/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.data;

import com.ibm.bi.predict.data.Category;
import com.ibm.bi.predict.data.Config;
import com.ibm.bi.predict.data.DataColumn;
import com.ibm.bi.predict.dataaccess.types.FieldType;
import com.ibm.bi.predict.exceptions.InvalidDataException;
import com.ibm.bi.predict.utils.Logger;
import com.ibm.bi.predict.utils.PredictLoggerFactory;
import com.ibm.bi.predict.utils.Tuple;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;

/**
 * Collapses the less frequent categories of a categorical {@link DataColumn} into a single
 * {@link Category#OTHER} category.
 *
 * <p>Numerical columns are returned untouched. For every other column the distinct encoded values
 * are counted, the most frequent ones are kept, and all remaining values are re-encoded to one
 * shared "other" code.
 */
public class CategoriesMerger {
    private static final Logger LOG = PredictLoggerFactory.getLogger(CategoriesMerger.class);

    /** Longest merged-category listing (in characters) that is logged and stored before truncation. */
    private static final int MAX_MERGED_DESCRIPTION_LENGTH = 2000;

    private final int maxCategoryCount;
    private final double minFrequencyPercent;
    private final int minUnmergedCategoryFrequency;

    /**
     * Creates a merger whose three limits are read from the given configuration.
     *
     * @param config source of the maximum category count, minimum frequency ratio and minimum
     *               unmerged category frequency
     */
    public CategoriesMerger(Config config) {
        this(config.maxCategoriesToProcess(), config.minFrequencyPercent(), config.minUnmergedCategoryFrequency());
    }

    /**
     * Creates a merger with explicit limits.
     *
     * @param maxCategoryCount             upper bound on the number of categories after merging
     *                                     (the merged "other" category included)
     * @param minFrequencyPercent          threshold compared against {@code keptRows / rowCount};
     *                                     a column below it is flagged as degenerate
     * @param minUnmergedCategoryFrequency minimum row count for a category to count towards the
     *                                     number of categories that may stay unmerged
     */
    public CategoriesMerger(int maxCategoryCount, double minFrequencyPercent, int minUnmergedCategoryFrequency) {
        this.maxCategoryCount = maxCategoryCount;
        this.minFrequencyPercent = minFrequencyPercent;
        this.minUnmergedCategoryFrequency = minUnmergedCategoryFrequency;
    }

    /**
     * Merges the infrequent categories of {@code column} into {@link Category#OTHER}.
     *
     * <p>The given column is returned as is when it is numerical or when it already has no more
     * distinct values than allowed. Otherwise a new column with re-encoded data is returned.
     * Note that status flags ({@code DEGENERATE}, {@code MERGED_CATEGORIES}) are added to the
     * <em>input</em> column in both cases; the new column is created from the input's status.
     *
     * @param column the column to process
     * @return {@code column} itself when nothing was merged, otherwise a new column
     */
    public DataColumn mergeCategories(DataColumn column) {
        if (column.getType() == FieldType.NUMERICAL) {
            return column;
        }
        Map<Double, Integer> categoryFrequencyMap = this.getFrequencyMap(column);
        int maxCategories = Math.min(this.maxCategoryCount, this.minFrequencyIndex(categoryFrequencyMap.values()) + 1);
        if (categoryFrequencyMap.size() <= maxCategories) {
            // Nothing to merge; a column with fewer than two distinct values carries no information.
            if (categoryFrequencyMap.size() < 2) {
                column.addStatus(DataColumn.Status.DEGENERATE);
            }
            return column;
        }

        List<List<Double>> keepAndMerge = this.getCategoriesToKeepAndMerge(categoryFrequencyMap, maxCategories);
        List<Double> categoriesToKeep = keepAndMerge.get(0);
        List<Double> categoriesToMerge = keepAndMerge.get(1);
        int sumNotMerged = CategoriesMerger.calculateTotalKeepFrequency(categoryFrequencyMap, categoriesToKeep);

        String mergedCatStr = categoriesToMerge.toString();
        if (mergedCatStr.length() > MAX_MERGED_DESCRIPTION_LENGTH) {
            mergedCatStr = mergedCatStr.substring(0, MAX_MERGED_DESCRIPTION_LENGTH) + " ...";
        }
        LOG.debug("Those categories were merged into a single category for field {}: [{}]", column.getId(), mergedCatStr);

        if (this.isDegenerate(categoryFrequencyMap, categoriesToKeep, column.rowCount(), column.getId())) {
            column.addStatus(DataColumn.Status.DEGENERATE);
        }
        column.addStatus(DataColumn.Status.MERGED_CATEGORIES);

        double[] newData = this.encodeColumnData(column, categoriesToMerge, categoryFrequencyMap.keySet());
        List<Category> mergedCategories = this.mergeCategories(column, categoriesToKeep, categoriesToMerge);
        DataColumn newColumn = new DataColumn(column.getMeta(), mergedCategories, newData, column.getStatus(), column.getIndex());
        newColumn.getSummaryData().setMergedCategories(mergedCatStr);
        newColumn.getSummaryData().setNonMergedCategoryProportion((double) sumNotMerged / (double) column.rowCount());
        return newColumn;
    }

