/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.sa.execution.annotation.impl;

import com.ibm.bi.predict.algorithms.ThreeLevelScale;
import com.ibm.bi.predict.algorithms.table.AdjustedCountR2;
import com.ibm.bi.predict.algorithms.table.OnewayChiSquareTestForCategTarget;
import com.ibm.bi.predict.data.matrix.Matrix;
import com.ibm.bi.predict.data.matrix.MatrixVectorFactory;
import com.ibm.bi.predict.dataaccess.Decorator;
import com.ibm.bi.predict.dataaccess.types.FieldType;
import com.ibm.bi.predict.sa.execution.annotation.ConditionalResponses;
import com.ibm.bi.predict.sa.execution.annotation.DataRowValidator;
import com.ibm.bi.predict.sa.execution.annotation.FieldRole;
import com.ibm.bi.predict.sa.execution.annotation.MessageServiceImpl;
import com.ibm.bi.predict.sa.execution.annotation.decorations.PredictiveStrengthDecoration;
import com.ibm.bi.predict.sa.execution.annotation.impl.AnnotationImpl;
import com.ibm.bi.predict.sa.execution.annotation.response.ConditionalResponse;
import com.ibm.bi.predict.sa.execution.annotation.result.AnnotationResult;
import com.ibm.bi.predict.sa.execution.annotation.result.PredictiveStrengthResult;
import com.ibm.bi.predict.sa.execution.api.DataRowAdapter;
import com.ibm.bi.predict.sa.execution.api.MetaDataAdapter;
import com.ibm.bi.predict.sa.execution.api.SuggestedAnnotation;
import com.ibm.bi.predict.utils.Logger;
import com.ibm.bi.predict.utils.PredictLoggerFactory;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import org.apache.commons.math3.util.Precision;

