/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.smarts.prescriptive.recommender.ml;

import com.ibm.smarts.db.query.util.RectangleOfData;
import com.ibm.smarts.prescriptive.recommender.internal.PrescriptiveException;
import com.ibm.smarts.prescriptive.recommender.internal.PrescriptiveStatus;
import com.ibm.smarts.prescriptive.recommender.ml.Adam;
import com.ibm.smarts.prescriptive.recommender.ml.ColStatistics;
import com.ibm.smarts.prescriptive.recommender.ml.DataExtractor;
import com.ibm.smarts.prescriptive.recommender.ml.IRegressor;
import com.ibm.smarts.prescriptive.recommender.ml.NnRegressorConfig;
import com.ibm.smarts.prescriptive.recommender.ml.NumJa;
import com.ibm.smarts.schema.ColumnInfo;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Regressor backed by a small fully-connected neural network with one hidden layer.
 *
 * <p>The network computes {@code w2 · tanh(w1 · x + b1) + b2}: a tanh hidden layer of
 * {@code config.getNnSize()} units followed by a single linear output unit. Training runs for
 * {@code config.getEpochNo()} epochs. Each epoch uses every training row at once, and each of the
 * four parameter matrices has its own {@link Adam} optimizer.
 *
 * <p>Data matrices are laid out one column per sample: the first index selects the feature (or
 * target), the second index selects the data row.
 *
 * <p>Values go through the {@link DataExtractor} built in {@link #fit(RectangleOfData, int, List)}.
 * Input feature {@code c} uses the extractor's column {@code c + 1}; the target uses column {@code 0}.
 *
 * <p>NOTE(review): instances hold mutable training state and no synchronization is visible here,
 * so treat them as not thread-safe.
 */
public class NnRegressor
implements IRegressor {
    static final Logger LOGGER = LoggerFactory.getLogger(NnRegressor.class);
    private NnRegressorConfig config;
    /** Number of input features (rows of the training X matrix); set by {@link #fit()}. */
    private int inputSize;
    /** Number of target rows in the training Y matrix; set by {@link #fit()}. */
    private int outputSize;
    private List<ColStatistics> colStats;
    private int tragetColumnIndex;
    private List<Integer> featureColumnsIndexes;
    private Map<Integer, ColumnInfo> colInfoMap = Collections.emptyMap();
    private DataExtractor dataExtractor;
    /** Hidden-layer weights, shape {@code nnSize x inputSize}. */
    private double[][] w1;
    /** Hidden-layer bias, shape {@code nnSize x 1}. */
    private double[][] b1;
    /** Output-layer weights, shape {@code 1 x nnSize}. */
    private double[][] w2;
    /** Output-layer bias, shape {@code 1 x 1}. */
    private double[][] b2;
    private Adam w1Adam;
    private Adam b1Adam;
    private Adam w2Adam;
    private Adam b2Adam;
    private boolean modelIsTrained = false;

    /**
     * Creates an untrained regressor.
     *
     * @param config network size, epoch count and learning rate used by {@link #fit}
     */
    public NnRegressor(NnRegressorConfig config) {
        this.config = config;
        this.resetOptimizers();
    }

    /**
     * Trains the network on the given data, replacing any previously trained model.
     *
     * @param rectangleOfData source table
     * @param tragetColumnIndex index of the column to predict
     * @param featureColumnsIndexes indexes of the columns used as inputs
     */
    @Override
    public void fit(RectangleOfData rectangleOfData, int tragetColumnIndex, List<Integer> featureColumnsIndexes) {
        // Clear the flag first: if this re-fit fails part-way, predict() must not run on a mix of
        // old weights and a new extractor.
        this.modelIsTrained = false;
        this.dataExtractor = new DataExtractor(rectangleOfData, tragetColumnIndex, featureColumnsIndexes);
        this.colStats = this.dataExtractor.getColStatistics();
        this.tragetColumnIndex = tragetColumnIndex;
        this.featureColumnsIndexes = featureColumnsIndexes;
        this.colInfoMap = rectangleOfData.getColumnInfoMap();
        this.fit();
        this.modelIsTrained = true;
    }

    /**
     * Predicts the target value for raw (un-normalized) input.
     *
     * <p>Only the first output entry ({@code z2[0][0]}) is returned. That is the prediction for
     * the first column of {@code rowOfData}; any further columns are computed but ignored.
     *
     * @param rowOfData input laid out as {@code [feature][sample]}, in original units
     * @return the prediction mapped back to the target's original scale
     * @throws PrescriptiveException with {@code MODEL_IS_NOT_TRAINED} if {@link #fit} has not
     *     completed successfully
     */
    @Override
    public double predict(double[][] rowOfData) throws PrescriptiveException {
        if (!this.modelIsTrained) {
            throw new PrescriptiveException(PrescriptiveStatus.MODEL_IS_NOT_TRAINED);
        }
        double[][] xNormalized = this.normalizeInput(rowOfData);
        double[][] z1 = NumJa.addDot(this.w1, xNormalized, this.b1);
        double[][] a1 = NumJa.tanh(z1);
        double[][] z2 = NumJa.addDot(this.w2, a1, this.b2);
        return this.dataExtractor.denormalize(z2[0][0], 0);
    }

    /** Returns the number of input features seen in the last training run (0 before training). */
    @Override
    public int getInputSize() {
        return this.inputSize;
    }

    /** Returns the number of target rows seen in the last training run (0 before training). */
    @Override
    public int getOutputSize() {
        return this.outputSize;
    }

    /** Returns the column statistics computed by the data extractor during {@link #fit}. */
    @Override
    public List<ColStatistics> getInputColStatistics() {
        return this.colStats;
    }

    /** Builds the training matrices, initializes parameters and runs training, timing the run. */
    private void fit() {
        double[][] x = this.dataExtractor.getX();
        double[][] y = this.dataExtractor.getY();
        this.inputSize = x.length;
        this.outputSize = y.length;
        this.initialize(this.inputSize);
        long startTime = System.currentTimeMillis();
        this.train(x, y);
        LOGGER.info("Training took {} ms.", System.currentTimeMillis() - startTime);
    }

    /**
     * Runs {@code config.getEpochNo()} epochs of training, each over the whole data set.
     *
     * <p>The gradients are those of a squared-error loss ({@code dZ2 = estimate - target}). The
     * cost that gets logged is {@code NumJa.mae}.
     *
     * @param x training inputs, {@code [feature][row]}
     * @param y training targets, {@code [target][row]}
     */
    private void train(double[][] x, double[][] y) {
        int rowNo = x[0].length;
        double cost = 0.0;
        // If getEpochNo() < 1 this all-zero matrix is what the error report reads.
        double[][] estimatedY = new double[y.length][rowNo];
        for (int i = 1; i <= this.config.getEpochNo(); ++i) {
            // Forward pass.
            double[][] z1 = NumJa.addDot(this.w1, x, this.b1);
            double[][] a1 = NumJa.tanh(z1);
            estimatedY = NumJa.addDot(this.w2, a1, this.b2);
            cost = NumJa.mae(rowNo, y, estimatedY);
            // Backward pass; (1 - a1^2) is the derivative of tanh at z1.
            double[][] dZ2 = NumJa.subtract(estimatedY, y);
            double[][] dW2 = NumJa.divide(NumJa.dotT(dZ2, a1), rowNo);
            double[][] db2 = NumJa.sumDivide(dZ2, rowNo);
            double[][] dZ1 = NumJa.multiply(NumJa.tDot(this.w2, dZ2), NumJa.subtract(1.0, NumJa.power(a1, 2)));
            double[][] dW1 = NumJa.divide(NumJa.dotT(dZ1, x), rowNo);
            double[][] db1 = NumJa.sumDivide(dZ1, rowNo);
            // The 1-based epoch number is passed to Adam as its time step.
            this.update(i, dW2, db2, dW1, db1);
            if (i % 1000 == 0) {
                LOGGER.debug("iteration number {}, mae error {}", i, cost);
            }
        }
        LOGGER.info("normlized mae error is {}", cost);
        this.logEstimationError(y, estimatedY, rowNo);
    }

    /**
     * Logs the mean relative error of the final fitted values, in the target's original scale.
     *
     * <p>Each row contributes {@code |estimate - actual| / |actual|}. Rows whose relative error is
     * not finite (e.g. {@code actual == 0}) are excluded from both the sum and the count. If no row
     * qualifies, NaN is logged.
     *
     * @param y training targets, {@code [target][row]}
     * @param estimatedY fitted values, same layout as {@code y}
     * @param rowNo number of data rows
     */
    private void logEstimationError(double[][] y, double[][] estimatedY, int rowNo) {
        double errorSum = 0.0;
        int counted = 0;
        for (int i = 0; i < rowNo; ++i) {
            double estimatedValue = this.dataExtractor.denormalize(estimatedY[0][i], 0);
            double actualValue = this.dataExtractor.denormalize(y[0][i], 0);
            // Per-row relative error; the original code accumulated a running total here.
            double error = Math.abs(estimatedValue - actualValue) / Math.abs(actualValue);
            if (Double.isFinite(error)) {
                errorSum += error;
                ++counted;
            }
        }
        double averageError = counted == 0 ? Double.NaN : errorSum / counted;
        LOGGER.info("Average estimation error is between \u00b1 {}", averageError);
    }

    /** Applies one Adam step to every parameter matrix; {@code t} is the 1-based step number. */
    private void update(int t, double[][] dW2, double[][] db2, double[][] dW1, double[][] db1) {
        this.w1 = this.w1Adam.update(dW1, this.w1, t);
        this.b1 = this.b1Adam.update(db1, this.b1, t);
        this.w2 = this.w2Adam.update(dW2, this.w2, t);
        this.b2 = this.b2Adam.update(db2, this.b2, t);
    }

    /**
     * Allocates fresh parameters for a network with {@code inputSize} inputs and resets the
     * optimizers.
     *
     * <p>Weights come from {@code NumJa.simiRandom} with fixed third arguments (0 and 17). Those are
     * presumably seeds, which would make initialization deterministic — TODO confirm in NumJa.
     * Biases start at zero.
     */
    private void initialize(int inputSize) {
        this.w1 = NumJa.simiRandom(this.config.getNnSize(), inputSize, 0);
        this.b1 = new double[this.config.getNnSize()][1];
        this.w2 = NumJa.simiRandom(1, this.config.getNnSize(), 17);
        this.b2 = new double[1][1];
        // New parameters need new optimizers: state from a previous fit() must not carry over.
        this.resetOptimizers();
    }

    /** Creates a new {@link Adam} optimizer, with the configured learning rate, per parameter matrix. */
    private void resetOptimizers() {
        this.w1Adam = new Adam(this.config.getLearningRate());
        this.b1Adam = new Adam(this.config.getLearningRate());
        this.w2Adam = new Adam(this.config.getLearningRate());
        this.b2Adam = new Adam(this.config.getLearningRate());
    }

    /**
     * Normalizes raw input with the statistics learned during {@link #fit}.
     *
     * @param a input laid out as {@code [feature][sample]}; feature {@code c} is normalized with
     *     extractor column {@code c + 1} (column 0 belongs to the target)
     * @return a new matrix of the same shape holding the normalized values
     * @throws PrescriptiveException with {@code MODEL_IS_NOT_TRAINED} if the model is not trained
     */
    public double[][] normalizeInput(double[][] a) throws PrescriptiveException {
        if (!this.modelIsTrained) {
            throw new PrescriptiveException(PrescriptiveStatus.MODEL_IS_NOT_TRAINED);
        }
        int colNo = a.length;
        int rowNo = a[0].length;
        double[][] b = new double[colNo][rowNo];
        for (int c = 0; c < colNo; ++c) {
            for (int r = 0; r < rowNo; ++r) {
                b[c][r] = this.dataExtractor.normalize(a[c][r], c + 1);
            }
        }
        return b;
    }

    /** Returns the target column index passed to the last {@link #fit} call. */
    @Override
    public int getTragetColumnIndex() {
        return this.tragetColumnIndex;
    }

    /** Returns the feature column list passed to the last {@link #fit} call (not copied). */
    @Override
    public List<Integer> getFeatureColumnsIndexes() {
        return this.featureColumnsIndexes;
    }

    /** Returns the column-info map; empty until {@link #fit} or {@link #setColInfoMap} sets it. */
    public Map<Integer, ColumnInfo> getColInfoMap() {
        return this.colInfoMap;
    }

    /** Replaces the column-info map used by {@link #getColId(int)}. */
    public void setColInfoMap(Map<Integer, ColumnInfo> colInfoMap) {
        this.colInfoMap = colInfoMap;
    }

    /**
     * Returns the pretty name of the column at {@code index}.
     *
     * <p>NOTE(review): throws {@link NullPointerException} if the map has no entry for
     * {@code index}, assuming the map returns null for a missing key.
     */
    @Override
    public String getColId(int index) {
        return this.colInfoMap.get(index).getPrettyName();
    }
}

