/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.explore.nlt.deriveddata.operations;

import com.ibm.bi.predict.dataaccess.DataAccessProvider;
import com.ibm.bi.predict.dataaccess.DataRow;
import com.ibm.bi.predict.dataaccess.MetaData;
import com.ibm.bi.predict.dataaccess.types.AggregationType;
import com.ibm.bi.predict.dataaccess.types.StatisticStatus;
import com.ibm.bi.predict.explore.ChartInsightsContext;
import com.ibm.bi.predict.explore.ChartInsightsContextFactory;
import com.ibm.bi.predict.explore.RequestCreator;
import com.ibm.bi.predict.explore.data.DataAccessProviderBuilder;
import com.ibm.bi.predict.explore.nlt.deriveddata.operations.DerivedDataOperation;
import com.ibm.bi.predict.explore.nlt.framework.Operation;
import com.ibm.bi.predict.explore.nlt.framework.Request;
import com.ibm.bi.predict.explore.nlt.math.SumSquares;
import com.ibm.bi.predict.math.NumericUtils;
import com.ibm.bi.predict.types.RoleType;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;

/**
 * Aggregates a response field per category of a single categorical input field.
 *
 * <p>Rows are streamed in through {@link #update(DataRow)}. Each row whose response and input
 * values are non-NaN and whose input category index lies in {@code [0, numCats)} is folded into
 * the per-category state:
 * <ul>
 *   <li>the aggregated response value ({@link AggFunc});</li>
 *   <li>the per-category row count;</li>
 *   <li>optionally, an aggregated weight;</li>
 *   <li>optionally, a sum-of-squares statistic.</li>
 * </ul>
 * {@link #getResults()} then produces a single derived {@link Request} whose data provider has one
 * row per category.
 */
