/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.sa.execution.annotation.impl;

import com.ibm.bi.predict.algorithms.table.results.InfluentialCategory;
import com.ibm.bi.predict.data.matrix.Matrix;
import com.ibm.bi.predict.data.matrix.MatrixVectorFactory;
import com.ibm.bi.predict.dataaccess.Decorator;
import com.ibm.bi.predict.dataaccess.types.AggregationType;
import com.ibm.bi.predict.dataaccess.types.FieldType;
import com.ibm.bi.predict.exceptions.PredictException;
import com.ibm.bi.predict.sa.execution.annotation.ConditionalResponses;
import com.ibm.bi.predict.sa.execution.annotation.DataRowValidator;
import com.ibm.bi.predict.sa.execution.annotation.FieldRole;
import com.ibm.bi.predict.sa.execution.annotation.MessageServiceImpl;
import com.ibm.bi.predict.sa.execution.annotation.decorations.MeaningfulDifferencesDecoration;
import com.ibm.bi.predict.sa.execution.annotation.impl.AnnotationImpl;
import com.ibm.bi.predict.sa.execution.annotation.impl.math.ChiSquare;
import com.ibm.bi.predict.sa.execution.annotation.impl.math.MeaningfulDifferences;
import com.ibm.bi.predict.sa.execution.annotation.impl.math.StatisticsMap;
import com.ibm.bi.predict.sa.execution.annotation.impl.types.CategoryName;
import com.ibm.bi.predict.sa.execution.annotation.response.ConditionalResponse;
import com.ibm.bi.predict.sa.execution.annotation.result.AnnotationResult;
import com.ibm.bi.predict.sa.execution.annotation.result.EmptyResult;
import com.ibm.bi.predict.sa.execution.annotation.result.MeaningfulDifferencesResult;
import com.ibm.bi.predict.sa.execution.api.DataRowAdapter;
import com.ibm.bi.predict.sa.execution.api.MetaDataAdapter;
import com.ibm.bi.predict.sa.execution.api.SuggestedAnnotation;
import com.ibm.bi.predict.utils.Logger;
import com.ibm.bi.predict.utils.PredictLoggerFactory;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Two-factor annotation that fills a (factor A category x factor B category) matrix from the
 * incoming data rows and reports the outlier cells found in it.
 *
 * <p>Factor A is always the first categorical (or date/time) explanatory field. Factor B depends on
 * the field layout (see {@code State}):
 * <ul>
 *   <li>a second categorical explanatory field;</li>
 *   <li>a single categorical repeat field;</li>
 *   <li>a categorical group field (SUM aggregation only);</li>
 *   <li>the combination of two categorical repeat fields (COUNT aggregation only).</li>
 * </ul>
 *
 * <p>When the response is SUM-aggregated, outliers come from {@code MeaningfulDifferences.detect};
 * otherwise a {@code ChiSquare} test is run over the collected matrix.
 */
