/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.sa.execution.annotation.impl;

import com.ibm.bi.predict.algorithms.ThreeLevelScale;
import com.ibm.bi.predict.algorithms.regression.Regression;
import com.ibm.bi.predict.algorithms.regression.RegressionFields;
import com.ibm.bi.predict.algorithms.regression.RegressionSelector;
import com.ibm.bi.predict.dataaccess.Decorator;
import com.ibm.bi.predict.dataaccess.types.FieldType;
import com.ibm.bi.predict.exceptions.InvalidDataException;
import com.ibm.bi.predict.math.NumericUtils;
import com.ibm.bi.predict.sa.execution.annotation.ConditionalResponses;
import com.ibm.bi.predict.sa.execution.annotation.DataRowValidator;
import com.ibm.bi.predict.sa.execution.annotation.FieldRole;
import com.ibm.bi.predict.sa.execution.annotation.MessageServiceImpl;
import com.ibm.bi.predict.sa.execution.annotation.decorations.PredictiveStrengthDecoration;
import com.ibm.bi.predict.sa.execution.annotation.impl.AnnotationImpl;
import com.ibm.bi.predict.sa.execution.annotation.impl.math.StatisticsMap;
import com.ibm.bi.predict.sa.execution.annotation.response.ConditionalResponse;
import com.ibm.bi.predict.sa.execution.annotation.result.AnnotationResult;
import com.ibm.bi.predict.sa.execution.annotation.result.PredictiveStrengthResult;
import com.ibm.bi.predict.sa.execution.api.DataRowAdapter;
import com.ibm.bi.predict.sa.execution.api.MetaDataAdapter;
import com.ibm.bi.predict.sa.execution.api.SuggestedAnnotation;
import com.ibm.bi.predict.sa.execution.utils.RegressionFitUtils;
import com.ibm.bi.predict.utils.PredictLoggerFactory;
import com.spss.math.statistics.DistributionFunctions;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import org.slf4j.Logger;