    /**
     * Builds the category list of the merged column: the kept categories followed by
     * {@link Category#OTHER}.
     *
     * <p>Kept categories are emitted in ascending order of their encoded value. This is the same
     * order in which {@link #encodeColumnData} hands out the new codes, so position {@code i} of
     * the returned list describes code {@code i} of the re-encoded data and the last position
     * describes the merged code. The list returned by {@code column.getCategories()} is never
     * modified.
     *
     * <p>NOTE(review): the i-th smallest distinct data value is looked up at index {@code i} of
     * {@code column.getCategories()}, exactly as before. This presumes every original category
     * occurs in the data and that the list is ordered by encoded value; confirm against
     * {@code DataColumn}.
     *
     * @param column            column providing the original categories
     * @param categoriesToKeep  encoded values that stay separate categories
     * @param categoriesToMerge encoded values that collapse into {@link Category#OTHER}
     * @return a new, mutable list of the kept categories plus {@link Category#OTHER}
     */
    private List<Category> mergeCategories(DataColumn column, List<Double> categoriesToKeep, List<Double> categoriesToMerge) {
        List<Category> originalCategories = column.getCategories();

        List<Double> sortedValues = new ArrayList<>(categoriesToKeep.size() + categoriesToMerge.size());
        sortedValues.addAll(categoriesToMerge);
        sortedValues.addAll(categoriesToKeep);
        Collections.sort(sortedValues);

        Set<Double> kept = new TreeSet<>(categoriesToKeep);
        List<Category> mergedCategories = new ArrayList<>(categoriesToKeep.size() + 1);
        for (int i = 0; i < sortedValues.size(); ++i) {
            if (kept.contains(sortedValues.get(i))) {
                mergedCategories.add(originalCategories.get(i));
            }
        }
        mergedCategories.add(Category.OTHER);
        return mergedCategories;
    }

    /**
     * Counts how many rows carry each distinct encoded value.
     *
     * @param colData the column to scan
     * @return map from encoded value to its number of occurrences
     */
    protected Map<Double, Integer> getFrequencyMap(DataColumn colData) {
        Map<Double, Integer> categoryFrequencyMap = new HashMap<>();
        int rowCount = colData.rowCount();
        for (int row = 0; row < rowCount; ++row) {
            categoryFrequencyMap.merge(colData.getValue(row), 1, Integer::sum);
        }
        return categoryFrequencyMap;
    }