public class TwoWayChiSquareAnnotation
extends AnnotationImpl<List<MeaningfulDifferencesResult.MeaningfulDifferencesData>>
implements DataRowValidator {
    private static final Logger log = PredictLoggerFactory.getLogger(TwoWayChiSquareAnnotation.class);
    /** Response value per (factor A, factor B) cell; stays null if either factor has no categories. */
    private Matrix values;
    /** Maps a cell's linear index (see getTrueIndex) to the index of the data row that filled it. */
    private final Map<Integer, Integer> indices = new HashMap<Integer, Integer>();
    private final AggregationType aggregationType;
    /** Per-cell ROW_COUNT and SUM_OF_SQUARES matrices; only populated for SUM aggregation. */
    private final StatisticsMap statistics = new StatisticsMap();
    private List<InfluentialCategory> outliers;
    /** Number of update() calls; passed to MeaningfulDifferences.detect for SUM aggregation. */
    private int numberOfCategories = 0;
    /** Passed as the threshold argument to ChiSquare. */
    private final double threshold = 0.1;

    /**
     * Creates the annotation and allocates its cell matrix, plus the per-cell statistics
     * matrices when the response is SUM-aggregated.
     *
     * @param metadata describes the fields and rows of the incoming data
     * @throws IllegalArgumentException if the response field has no aggregation type
     * @throws IllegalStateException if the field layout matches none of the supported states
     */
    public TwoWayChiSquareAnnotation(MetaDataAdapter metadata) {
        super(metadata);
        this.aggregationType = this.getAggregation();
        int factorACatCount = this.factorACatCount();
        int factorBCatCount = this.factorBCatCount();
        if (factorACatCount <= 0 || factorBCatCount <= 0) {
            // Nothing to allocate; values stays null.
            // NOTE(review): hasEnoughCategories() still passes when one factor has 0 categories
            // and the other has 3+, after which update() would dereference the null matrix.
            // Confirm that combination cannot reach update().
            return;
        }
        int rowCount = (int) metadata.rowCount();
        this.values = MatrixVectorFactory.makeMatrix(factorACatCount, factorBCatCount, rowCount);
        if (this.aggregationType == AggregationType.SUM) {
            this.statistics.addStatistic(StatisticsMap.StatisticName.ROW_COUNT,
                    MatrixVectorFactory.makeMatrix(factorACatCount, factorBCatCount, rowCount));
            this.statistics.addStatistic(StatisticsMap.StatisticName.SUM_OF_SQUARES,
                    MatrixVectorFactory.makeMatrix(factorACatCount, factorBCatCount, rowCount));
        }
    }

    /** Succeeds when the factors have enough categories to test; otherwise names the field that falls short. */
    @Override
    public ConditionalResponse assertPreconditions() {
        return this.hasEnoughCategories() ? ConditionalResponses.SUCCESS : this.tooFewCategories();
    }

    /**
     * Stores one data row's response value in its (factor A, factor B) cell and records which
     * data row filled that cell. For SUM aggregation, also accumulates the row's ROW_COUNT and
     * SUM_OF_SQUARES target statistics for that cell.
     */
    @Override
    public void update(DataRowAdapter dataRow) {
        int factorA = this.factorAValue(dataRow);
        int factorB = this.factorBValue(dataRow);
        double value = dataRow.getFieldValueByIndex(this.metadata.getResponseIndex());
        this.values.setValue(factorA, factorB, value);
        this.indices.put(this.getTrueIndex(factorA, factorB, this.values.rowDimension()), dataRow.getDataRowIndex());
        ++this.numberOfCategories;
        if (this.aggregationType == AggregationType.SUM) {
            this.statistics.updateStatistic(StatisticsMap.StatisticName.ROW_COUNT, factorA, factorB,
                    dataRow.getTargetStatistic(StatisticsMap.StatisticName.ROW_COUNT));
            this.statistics.updateStatistic(StatisticsMap.StatisticName.SUM_OF_SQUARES, factorA, factorB,
                    dataRow.getTargetStatistic(StatisticsMap.StatisticName.SUM_OF_SQUARES));
        }
    }

    /** Runs the outlier detection that matches the response's aggregation type. */
    @Override
    public ConditionalResponse postUpdate() {
        log.debug("Beginning to compute two-way chi square");
        log.perfStart();
        this.outliers = this.aggregationType == AggregationType.SUM
                ? this.outliersForSumAggregation()
                : this.outliersForCountAggregation();
        log.perfLog("Completed two-way chi square - numOutliers={}", (Object) this.outliers.size());
        log.perfStop();
        return ConditionalResponses.SUCCESS;
    }

    /**
     * Converts the outliers found by postUpdate() into result data.
     *
     * <p>Outliers whose cell was never filled by a data row are skipped. An EmptyResult is returned
     * when nothing remains.
     */
    @Override
    public AnnotationResult<List<MeaningfulDifferencesResult.MeaningfulDifferencesData>> buildResult() {
        if (this.outliers.isEmpty()) {
            return new EmptyResult<List<MeaningfulDifferencesResult.MeaningfulDifferencesData>>();
        }
        List<MeaningfulDifferencesResult.MeaningfulDifferencesData> annotations = this.outliers.stream()
                .filter(outlier -> !this.shouldExcludeOutlier(outlier))
                .map(this::toDifferencesData)
                .collect(Collectors.toList());
        if (annotations.isEmpty()) {
            return new EmptyResult<List<MeaningfulDifferencesResult.MeaningfulDifferencesData>>();
        }
        return new MeaningfulDifferencesResult(annotations);
    }

    /** Builds the result data for one outlier whose cell is known to have been filled. */
    private MeaningfulDifferencesResult.MeaningfulDifferencesData toDifferencesData(InfluentialCategory outlier) {
        int rowIdx = (Integer) outlier.categoryIndex._1;
        int columnIdx = (Integer) outlier.categoryIndex._2;
        int dataRowIndex = this.indices.get(this.getTrueIndex(rowIdx, columnIdx, this.values.rowDimension()));
        // Locale.ROOT keeps the direction token stable regardless of the JVM default locale
        // (e.g. avoids the dotless-i mapping of 'I' under a Turkish locale).
        String direction = outlier.direction.toString().toLowerCase(Locale.ROOT);
        return new MeaningfulDifferencesResult.MeaningfulDifferencesData(dataRowIndex, direction,
                outlier.expected, outlier.pValue, this.getCategoryNames(rowIdx, columnIdx),
                this.getResponseName(), this.getFactorNames(), Optional.of(this.aggregationType));
    }

    /**
     * Decorates the field carrying the aggregation with this annotation's result.
     *
     * <p>The decorated field is the response if it is aggregated, else the first aggregated
     * explanatory field. Nothing is decorated when the result is empty.
     */
    @Override
    public void decorate(Decorator decorator, SuggestedAnnotation annotationSuggestion, Locale locale) {
        AnnotationResult<List<MeaningfulDifferencesResult.MeaningfulDifferencesData>> result = this.getResult();
        // Resolved before the emptiness check, so aggregatedExplanatoryIndex() can throw
        // PredictException even for an empty result.
        // NOTE(review): the constructor already requires an aggregated response, so the
        // explanatory branch looks unreachable. Confirm before relying on it.
        int indexToBeDecorated = this.responseIsAggregated()
                ? this.metadata.getResponseIndex()
                : this.aggregatedExplanatoryIndex();
        if (!result.isEmpty()) {
            new MeaningfulDifferencesDecoration(indexToBeDecorated, new MessageServiceImpl(), this.aggregationType)
                    .decorate(decorator, result, locale);
        }
    }

    /** Accepts every row; this annotation performs no per-row validation. */
    @Override
    public ConditionalResponse validateDataRow(DataRowAdapter dataRow) {
        return ConditionalResponses.SUCCESS;
    }

    /** Returns the display names of a cell's factor A and factor B categories, in that order. */
    private List<CategoryName> getCategoryNames(int factorACategoryIdx, int factorBCategoryIdx) {
        return Arrays.asList(
                CategoryName.fromString(this.metadata.getCategoryNameForField(this.factorAIndex(), factorACategoryIdx)),
                this.factorBCategoryName(factorBCategoryIdx));
    }

    /**
     * Returns true when both factors have at least two categories, or when one factor has fewer
     * than two categories and the other has at least three.
     */
    private boolean hasEnoughCategories() {
        int factorACount = this.factorACatCount();
        int factorBCount = this.factorBCatCount();
        if (factorACount >= 2 && factorBCount >= 2) {
            return true;
        }
        return factorACount < 2 && factorBCount >= 3 || factorACount >= 3 && factorBCount < 2;
    }

    /** Builds the TOO_FEW_CATEGORIES response, naming the field that lacks categories. */
    private ConditionalResponse tooFewCategories() {
        int factorACount = this.factorACatCount();
        int factorBCount = this.factorBCatCount();
        String name;
        if (factorACount < 2 && factorBCount < 2) {
            name = this.metadata.getNameOfField(this.factorAIndex());
        } else if (factorACount < 2 && factorBCount < 3) {
            name = this.factorBName();
        } else if (factorBCount < 2 && factorACount < 3) {
            name = this.metadata.getNameOfField(this.factorAIndex());
        } else {
            name = this.metadata.getResponseName();
        }
        return ConditionalResponses.TOO_FEW_CATEGORIES(name);
    }

    /** Returns factor B's field name; for two repeat fields, their names joined with " | ". */
    private String factorBName() {
        if (this.factorBIsCompound()) {
            List<Integer> catRepeatIndices = this.categoricalRepeatFields();
            int rep1 = catRepeatIndices.get(0);
            int rep2 = catRepeatIndices.get(1);
            return this.metadata.getNameOfField(rep1) + " | " + this.metadata.getNameOfField(rep2);
        }
        return this.metadata.getNameOfField(this.factorBIndex());
    }

    /** Returns the names of factor A and factor B, in that order. */
    private List<String> getFactorNames() {
        return Arrays.asList(this.metadata.getNameOfField(this.factorAIndex()), this.factorBName());
    }

    /**
     * Returns the name of factor B's category at {@code catIndex}.
     *
     * <p>For two repeat fields, the index is decoded as
     * {@code rep1Idx * categoryCount(rep2) + rep2Idx}. This is the same encoding factorBValue
     * produces.
     */
    private CategoryName factorBCategoryName(int catIndex) {
        if (this.factorBIsCompound()) {
            List<Integer> catRepeatIndices = this.categoricalRepeatFields();
            int rep1 = catRepeatIndices.get(0);
            int rep2 = catRepeatIndices.get(1);
            int nCatsRep2 = this.metadata.getCountOfFieldCategories(rep2);
            int rep1Idx = catIndex / nCatsRep2;
            int rep2Idx = catIndex % nCatsRep2;
            return CategoryName.compound(this.metadata.getCategoryNameForField(rep1, rep1Idx),
                    this.metadata.getCategoryNameForField(rep2, rep2Idx));
        }
        return CategoryName.fromString(this.metadata.getCategoryNameForField(this.factorBIndex(), catIndex));
    }

    /** Returns the response field's aggregation type; throws IllegalArgumentException if it has none. */
    private AggregationType getAggregation() {
        return this.getAggregationTypeOfField(this.metadata.getResponseIndex());
    }

    private AggregationType getAggregationTypeOfField(int fieldIdx) {
        return this.metadata.getAggregationTypeOfField(fieldIdx)
                .orElseThrow(() -> new IllegalArgumentException("Expected field to have aggregation"));
    }

    private boolean responseIsAggregated() {
        return this.metadata.getAggregationTypeOfField(this.metadata.getResponseIndex()).isPresent();
    }

    /** Returns true when factor B is the combination of two categorical repeat fields. */
    private boolean factorBIsCompound() {
        return this.getState() == State.ONE_EXPL_TWO_REPEAT;
    }

    /** Returns the field index of factor A: the first categorical/date-time explanatory field. */
    private int factorAIndex() {
        return this.categoricalExplanatoryFields().get(0);
    }

    /**
     * Returns the field index of factor B.
     *
     * @throws IllegalStateException in the ONE_EXPL_TWO_REPEAT state, which has no single factor B
     *     field. Callers check factorBIsCompound() first.
     */
    private int factorBIndex() {
        switch (this.getState()) {
            case TWO_EXPL:
                return this.categoricalExplanatoryFields().get(1);
            case ONE_EXPL_ONE_REPEAT:
                return this.categoricalRepeatFields().get(0);
            case ONE_EXPL_ONE_GROUP:
                return this.fieldsForRole(FieldRole.GROUP, FieldType.CATEGORICAL).get(0);
            default:
                throw new IllegalStateException("Unknown state");
        }
    }

    /** Returns explanatory field indices whose type is CATEGORICAL or DATETIME, in metadata order. */
    private List<Integer> categoricalExplanatoryFields() {
        return this.metadata.getExplanatoryFieldIndices().stream().filter(i -> {
            FieldType type = this.metadata.getFieldType((int) i);
            return type == FieldType.CATEGORICAL || type == FieldType.DATETIME;
        }).collect(Collectors.toList());
    }

    /** Returns REPEAT-role field indices whose type is CATEGORICAL, in metadata order. */
    private List<Integer> categoricalRepeatFields() {
        return this.fieldsForRole(FieldRole.REPEAT, FieldType.CATEGORICAL);
    }

    /** Returns the first explanatory field with an aggregation type; throws PredictException if none. */
    private int aggregatedExplanatoryIndex() {
        return this.metadata.getExplanatoryFieldIndices().stream()
                .filter(i -> this.metadata.getAggregationTypeOfField((int) i).isPresent())
                .findFirst()
                .orElseThrow(PredictException::new);
    }

    /** Returns true when no data row was recorded for the outlier's cell. */
    private boolean shouldExcludeOutlier(InfluentialCategory result) {
        int trueIndex = this.getTrueIndex((Integer) result.categoryIndex._1, (Integer) result.categoryIndex._2,
                this.values.rowDimension());
        return !this.indices.containsKey(trueIndex);
    }

    private List<InfluentialCategory> outliersForSumAggregation() {
        return MeaningfulDifferences.detect(this.values, AggregationType.SUM, this.numberOfCategories, this.statistics);
    }

    /**
     * Runs a ChiSquare test over the matrix. When factor B has a single category, the transposed
     * matrix is tested instead, with the third ChiSquare argument set to true.
     */
    private List<InfluentialCategory> outliersForCountAggregation() {
        // NOTE(review): buildResult maps outlier indices back using the untransposed row
        // dimension. Confirm ChiSquare reports indices in the original orientation when tested
        // on the transpose.
        ChiSquare chiSquareTest = this.factorBCatCount() == 1
                ? new ChiSquare(this.aggregationType, this.values.transpose(), true, this.threshold)
                : new ChiSquare(this.aggregationType, this.values, false, this.threshold);
        chiSquareTest.compute();
        return chiSquareTest.getOutliers();
    }

    /** Column-major linear index of cell (rowIdx, columnIdx) in a matrix with numRows rows. */
    private int getTrueIndex(int rowIdx, int columnIdx, int numRows) {
        return columnIdx * numRows + rowIdx;
    }

    private int factorACatCount() {
        return this.catCount(this.factorAIndex());
    }

    /** Returns factor B's category count; for two repeat fields, the product of their counts. */
    private int factorBCatCount() {
        switch (this.getState()) {
            case TWO_EXPL:
                return this.catCount(this.categoricalExplanatoryFields().get(1));
            case ONE_EXPL_ONE_REPEAT:
                return this.catCount(this.categoricalRepeatFields().get(0));
            case ONE_EXPL_ONE_GROUP:
                return this.catCount(this.fieldsForRole(FieldRole.GROUP, FieldType.CATEGORICAL).get(0));
            case ONE_EXPL_TWO_REPEAT: {
                List<Integer> repeats = this.categoricalRepeatFields();
                return this.catCount(repeats.get(0)) * this.catCount(repeats.get(1));
            }
            default:
                throw new IllegalStateException("Unknown state");
        }
    }

    /** Returns field indices with the given role and exactly the given type, in metadata order. */
    private List<Integer> fieldsForRole(FieldRole role, FieldType type) {
        return this.metadata.getFieldIndicesByRole(role).stream()
                .filter(i -> this.metadata.getFieldType((int) i) == type)
                .collect(Collectors.toList());
    }

    private int catCount(int index) {
        return this.metadata.getCountOfFieldCategories(index);
    }

    /** Reads the row's factor A category index (the field value truncated to int). */
    private int factorAValue(DataRowAdapter dataRow) {
        return (int) dataRow.getFieldValueByIndex(this.factorAIndex());
    }

    /**
     * Reads the row's factor B category index. For two repeat fields, the two values are combined
     * as {@code rep1Value * categoryCount(rep2) + rep2Value}.
     */
    private int factorBValue(DataRowAdapter dataRow) {
        switch (this.getState()) {
            case TWO_EXPL:
                return (int) dataRow.getFieldValueByIndex(this.categoricalExplanatoryFields().get(1));
            case ONE_EXPL_ONE_REPEAT:
                return (int) dataRow.getFieldValueByIndex(this.categoricalRepeatFields().get(0));
            case ONE_EXPL_ONE_GROUP:
                return (int) dataRow.getFieldValueByIndex(this.fieldsForRole(FieldRole.GROUP, FieldType.CATEGORICAL).get(0));
            case ONE_EXPL_TWO_REPEAT: {
                List<Integer> repeats = this.categoricalRepeatFields();
                int outer = (int) dataRow.getFieldValueByIndex(repeats.get(0));
                int inner = (int) dataRow.getFieldValueByIndex(repeats.get(1));
                return outer * this.catCount(repeats.get(1)) + inner;
            }
            default:
                throw new IllegalStateException("Unknown state");
        }
    }

    /** Returns the response name for SUM aggregation, otherwise factor A's field name. */
    private String getResponseName() {
        if (this.aggregationType == AggregationType.SUM) {
            return this.metadata.getResponseName();
        }
        return this.metadata.getNameOfField(this.factorAIndex());
    }

    /**
     * Classifies the field layout. Checks are applied in order, so two categorical explanatory
     * fields win over any repeat or group fields.
     *
     * @throws IllegalStateException if no supported layout matches
     */
    private State getState() {
        List<Integer> explanatories = this.categoricalExplanatoryFields();
        List<Integer> repeats = this.categoricalRepeatFields();
        List<Integer> groups = this.fieldsForRole(FieldRole.GROUP, FieldType.CATEGORICAL);
        if (explanatories.size() == 2) {
            return State.TWO_EXPL;
        }
        if (explanatories.size() == 1 && repeats.size() == 1) {
            return State.ONE_EXPL_ONE_REPEAT;
        }
        if (this.aggregationType == AggregationType.SUM && explanatories.size() == 1 && groups.size() == 1) {
            return State.ONE_EXPL_ONE_GROUP;
        }
        if (this.aggregationType == AggregationType.COUNT && explanatories.size() == 1 && repeats.size() == 2) {
            return State.ONE_EXPL_TWO_REPEAT;
        }
        throw new IllegalStateException("Unknown state");
    }

    /** Supported field layouts, i.e. where factor B comes from. */
    private enum State {
        /** Factor B is the second categorical explanatory field. */
        TWO_EXPL,
        /** Factor B is the single categorical group field (SUM aggregation only). */
        ONE_EXPL_ONE_GROUP,
        /** Factor B is the single categorical repeat field. */
        ONE_EXPL_ONE_REPEAT,
        /** Factor B combines two categorical repeat fields (COUNT aggregation only). */
        ONE_EXPL_TWO_REPEAT;
    }
}

