/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.algorithms.tree.eval;

import com.ibm.bi.predict.algorithms.tree.DecisionTree;
import com.ibm.bi.predict.algorithms.tree.eval.EvaluationStatistic;
import com.ibm.bi.predict.algorithms.tree.util.TreeParameters;
import com.ibm.bi.predict.data.DataColumn;
import com.ibm.bi.predict.dataaccess.types.FieldType;
import com.ibm.bi.predict.math.NumericUtils;
import com.ibm.bi.predict.utils.Logger;
import com.ibm.bi.predict.utils.PredictLoggerFactory;
import java.util.Collections;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Evaluates a {@link DecisionTree} on a fixed set of row indices.
 *
 * <p>Summary statistics are computed lazily on first access and cached in an {@link EnumMap}.
 * Which statistics are produced depends on the target type:
 * <ul>
 *   <li>Numeric targets (any target whose type is not {@code CATEGORICAL}) get residual sum of
 *       squares, correlation and fraction of variance explained.</li>
 *   <li>Categorical targets get accuracy-style statistics.</li>
 * </ul>
 *
 * <p>Leave-one-out field importance is also computed lazily and cached. No synchronization is
 * performed for either cache.
 */
public class ModelEvaluation {
    private static final Logger LOG = PredictLoggerFactory.getLogger(ModelEvaluation.class);
    private final DecisionTree tree;
    /** True unless the tree's target column type is {@code FieldType.CATEGORICAL}. */
    private final boolean numericTarget;
    /** Row indices the tree is evaluated on. */
    private final int[] rows;
    /** Lazily filled statistics; the presence of ROW_COUNT marks the cache as populated. */
    private final Map<EvaluationStatistic, Double> statistics;
    /** Lazily computed, importance-sorted cache for {@link #fieldImportance()}. */
    private LinkedHashMap<DataColumn, Double> fieldImportance;

    /**
     * Creates an evaluation of {@code tree} over the given rows. No statistics are computed until
     * they are first requested.
     *
     * @param tree the tree to evaluate; its target column decides numeric vs. categorical mode
     * @param rows row indices to evaluate on (stored by reference, not copied)
     */
    public ModelEvaluation(DecisionTree tree, int[] rows) {
        this.tree = tree;
        this.numericTarget = tree.target().getType() != FieldType.CATEGORICAL;
        this.rows = rows;
        this.statistics = new EnumMap<>(EvaluationStatistic.class);
    }

    /**
     * Randomly splits the row indices {@code 0 .. count-1} into a training part and a test part.
     *
     * <p>The test part holds {@code round(count * config.getTestFraction())} rows. The shuffle
     * uses {@code config.makeRandom()}, so the split is as reproducible as that random source.
     *
     * @param count  total number of rows
     * @param config supplies the test fraction and the random source
     * @return a two-element array: {@code [0]} are the training rows, {@code [1]} the test rows,
     *         each sorted ascending
     */
    public static int[][] partitionRowsIntoTrainingAndTest(int count, TreeParameters config) {
        int nTest = (int) Math.round((double) count * config.getTestFraction());
        int nTrain = count - nTest;
        List<Integer> rows = IntStream.range(0, count).boxed().collect(Collectors.toList());
        Collections.shuffle(rows, config.makeRandom());
        int[] training = rows.subList(0, nTrain).stream().mapToInt(Integer::intValue).sorted().toArray();
        int[] test = rows.subList(nTrain, count).stream().mapToInt(Integer::intValue).sorted().toArray();
        return new int[][]{training, test};
    }

    /**
     * Returns the importance of every field the tree actually splits on.
     *
     * <p>Entries are ordered by descending importance; ties are broken by ascending column id.
     * The map is computed on first call and the same cached instance is returned afterwards, so
     * callers should not modify it.
     *
     * @return importance per split field, each value {@code >= 0}
     */
    public Map<DataColumn, Double> fieldImportance() {
        if (this.fieldImportance == null) {
            this.fieldImportance = this.calculateFieldImportance();
        }
        return this.fieldImportance;
    }

    /**
     * Renders all computed statistics as {@code NAME:value} pairs separated by {@code ", "}.
     * Calling this triggers statistic computation if it has not happened yet.
     */
    @Override
    public String toString() {
        return this.statistics().entrySet().stream()
                .map(e -> String.format("%s:%3.3f", e.getKey(), e.getValue()))
                .collect(Collectors.joining(", "));
    }

    /**
     * Returns one statistic, computing all statistics first if necessary.
     *
     * @param statistic the statistic to look up
     * @return its value, or {@code null} if it was not computed. Only ROW_COUNT is present when
     *         there are no rows, and only the statistics for this target kind are ever stored.
     */
    public Double valueOf(EvaluationStatistic statistic) {
        return this.statistics().get(statistic);
    }

