/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.sa.execution.annotation.impl;

import com.ibm.bi.predict.algorithms.ThreeLevelScale;
import com.ibm.bi.predict.algorithms.regression.Regression;
import com.ibm.bi.predict.algorithms.regression.RegressionFields;
import com.ibm.bi.predict.algorithms.regression.RegressionSelector;
import com.ibm.bi.predict.dataaccess.Decorator;
import com.ibm.bi.predict.dataaccess.types.FieldType;
import com.ibm.bi.predict.exceptions.InvalidDataException;
import com.ibm.bi.predict.math.NumericUtils;
import com.ibm.bi.predict.sa.execution.annotation.ConditionalResponses;
import com.ibm.bi.predict.sa.execution.annotation.DataRowValidator;
import com.ibm.bi.predict.sa.execution.annotation.FieldRole;
import com.ibm.bi.predict.sa.execution.annotation.MessageServiceImpl;
import com.ibm.bi.predict.sa.execution.annotation.decorations.FitLineDecoration;
import com.ibm.bi.predict.sa.execution.annotation.impl.AnnotationImpl;
import com.ibm.bi.predict.sa.execution.annotation.impl.math.EquationBuilder;
import com.ibm.bi.predict.sa.execution.annotation.impl.math.StatisticsMap;
import com.ibm.bi.predict.sa.execution.annotation.impl.types.LineEquation;
import com.ibm.bi.predict.sa.execution.annotation.response.ConditionalResponse;
import com.ibm.bi.predict.sa.execution.annotation.result.AnnotationResult;
import com.ibm.bi.predict.sa.execution.annotation.result.EmptyResult;
import com.ibm.bi.predict.sa.execution.annotation.result.FitLineResult;
import com.ibm.bi.predict.sa.execution.api.DataRowAdapter;
import com.ibm.bi.predict.sa.execution.api.MetaDataAdapter;
import com.ibm.bi.predict.sa.execution.api.SuggestedAnnotation;
import com.ibm.bi.predict.sa.execution.utils.RegressionFitUtils;
import com.ibm.bi.predict.utils.Logger;
import com.ibm.bi.predict.utils.PredictLoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.math3.util.Precision;

