/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.sa.execution.annotation.impl;

import com.ibm.bi.predict.algorithms.ThreeLevelScale;
import com.ibm.bi.predict.algorithms.table.OneWayAnova;
import com.ibm.bi.predict.algorithms.table.results.AnovaResult;
import com.ibm.bi.predict.data.matrix.Matrix;
import com.ibm.bi.predict.data.matrix.MatrixVectorFactory;
import com.ibm.bi.predict.dataaccess.Decorator;
import com.ibm.bi.predict.exceptions.PredictException;
import com.ibm.bi.predict.math.NumericUtils;
import com.ibm.bi.predict.sa.execution.annotation.ConditionalResponses;
import com.ibm.bi.predict.sa.execution.annotation.FieldRole;
import com.ibm.bi.predict.sa.execution.annotation.MessageServiceImpl;
import com.ibm.bi.predict.sa.execution.annotation.decorations.PredictiveStrengthDecoration;
import com.ibm.bi.predict.sa.execution.annotation.impl.AnnotationImpl;
import com.ibm.bi.predict.sa.execution.annotation.impl.math.StatisticsMap;
import com.ibm.bi.predict.sa.execution.annotation.response.ConditionalResponse;
import com.ibm.bi.predict.sa.execution.annotation.result.AnnotationResult;
import com.ibm.bi.predict.sa.execution.annotation.result.PredictiveStrengthResult;
import com.ibm.bi.predict.sa.execution.api.DataRowAdapter;
import com.ibm.bi.predict.sa.execution.api.MetaDataAdapter;
import com.ibm.bi.predict.sa.execution.api.SuggestedAnnotation;
import com.ibm.bi.predict.utils.Logger;
import com.ibm.bi.predict.utils.PredictLoggerFactory;
import com.spss.ac.acmath.accumstats.InteractTwoFactorsForContTarget;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;