public class FitRatioAnnotation
extends AnnotationImpl<PredictiveStrengthResult.PredictiveStrengthData>
implements DataRowValidator {
    private static final Logger log = PredictLoggerFactory.getLogger(FitRatioAnnotation.class);
    private static final double MINIMUM_R2 = 0.1;
    private static final int MAX_REGRESSION_DEGREE = 2;
    private static final int RESTRICTED_MAX_REGRESSION_DEGREE = 1;
    private static final int MIN_NUM_ROWS = 3;
    private static final int MIN_NUM_ROWS_QUADRATIC = 4;
    private static final int MIN_GROUP_SIZE = 2;
    private static final int MAX_GROUP_SIZE = 3;
    private final Optional<Integer> groupIndex;
    private final Optional<Integer> weightIndex;
    private final double[] targetValues;
    private final double[] inputValues;
    private final double[] rowCounts;
    private final double[] groups;
    private final double[] weights;
    private Map<Integer, Integer> coded = new HashMap<Integer, Integer>();
    private int codedIndex = 0;
    private ThreeLevelScale associationStrengthLevel;
    private double pValue;
    private double effectSize;
    private double adjustedRSquared;

    public FitRatioAnnotation(MetaDataAdapter metaData) {
        super(metaData);
        this.targetValues = new double[this.metadata.rowCount()];
        this.inputValues = new double[this.metadata.rowCount()];
        this.rowCounts = new double[this.metadata.rowCount()];
        this.groupIndex = this.getGroupIndex();
        this.weightIndex = this.metadata.getIndexByRole(FieldRole.WEIGHT);
        this.groups = this.groupIndex.isPresent() ? new double[this.metadata.rowCount()] : new double[]{};
        this.weights = this.weightIndex.isPresent() ? new double[this.metadata.rowCount()] : new double[]{};
    }

    @Override
    public ConditionalResponse assertPreconditions() {
        if (this.metadata.rowCount() < 3) {
            return ConditionalResponses.TOO_FEW_RECORDS;
        }
        return ConditionalResponses.SUCCESS;
    }

    @Override
    public void update(DataRowAdapter dataRow) {
        double targetValue = dataRow.getTargetValue();
        double inputValue = dataRow.getExplanatoryValue();
        double rowCount = dataRow.getTargetStatistic(StatisticsMap.StatisticName.ROW_COUNT);
        this.targetValues[dataRow.getDataRowIndex()] = targetValue;
        this.inputValues[dataRow.getDataRowIndex()] = inputValue;
        this.rowCounts[dataRow.getDataRowIndex()] = rowCount;
        this.groupIndex.ifPresent(i -> {
            this.groups[dataRow.getDataRowIndex()] = this.getCodedGroupIndex(dataRow, (int)i);
        });
        this.weightIndex.ifPresent(i -> {
            this.weights[dataRow.getDataRowIndex()] = dataRow.getFieldValueByIndex((int)i);
        });
    }

    @Override
    public ConditionalResponse postUpdate() {
        RegressionFields f = new RegressionFields(this.targetValues, this.inputValues, this.rowCounts);
        this.groupIndex.ifPresent(i -> f.withGroups(this.groups));
        this.weightIndex.ifPresent(i -> f.withWeights(this.weights));
        try {
            Regression regression = RegressionSelector.findBestFit((RegressionFields)f, (int)this.getMaxRegressionDegree());
            this.pValue = this.computePValue(regression);
            this.adjustedRSquared = regression.adjustedRSquared();
            if (!NumericUtils.isMissingValue((double)this.pValue) && this.pValue <= 0.05 && this.adjustedRSquared > 0.1) {
                this.effectSize = regression.etaSquared();
                this.associationStrengthLevel = this.getAssociationStrength(this.adjustedRSquared);
                log.debug("Found significant predictive strength relationship - pValue={} effectSize={}", (Object)this.pValue, (Object)this.effectSize);
            } else {
                log.debug("Predictive strength relationship was not significant - pValue={}", (Object)this.pValue);
            }
            return ConditionalResponses.SUCCESS;
        }
        catch (InvalidDataException e) {
            log.error("Attempted to build regression model with invalid data", (Throwable)e);
            return ConditionalResponses.INVALID_FIELD_UNSPECIFIED();
        }
    }

    @Override
    public AnnotationResult<PredictiveStrengthResult.PredictiveStrengthData> buildResult() {
        return new PredictiveStrengthResult(new PredictiveStrengthResult.PredictiveStrengthData(this.pValue, this.effectSize, this.adjustedRSquared, this.associationStrengthLevel, this.metadata.getResponseName(), this.metadata.getNameOfFirstExplanatoryField()));
    }

    @Override
    public void decorate(Decorator decorator, SuggestedAnnotation annotationSuggestion, Locale locale) {
        new PredictiveStrengthDecoration(new MessageServiceImpl()).decorate(decorator, this.getResult(), locale);
    }

    @Override
    public ConditionalResponse validateDataRow(DataRowAdapter dataRow) {
        return RegressionFitUtils.validateDataRow(dataRow, this.weightIndex, this.metadata);
    }

    private int getMaxRegressionDegree() {
        return this.metadata.rowCount() >= 4 ? 2 : 1;
    }

    private double computePValue(Regression regression) {
        if (regression.getErrorSumSquares() < 1.0E-12 || regression.getNumOfPredictors() <= 0) {
            if (regression.getRegressionSumSquares() >= 1.0E-12) {
                return 0.0;
            }
            return Double.NaN;
        }
        double meanSquareRegression = regression.getMeanSquaredRegression();
        double meanSquareError = regression.getMeanSquaredError();
        double fStat = meanSquareRegression / meanSquareError;
        double dfR = regression.getDegreesOfFreedomRegression();
        double dfE = regression.getDegreesOfFreedomError();
        double cdfF = Double.isNaN(fStat) ? 0.0 : DistributionFunctions.cdfF((double)fStat, (double)dfR, (double)dfE);
        this.pValue = 1.0 - cdfF;
        log.debug("Completed determination of predictive strength - pValue={} meanSquareRegression={} meanSquareError={} adjustedR2={}", new Object[]{this.pValue, meanSquareRegression, meanSquareError, regression.adjustedRSquared()});
        return this.pValue;
    }

    private ThreeLevelScale getAssociationStrength(double adjRSquared) {
        if (adjRSquared <= 0.35) {
            return ThreeLevelScale.LOW;
        }
        if (adjRSquared <= 0.7) {
            return ThreeLevelScale.MEDIUM;
        }
        return ThreeLevelScale.HIGH;
    }

    private Optional<Integer> getGroupIndex() {
        return this.metadata.getIndexByRole(FieldRole.GROUP).filter(this::groupIsCategorical).filter(this::validNumOfCategories);
    }

    private boolean groupIsCategorical(int groupIndex) {
        return this.metadata.getFieldType(groupIndex) == FieldType.CATEGORICAL;
    }

    private boolean validNumOfCategories(int groupIndex) {
        return this.metadata.getCountOfCategoriesByRole(FieldRole.GROUP).filter(numCategories -> numCategories <= 3 && numCategories >= 2).isPresent();
    }

    private double getCodedGroupIndex(DataRowAdapter dataRow, int groupIndex) {
        int index = (int)dataRow.getFieldValueByIndex(groupIndex);
        if (!this.coded.containsKey(index)) {
            this.coded.put(index, this.codedIndex);
            ++this.codedIndex;
        }
        return this.coded.get(index).intValue();
    }
}

