/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.data;

import com.ibm.bi.predict.data.Category;
import com.ibm.bi.predict.data.Config;
import com.ibm.bi.predict.data.DataColumn;
import com.ibm.bi.predict.dataaccess.types.FieldType;
import com.ibm.bi.predict.exceptions.BinningException;
import com.ibm.bi.predict.math.NumericUtils;
import com.ibm.bi.predict.utils.Logger;
import com.ibm.bi.predict.utils.PredictLoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import java.util.function.DoubleUnaryOperator;
import java.util.stream.Collectors;

/**
 * Converts a non-categorical {@link DataColumn} into a column of category indices.
 *
 * <p>Two strategies are used, chosen per column:
 * <ul>
 *   <li><b>Re-indexing</b> - when the column has at most {@code 2 * binCount} distinct
 *       non-missing values, every distinct value becomes its own category.</li>
 *   <li><b>Percentile binning</b> - otherwise the sorted values are cut at percentile
 *       positions into at most {@code binCount} bins, with an optional leading bin that
 *       holds only zeros when the column is detected as zero-inflated.</li>
 * </ul>
 *
 * <p>The code relies on missing values sorting to the end of the array produced by
 * {@link #makeSortedCopy(DataColumn)} (presumably they are NaN - confirm against
 * {@code NumericUtils.isMissingValue}).
 */
public class Binner {
    private static final Logger LOG = PredictLoggerFactory.getLogger(Binner.class);

    /** Number of percentile slices used when a column is cut into bins. */
    private final int binCount;
    /** Columns with at most this many distinct values are re-indexed instead of binned. */
    private final int minCategoriesToRequireBinning;
    /** Whether missing values are mapped to their own trailing category. */
    private final boolean binMissingValues;

    /**
     * Creates a binner from {@code config.binCount()} and
     * {@code config.missingValuesInSeparateCategory()}.
     *
     * @param config source of the bin count and the missing-value policy
     */
    public Binner(Config config) {
        this(config.binCount(), config.missingValuesInSeparateCategory());
    }

    /**
     * @param binCount         number of percentile slices to aim for; twice this value is the
     *                         largest number of distinct values that is still re-indexed
     * @param binMissingValues {@code true} to give missing values their own category
     */
    public Binner(int binCount, boolean binMissingValues) {
        this.binCount = binCount;
        this.minCategoriesToRequireBinning = 2 * binCount;
        this.binMissingValues = binMissingValues;
    }

    /**
     * Bins a column.
     *
     * <p>A column whose type is {@code FieldType.CATEGORICAL} is returned as-is (same
     * instance). Any other column is copied, flagged {@code BINNED}, has its values replaced
     * by category indices and receives the matching category list. Further status flags are
     * added when applicable: {@code ZERO_INFLATED}, {@code ADDED_MISSING} and
     * {@code DEGENERATE} (fewer than two categories).
     *
     * @param column the column to bin; not modified unless {@code copy()} shares state
     * @return the input column itself when categorical, otherwise the binned copy
     * @throws BinningException if a non-missing value does not fall into any created bin
     */
    public DataColumn bin(DataColumn column) {
        if (column.getType() == FieldType.CATEGORICAL) {
            return column;
        }
        DataColumn result = column.copy().addStatus(DataColumn.Status.BINNED);
        double[] data = Binner.makeSortedCopy(column);
        int validCount = Binner.countNonMissingValues(data);
        result.getSummaryData().setMissingValuesCount(data.length - validCount);

        List<Category> categories;
        Map<Double, Double> indexMap = this.createReIndexMapping(data, validCount);
        if (indexMap != null) {
            // Few distinct values: one category per distinct value.
            this.replaceData(result, indexMap::get, indexMap.size());
            categories = Binner.makeIndexingCategories(indexMap);
        } else {
            // Many distinct values: percentile bins, optionally preceded by a zeros-only bin.
            int zeroInflatedCount = this.countInflatedZeroes(data);
            if (zeroInflatedCount > 0) {
                result.addStatus(DataColumn.Status.ZERO_INFLATED);
                result.getSummaryData().setZerosCount(zeroInflatedCount);
            }
            double[] bins = this.createBins(data, zeroInflatedCount, validCount);
            this.replaceData(result, x -> Binner.findBin(x, bins), bins.length);
            categories = Binner.makeBinCategories(bins);
            result.getSummaryData().setBinBoundaries(bins);
        }
        if (this.binMissingValues && validCount < data.length) {
            // replaceData mapped missing values to the index one past the last regular
            // category, which is exactly where this entry is appended.
            result.addStatus(DataColumn.Status.ADDED_MISSING);
            categories.add(Category.MISSING);
        }
        result.setCategories(categories);
        if (result.getCategoryCount() < 2) {
            result.addStatus(DataColumn.Status.DEGENERATE);
        }
        return result;
    }