public class FitLineAnnotation
extends AnnotationImpl<List<FitLineResult.FitLineData>>
implements DataRowValidator {
    private static final Logger log = PredictLoggerFactory.getLogger(FitLineAnnotation.class);
    private static final int MAX_REGRESSION_DEGREE = 2;
    private static final int RESTRICTED_MAX_REGRESSION_DEGREE = 1;
    private static final int MAX_GROUP_SIZE = 3;
    private static final int MIN_GROUP_SIZE = 2;
    private static final int MIN_NUM_ROWS = 3;
    private static final int MIN_NUM_ROWS_QUADRATIC = 4;
    private static final double ZERO_THRESHOLD = 1.0E-12;
    private static final int ROUNDING_NUM_DECIMAL_PLACES = 12;
    private final Optional<Integer> groupIndex;
    private final Optional<Integer> weightIndex;
    private final double[] targetValues;
    private final double[] inputValues;
    private final double[] rowCounts;
    private final double[] groups;
    private final double[] weights;
    private Map<Integer, Integer> coded = new HashMap<Integer, Integer>();
    private Map<Integer, String> categoryNames = new HashMap<Integer, String>();
    private Map<Integer, String> categoryIdentifiers = new HashMap<Integer, String>();
    private int codedIndex = 0;
    private ThreeLevelScale associationStrengthLevel;
    private List<Double[]> coefficients;
    private List<LineEquation> equations;
    private double adjustedRSquared;

    public FitLineAnnotation(MetaDataAdapter metadata) {
        super(metadata);
        this.targetValues = new double[metadata.rowCount()];
        this.inputValues = new double[metadata.rowCount()];
        this.rowCounts = new double[metadata.rowCount()];
        this.groupIndex = this.getGroupIndex();
        this.weightIndex = metadata.getIndexByRole(FieldRole.WEIGHT);
        this.groups = this.groupIndex.isPresent() ? new double[metadata.rowCount()] : new double[]{};
        this.weights = this.weightIndex.isPresent() ? new double[metadata.rowCount()] : new double[]{};
    }

    @Override
    public ConditionalResponse assertPreconditions() {
        if (this.metadata.rowCount() < 3) {
            return ConditionalResponses.TOO_FEW_RECORDS;
        }
        return ConditionalResponses.SUCCESS;
    }

    @Override
    public void update(DataRowAdapter dataRow) {
        double targetValue = dataRow.getTargetValue();
        double inputValue = dataRow.getExplanatoryValue();
        double rowCount = dataRow.getTargetStatistic(StatisticsMap.StatisticName.ROW_COUNT);
        this.targetValues[dataRow.getDataRowIndex()] = targetValue;
        this.inputValues[dataRow.getDataRowIndex()] = inputValue;
        this.rowCounts[dataRow.getDataRowIndex()] = rowCount;
        this.groupIndex.ifPresent(i -> {
            this.groups[dataRow.getDataRowIndex()] = this.getCodedGroupIndex(dataRow, (int)i);
        });
        this.weightIndex.ifPresent(i -> {
            this.weights[dataRow.getDataRowIndex()] = dataRow.getFieldValueByIndex((int)i);
        });
    }

    @Override
    public ConditionalResponse postUpdate() {
        RegressionFields f = new RegressionFields(this.targetValues, this.inputValues, this.rowCounts);
        this.groupIndex.ifPresent(i -> f.withGroups(this.groups));
        this.weightIndex.ifPresent(i -> f.withWeights(this.weights));
        try {
            Regression regression = RegressionSelector.findBestFit((RegressionFields)f, (int)this.getMaxRegressionDegree());
            this.coefficients = this.getCoefficients(regression);
            this.equations = this.buildEquations(this.coefficients, regression.degree());
            this.adjustedRSquared = regression.adjustedRSquared();
            this.associationStrengthLevel = FitLineAnnotation.getAssociationStrength(this.adjustedRSquared);
            return ConditionalResponses.SUCCESS;
        }
        catch (InvalidDataException e) {
            log.error("Attempted to build regression model with invalid data", (Throwable)e);
            return ConditionalResponses.INVALID_FIELD_UNSPECIFIED();
        }
    }

    @Override
    public AnnotationResult<List<FitLineResult.FitLineData>> buildResult() {
        if (this.coefficients == null) {
            return new EmptyResult<List<FitLineResult.FitLineData>>();
        }
        ArrayList<FitLineResult.FitLineData> fitLineResults = new ArrayList<FitLineResult.FitLineData>();
        for (int i = 0; i < this.coefficients.size(); ++i) {
            if (this.groupIndex.isPresent() && this.coefficients.size() > 1) {
                String groupName = this.metadata.getNameOfField(this.groupIndex.get());
                fitLineResults.add(new FitLineResult.FitLineData(ArrayUtils.toPrimitive((Double[])this.coefficients.get(i)), this.equations.get(i), this.adjustedRSquared, this.categoryNames.get(i), this.categoryIdentifiers.get(i), groupName, this.associationStrengthLevel, this.metadata.getResponseName(), this.metadata.getNameOfExplanatoryField(0)));
                continue;
            }
            fitLineResults.add(new FitLineResult.FitLineData(ArrayUtils.toPrimitive((Double[])this.coefficients.get(i)), this.equations.get(i), this.adjustedRSquared, this.associationStrengthLevel, this.metadata.getResponseName(), this.metadata.getNameOfExplanatoryField(0)));
        }
        return new FitLineResult(fitLineResults);
    }

    @Override
    public void decorate(Decorator decorator, SuggestedAnnotation annotationSuggestion, Locale locale) {
        AnnotationResult<List<FitLineResult.FitLineData>> result = this.getResult();
        if (!result.isEmpty()) {
            int responseIndex = annotationSuggestion.indexOfTarget;
            new FitLineDecoration(responseIndex, new MessageServiceImpl()).decorate(decorator, result, locale);
        }
    }

    @Override
    public ConditionalResponse validateDataRow(DataRowAdapter dataRow) {
        return RegressionFitUtils.validateDataRow(dataRow, this.weightIndex, this.metadata);
    }

    private int getMaxRegressionDegree() {
        return this.metadata.rowCount() >= 4 ? 2 : 1;
    }

    private Optional<Integer> getGroupIndex() {
        return this.metadata.getIndexByRole(FieldRole.GROUP).filter(this::groupIsCategorical).filter(this::validNumOfCategories);
    }

    private List<Double[]> getCoefficients(Regression regression) {
        double[] solutionCoefficients = FitLineAnnotation.checkNearZeroness(regression.solve());
        if (this.groupIndex.isPresent() && regression.numberOfParameters() > regression.degree()) {
            return this.formatCoefficients(solutionCoefficients, regression.degree());
        }
        return Arrays.asList(new Double[][]{ArrayUtils.toObject((double[])solutionCoefficients)});
    }

    private static double[] checkNearZeroness(double[] coefficients) {
        double[] formatted = new double[coefficients.length];
        for (int i = 0; i < coefficients.length; ++i) {
            formatted[i] = NumericUtils.equals((double)coefficients[i], (double)0.0, (double)1.0E-12) ? 0.0 : Precision.round((double)coefficients[i], (int)12, (int)4);
        }
        return formatted;
    }

    private double getCodedGroupIndex(DataRowAdapter dataRow, int groupIndex) {
        int index = (int)dataRow.getFieldValueByIndex(groupIndex);
        if (!this.coded.containsKey(index)) {
            this.coded.put(index, this.codedIndex);
            this.categoryNames.put(this.codedIndex, this.metadata.getCategoryNameForField(groupIndex, index));
            this.categoryIdentifiers.put(this.codedIndex, this.metadata.getCategoryIdentifierForField(groupIndex, index));
            ++this.codedIndex;
        }
        return this.coded.get(index).intValue();
    }

    private List<Double[]> formatCoefficients(double[] coefficients, int degree) {
        List<Double[]> separatedCoefficients = this.separateCoefficients(coefficients, degree);
        int numCoefficients = degree + 1;
        return separatedCoefficients.stream().map(coeff -> {
            int i;
            double[] c = new double[numCoefficients];
            for (i = 0; i < ((Double[])coeff).length; ++i) {
                int n = i % numCoefficients;
                c[n] = c[n] + coeff[i];
            }
            for (i = 0; i < c.length; ++i) {
                c[i] = Math.abs(c[i]) < 1.0E-12 ? 0.0 : Precision.round((double)c[i], (int)12, (int)4);
            }
            return ArrayUtils.toObject((double[])c);
        }).collect(Collectors.toList());
    }

    private List<Double[]> separateCoefficients(double[] coefficients, int degree) {
        ArrayList<Double[]> separatedCoefficients = new ArrayList<Double[]>();
        int numFactorCategories = this.coded.size();
        for (int i = 0; i < numFactorCategories; ++i) {
            int j;
            ArrayList<Double> c = new ArrayList<Double>();
            for (j = 0; j < degree + 1; ++j) {
                c.add(coefficients[j]);
            }
            if (numFactorCategories == 2) {
                this.addCoefficientForSingleDummyVariable(coefficients, i, c, j);
            } else {
                this.addCoefficientsForMultipleDummyVariables(coefficients, numFactorCategories, i, c, j);
            }
            separatedCoefficients.add(c.toArray(new Double[0]));
        }
        return separatedCoefficients;
    }

    private void addCoefficientForSingleDummyVariable(double[] coefficients, int i, List<Double> c, int j) {
        for (int x = j; x < coefficients.length; ++x) {
            if (i <= 0) continue;
            c.add(coefficients[x]);
        }
    }

    private void addCoefficientsForMultipleDummyVariables(double[] coefficients, int numFactorCategories, int i, List<Double> c, int j) {
        for (int x = j + (i - 1); x < coefficients.length; x += numFactorCategories - 1) {
            if (i <= 0) continue;
            c.add(coefficients[x]);
        }
    }

    private List<LineEquation> buildEquations(List<Double[]> coefficients, int degree) {
        return coefficients.stream().map(c -> this.buildEquation((Double[])c, degree)).collect(Collectors.toList());
    }

    private LineEquation buildEquation(Double[] coefficients, int degree) {
        return EquationBuilder.buildEquation(this.metadata.getResponseName(), this.metadata.getNameOfExplanatoryField(0), ArrayUtils.toPrimitive((Double[])coefficients), degree);
    }

    private boolean groupIsCategorical(int groupIndex) {
        return this.metadata.getFieldType(groupIndex) == FieldType.CATEGORICAL;
    }

    private boolean validNumOfCategories(int groupIndex) {
        return this.metadata.getCountOfCategoriesByRole(FieldRole.GROUP).filter(numCategories -> numCategories <= 3 && numCategories >= 2).isPresent();
    }

    private static ThreeLevelScale getAssociationStrength(double adjRSquared) {
        if (adjRSquared <= 0.35) {
            return ThreeLevelScale.LOW;
        }
        if (adjRSquared <= 0.7) {
            return ThreeLevelScale.MEDIUM;
        }
        return ThreeLevelScale.HIGH;
    }
}

