/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.algorithms.table;

import com.ibm.bi.predict.algorithms.ThreeLevelScale;
import com.ibm.bi.predict.algorithms.table.results.InfluentialCategory;
import com.ibm.bi.predict.data.matrix.Matrix;
import com.ibm.bi.predict.math.NumericUtils;
import com.ibm.bi.predict.math.TopNSelector;
import com.spss.math.statistics.DistributionFunctions;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang.mutable.MutableInt;
import org.apache.commons.math3.stat.StatUtils;

/**
 * Flags categories whose summed value is unusually high or low compared with the average
 * per-category sum.
 *
 * <p>{@link #compute()} first runs an omnibus chi-square test over every cell of the {@code values}
 * matrix (cells whose value is zero are included in the statistic). Only when that test is
 * significant at the 0.05 level is each cell with a non-zero value and a non-zero count tested
 * individually with a two-sided z-test. That z-test's p-value is multiplied by the number of
 * categories (Bonferroni-style correction). Significant cells are sorted by descending effect size
 * and trimmed with {@link TopNSelector}.
 *
 * <p>{@link #compute()} writes the total frequency into a mutable field without synchronization, so
 * one instance should not be used by several threads calling it concurrently.
 */
public class ChiSquareForSum {
    /** Significance level applied to both the omnibus test and the per-category tests. */
    private static final double SIGNIFICANCE_LEVEL = 0.05;

    /** Per-cell sums; presumably the sum of the analysed field within each category (confirm with caller). */
    private final Matrix values;
    /** Per-cell frequencies; their total is used as the overall sample size. */
    private final Matrix counts;
    /** Per-cell sums of squares; assumed to be within-category squared deviations (TODO confirm with caller). */
    private final Matrix sumOfSquares;
    /** Number of categories: divisor for the mean sum, chi-square df is this minus one, and Bonferroni multiplier. */
    private final int numOfCategories;
    /** Total of {@code counts}; recomputed at the start of every {@link #compute()} call. */
    private double totalFrequency = 0.0;

    /**
     * Creates the test over the given per-cell statistics.
     *
     * @param values per-cell sums
     * @param counts per-cell frequencies
     * @param sumOfSquares per-cell sums of squares
     * @param numberOfCategories number of categories being compared
     */
    public ChiSquareForSum(Matrix values, Matrix counts, Matrix sumOfSquares, int numberOfCategories) {
        this.values = values;
        this.counts = counts;
        this.sumOfSquares = sumOfSquares;
        this.numOfCategories = numberOfCategories;
    }

    /**
     * Runs the omnibus test and, if it is significant, the per-category tests.
     *
     * @return the significant categories ordered by descending effect size and trimmed by
     *     {@link TopNSelector}; an empty list when the variance estimate is unusable (NaN, infinite
     *     or non-positive) or when nothing is significant
     */
    public List<InfluentialCategory> compute() {
        this.totalFrequency = this.counts.sum();
        double meanSum = this.meanOfSums();
        double variance = this.estimatedVariance();
        if (NumericUtils.isMissingValue(variance)) {
            return Collections.emptyList();
        }
        double pValue = this.calculateChiSquare(variance, meanSum);
        List<InfluentialCategory> unusualCategories = new ArrayList<>();
        if (this.isSignificant(pValue)) {
            // Loop-invariant: the standard deviation of the per-category sums.
            double stdDevOfSums = Math.sqrt(variance);
            // Running index of the cells that were actually tested (non-zero value and non-zero count).
            MutableInt idx = new MutableInt();
            this.values.walkNonZero((row, col, val) -> {
                double count = this.counts.getValue(row, col);
                if (NumericUtils.isZero(count)) {
                    return;
                }
                double value = val;
                double zStat = (value - meanSum) / stdDevOfSums;
                // Two-sided p-value multiplied by the number of categories (Bonferroni), capped at 1.
                double adjustedPValue = Math.min(1.0,
                        2.0 * ((1.0 - this.cdfStandardNorm(Math.abs(zStat))) * this.numOfCategories));
                this.addUnusualCategory(meanSum, unusualCategories, idx, row, col, value, this.totalFrequency, zStat, adjustedPValue);
                idx.increment();
            });
        }
        return this.filter(unusualCategories);
    }

    /**
     * Appends an {@link InfluentialCategory} for the given cell if its adjusted p-value is significant.
     * The effect size is {@code |zStat| / sqrt(totalCount)}; the direction is LOW when the cell's value
     * is below {@code meanSum} and HIGH otherwise.
     */
    private void addUnusualCategory(double meanSum, List<InfluentialCategory> unusualCategories, MutableInt idx, int row, int col, double value, double totalCount, double zStat, double adjustedPValue) {
        if (this.isSignificant(adjustedPValue)) {
            double effectSize = Math.abs(zStat) / Math.sqrt(totalCount);
            ThreeLevelScale direction = value < meanSum ? ThreeLevelScale.LOW : ThreeLevelScale.HIGH;
            unusualCategories.add(new InfluentialCategory(idx.toInteger(), row, col, adjustedPValue, effectSize, meanSum, direction));
        }
    }