public class CategoricalTargetChiSquareAnnotation
extends AnnotationImpl<PredictiveStrengthResult.PredictiveStrengthData>
implements DataRowValidator {
    private static final Logger log = PredictLoggerFactory.getLogger(CategoricalTargetChiSquareAnnotation.class);
    private static final double PREDICTIVE_STRENGTH_THRESHOLD = 0.1;
    private static final int MIN_CATS = 2;
    private static final double ROWCOUNT_THRESHOLD = 2.147483647E9;
    private static final double CHISQUARE_ADJUSTMENT_THRESHOLD = 0.1;
    private static final int MIN_CATEGORY_COUNT_FOR_R2 = 3;
    private static final int RESPONSE_IDX = 0;
    private static final int FACTOR_IDX = 1;
    private static final int COUNT_IDX = 2;
    private final OnewayChiSquareTestForCategTarget chisquareTest = new OnewayChiSquareTestForCategTarget(this.responseCatCount(), this.factorCatCount(), 0, 1, this.metadata.rowCount(), 2, 0.1);
    private final Matrix adjR2Matrix;
    private double pValue = 0.0;
    private double effectSize = 0.0;
    private double adjustedCountR2 = 0.0;
    private ThreeLevelScale associationStrengthLevel;

    public CategoricalTargetChiSquareAnnotation(MetaDataAdapter metaData) {
        super(metaData);
        this.adjR2Matrix = MatrixVectorFactory.makeMatrix((int)this.responseCatCount(), (int)this.factorCatCount(), (int)metaData.rowCount());
    }

    @Override
    public ConditionalResponse assertPreconditions() {
        if (this.factorCatCount() < 2) {
            String label = this.metadata.getNameOfField(this.factorAIndex());
            if (this.isTwoRepeat()) {
                label = label + "|" + this.metadata.getNameOfField(this.factorBIndex());
            }
            return ConditionalResponses.TOO_FEW_CATEGORIES(label);
        }
        return ConditionalResponses.SUCCESS;
    }

    @Override
    public void update(DataRowAdapter dataRow) {
        int responseCatIdx = (int)dataRow.getFieldValueByIndex(this.responseIndex());
        int factorCatIdx = this.getFactorCategory(dataRow);
        double count = dataRow.getFieldValueByIndex(this.getCountIndex());
        double[] record = new double[]{responseCatIdx, factorCatIdx, count};
        this.chisquareTest.update(record, true);
        this.adjR2Matrix.increment(responseCatIdx, factorCatIdx, count);
    }

    @Override
    public ConditionalResponse postUpdate() {
        this.pValue = this.chisquareTest.computeStatistics();
        this.adjustedCountR2 = new AdjustedCountR2(this.adjR2Matrix, 3).getAdjustedCountR2();
        if (this.adjustedCountR2 > 0.1 && (this.pValue <= 0.05 || Precision.equals((double)this.adjustedCountR2, (double)1.0))) {
            this.effectSize = this.chisquareTest.getEffectSize();
            log.debug("Factor is significant - pValue={} effectSize={} adjustedCountRSquared={}", new Object[]{this.pValue, this.effectSize, this.adjustedCountR2});
            this.associationStrengthLevel = CategoricalTargetChiSquareAnnotation.getAssociationStrength(this.adjustedCountR2);
        } else {
            this.effectSize = 0.0;
            this.associationStrengthLevel = null;
        }
        log.debug("Test complete - p-value={} chi-square={} accuracy={} effect-size={}", new Object[]{this.pValue, this.chisquareTest.getChiSquare(), this.chisquareTest.getAccuracy(), this.chisquareTest.getEffectSize()});
        return ConditionalResponses.SUCCESS;
    }

    @Override
    public AnnotationResult<PredictiveStrengthResult.PredictiveStrengthData> buildResult() {
        ArrayList<String> explanatoryNames = new ArrayList<String>();
        explanatoryNames.add(this.metadata.getNameOfField(this.factorAIndex()));
        if (this.isTwoRepeat()) {
            explanatoryNames.add(this.metadata.getNameOfField(this.factorBIndex()));
        }
        return new PredictiveStrengthResult(new PredictiveStrengthResult.PredictiveStrengthData(this.pValue, this.effectSize, this.adjustedCountR2, this.associationStrengthLevel, this.metadata.getNameOfField(this.responseIndex()), explanatoryNames));
    }

    @Override
    public void decorate(Decorator decorator, SuggestedAnnotation annotationSuggestion, Locale locale) {
        new PredictiveStrengthDecoration(new MessageServiceImpl()).decorate(decorator, this.getResult(), locale);
    }

    @Override
    public ConditionalResponse validateDataRow(DataRowAdapter dataRow) {
        double count = dataRow.getFieldValueByIndex(this.getCountIndex());
        if (count > 2.147483647E9) {
            return ConditionalResponses.VALUES_ABOVE_THRESHOLD(this.metadata.getResponseName());
        }
        return ConditionalResponses.SUCCESS;
    }

    private int getFactorCategory(DataRowAdapter dataRow) {
        if (this.isTwoRepeat()) {
            List<Integer> repeats = this.repeatIndices();
            return (int)dataRow.getFieldValueByIndex(repeats.get(0)) * this.repeatMergeMultiplier() + (int)dataRow.getFieldValueByIndex(repeats.get(1));
        }
        return (int)dataRow.getFieldValueByIndex(this.factorAIndex());
    }

    private static ThreeLevelScale getAssociationStrength(double adjRSquared) {
        if (adjRSquared > 0.7) {
            return ThreeLevelScale.HIGH;
        }
        if (adjRSquared > 0.35) {
            return ThreeLevelScale.MEDIUM;
        }
        return ThreeLevelScale.LOW;
    }

    private int getCountIndex() {
        List<Integer> counts = this.fieldsForRole(FieldRole.RESPONSE, FieldType.NUMERICAL);
        if (counts.size() != 1) {
            throw new IllegalArgumentException(String.format("Expected exactly one explanatory numerical field, got %d", counts.size()));
        }
        return counts.get(0);
    }

    private List<Integer> fieldsForRole(FieldRole role, FieldType type) {
        return this.metadata.getFieldIndicesByRole(role).stream().filter(i -> this.metadata.getFieldType((int)i) == type).collect(Collectors.toList());
    }

    private boolean isTwoRepeat() {
        return this.repeatIndices().size() == 2;
    }

    private int repeatMergeMultiplier() {
        return this.metadata.getCountOfFieldCategories(this.repeatIndices().get(1));
    }

    private int factorAIndex() {
        List<Integer> explanatories = this.explanatoryIndices();
        List<Integer> repeats = this.repeatIndices();
        if (explanatories.size() == 2) {
            return explanatories.get(1);
        }
        if (explanatories.size() == 1 && repeats.size() == 1) {
            return repeats.get(0);
        }
        if (explanatories.size() == 1 && repeats.size() == 2) {
            return repeats.get(0);
        }
        throw new IllegalArgumentException("Unknown state for factor A");
    }

    private int factorBIndex() {
        if (this.isTwoRepeat()) {
            return this.repeatIndices().get(1);
        }
        throw new IllegalArgumentException("Unknown state for factor B");
    }

    private int responseIndex() {
        List<Integer> explanatories = this.explanatoryIndices();
        if (explanatories.size() == 2 || explanatories.size() == 1) {
            return explanatories.get(0);
        }
        throw new IllegalArgumentException(String.format("Expected exactly one response categorical field, got %d", explanatories.size()));
    }

    private int responseCatCount() {
        return this.metadata.getCountOfFieldCategories(this.responseIndex());
    }

    private int factorCatCount() {
        List<Integer> explanatories = this.explanatoryIndices();
        List<Integer> repeats = this.repeatIndices();
        if (explanatories.size() == 2) {
            return this.metadata.getCountOfFieldCategories(explanatories.get(1));
        }
        if (explanatories.size() == 1 && repeats.size() == 1) {
            return this.metadata.getCountOfFieldCategories(repeats.get(0));
        }
        if (explanatories.size() == 1 && repeats.size() == 2) {
            return this.metadata.getCountOfFieldCategories(repeats.get(0)) * this.repeatMergeMultiplier();
        }
        throw new IllegalArgumentException("Unknown state");
    }

    private List<Integer> explanatoryIndices() {
        return this.fieldsForRole(FieldRole.EXPLANATORY, FieldType.CATEGORICAL);
    }

    private List<Integer> repeatIndices() {
        return this.fieldsForRole(FieldRole.REPEAT, FieldType.CATEGORICAL);
    }
}