    /**
     * Computes leave-one-out importance for each non-null split field found among the tree's
     * nodes, then orders the result with {@link #compareByValue}.
     */
    private LinkedHashMap<DataColumn, Double> calculateFieldImportance() {
        Set<DataColumn> used = this.tree.nodes()
                .map(DecisionTree::splitField)
                .filter(Objects::nonNull)
                .collect(Collectors.toSet());
        Map<DataColumn, Double> importance = used.stream()
                .collect(Collectors.toMap(c -> c, c -> this.leaveOneOutImportance(c, used)));
        return importance.entrySet().stream()
                .sorted(this::compareByValue)
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (u, v) -> u, LinkedHashMap::new));
    }

    /** Orders entries by descending value, then by ascending column id. */
    private int compareByValue(Map.Entry<DataColumn, Double> a, Map.Entry<DataColumn, Double> b) {
        int d = Double.compare(b.getValue(), a.getValue());
        return d != 0 ? d : a.getKey().getId().compareTo(b.getKey().getId());
    }

    /**
     * Computes (once) and returns the statistics map.
     *
     * <p>With no rows, only ROW_COUNT is stored. Otherwise the tree's predictions ("expected")
     * are compared with the target column values ("actual").
     */
    private Map<EvaluationStatistic, Double> statistics() {
        // ROW_COUNT doubles as the "already computed" marker for the cache.
        if (this.statistics.containsKey(EvaluationStatistic.ROW_COUNT)) {
            return this.statistics;
        }
        this.statistics.put(EvaluationStatistic.ROW_COUNT, (double) this.rows.length);
        if (this.rows.length == 0) {
            return this.statistics;
        }
        DataColumn target = this.tree.target();
        double[] actual = new double[this.rows.length];
        for (int i = 0; i < actual.length; ++i) {
            actual[i] = target.getValue(this.rows[i]);
        }
        double[] expected = ModelEvaluation.score(this.tree, this.rows);
        if (this.numericTarget) {
            this.evaluateForNumericTarget(expected, actual);
        } else {
            this.evaluateForCategoricalTarget(expected, actual);
        }
        return this.statistics;
    }

    /**
     * Stores classification statistics. Predictions and actual values are compared with
     * {@code NumericUtils.equals}, and each distinct actual value is treated as a class.
     *
     * <p>The null hypothesis always predicts the most frequent actual class.
     * CLASSIFICATION_IMPROVEMENT is {@code 1 - modelErrors / nullHypothesisErrors}. It is NaN or
     * -Infinity when all rows share one actual class, because nullHypothesisErrors is then 0.
     *
     * @param expected model predictions, same length as {@code actual} (length > 0)
     * @param actual   target values
     */
    private void evaluateForCategoricalTarget(double[] expected, double[] actual) {
        int n = actual.length;
        int correctCount = 0;
        Map<Double, Integer> correctByClass = new HashMap<>();
        Map<Double, Integer> countByClass = new HashMap<>();
        for (int i = 0; i < n; ++i) {
            double v = actual[i];
            if (NumericUtils.equals(expected[i], v)) {
                ++correctCount;
                correctByClass.merge(v, 1, Integer::sum);
            }
            countByClass.merge(v, 1, Integer::sum);
        }
        // n > 0 is guaranteed by statistics(), so the -1 fallback is never used.
        int mostCommonCount = countByClass.values().stream().mapToInt(Integer::intValue).max().orElse(-1);
        double nullHypothesisErrors = n - mostCommonCount;
        double modelErrors = n - correctCount;
        double accuracyByClass = 0.0;
        for (Map.Entry<Double, Integer> c : countByClass.entrySet()) {
            accuracyByClass += (double) correctByClass.getOrDefault(c.getKey(), 0) / (double) c.getValue();
        }
        this.statistics.put(EvaluationStatistic.CLASSIFICATION_ACCURACY, (double) correctCount / (double) n);
        this.statistics.put(EvaluationStatistic.BALANCED_CLASSIFICATION_ACCURACY, accuracyByClass / (double) countByClass.size());
        this.statistics.put(EvaluationStatistic.CLASSIFICATION_IMPROVEMENT, 1.0 - modelErrors / nullHypothesisErrors);
        this.statistics.put(EvaluationStatistic.CLASSIFICATION_ERRORS, modelErrors);
        this.statistics.put(EvaluationStatistic.NULL_HYPOTHESIS_ERRORS, nullHypothesisErrors);
    }

    /**
     * Stores regression statistics for predictions versus actual values:
     * <ul>
     *   <li>the residual sum of squares;</li>
     *   <li>the Pearson correlation, which is NaN if either series is constant;</li>
     *   <li>{@code 1 - SSR / SST}, where SST is taken about the mean of the actual values.</li>
     * </ul>
     *
     * @param expected model predictions, same length as {@code actual} (length > 0)
     * @param actual   target values
     */
    private void evaluateForNumericTarget(double[] expected, double[] actual) {
        double eMean = ModelEvaluation.mean(expected);
        double aMean = ModelEvaluation.mean(actual);
        double sxy = 0.0;
        double sxx = 0.0; // sum of squared deviations of actual values
        double syy = 0.0; // sum of squared deviations of predictions
        double ssr = 0.0;
        for (int i = 0; i < actual.length; ++i) {
            double residual = expected[i] - actual[i];
            ssr += residual * residual;
            double e = expected[i] - eMean;
            double o = actual[i] - aMean;
            sxx += o * o;
            syy += e * e;
            sxy += o * e;
        }
        this.statistics.put(EvaluationStatistic.RESIDUAL_SUM_OF_SQUARES, ssr);
        this.statistics.put(EvaluationStatistic.CORRELATION, sxy / Math.sqrt(sxx * syy));
        this.statistics.put(EvaluationStatistic.FRACTION_VARIANCE_EXPLAINED, 1.0 - ssr / sxx);
    }

    /**
     * Estimates how much {@code predictor} contributes. It rebuilds the tree without that
     * predictor and compares the restricted model's statistics with this evaluation's.
     *
     * <p>For numeric targets the score is the relative reduction in residual sum of squares. For
     * categorical targets it is the relative gain in CLASSIFICATION_IMPROVEMENT.
     *
     * <p>Undefined results are reported as 0 rather than NaN. These arise when the restricted
     * model is already perfect (a 0/0 division), when statistics are missing (empty row set), or
     * when a single-class target makes CLASSIFICATION_IMPROVEMENT NaN. A NaN would otherwise rank
     * as the most important field under {@link Double#compare}.
     *
     * <p>NOTE(review): {@code this} may be evaluated on a different row set than the restricted
     * model, which uses {@code evaluateOnTrainingData()}. The unnormalized SSR comparison is only
     * meaningful if both cover the same rows. Confirm with the callers.
     *
     * @return importance score clamped to {@code >= 0}; magnitudes below 1e-12 become 0
     */
    private double leaveOneOutImportance(DataColumn predictor, Set<DataColumn> predictors) {
        LinkedHashSet<DataColumn> leaveOutOnePredictors = new LinkedHashSet<>(predictors);
        leaveOutOnePredictors.remove(predictor);
        LOG.debug(() -> String.format("Calculating field importance leaving out %s from %s", predictor.toString(), predictors.toString()));
        ModelEvaluation restricted = this.tree.withDifferentPredictors(leaveOutOnePredictors).evaluateOnTrainingData();
        double score;
        if (this.numericTarget) {
            double fullModelSSR = this.statisticOrNaN(EvaluationStatistic.RESIDUAL_SUM_OF_SQUARES);
            double restrictedSSR = restricted.statisticOrNaN(EvaluationStatistic.RESIDUAL_SUM_OF_SQUARES);
            score = (restrictedSSR - fullModelSSR) / restrictedSSR;
        } else {
            double fullModelAccuracy = this.statisticOrNaN(EvaluationStatistic.CLASSIFICATION_IMPROVEMENT);
            double restrictedAccuracy = restricted.statisticOrNaN(EvaluationStatistic.CLASSIFICATION_IMPROVEMENT);
            score = (fullModelAccuracy - restrictedAccuracy) / (1.0 - restrictedAccuracy);
        }
        // Math.max(NaN, 0.0) would return NaN, so undefined scores must be zeroed explicitly.
        if (Double.isNaN(score) || Math.abs(score) < 1.0E-12) {
            return 0.0;
        }
        // Negative values, including -Infinity from x/0 with x < 0, mean "no improvement".
        return Math.max(score, 0.0);
    }

    /** Returns the statistic's value, or NaN when it was not computed (e.g. no rows). */
    private double statisticOrNaN(EvaluationStatistic statistic) {
        Double value = this.valueOf(statistic);
        return value == null ? Double.NaN : value;
    }

    /** Arithmetic mean; NaN for an empty array (callers guarantee a non-empty one). */
    private static double mean(double[] values) {
        double d = 0.0;
        for (double value : values) {
            d += value;
        }
        return d / (double) values.length;
    }

    /**
     * Scores {@code rows} with {@code tree}. Predictor values are looked up by column id among
     * {@code tree.predictors()}; duplicate ids make {@code Collectors.toMap} throw.
     */
    private static double[] score(DecisionTree tree, int[] rows) {
        Map<String, DataColumn> predictorMap = tree.predictors().stream().collect(Collectors.toMap(DataColumn::getId, p -> p));
        return tree.score((s, i) -> predictorMap.get(s).getValue(i.intValue()), rows);
    }
}