    /**
     * Builds one category per bin boundary. Bin {@code i} spans from the previous boundary
     * (negative infinity for the first bin) up to {@code values[i]}.
     *
     * @param values upper bin boundaries in ascending order
     * @return a new mutable list with {@code values.length} categories
     */
    private static List<Category> makeBinCategories(double[] values) {
        List<Category> result = new ArrayList<>(values.length);
        for (int i = 0; i < values.length; ++i) {
            double low = i == 0 ? Double.NEGATIVE_INFINITY : values[i - 1];
            double high = values[i];
            result.add(Category.forBin(low, high));
        }
        return result;
    }

    /**
     * Builds one category per key of a re-index mapping. Keys are visited in ascending
     * order; category {@code i} spans from key {@code i} to key {@code i + 1}, except that
     * the first category starts at negative infinity and the last ends at positive infinity.
     *
     * @param map mapping created by {@link #createReIndexMapping(double[], int)}
     * @return a new mutable list with {@code map.size()} categories
     */
    private static List<Category> makeIndexingCategories(Map<Double, Double> map) {
        List<Double> values = map.keySet().stream().sorted().collect(Collectors.toList());
        int last = values.size() - 1;
        List<Category> result = new ArrayList<>(values.size());
        for (int i = 0; i <= last; ++i) {
            double low = i == 0 ? Double.NEGATIVE_INFINITY : values.get(i);
            double high = i == last ? Double.POSITIVE_INFINITY : values.get(i + 1);
            result.add(Category.forBin(low, high));
        }
        return result;
    }

    /**
     * Maps every distinct non-missing value to its rank (0, 1, 2, ...) in ascending order.
     *
     * @param data       sorted values; only the first {@code validCount} entries are read
     * @param validCount number of leading non-missing entries in {@code data}
     * @return the value-to-rank mapping, or {@code null} when the column has more than
     *         {@code minCategoriesToRequireBinning} distinct values and must be binned
     */
    private Map<Double, Double> createReIndexMapping(double[] data, int validCount) {
        TreeSet<Double> uniques = new TreeSet<>();
        for (int i = 0; i < validCount; ++i) {
            // Give up as soon as a newly seen value pushes the count over the threshold.
            if (uniques.add(data[i]) && uniques.size() > this.minCategoriesToRequireBinning) {
                return null;
            }
        }
        Map<Double, Double> result = new HashMap<>();
        for (Double v : uniques) {
            result.put(v, Double.valueOf(result.size()));
        }
        return result;
    }