public class AggregateToOneInputOperation
extends DerivedDataOperation {
    private static final String ROWCOUNT_LABEL = "rowCount";
    private static final String SUMSQR_LABEL = "sumSqr";

    private final int response;
    private final int input;
    private final Optional<Integer> weightIndex;
    private final AggregationType aggType;
    private final AggFunc aggFunc;
    /** Optional post-processing applied to each accumulated value together with its row count. */
    private final Optional<OutputFunc> outputFunc;
    /** Number of categories of the input field; sizes every per-category array. */
    private final int numCats;
    /** Whether the response carries a per-row "rowCount" statistic (otherwise each row counts as 1). */
    private final boolean hasRowCount;
    /** Whether a sum-of-squares statistic is tracked and emitted for the response. */
    private final boolean hasSumSqr;
    private final double[] accumulator;
    private final double[] rowCount;
    private final double[] sumSqrs;
    private final SumSquares[] sumSqrsCalcs;
    /** Per-category aggregated weight; {@code null} when no weight field is configured. */
    private final double[] weightAccumulator;

    /**
     * Creates an aggregation with no output post-processing.
     *
     * @param request the source request; its {@link DataAccessProvider} metadata is read here
     * @param response index of the response field to aggregate
     * @param input index of the categorical input field to group by
     * @param weightIndex index of an optional weight field
     * @param initialAccumulatorValue starting value of every per-category accumulator
     * @param aggType aggregation type reported on the derived response field
     * @param aggFunc folding function applied to each qualifying row
     */
    public AggregateToOneInputOperation(Request request, int response, int input, Optional<Integer> weightIndex, double initialAccumulatorValue, AggregationType aggType, AggFunc aggFunc) {
        this(request, response, input, weightIndex, initialAccumulatorValue, aggType, aggFunc, Optional.empty());
    }

    /**
     * Creates an aggregation whose accumulated values are passed through {@code outputFunc}
     * before being emitted.
     *
     * @see #AggregateToOneInputOperation(Request, int, int, Optional, double, AggregationType, AggFunc)
     */
    public AggregateToOneInputOperation(Request request, int response, int input, Optional<Integer> weightIndex, double initialAccumulatorValue, AggregationType aggType, AggFunc aggFunc, OutputFunc outputFunc) {
        this(request, response, input, weightIndex, initialAccumulatorValue, aggType, aggFunc, Optional.of(outputFunc));
    }

    /**
     * Primary constructor; all other constructors delegate here.
     *
     * <p>The number of categories is read from the input field's metadata. The "rowCount" and
     * "sumSqr" statistics are used only when the response metadata reports them as
     * {@link StatisticStatus#AVAILABLE}. Sum-of-squares is additionally restricted to SUM, or
     * AVERAGE with a row count.
     */
    public AggregateToOneInputOperation(Request request, int response, int input, Optional<Integer> weightIndex, double initialAccumulatorValue, AggregationType aggType, AggFunc aggFunc, Optional<OutputFunc> outputFunc) {
        super(request);
        this.response = response;
        this.input = input;
        this.weightIndex = weightIndex;
        this.aggType = aggType;
        this.aggFunc = aggFunc;
        this.outputFunc = outputFunc;

        MetaData metadata = request.get(DataAccessProvider.class).getMetaData();
        this.numCats = metadata.getFieldCategories(input);
        this.hasRowCount = metadata.getStatisticStatus(ROWCOUNT_LABEL, response) == StatisticStatus.AVAILABLE;
        this.hasSumSqr = metadata.getStatisticStatus(SUMSQR_LABEL, response) == StatisticStatus.AVAILABLE
                && (aggType == AggregationType.SUM || (aggType == AggregationType.AVERAGE && this.hasRowCount));

        this.accumulator = new double[this.numCats];
        Arrays.fill(this.accumulator, initialAccumulatorValue);
        this.rowCount = new double[this.numCats];
        this.sumSqrs = new double[this.numCats];
        this.sumSqrsCalcs = new SumSquares[this.numCats];
        if (this.hasSumSqr) {
            for (int cat = 0; cat < this.numCats; ++cat) {
                this.sumSqrsCalcs[cat] = new SumSquares();
            }
        }

        if (weightIndex.isPresent()) {
            this.weightAccumulator = new double[this.numCats];
            Arrays.fill(this.weightAccumulator, initialAccumulatorValue);
        } else {
            this.weightAccumulator = null;
        }
    }

    public int getResponseIndex() {
        return this.response;
    }

    public int getInputIndex() {
        return this.input;
    }

    public Optional<Integer> getWeightIndex() {
        return this.weightIndex;
    }

    /**
     * Folds one data row into the per-category state.
     *
     * <p>The row is ignored when any of these holds:
     * <ul>
     *   <li>the response value is NaN;</li>
     *   <li>the input value is NaN;</li>
     *   <li>the truncated input value is outside {@code [0, numCats)};</li>
     *   <li>the row's "rowCount" statistic (when used) is NaN.</li>
     * </ul>
     *
     * @return this operation, for chaining
     */
    @Override
    public Operation<Request> update(DataRow row) {
        double val = row.getValue(this.response);
        double dcat = row.getValue(this.input);
        if (Double.isNaN(val) || Double.isNaN(dcat)) {
            return this;
        }
        int cat = (int) dcat;
        if (cat < 0 || cat >= this.numCats) {
            return this;
        }
        double rc = this.hasRowCount ? row.getStatistic(ROWCOUNT_LABEL, this.response) : 1.0;
        if (Double.isNaN(rc)) {
            return this;
        }
        // The weight must be updated BEFORE the accumulator: for MIN/MAX it compares the row's
        // response against the previous extreme stored in accumulator[cat].
        this.updateWeight(row, val, cat, rc);
        this.updateSumSqr(row, val, cat, rc);
        this.accumulator[cat] = this.aggFunc.apply(this.accumulator[cat], val, rc);
        this.rowCount[cat] += rc;
        return this;
    }

    /**
     * Updates the weight accumulator for {@code cat} when a weight field is configured and the
     * row's weight is not NaN.
     *
     * <p>For MIN/MAX the weight of the row holding the current extreme is kept. For every other
     * aggregation the weight is folded with the same {@link AggFunc} as the response.
     */
    private void updateWeight(DataRow row, double responseValue, int cat, double rc) {
        this.weightIndex.ifPresent(idx -> {
            double wgt = row.getValue(idx.intValue());
            if (Double.isNaN(wgt)) {
                // NOTE(review): for MIN/MAX, if a new extreme row has a NaN weight, the weight of the
                // previous extreme row is retained. For AVERAGE, rows skipped here still contribute to
                // rowCount, which later divides the weight sum. Confirm both are intended.
                return;
            }
            boolean isExtreme = this.aggType == AggregationType.MINIMUM || this.aggType == AggregationType.MAXIMUM;
            if (!isExtreme) {
                this.weightAccumulator[cat] = this.aggFunc.apply(this.weightAccumulator[cat], wgt, rc);
            } else if (this.shouldAccumulateWeight(responseValue, cat)) {
                this.weightAccumulator[cat] = wgt;
            }
        });
    }

    /**
     * Feeds the row's "sumSqr" statistic into the category's {@link SumSquares} calculator.
     *
     * <p>Only rows with a positive row count and a non-NaN statistic are used. For SUM the
     * response value is passed as the row sum. For AVERAGE it is scaled by the row count.
     */
    private void updateSumSqr(DataRow row, double responseValue, int cat, double rc) {
        if (this.hasSumSqr && rc > 0.0) {
            double ss = row.getStatistic(SUMSQR_LABEL, this.response);
            if (!Double.isNaN(ss)) {
                double sum = this.aggType == AggregationType.SUM ? responseValue : responseValue * rc;
                this.sumSqrsCalcs[cat].update(rc, ss, sum);
            }
        }
    }

    /**
     * Materialises the per-category sum-of-squares values once all rows have been seen.
     *
     * @return this operation, for chaining
     */
    @Override
    public Operation<Request> postUpdate() {
        if (this.hasSumSqr) {
            for (int cat = 0; cat < this.numCats; ++cat) {
                this.sumSqrs[cat] = this.sumSqrsCalcs[cat].getSumSquares();
            }
        }
        return this;
    }

    /** Returns whether {@code response} is a new MIN (strictly lower) or MAX (strictly higher) for {@code cat}. */
    private boolean shouldAccumulateWeight(double response, int cat) {
        return (this.aggType == AggregationType.MINIMUM && response < this.accumulator[cat])
                || (this.aggType == AggregationType.MAXIMUM && response > this.accumulator[cat]);
    }

    /**
     * Builds the single derived request produced by this operation.
     *
     * <p>The derived data provider contains:
     * <ul>
     *   <li>the aggregated response field, carrying the "rowCount" statistic and, when tracked,
     *       the "sumSqr" statistic;</li>
     *   <li>the input field, whose data is the category indices {@code 0..numCats-1};</li>
     *   <li>optionally, the aggregated weight field.</li>
     * </ul>
     * Roles are assigned as RESPONSE, the input's role, and WEIGHT respectively.
     */
    @Override
    public List<Request> getResults() {
        DataAccessProviderBuilder builder = new DataAccessProviderBuilder();
        MetaData metadata = this.request.get(DataAccessProvider.class).getMetaData();
        ChartInsightsContext params = this.request.get(ChartInsightsContext.class);
        String responseId = metadata.getFieldIdentifier(this.response);
        String explanatoryId = metadata.getFieldIdentifier(this.input);

        DataAccessProviderBuilder.FieldBuilder field = builder.addField(responseId)
                .label(metadata.getFieldDisplayLabel(this.response))
                .type(metadata.getFieldType(this.response))
                .aggregation(this.aggType)
                .data(this.getAggregatedResponse())
                .statistic(ROWCOUNT_LABEL, this.rowCount);
        if (this.hasSumSqr) {
            field.statistic(SUMSQR_LABEL, this.sumSqrs);
        }

        builder.addField(explanatoryId)
                .label(metadata.getFieldDisplayLabel(this.input))
                .type(metadata.getFieldType(this.input))
                .categories(this.getCategoryLabels(metadata, this.input))
                .data(this.getInputData());

        HashMap<String, RoleType> roles = new HashMap<>();
        roles.put(responseId, RoleType.RESPONSE);
        roles.put(explanatoryId, this.getInputRole(metadata, params, this.input));

        this.weightIndex.ifPresent(wgt -> {
            String wgtId = metadata.getFieldIdentifier(wgt.intValue());
            builder.addField(wgtId)
                    .label(metadata.getFieldDisplayLabel(wgt.intValue()))
                    .type(metadata.getFieldType(wgt.intValue()))
                    .aggregation(this.aggType)
                    .data(this.getAggregatedWeight());
            roles.put(wgtId, RoleType.WEIGHT);
        });

        return Collections.singletonList(RequestCreator.derive(this.request, builder.makeProvider(), ChartInsightsContextFactory.create(roles, params.doSmartAnnotations())));
    }

    /** Per-category response values, post-processed by {@link #outputFunc} when present. */
    private double[] getAggregatedResponse() {
        return this.applyOutputFunc(this.accumulator);
    }

    /** Per-category weight values, post-processed by {@link #outputFunc} when present. */
    private double[] getAggregatedWeight() {
        return this.applyOutputFunc(this.weightAccumulator);
    }

    /**
     * Applies {@link #outputFunc} element-wise with the matching row count.
     *
     * @return a new array when an output function is configured, otherwise {@code values} itself
     */
    private double[] applyOutputFunc(double[] values) {
        if (!this.outputFunc.isPresent()) {
            return values;
        }
        OutputFunc f = this.outputFunc.get();
        double[] result = new double[this.numCats];
        for (int cat = 0; cat < this.numCats; ++cat) {
            result[cat] = f.apply(values[cat], this.rowCount[cat]);
        }
        return result;
    }

    /** Category indices {@code 0.0, 1.0, ..., numCats-1} used as the derived input field's data. */
    private double[] getInputData() {
        double[] result = new double[this.numCats];
        Arrays.setAll(result, i -> i);
        return result;
    }

    /**
     * Sums the response per category (the row count is not used as a multiplier).
     *
     * <p>The derived aggregation type is COUNT when the response field's metadata aggregation is
     * COUNT, otherwise SUM.
     */
    public static AggregateToOneInputOperation sum(Request request, int response, int input, Optional<Integer> weightIndex) {
        AggregationType fieldAggregation = request.get(DataAccessProvider.class).getMetaData().getFieldAggregation(response);
        AggregationType aggType = fieldAggregation == AggregationType.COUNT ? AggregationType.COUNT : AggregationType.SUM;
        return new AggregateToOneInputOperation(request, response, input, weightIndex, 0.0, aggType, (accum, value, rc) -> accum + value);
    }

    /**
     * Row-count-weighted average of the response per category: {@code sum(value * rc) / sum(rc)}.
     *
     * <p>Yields 0 when the category's row count equals 0 according to {@link NumericUtils#equals}.
     */
    public static AggregateToOneInputOperation average(Request request, int response, int index, Optional<Integer> weightIndex) {
        return new AggregateToOneInputOperation(request, response, index, weightIndex, 0.0, AggregationType.AVERAGE,
                (accum, value, rc) -> accum + value * rc,
                (accum, rc) -> NumericUtils.equals(rc, 0.0) ? 0.0 : accum / rc);
    }

    /**
     * Minimum response per category.
     *
     * <p>Reports {@code defaultValue} for categories whose accumulator is still at its initial
     * {@code +Infinity}, for example because no qualifying rows were seen.
     */
    public static AggregateToOneInputOperation min(Request request, int response, int input, Optional<Integer> weightIndex, double defaultValue) {
        return new AggregateToOneInputOperation(request, response, input, weightIndex, Double.POSITIVE_INFINITY, AggregationType.MINIMUM,
                (accum, value, rc) -> Math.min(accum, value),
                (accum, rc) -> NumericUtils.equals(accum, Double.POSITIVE_INFINITY) ? defaultValue : accum);
    }

    /**
     * Maximum response per category.
     *
     * <p>Reports {@code defaultValue} for categories whose accumulator is still at its initial
     * {@code -Infinity}, for example because no qualifying rows were seen.
     */
    public static AggregateToOneInputOperation max(Request request, int response, int input, Optional<Integer> weightIndex, double defaultValue) {
        return new AggregateToOneInputOperation(request, response, input, weightIndex, Double.NEGATIVE_INFINITY, AggregationType.MAXIMUM,
                (accum, value, rc) -> Math.max(accum, value),
                (accum, rc) -> NumericUtils.equals(accum, Double.NEGATIVE_INFINITY) ? defaultValue : accum);
    }

    /** Post-processes an accumulated per-category value given that category's total row count. */
    @FunctionalInterface
    public interface OutputFunc {
        double apply(double accumulated, double rowCount);
    }

    /** Folds one row's value (with its row count) into a category's running accumulator. */
    @FunctionalInterface
    public interface AggFunc {
        double apply(double accumulated, double value, double rowCount);
    }
}

