/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.sa.execution.annotation.impl.math;

import com.ibm.bi.predict.algorithms.ThreeLevelScale;
import com.ibm.bi.predict.algorithms.table.AbstractChiSquareTestForCategTarget;
import com.ibm.bi.predict.algorithms.table.ChiSquareTest;
import com.ibm.bi.predict.algorithms.table.FrequencyChiSquareTest;
import com.ibm.bi.predict.algorithms.table.NonFrequencyChiSquareTest;
import com.ibm.bi.predict.algorithms.table.results.ChiSquareTestResult;
import com.ibm.bi.predict.algorithms.table.results.InfluentialCategory;
import com.ibm.bi.predict.data.matrix.Matrix;
import com.ibm.bi.predict.dataaccess.types.AggregationType;
import com.ibm.bi.predict.math.TopNSelector;
import com.ibm.bi.predict.utils.Logger;
import com.ibm.bi.predict.utils.PredictLoggerFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Runs a chi-square test over a two-dimensional {@link Matrix}. When the overall test is
 * significant, it also selects the individual cells that contribute most strongly to the result.
 *
 * <p>A {@link FrequencyChiSquareTest} is used when the aggregation is
 * {@link AggregationType#COUNT}. Every other aggregation uses a {@link NonFrequencyChiSquareTest}.
 *
 * <p>Instances keep mutable state: {@link #compute()} stores the test it ran and the detected
 * outliers. No synchronization is performed.
 */
public class ChiSquare {
    private static final Logger log = PredictLoggerFactory.getLogger(ChiSquare.class);

    /** P-value at or below which the overall test, and each individual cell, counts as significant. */
    private static final double SIGNIFICANCE_LEVEL = 0.05;

    private final AggregationType aggregation;
    private final Matrix data;
    private ChiSquareTest chiSquareTest;
    private List<InfluentialCategory> outliers = Collections.emptyList();
    private final boolean isTransposed;
    /** Adjustment value passed to the frequency test; stays 0.0 unless a positive threshold was given. */
    private double adjustedChiSquare;

    /**
     * Creates a non-transposed test with no chi-square adjustment.
     *
     * @param aggregation aggregation used to build {@code data}; {@code COUNT} selects the frequency test
     * @param data        matrix of aggregated values to test
     */
    public ChiSquare(AggregationType aggregation, Matrix data) {
        this(aggregation, data, false);
    }

    /**
     * Creates a test with no chi-square adjustment.
     *
     * @param aggregation  aggregation used to build {@code data}; {@code COUNT} selects the frequency test
     * @param data         matrix of aggregated values to test
     * @param isTransposed if {@code true}, {@link #getOutliers()} swaps the two category indices
     *                     of every reported outlier
     */
    public ChiSquare(AggregationType aggregation, Matrix data, boolean isTransposed) {
        this(aggregation, data, isTransposed, 0.0);
    }

    /**
     * Creates a test, optionally with a chi-square adjustment.
     *
     * @param aggregation  aggregation used to build {@code data}; {@code COUNT} selects the frequency test
     * @param data         matrix of aggregated values to test
     * @param isTransposed if {@code true}, {@link #getOutliers()} swaps the two category indices
     *                     of every reported outlier
     * @param threshold    when {@code > 0}, passed to
     *                     {@link AbstractChiSquareTestForCategTarget#getAdjustValue} to compute the
     *                     adjustment handed to the frequency test; otherwise no adjustment (0.0) is used
     */
    public ChiSquare(AggregationType aggregation, Matrix data, boolean isTransposed, double threshold) {
        this.aggregation = aggregation;
        this.data = data;
        this.isTransposed = isTransposed;
        if (threshold > 0.0) {
            this.adjustedChiSquare = AbstractChiSquareTestForCategTarget.getAdjustValue(data, threshold);
        }
    }

    /**
     * Runs the overall chi-square test. If it is significant, also detects influential cells.
     * The detected cells are then available from {@link #getOutliers()}.
     *
     * @return the p-value of the overall chi-square test
     */
    public double compute() {
        log.perfStart();
        try {
            // Clear results from any previous call so a non-significant run reports no outliers.
            this.outliers = Collections.emptyList();
            double overallPValue = this.computeChiSquare();
            if (overallPValue <= SIGNIFICANCE_LEVEL) {
                this.determineInfluentialCategories();
            }
            return overallPValue;
        } finally {
            // Stop the perf timer even if the test computation throws.
            log.perfStop();
        }
    }

    /**
     * Returns the influential cells found by the last {@link #compute()} call.
     *
     * <p>The list is empty if {@code compute()} has not run or the overall test was not
     * significant. When this instance is transposed, each entry is rebuilt with its two category
     * indices swapped.
     *
     * @return the detected outliers, never {@code null}
     */
    public List<InfluentialCategory> getOutliers() {
        if (this.isTransposed) {
            return this.outliers.stream().map(o -> new InfluentialCategory(o.cellIndex, ((Integer)o.categoryIndex._2).intValue(), ((Integer)o.categoryIndex._1).intValue(), o.pValue, o.effectSize, o.expected, o.direction)).collect(Collectors.toList());
        }
        return this.outliers;
    }

    /** Builds the appropriate test, stores it, and returns the overall chi-square p-value. */
    private double computeChiSquare() {
        this.chiSquareTest = this.getChiSquareTest();
        log.perfLog("Starting computation of chi-square");
        double overallPValue = this.chiSquareTest.computeOverallChiSquare().pValue;
        log.perfLog("Completed computation of chi-square");
        return overallPValue;
    }

    /**
     * Selects the strongest significant cells and stores them as {@link #outliers}.
     *
     * <p>Cells with a p-value at or below {@link #SIGNIFICANCE_LEVEL} are sorted by descending
     * effect size. {@link TopNSelector} then decides how many of the leading cells to keep.
     */
    private void determineInfluentialCategories() {
        List<ChiSquareTestResult> significantPoints = this.chiSquareTest.computeChiSquareStatisticsForCells().stream()
                .filter(c -> c.pValue <= SIGNIFICANCE_LEVEL)
                .sorted((c1, c2) -> Double.compare(c2.effectSize, c1.effectSize))
                .collect(Collectors.toList());
        double[] effectSizes = significantPoints.stream().mapToDouble(s -> s.effectSize).toArray();
        // Clamp defensively: never read past the end of significantPoints,
        // whatever the selector returns.
        int topN = Math.min(TopNSelector.selectTopN(effectSizes, this.totalDataPointCount()), significantPoints.size());
        log.perfLog("Completed influential cell detection");
        List<InfluentialCategory> influentialCategories = new ArrayList<>();
        for (int i = 0; i < topN; ++i) {
            influentialCategories.add(this.resultToInfluentialCategory(significantPoints.get(i), i));
        }
        this.outliers = influentialCategories;
    }

    /**
     * Converts a per-cell test result into an {@link InfluentialCategory}.
     *
     * <p>The direction is {@code HIGH} when the observed matrix value exceeds the expected count.
     * Otherwise, including ties, it is {@code LOW}.
     *
     * @param result per-cell chi-square result
     * @param index  rank of the cell among the selected outliers
     * @return the influential-category description of the cell
     */
    private InfluentialCategory resultToInfluentialCategory(ChiSquareTestResult result, int index) {
        double observedValue = this.data.getValue(result.responseFieldIndex, result.explanatoryFieldIndex);
        ThreeLevelScale direction = observedValue > result.expectedCount ? ThreeLevelScale.HIGH : ThreeLevelScale.LOW;
        return new InfluentialCategory(index, result.responseFieldIndex, result.explanatoryFieldIndex, result.pValue, result.effectSize, result.expectedCount, direction);
    }

    /** Returns a frequency test (with the adjustment) for COUNT data, otherwise a non-frequency test. */
    private ChiSquareTest getChiSquareTest() {
        if (this.isCountAggregation()) {
            return new FrequencyChiSquareTest(this.data, this.adjustedChiSquare);
        }
        return new NonFrequencyChiSquareTest(this.data);
    }

    /** Returns the number of cells in the matrix (rows times columns). */
    private int totalDataPointCount() {
        return this.data.rowDimension() * this.data.columnDimension();
    }

    /** Returns whether the data was aggregated with {@link AggregationType#COUNT}. */
    private boolean isCountAggregation() {
        return this.aggregation == AggregationType.COUNT;
    }
}