    /**
     * Estimates the variance of a per-category sum as
     * {@code N/k * (s^2 + mean^2 * (k-1)/k)}, where N is the total frequency, k the number of
     * categories, s the pooled standard deviation and mean the overall per-record mean.
     *
     * @return the estimate, or {@code NaN} when it is NaN, infinite or not strictly positive; a zero
     *     variance would make every downstream statistic divide by zero
     */
    private double estimatedVariance() {
        double mean = this.overallMean();
        double stdDev = this.standardDeviation();
        double variance = this.totalFrequency / (double) this.numOfCategories
                * (Math.pow(stdDev, 2.0) + Math.pow(mean, 2.0) * (double) (this.numOfCategories - 1) / (double) this.numOfCategories);
        if (!Double.isFinite(variance) || variance <= 0.0) {
            return Double.NaN;
        }
        return variance;
    }

    /**
     * Returns the upper-tail p-value of the chi-square statistic with {@code numOfCategories - 1}
     * degrees of freedom.
     */
    private double calculateChiSquare(double variance, double meanSum) {
        double chiSquareStat = this.chiSquare(variance, meanSum);
        return 1.0 - this.cdfChi(chiSquareStat, this.numOfCategories - 1.0);
    }

    /**
     * Returns {@code sqrt(SS / (N - k))} from the summed {@code sumOfSquares} matrix, or 0 when that
     * sum is zero. Presumably the pooled within-category standard deviation; confirm with caller.
     */
    private double standardDeviation() {
        double totalSumOfSquares = this.sumOfSquares.sum();
        if (NumericUtils.isZero(totalSumOfSquares)) {
            return 0.0;
        }
        return Math.sqrt(totalSumOfSquares / (this.totalFrequency - this.numOfCategories));
    }

    /**
     * Computes {@code sum((value - meanSum)^2 / variance)} over every cell of {@code values}. Cells
     * skipped by the non-zero walk each contribute {@code meanSum^2 / variance}.
     */
    private double chiSquare(double variance, double meanSum) {
        double sumForNonZero = this.values.sumNonZero((row, col, val) -> Math.pow(val - meanSum, 2.0) / variance);
        int countCellsWithZero = this.values.columnDimension() * this.values.rowDimension() - this.values.countNonZero();
        double sumForZero = (double) countCellsWithZero * meanSum * meanSum / variance;
        return sumForZero + sumForNonZero;
    }

    /** Returns the grand total of {@code values} divided by the number of categories. */
    private double meanOfSums() {
        return this.values.sum() / (double) this.numOfCategories;
    }

    /** Returns the grand total of {@code values} divided by the total frequency. */
    private double overallMean() {
        return this.values.sum() / this.totalFrequency;
    }

    /** Returns {@code true} when {@code pValue} is at most the 0.05 significance level (NaN is never significant). */
    private boolean isSignificant(double pValue) {
        return pValue <= SIGNIFICANCE_LEVEL;
    }

    /**
     * Sorts the candidates by descending effect size and keeps the first {@code n}, where {@code n}
     * is chosen by {@link TopNSelector#selectTopN} from those effect sizes.
     */
    private List<InfluentialCategory> filter(List<InfluentialCategory> influentialCategories) {
        // Typed (not raw) so that the effectSize field is accessible in the lambda below.
        List<InfluentialCategory> sortedInfluentialCategories = influentialCategories.stream()
                .sorted((c1, c2) -> Double.compare(c2.effectSize, c1.effectSize))
                .collect(Collectors.toList());
        double[] effectSizes = sortedInfluentialCategories.stream().mapToDouble(c -> c.effectSize).toArray();
        int topN = TopNSelector.selectTopN(effectSizes, this.numOfCategories);
        return sortedInfluentialCategories.subList(0, topN);
    }

    /**
     * Sums every element of a jagged 2-D array.
     *
     * @param values rows to sum; an empty row contributes 0 (commons-math {@code StatUtils.sum}
     *     would return NaN for it)
     * @return the total, or 0 when {@code values} has no rows
     */
    public double sumList(double[][] values) {
        return Arrays.stream(values)
                .mapToDouble(row -> row.length == 0 ? 0.0 : StatUtils.sum(row))
                .reduce(0.0, (acc, a) -> acc + a);
    }

    /** Chi-square CDF that propagates {@code NaN} statistics instead of evaluating them. */
    private double cdfChi(double chiSquareStat, double degreesOfFreedom) {
        return Double.isNaN(chiSquareStat) ? Double.NaN : DistributionFunctions.cdfChi(chiSquareStat, degreesOfFreedom);
    }

    /** Standard-normal CDF that propagates {@code NaN} statistics instead of evaluating them. */
    private double cdfStandardNorm(double zStatistic) {
        return Double.isNaN(zStatistic) ? Double.NaN : DistributionFunctions.cdfStdNorm(zStatistic);
    }
}