    /**
     * Decides whether the column is zero-inflated and, if so, counts its leading zeros.
     *
     * <p>The column counts as zero-inflated when it has at least four entries, its smallest
     * value is zero, the value at the end of the second of {@code binCount} equal slices is
     * still zero, and its last entry differs from its first.
     *
     * <p>NOTE(review): slice positions are computed over the whole array, including any
     * trailing missing values - confirm this is intended.
     *
     * @param data sorted values
     * @return index of the first entry that is not below {@code Double.MIN_VALUE} (i.e. the
     *         number of leading zeros), or 0 when the column is not zero-inflated
     */
    private int countInflatedZeroes(double[] data) {
        int n = data.length;
        if (n < 4) {
            return 0;
        }
        // Last index of the second slice. Clamped because for binCount < 2 the raw value
        // lies past the end of the array.
        int secondBin = Math.min(n - 1, (int) Math.ceil(2.0 * (double) n / (double) this.binCount) - 1);
        if (!NumericUtils.equals(data[0], 0.0)) {
            return 0;
        }
        if (!NumericUtils.equals(data[secondBin], 0.0)) {
            return 0;
        }
        if (NumericUtils.equals(data[n - 1], data[0], 0.0)) {
            return 0;
        }
        int pos = Arrays.binarySearch(data, Double.MIN_VALUE);
        int nonZeroIndex;
        if (pos < 0) {
            // Not found: binarySearch encodes the insertion point as -(point) - 1.
            nonZeroIndex = -pos - 1;
        } else {
            // Exact hit on Double.MIN_VALUE: binarySearch may land on any duplicate, so walk
            // back to the first one. That is the first non-zero entry.
            nonZeroIndex = pos;
            while (nonZeroIndex > 0 && data[nonZeroIndex - 1] == Double.MIN_VALUE) {
                --nonZeroIndex;
            }
        }
        assert (NumericUtils.equals(data[nonZeroIndex - 1], 0.0) && data[nonZeroIndex] > 0.0);
        return nonZeroIndex;
    }

    /**
     * Copies all row values of a column into a new array and sorts it ascending with
     * {@link Arrays#sort(double[])}.
     *
     * @param data the column to read; not modified
     * @return a new sorted array of length {@code data.rowCount()}
     */
    public static double[] makeSortedCopy(DataColumn data) {
        double[] sortedData = new double[data.rowCount()];
        for (int i = 0; i < sortedData.length; ++i) {
            sortedData[i] = data.getValue(i);
        }
        Arrays.sort(sortedData);
        return sortedData;
    }

    /**
     * Replaces every value of the column by its category index.
     *
     * <p>Non-missing values are passed through {@code method}. Missing values become
     * {@code categoryCount} (the index of the extra "missing" category) when
     * {@code binMissingValues} is set, and are left unchanged otherwise.
     *
     * @param data          the column to rewrite in place
     * @param method        maps a non-missing value to its category index
     * @param categoryCount number of regular categories
     */
    private void replaceData(DataColumn data, DoubleUnaryOperator method, int categoryCount) {
        if (this.binMissingValues) {
            data.replaceValues(v -> NumericUtils.isMissingValue((double)v) ? (double)categoryCount : method.applyAsDouble(v));
        } else {
            data.replaceValues(v -> NumericUtils.isMissingValue((double)v) ? v : method.applyAsDouble(v));
        }
    }

    /**
     * Finds the bin of a value: the index of the first boundary strictly greater than it.
     *
     * @param data the value to classify
     * @param bins upper bin boundaries in ascending order
     * @return the bin index (as a double so it can be stored back into the column)
     * @throws BinningException if no boundary is greater than {@code data}; the message is
     *                          also logged at error level
     */
    private static double findBin(double data, double[] bins) {
        for (int i = 0; i < bins.length; ++i) {
            if (data < bins[i]) {
                return i;
            }
        }
        String message = String.format("Bad bins created. Value: %f does not belong to any of the bins: %s", data, Arrays.toString(bins));
        LOG.error(message);
        throw new BinningException(message);
    }

