/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.sa.execution.annotation.impl;

import com.google.common.collect.Sets;
import com.ibm.bi.predict.algorithms.regression.Outlier;
import com.ibm.bi.predict.algorithms.regression.Regression;
import com.ibm.bi.predict.algorithms.regression.RegressionFields;
import com.ibm.bi.predict.algorithms.regression.RegressionSelector;
import com.ibm.bi.predict.algorithms.regression.StudentizedResiduals;
import com.ibm.bi.predict.algorithms.summaries.Means;
import com.ibm.bi.predict.dataaccess.Decorator;
import com.ibm.bi.predict.exceptions.InvalidDataException;
import com.ibm.bi.predict.math.NumericUtils;
import com.ibm.bi.predict.math.TopNSelector;
import com.ibm.bi.predict.sa.execution.annotation.ConditionalResponses;
import com.ibm.bi.predict.sa.execution.annotation.DataRowValidator;
import com.ibm.bi.predict.sa.execution.annotation.FieldRole;
import com.ibm.bi.predict.sa.execution.annotation.MessageServiceImpl;
import com.ibm.bi.predict.sa.execution.annotation.decorations.OutlierDecoration;
import com.ibm.bi.predict.sa.execution.annotation.impl.AnnotationImpl;
import com.ibm.bi.predict.sa.execution.annotation.impl.math.StatisticsMap;
import com.ibm.bi.predict.sa.execution.annotation.response.ConditionalResponse;
import com.ibm.bi.predict.sa.execution.annotation.result.AnnotationResult;
import com.ibm.bi.predict.sa.execution.annotation.result.EmptyResult;
import com.ibm.bi.predict.sa.execution.annotation.result.OutlierResult;
import com.ibm.bi.predict.sa.execution.api.DataRowAdapter;
import com.ibm.bi.predict.sa.execution.api.MetaDataAdapter;
import com.ibm.bi.predict.sa.execution.api.SuggestedAnnotation;
import com.ibm.bi.predict.utils.Logger;
import com.ibm.bi.predict.utils.PredictLoggerFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Annotation that looks for outlying (target, explanatory) observations.
 *
 * <p>Rows are accumulated through {@link #update(DataRowAdapter)}. In
 * {@link #postUpdate()} the annotation:
 * <ol>
 *   <li>flags "univariate" outliers, i.e. target or explanatory values lying more than
 *       {@link #UNIVARIATE_OUTLIER_THRESHOLD} count-weighted standard deviations from the
 *       count-weighted mean;</li>
 *   <li>builds a regression model on the rows that remain after removing those;</li>
 *   <li>runs {@link StudentizedResiduals} with that model over <em>all</em> accumulated rows
 *       (so reported outlier positions index into the full, unfiltered lists);</li>
 *   <li>keeps only the leading outliers as chosen by {@link TopNSelector}.</li>
 * </ol>
 *
 * <p>When the metadata declares a field with role {@link FieldRole#WEIGHT}, that field's value
 * is collected per row, rows with a non-positive weight are reported by
 * {@link #validateDataRow(DataRowAdapter)}, and the weights are handed to the regression.
 */
public class StudentizedResidualsAnnotation
extends AnnotationImpl<List<OutlierResult.OutlierData>>
implements DataRowValidator {
    private static final Logger log = PredictLoggerFactory.getLogger(StudentizedResidualsAnnotation.class);
    /** Threshold handed to {@link StudentizedResiduals}. */
    private static final double OUTLIER_THRESHOLD = 3.0;
    /** Number of standard deviations from the mean beyond which a value is a univariate outlier. */
    private static final double UNIVARIATE_OUTLIER_THRESHOLD = 10.0;
    /** Regression degree allowed when at least {@link #MIN_NUM_ROWS_QUADRATIC} rows are present. */
    private static final int MAX_REGRESSION_DEGREE = 2;
    /** Regression degree allowed when fewer than {@link #MIN_NUM_ROWS_QUADRATIC} rows are present. */
    private static final int RESTRICTED_MAX_REGRESSION_DEGREE = 1;
    /** Minimum row count for the annotation to run at all. */
    private static final int MIN_PTS = 3;
    /** Minimum row count for which a degree-{@link #MAX_REGRESSION_DEGREE} fit is permitted. */
    private static final int MIN_NUM_ROWS_QUADRATIC = 4;
    /** Value of {@link #weightIndex} when the metadata declares no weight field. */
    private static final int NO_WEIGHT_INDEX = -1;

    // Parallel per-row accumulators: position i in each list describes the i-th row seen by update().
    private final List<Double> targetValues = new ArrayList<>();
    private final List<Double> inputValues = new ArrayList<>();
    private final List<Double> counts = new ArrayList<>();
    /** Only populated when {@link #hasWeight} is true; otherwise stays empty. */
    private final List<Double> weights = new ArrayList<>();
    /** Original data row index of each accumulated row. */
    private final List<Integer> indexes = new ArrayList<>();
    /** Outliers found by the last successful {@link #postUpdate()}. */
    private List<Outlier> outliers = new ArrayList<>();
    private final boolean hasWeight;
    private final int weightIndex;

    /**
     * Creates the annotation and resolves the optional weight field from the metadata.
     *
     * @param metaData metadata describing the data being annotated
     */
    public StudentizedResidualsAnnotation(MetaDataAdapter metaData) {
        super(metaData);
        Optional<Integer> maybeWeightIndex = this.metadata.getIndexByRole(FieldRole.WEIGHT);
        this.hasWeight = maybeWeightIndex.isPresent();
        this.weightIndex = maybeWeightIndex.orElse(NO_WEIGHT_INDEX);
    }

    /**
     * Records the target value, explanatory value, row count statistic, data row index and,
     * if a weight field exists, the weight of the given row.
     *
     * @param dataRow the row to accumulate
     */
    @Override
    public void update(DataRowAdapter dataRow) {
        this.targetValues.add(dataRow.getTargetValue());
        this.inputValues.add(dataRow.getExplanatoryValue());
        this.counts.add(dataRow.getTargetStatistic(StatisticsMap.StatisticName.ROW_COUNT));
        this.indexes.add(dataRow.getDataRowIndex());
        if (this.hasWeight) {
            this.weights.add(dataRow.getFieldValueByIndex(this.weightIndex));
        }
    }

    /**
     * Requires at least {@link #MIN_PTS} rows according to the metadata.
     *
     * @return {@code TOO_FEW_RECORDS} when there are fewer rows, otherwise {@code SUCCESS}
     */
    @Override
    public ConditionalResponse assertPreconditions() {
        if (this.metadata.rowCount() < MIN_PTS) {
            return ConditionalResponses.TOO_FEW_RECORDS;
        }
        return ConditionalResponses.SUCCESS;
    }

    /**
     * Runs outlier detection over the accumulated rows and stores the outcome.
     *
     * @return {@code SUCCESS} when detection completed (even with zero outliers), or
     *         {@code INVALID_FIELD_UNSPECIFIED} when the regression could not be built
     */
    @Override
    public ConditionalResponse postUpdate() {
        log.debug("Beginning postUpdate of studentized residuals annotation - numRows={}", (Object)this.targetValues.size());
        Optional<List<Outlier>> detectedOutliers = this.detectOutliers();
        if (!detectedOutliers.isPresent()) {
            return ConditionalResponses.INVALID_FIELD_UNSPECIFIED();
        }
        this.outliers = detectedOutliers.get();
        log.debug("Completed outlier detection - numOutliers={}", (Object)this.outliers.size());
        return ConditionalResponses.SUCCESS;
    }

    /**
     * Converts the stored outliers into the annotation result.
     *
     * @return an {@link EmptyResult} when no outliers are stored, otherwise an
     *         {@link OutlierResult} with one entry per outlier
     */
    @Override
    public AnnotationResult<List<OutlierResult.OutlierData>> buildResult() {
        if (this.outliers.isEmpty()) {
            return new EmptyResult<List<OutlierResult.OutlierData>>();
        }
        List<OutlierResult.OutlierData> outlierData = this.outliers.stream()
                .map(this::toOutlierData)
                .collect(Collectors.toList());
        return new OutlierResult(outlierData);
    }

    /**
     * Applies an {@link OutlierDecoration} for the suggestion's target when the result is
     * non-empty; does nothing otherwise.
     */
    @Override
    public void decorate(Decorator decorator, SuggestedAnnotation annotationSuggestion, Locale locale) {
        AnnotationResult<List<OutlierResult.OutlierData>> result = this.getResult();
        if (result.isEmpty()) {
            return;
        }
        int responseIndex = annotationSuggestion.indexOfTarget;
        new OutlierDecoration(responseIndex, new MessageServiceImpl()).decorate(decorator, result, locale);
    }

    /**
     * Checks the weight of a row when a weight field exists.
     *
     * @param dataRow the row to validate
     * @return an {@code INVALID_DATA_RECORD_WARNING} naming the weight field when the weight
     *         is zero or negative; {@code SUCCESS} otherwise or when there is no weight field
     */
    @Override
    public ConditionalResponse validateDataRow(DataRowAdapter dataRow) {
        if (!this.hasWeight) {
            return ConditionalResponses.SUCCESS;
        }
        double weightValue = dataRow.getFieldValueByIndex(this.weightIndex);
        if (weightValue <= 0.0) {
            return ConditionalResponses.INVALID_DATA_RECORD_WARNING(this.metadata.getNameOfField(this.weightIndex));
        }
        return ConditionalResponses.SUCCESS;
    }

    /** Builds the result entry for one outlier, mapping its position back to the data row index. */
    private OutlierResult.OutlierData toOutlierData(Outlier o) {
        return new OutlierResult.OutlierData(
                this.getDataRowIndex(o.index),
                null,
                o.residualVal,
                o.responseVal,
                o.explanatoryVal,
                this.metadata.getResponseName(),
                this.metadata.getNameOfFirstExplanatoryField(),
                this.metadata.getAggregationTypeOfField(this.metadata.getResponseIndex()));
    }

    /**
     * Fits a regression on the rows that are not univariate outliers, then detects and ranks
     * outliers against that model.
     *
     * @return the selected outliers, or an empty {@code Optional} when an
     *         {@link InvalidDataException} was raised while building or applying the model
     */
    private Optional<List<Outlier>> detectOutliers() {
        Set<Integer> outlierIndexes = this.detectUnivariateOutliers();
        List<Double> filteredTargetValues = deleteOutlierRows(this.targetValues, outlierIndexes);
        List<Double> filteredInputValues = deleteOutlierRows(this.inputValues, outlierIndexes);
        List<Double> filteredCounts = deleteOutlierRows(this.counts, outlierIndexes);
        List<Double> filteredWeights = this.hasWeight
                ? deleteOutlierRows(this.weights, outlierIndexes)
                : Collections.<Double>emptyList();
        try {
            Regression regression = this.buildRegressionModel(filteredTargetValues, filteredInputValues, filteredCounts, filteredWeights);
            return Optional.of(this.filter(this.detectOutliersFromModel(regression)));
        }
        catch (InvalidDataException e) {
            log.error("Attempted to build regression model with invalid data", (Throwable)e);
            return Optional.empty();
        }
    }

    /** Positions that are univariate outliers in either the target or the explanatory values. */
    private Set<Integer> detectUnivariateOutliers() {
        return Sets.union(
                detectUnivariateOutliers(this.targetValues, this.counts),
                detectUnivariateOutliers(this.inputValues, this.counts));
    }

    /** Returns the values whose position is not contained in {@code outlierIndexes}, in order. */
    private static List<Double> deleteOutlierRows(List<Double> values, Set<Integer> outlierIndexes) {
        return IntStream.range(0, values.size())
                .filter(i -> !outlierIndexes.contains(i))
                .mapToObj(values::get)
                .collect(Collectors.toList());
    }

    /**
     * Selects the best-fitting regression for the given rows, attaching weights when a weight
     * field exists.
     */
    private Regression buildRegressionModel(List<Double> targetValues, List<Double> inputValues, List<Double> counts, List<Double> weights) {
        RegressionFields f = new RegressionFields(NumericUtils.listToDoubleArray(targetValues), NumericUtils.listToDoubleArray(inputValues), NumericUtils.listToDoubleArray(counts));
        if (this.hasWeight) {
            f.withWeights(NumericUtils.listToDoubleArray(weights));
        }
        return RegressionSelector.findBestFit(f, this.getMaxRegressionDegree());
    }

    /**
     * Applies studentized-residual detection with the given model to all accumulated rows
     * (including the rows that were excluded when fitting the model).
     */
    private List<Outlier> detectOutliersFromModel(Regression regression) {
        // NOTE(review): this.weights is empty when there is no weight field, so an empty
        // array is passed in that case -- confirm StudentizedResiduals.fit treats that as "unweighted".
        StudentizedResiduals s = new StudentizedResiduals(regression, OUTLIER_THRESHOLD).fit(NumericUtils.listToDoubleArray(this.targetValues), NumericUtils.listToDoubleArray(this.inputValues), NumericUtils.listToDoubleArray(this.weights));
        return s.get();
    }

    /**
     * Finds positions whose value lies outside
     * {@code mean +/- UNIVARIATE_OUTLIER_THRESHOLD * stdDev}, where the mean comes from
     * {@link Means#weightedMean} and the standard deviation is computed from the weighted
     * squared deviations divided by {@code (sum of weights - 1)}.
     *
     * <p>If the standard deviation is NaN (for example when the weights sum to one or less)
     * both comparisons are false and no position is reported.
     *
     * @param values  the observations to inspect
     * @param weights per-observation weights, parallel to {@code values}
     * @return the set of outlying positions (possibly empty)
     */
    private static Set<Integer> detectUnivariateOutliers(List<Double> values, List<Double> weights) {
        Set<Integer> univariateOutliers = new HashSet<>();
        double mean = Means.weightedMean((double[])NumericUtils.listToDoubleArray(values), (double[])NumericUtils.listToDoubleArray(weights));
        double weightedSquares = 0.0;
        double totalWeight = 0.0;
        for (int i = 0; i < values.size(); ++i) {
            double weight = weights.get(i);
            weightedSquares += weight * Math.pow(values.get(i) - mean, 2.0);
            totalWeight += weight;
        }
        double stdDev = Math.sqrt(weightedSquares / (totalWeight - 1.0));
        double lowerBound = mean - UNIVARIATE_OUTLIER_THRESHOLD * stdDev;
        double upperBound = mean + UNIVARIATE_OUTLIER_THRESHOLD * stdDev;
        for (int i = 0; i < values.size(); ++i) {
            double value = values.get(i);
            if (value < lowerBound || value > upperBound) {
                univariateOutliers.add(i);
            }
        }
        return univariateOutliers;
    }

    /** Maps a position in the accumulated lists back to the original data row index. */
    private int getDataRowIndex(int outlierIndex) {
        return this.indexes.get(outlierIndex);
    }

    /**
     * Orders outliers by descending absolute residual and keeps the leading entries, the
     * number of which is chosen by {@link TopNSelector#selectTopN}.
     *
     * <p>Works on a copy so the list owned by the detector is not reordered, and returns an
     * independent list rather than a {@code subList} view.
     */
    private List<Outlier> filter(List<Outlier> outliers) {
        List<Outlier> ranked = new ArrayList<>(outliers);
        Collections.sort(ranked, (c1, c2) -> Double.compare(Math.abs(c2.residualVal), Math.abs(c1.residualVal)));
        double[] resValues = ranked.stream().mapToDouble(c -> Math.abs(c.residualVal)).toArray();
        int topN = TopNSelector.selectTopN(resValues, this.targetValues.size());
        // Defensive clamp: selectTopN's contract is not visible from this class, so make sure
        // the bound can never fall outside the list.
        int keep = Math.max(0, Math.min(topN, ranked.size()));
        return new ArrayList<>(ranked.subList(0, keep));
    }

    /**
     * Highest regression degree to try, based on the metadata row count.
     *
     * <p>NOTE(review): this uses the metadata row count, not the number of rows left after
     * univariate outliers are removed -- confirm that is intended.
     */
    private int getMaxRegressionDegree() {
        return this.metadata.rowCount() >= MIN_NUM_ROWS_QUADRATIC ? MAX_REGRESSION_DEGREE : RESTRICTED_MAX_REGRESSION_DEGREE;
    }
}