    /**
     * Returns a copy of {@code map} ordered by descending frequency, ties broken by ascending key.
     *
     * @param map frequencies keyed by encoded value
     * @return an insertion-ordered map holding the same entries in sorted order
     */
    protected Map<Double, Integer> sortMapByDecendingValue(Map<Double, Integer> map) {
        Comparator<Map.Entry<Double, Integer>> byFrequencyDescThenKey =
                Map.Entry.<Double, Integer>comparingByValue(Collections.<Integer>reverseOrder())
                        .thenComparing(Map.Entry.<Double, Integer>comparingByKey());
        return map.entrySet().stream()
                .sorted(byFrequencyDescThenKey)
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (first, second) -> first, LinkedHashMap::new));
    }

    /**
     * Splits the distinct values into the ones to keep and the ones to merge.
     *
     * <p>The {@code maxCategories - 1} most frequent values are kept (one slot is reserved for the
     * merged category); never more than there are values. Values with equal frequency are taken in
     * the order the priority queue yields them, which is not further specified.
     *
     * <p>The order of the merge list is preserved from the previous implementation because it is
     * rendered into the merged-category description: map iteration order when there are more than
     * {@code 2 * maxCategories} values, descending frequency otherwise.
     *
     * @param categoryFrequencyMap frequencies keyed by encoded value
     * @param maxCategories        number of categories allowed after merging
     * @return a two-element list: index 0 holds the values to keep, index 1 the values to merge
     */
    protected List<List<Double>> getCategoriesToKeepAndMerge(Map<Double, Integer> categoryFrequencyMap, int maxCategories) {
        int categoryCount = categoryFrequencyMap.size();
        // PriorityQueue rejects an initial capacity below 1.
        PriorityQueue<Tuple<Double, Integer>> queue =
                new PriorityQueue<>(Math.max(1, categoryCount), new CategoryValueComparable());
        for (Map.Entry<Double, Integer> entry : categoryFrequencyMap.entrySet()) {
            queue.add(Tuple.of(entry.getKey(), entry.getValue()));
        }

        // Clamp so that polling can never run past the end of the queue.
        int keepCount = Math.max(0, Math.min(maxCategories - 1, categoryCount));
        List<Double> categoriesToKeep = new ArrayList<>(keepCount);
        for (int i = 0; i < keepCount; ++i) {
            categoriesToKeep.add(queue.poll()._1);
        }

        List<Double> categoriesToMerge = new ArrayList<>(categoryCount - keepCount);
        if (categoryCount > 2L * maxCategories) {
            // Many values to merge: skip sorting them and list them in map iteration order.
            Set<Double> kept = new TreeSet<>(categoriesToKeep);
            for (Double category : categoryFrequencyMap.keySet()) {
                if (!kept.contains(category)) {
                    categoriesToMerge.add(category);
                }
            }
        } else {
            // Few values to merge: drain the queue so they are listed by descending frequency.
            while (!queue.isEmpty()) {
                categoriesToMerge.add(queue.poll()._1);
            }
        }

        List<List<Double>> result = new ArrayList<>(2);
        result.add(categoriesToKeep);
        result.add(categoriesToMerge);
        return result;
    }

    /**
     * Re-encodes the column so that kept values get the codes {@code 0..k-1} in ascending order of
     * their original value and every other value gets the code {@code k}.
     *
     * @param colData              the column whose rows are re-encoded
     * @param categoriesToBeMerged encoded values that collapse into the shared code
     * @param allCategories        every distinct encoded value of the column
     * @return the re-encoded values, one per row
     * @throws InvalidDataException if {@code categoriesToBeMerged} is {@code null} or empty
     */
    protected double[] encodeColumnData(DataColumn colData, List<Double> categoriesToBeMerged, Set<Double> allCategories) {
        if (categoriesToBeMerged == null || categoriesToBeMerged.isEmpty()) {
            throw new InvalidDataException("Invalid categories to be merged.");
        }
        TreeSet<Double> goodCategories = new TreeSet<>(allCategories);
        for (Double merged : categoriesToBeMerged) {
            goodCategories.remove(merged);
        }
        // The sorted iteration order of the TreeSet defines the new, dense codes.
        Map<Double, Integer> goodCategoryMap = new HashMap<>();
        for (Double category : goodCategories) {
            goodCategoryMap.put(category, goodCategoryMap.size());
        }
        return CategoriesMerger.modifyData(colData, goodCategoryMap);
    }

    /**
     * Decides whether the column is degenerate after merging.
     *
     * <p>A column is degenerate when it has fewer than two distinct values, or when the share of
     * rows belonging to kept categories is below the configured minimum ratio.
     *
     * @param categoryFrequencyMap frequencies keyed by encoded value
     * @param categoriesKeep       encoded values that stay separate categories
     * @param rowCount             total number of rows of the column
     * @param fieldName            field identifier, used for the debug message only
     * @return {@code true} if the column is degenerate
     */
    boolean isDegenerate(Map<Double, Integer> categoryFrequencyMap, List<Double> categoriesKeep, int rowCount, String fieldName) {
        if (categoryFrequencyMap.size() < 2) {
            return true;
        }
        int sum = CategoriesMerger.calculateTotalKeepFrequency(categoryFrequencyMap, categoriesKeep);
        double ratio = (double) sum / (double) rowCount;
        boolean degenerate = ratio < this.minFrequencyPercent;
        if (degenerate) {
            String msg = "The field [" + fieldName + "] is degenerate with frequency ratio: " + String.format("%.2f", ratio);
            LOG.debug(msg);
        }
        return degenerate;
    }

    /**
     * Maps every row through {@code goodCategoryMap}; values without a mapping receive the code
     * {@code goodCategoryMap.size()}, i.e. the first code after the kept ones.
     */
    private static double[] modifyData(DataColumn colData, Map<Double, Integer> goodCategoryMap) {
        int mergedCode = goodCategoryMap.size();
        double[] newValues = new double[colData.rowCount()];
        for (int row = 0; row < newValues.length; ++row) {
            Integer code = goodCategoryMap.get(colData.getValue(row));
            newValues[row] = code == null ? (double) mergedCode : (double) code.intValue();
        }
        return newValues;
    }

    /** Sums the frequencies of the kept categories; every kept value must be a key of the map. */
    private static int calculateTotalKeepFrequency(Map<Double, Integer> categoryFrequencyMap, List<Double> categoriesToKeep) {
        int sum = 0;
        for (Double category : categoriesToKeep) {
            sum += categoryFrequencyMap.get(category).intValue();
        }
        return sum;
    }

    /**
     * Counts the categories whose frequency reaches {@code minUnmergedCategoryFrequency}.
     *
     * @param frequencies the per-category row counts
     * @return {@code maxCategoryCount} as soon as that many qualifying categories were seen,
     *         otherwise the number of qualifying categories, but at least 1
     */
    private int minFrequencyIndex(Collection<Integer> frequencies) {
        int qualifying = 0;
        for (Integer frequency : frequencies) {
            if (frequency >= this.minUnmergedCategoryFrequency) {
                ++qualifying;
                if (qualifying >= this.maxCategoryCount) {
                    return this.maxCategoryCount;
                }
            }
        }
        return Math.max(qualifying, 1);
    }

    /** Orders (value, frequency) tuples by descending frequency; equal frequencies compare as equal. */
    private static class CategoryValueComparable
    implements Comparator<Tuple<Double, Integer>> {
        private CategoryValueComparable() {
        }

        @Override
        public int compare(Tuple<Double, Integer> e1, Tuple<Double, Integer> e2) {
            return Integer.compare(e2._2, e1._2);
        }
    }
}