    /**
     * Computes the upper boundaries of the bins.
     *
     * <p>When {@code zeroInflatedCount > 0} the first boundary is the first non-zero value,
     * so the leading zeros form a bin of their own. The remaining range
     * {@code [zeroInflatedCount, validCount)} is then cut at percentile positions; positions
     * whose value does not differ from the previous boundary position are skipped so that
     * boundaries are strictly increasing. The final boundary is always positive infinity.
     *
     * @param data              sorted values
     * @param zeroInflatedCount number of leading zeros to isolate, or 0
     * @param validCount        number of leading non-missing entries in {@code data}
     * @return upper bin boundaries in ascending order; empty when there are no valid values
     */
    private double[] createBins(double[] data, int zeroInflatedCount, int validCount) {
        List<Double> bins = new ArrayList<>();
        if (validCount == zeroInflatedCount) {
            // Nothing but zeros (or nothing at all) is valid. Checked before touching
            // data[zeroInflatedCount], which would be a missing value or out of range here;
            // the zeros get a single catch-all bin instead.
            if (zeroInflatedCount > 0) {
                bins.add(Double.POSITIVE_INFINITY);
            }
            return Binner.asArray(bins);
        }
        if (zeroInflatedCount > 0) {
            bins.add(data[zeroInflatedCount]);
        }
        int currentPercentileIndex = zeroInflatedCount;
        int nextPercentileIndex = zeroInflatedCount;
        int binIdx = 1;
        while (binIdx < this.binCount) {
            // Advance through percentile positions until the value there changes.
            while (binIdx < this.binCount && NumericUtils.equals(data[currentPercentileIndex], data[nextPercentileIndex], 0.0)) {
                nextPercentileIndex = this.getPercentile(binIdx, zeroInflatedCount, validCount);
                ++binIdx;
            }
            if (NumericUtils.equals(data[currentPercentileIndex], data[nextPercentileIndex], 0.0)) {
                // Ran out of percentiles without finding a new value.
                break;
            }
            bins.add(data[nextPercentileIndex]);
            currentPercentileIndex = nextPercentileIndex;
        }
        if (bins.isEmpty() && !NumericUtils.equals(data[0], data[validCount - 1], 0.0)) {
            // No percentile produced a boundary although the values are not all equal:
            // split at the midpoint of the value range so that two bins result.
            double avg = (data[0] + data[validCount - 1]) / 2.0;
            bins.add(avg);
            LOG.debug(String.format("Only one bin, adding average value of %s.  bins=%s", avg, bins));
        }
        bins.add(Double.POSITIVE_INFINITY);
        return Binner.asArray(bins);
    }

    /**
     * Unboxes a list of doubles into a primitive array of the same length and order.
     */
    private static double[] asArray(List<Double> bins) {
        double[] a = new double[bins.size()];
        for (int i = 0; i < a.length; ++i) {
            a[i] = bins.get(i);
        }
        return a;
    }

    /**
     * Returns the array index of the last element of slice {@code binIdx} when the range
     * {@code [start, end)} is divided into {@code binCount} slices.
     *
     * @param binIdx 1-based slice number
     * @param start  first index of the range (inclusive)
     * @param end    end of the range (exclusive)
     */
    private int getPercentile(int binIdx, int start, int end) {
        int idx = start + (int)Math.ceil((double)binIdx * (double)(end - start) / (double)this.binCount);
        return idx - 1;
    }

    /**
     * Counts the leading non-missing entries of a sorted array, assuming missing values
     * are sorted to the end.
     *
     * @param sortedData values sorted ascending
     * @return number of entries before the trailing run of missing values
     */
    private static int countNonMissingValues(double[] sortedData) {
        if (sortedData.length == 0 || !NumericUtils.isMissingValue(sortedData[sortedData.length - 1])) {
            return sortedData.length;
        }
        int idx = Arrays.binarySearch(sortedData, Double.NaN);
        if (idx >= 0) {
            // binarySearch may hit any element of the missing run; walk back to its start.
            --idx;
            while (idx >= 0 && NumericUtils.isMissingValue(sortedData[idx])) {
                --idx;
            }
            return idx + 1;
        }
        return -idx - 1;
    }
}