public class TwoWayAnovaAnnotation
extends AnnotationImpl<PredictiveStrengthResult.PredictiveStrengthData> {
    private static final Logger log = PredictLoggerFactory.getLogger(TwoWayAnovaAnnotation.class);
    private static final double ADJ_R2_THRESHOLD_VALUE = 0.1;
    private static final double MODEL_IMPROVEMENT_THRESHOLD_VALUE = 0.1;
    private static final double SIGNIFICANCE_LEVEL = 0.05;
    private static final int MIN_CATS = 2;
    private final int factorACategoryCount;
    private final int factorBCategoryCount;
    private final int factorAIndex;
    private final int factorBIndex;
    private double pValue;
    private ThreeLevelScale strengthLevel;
    private double effectSize;
    private double adjustedRSquared;
    private List<Integer> significantFactors = Collections.emptyList();
    private List<Double> factorACounts;
    private List<Double> factorBCounts;
    private List<Double> factorAMeans;
    private List<Double> factorBMeans;
    private List<Double> factorASumSqrs;
    private List<Double> factorBSumSqrs;
    private Matrix counts;
    private Matrix means;
    private Matrix sumOfSquares;

    public TwoWayAnovaAnnotation(MetaDataAdapter metadata) {
        super(metadata);
        this.factorAIndex = this.factorAIndex();
        this.factorBIndex = this.factorBIndex();
        this.factorACategoryCount = metadata.getCountOfFieldCategories(this.factorAIndex);
        this.factorBCategoryCount = metadata.getCountOfFieldCategories(this.factorBIndex);
        this.factorACounts = new ArrayList<Double>(Collections.nCopies(this.factorACategoryCount, 0.0));
        this.factorBCounts = new ArrayList<Double>(Collections.nCopies(this.factorBCategoryCount, 0.0));
        this.factorAMeans = new ArrayList<Double>(Collections.nCopies(this.factorACategoryCount, 0.0));
        this.factorBMeans = new ArrayList<Double>(Collections.nCopies(this.factorBCategoryCount, 0.0));
        this.factorASumSqrs = new ArrayList<Double>(Collections.nCopies(this.factorACategoryCount, 0.0));
        this.factorBSumSqrs = new ArrayList<Double>(Collections.nCopies(this.factorBCategoryCount, 0.0));
        this.counts = MatrixVectorFactory.makeMatrix((int)this.factorACategoryCount, (int)this.factorBCategoryCount, (int)metadata.rowCount());
        this.means = MatrixVectorFactory.makeMatrix((int)this.factorACategoryCount, (int)this.factorBCategoryCount, (int)metadata.rowCount());
        this.sumOfSquares = MatrixVectorFactory.makeMatrix((int)this.factorACategoryCount, (int)this.factorBCategoryCount, (int)metadata.rowCount());
    }

    @Override
    public void update(DataRowAdapter dataRow) {
        double value = dataRow.getTargetValue();
        double cellFrequency = dataRow.getTargetStatistic(StatisticsMap.StatisticName.ROW_COUNT);
        double sumSquares = dataRow.getTargetStatistic(StatisticsMap.StatisticName.SUM_OF_SQUARES);
        int factorACategory = (int)dataRow.getFieldValueByIndex(this.factorAIndex);
        int factorBCategory = (int)dataRow.getFieldValueByIndex(this.factorBIndex);
        double factorACount = this.updateCount(dataRow, this.factorAIndex);
        double factorBCount = this.updateCount(dataRow, this.factorBIndex);
        this.factorACounts.set(factorACategory, factorACount);
        this.factorBCounts.set(factorBCategory, factorBCount);
        double prevFactorAMean = this.factorAMeans.get(factorACategory);
        double prevFactorBMean = this.factorBMeans.get(factorBCategory);
        double newFactorAMean = this.updateMean(dataRow, this.factorAIndex, factorACount);
        double newFactorBMean = this.updateMean(dataRow, this.factorBIndex, factorBCount);
        this.factorAMeans.set(factorACategory, newFactorAMean);
        this.factorBMeans.set(factorBCategory, newFactorBMean);
        this.factorASumSqrs.set(factorACategory, this.updateSumOfSquares(dataRow, this.factorAIndex, prevFactorAMean, newFactorAMean));
        this.factorBSumSqrs.set(factorBCategory, this.updateSumOfSquares(dataRow, this.factorBIndex, prevFactorBMean, newFactorBMean));
        this.means.setValue(factorACategory, factorBCategory, value);
        this.counts.setValue(factorACategory, factorBCategory, cellFrequency);
        this.sumOfSquares.setValue(factorACategory, factorBCategory, sumSquares);
    }

    @Override
    public ConditionalResponse assertPreconditions() {
        if (this.metadata.getCountOfFieldCategories(this.factorAIndex) < 2) {
            return ConditionalResponses.TOO_FEW_CATEGORIES(this.metadata.getNameOfField(this.factorAIndex));
        }
        if (this.metadata.getCountOfFieldCategories(this.factorBIndex) < 2) {
            return ConditionalResponses.TOO_FEW_CATEGORIES(this.metadata.getNameOfField(this.factorBIndex));
        }
        return ConditionalResponses.SUCCESS;
    }

    @Override
    public ConditionalResponse postUpdate() {
        log.perfStart();
        log.perfLog("Executing post-update of TwoWayAnovaAnnotation");
        AnovaResult twoWayResult = this.runTwoWayAnova();
        this.pValue = twoWayResult.pValue;
        log.perfLog("Computed two-way ANOVA");
        log.perfLog("Computing one-way ANOVA for factors");
        OneWayAnova factorAAnova = new OneWayAnova(this.factorAMeans, this.factorACounts, this.factorASumSqrs);
        AnovaResult anovaAResult = factorAAnova.compute();
        log.perfLog("Computed one-way ANOVA for factor A");
        OneWayAnova factorBAnova = new OneWayAnova(this.factorBMeans, this.factorBCounts, this.factorBSumSqrs);
        AnovaResult anovaBResult = factorBAnova.compute();
        log.perfLog("Computed one-way ANOVA for factor B");
        double fullModelAdjR2 = twoWayResult.adjustedRSquared;
        double mainEffectModelAAdjR2 = factorAAnova.getAdjustedRSquared();
        double mainEffectModelBAdjR2 = factorBAnova.getAdjustedRSquared();
        double maxMainEffectAdjR2 = Math.max(mainEffectModelAAdjR2, mainEffectModelBAdjR2);
        if (TwoWayAnovaAnnotation.isFullModelBetterThanFactorModel(fullModelAdjR2, maxMainEffectAdjR2) && TwoWayAnovaAnnotation.isSignificant(this.pValue)) {
            this.effectSize = twoWayResult.effectSize;
            this.adjustedRSquared = fullModelAdjR2;
            this.strengthLevel = TwoWayAnovaAnnotation.calculateStrengthLevel(this.adjustedRSquared);
            this.significantFactors = Arrays.asList(this.factorAIndex, this.factorBIndex);
        } else if (!(maxMainEffectAdjR2 <= 0.1) && (anovaAResult.hasPredictiveRelationship() || anovaBResult.hasPredictiveRelationship()) && (anovaAResult.hasPredictiveRelationship() || anovaBResult.hasPredictiveRelationship())) {
            OneWayAnova significantFactor;
            int factorIdx;
            if (anovaAResult.hasPredictiveRelationship()) {
                factorIdx = this.factorAIndex;
                significantFactor = factorAAnova;
            } else {
                factorIdx = this.factorBIndex;
                significantFactor = factorBAnova;
            }
            if (significantFactor.getAdjustedRSquared() > 0.1) {
                this.effectSize = significantFactor.getEffectSize();
                this.adjustedRSquared = significantFactor.getAdjustedRSquared();
                this.strengthLevel = TwoWayAnovaAnnotation.calculateStrengthLevel(this.adjustedRSquared);
                this.significantFactors = Arrays.asList(factorIdx);
            }
        }
        log.perfStop("Finished Two-way ANOVA");
        return ConditionalResponses.SUCCESS;
    }

    @Override
    public AnnotationResult<PredictiveStrengthResult.PredictiveStrengthData> buildResult() {
        if (this.strengthLevel == null) {
            return PredictiveStrengthResult.noRelationshipResult(this.pValue, this.effectSize, this.metadata.getResponseName(), this.getFactorNames());
        }
        List<String> significantFactorNames = this.significantFactors.stream().map(this.metadata::getNameOfField).collect(Collectors.toList());
        return new PredictiveStrengthResult(new PredictiveStrengthResult.PredictiveStrengthData(this.pValue, this.effectSize, this.adjustedRSquared, this.strengthLevel, this.metadata.getResponseName(), significantFactorNames));
    }

    @Override
    public void decorate(Decorator decorator, SuggestedAnnotation annotationSuggestion, Locale locale) {
        AnnotationResult<PredictiveStrengthResult.PredictiveStrengthData> result = this.getResult();
        if (!result.isEmpty()) {
            new PredictiveStrengthDecoration(new MessageServiceImpl()).decorate(decorator, result, locale);
        }
    }

    private double updateMean(DataRowAdapter dataRow, int axis, double factorCount) {
        int factorIndex = (int)dataRow.getFieldValueByIndex(axis);
        double cellFrequency = dataRow.getTargetStatistic(StatisticsMap.StatisticName.ROW_COUNT);
        double value = dataRow.getTargetValue();
        List<Double> factorMeans = axis == this.factorAIndex ? this.factorAMeans : this.factorBMeans;
        double prevFactorAMean = factorMeans.get(factorIndex);
        return prevFactorAMean + cellFrequency / factorCount * (value - prevFactorAMean);
    }

    private double updateCount(DataRowAdapter dataRow, int axis) {
        int factorIndex = (int)dataRow.getFieldValueByIndex(axis);
        double cellFrequency = dataRow.getTargetStatistic(StatisticsMap.StatisticName.ROW_COUNT);
        List<Double> factorCounts = axis == this.factorAIndex ? this.factorACounts : this.factorBCounts;
        return factorCounts.get(factorIndex) + cellFrequency;
    }

    private double updateSumOfSquares(DataRowAdapter dataRow, int axis, double prevFactorMean, double newFactorMean) {
        int factorIndex = (int)dataRow.getFieldValueByIndex(axis);
        double sumSquares = dataRow.getTargetStatistic(StatisticsMap.StatisticName.SUM_OF_SQUARES);
        double cellFrequency = dataRow.getTargetStatistic(StatisticsMap.StatisticName.ROW_COUNT);
        double value = dataRow.getTargetValue();
        List<Double> factorSumSqrs = axis == this.factorAIndex ? this.factorASumSqrs : this.factorBSumSqrs;
        double factorSumSqr = factorSumSqrs.get(factorIndex);
        return sumSquares + (factorSumSqr + cellFrequency * (value - prevFactorMean) * (value - newFactorMean));
    }

    private AnovaResult runTwoWayAnova() {
        int totalRowCount = (int)this.counts.sum();
        InteractTwoFactorsForContTarget twoWayAnova = new InteractTwoFactorsForContTarget(this.factorACategoryCount, this.factorBCategoryCount, this.counts.getDataAsVector(), this.means.getDataAsVector(), this.sumOfSquares.getDataAsVector(), totalRowCount, 0, (double)totalRowCount, (double)totalRowCount);
        twoWayAnova.computeStatistics();
        double returnedPValue = NumericUtils.isMissingValue((double)twoWayAnova.getPValue()) ? 0.0 : twoWayAnova.getPValue();
        return new AnovaResult(returnedPValue, twoWayAnova.getEffectSize(), twoWayAnova.getFitMeasure(), twoWayAnova.getFStat(), TwoWayAnovaAnnotation.calculateStrengthLevel(twoWayAnova.getFitMeasure()));
    }

    private static boolean isSignificant(double pValue) {
        return pValue <= 0.05;
    }

    private static boolean isFullModelBetterThanFactorModel(double fullModelAdjR2, double maxMainEffectAdjR2) {
        return fullModelAdjR2 > Math.max(0.1, maxMainEffectAdjR2 + 0.1 * (1.0 - maxMainEffectAdjR2));
    }

    private static ThreeLevelScale calculateStrengthLevel(double adjRSquared) {
        if (adjRSquared <= 0.35) {
            return ThreeLevelScale.LOW;
        }
        if (adjRSquared <= 0.7) {
            return ThreeLevelScale.MEDIUM;
        }
        return ThreeLevelScale.HIGH;
    }

    private List<String> getFactorNames() {
        return Arrays.asList(this.metadata.getNameOfField(this.factorAIndex), this.metadata.getNameOfField(this.factorBIndex));
    }

    private int factorAIndex() {
        return this.metadata.getExplanatoryFieldIndices().get(0);
    }

    private int factorBIndex() {
        List<Integer> explanatoryIndices = this.metadata.getExplanatoryFieldIndices();
        List<Integer> groupIndices = this.metadata.getFieldIndicesByRole(FieldRole.GROUP);
        if (explanatoryIndices.size() == 2) {
            return explanatoryIndices.get(1);
        }
        if (explanatoryIndices.size() == 1 && groupIndices.size() == 1) {
            return groupIndices.get(0);
        }
        throw new PredictException("Unknown state. Expected 2 explanatory fields or 1 explanatory, 1 group");
    }
}

