/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.algorithms.regression;

import com.ibm.bi.predict.algorithms.MutatingCellVisitor;
import com.ibm.bi.predict.exceptions.InvalidDataException;
import com.ibm.bi.predict.math.NumericUtils;
import java.util.Arrays;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.DecompositionSolver;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
import org.apache.commons.math3.linear.RealVector;

public class Regression {
    /** Observed response values, one per record. */
    private final double[] targetValues;
    /** Unweighted design matrix: column 0 is the intercept (1.0), followed by each record's inputs. */
    private final double[][] regressorMatrix;
    /** Per-record frequency (count) weights. */
    private final double[] freqWeights;
    /** Per-record regression weights. */
    private final double[] regWeights;
    /** Fitted coefficients (intercept first); {@code null} until {@link #solve()} runs. */
    private double[] coefficients;
    private double regressionSumSqr = 0.0;
    private double errorSumSqr = 0.0;
    private double totalSumSqr = 0.0;
    /** Sum over records of {@code regWeights[i] * freqWeights[i]}. */
    private final double totalWeight;
    /** Sum of the frequency weights; only populated once {@link #solve()} has run. */
    private double numRecordsWeighted = 0.0;
    private boolean isFit = false;
    private final int degree;
    /** Weighted design matrix built by {@link #solve()}; {@code null} before that. */
    private Array2DRowRealMatrix matrix;

    /**
     * Creates an (unfitted) weighted linear regression.
     *
     * @param targetValues      response value of each record
     * @param inputValues       predictor values, one row per record; an intercept column is prepended
     * @param frequencyWeights  per-record frequency (count) weights
     * @param regressionWeights per-record regression weights
     * @param degree            degree value reported by {@link #degree()}
     * @throws InvalidDataException if the combined weight of all records is zero
     */
    public Regression(double[] targetValues, double[][] inputValues, double[] frequencyWeights, double[] regressionWeights, int degree) {
        this.targetValues = targetValues;
        this.degree = degree;
        this.freqWeights = frequencyWeights;
        this.regWeights = regressionWeights;
        this.totalWeight = this.totalWeight();
        if (NumericUtils.isZero(this.totalWeight)) {
            throw new InvalidDataException("Total computed weight is zero");
        }
        this.regressorMatrix = Regression.constructRegressorMatrix(targetValues, inputValues);
    }

    /**
     * Fits the model by weighted least squares and computes the sum-of-squares statistics.
     *
     * <p>Each row of the design matrix and of the response vector is scaled by the square
     * root of its combined weight, and the resulting ordinary least-squares problem is
     * solved with a QR decomposition.
     *
     * @return the fitted coefficients, intercept first
     */
    public double[] solve() {
        this.matrix = this.makeWeightedDesignMatrix();
        RealVector responseVector = this.makeWeightedResponseVector();
        DecompositionSolver solver = new QRDecomposition(this.matrix).getSolver();
        RealVector solution = solver.solve(responseVector);
        this.coefficients = solution.toArray();
        this.calculateStats();
        this.isFit = true;
        return this.getCoefficients();
    }

    /** @return the degree value supplied at construction */
    public int degree() {
        return this.degree;
    }

    /** @return the number of predictor columns, excluding the intercept column */
    public int numberOfParameters() {
        return this.regressorMatrix[0].length - 1;
    }

    /**
     * @return {@code 1 - SSE / SST}; not finite when the total sum of squares is zero.
     *         Meaningful only after {@link #solve()}.
     */
    public double rSquared() {
        return 1.0 - this.errorSumSqr / this.getTotalSumSquares();
    }

    /**
     * @return R-squared adjusted for the number of predictors and the weighted record count,
     *         or {@code 1.0} when the error degrees of freedom are not positive
     */
    public double adjustedRSquared() {
        double rSquared = this.rSquared();
        double degFreedomError = this.getNumRecordsWeighted() - (double) this.getNumOfPredictors() - 1.0;
        if (degFreedomError <= 0.0) {
            return 1.0;
        }
        return 1.0 - (1.0 - rSquared) * (this.numRecordsWeighted - 1.0) / degFreedomError;
    }

    /** @return weighted total sum of squares about the weighted target mean */
    public double getTotalSumSquares() {
        return this.totalSumSqr;
    }

    /** @return the frequency weights array (not copied) */
    public double[] getFreqWeights() {
        return this.freqWeights;
    }

    /** @return the regression weights array (not copied) */
    public double[] getRegWeights() {
        return this.regWeights;
    }

    /** @return weighted regression (explained) sum of squares */
    public double getRegressionSumSquares() {
        return this.regressionSumSqr;
    }

    /** @return weighted error (residual) sum of squares */
    public double getErrorSumSquares() {
        return this.errorSumSqr;
    }

    /** @return weighted record count minus one */
    public double getDegreesOfFreedomTotal() {
        return this.numRecordsWeighted - 1.0;
    }

    /** @return weighted record count minus the number of predictors minus one */
    public double getDegreesOfFreedomError() {
        return this.numRecordsWeighted - (double) this.numberOfParameters() - 1.0;
    }

    /** @return the number of predictors */
    public double getDegreesOfFreedomRegression() {
        return this.numberOfParameters();
    }

    /** @return error sum of squares divided by the error degrees of freedom */
    public double getMeanSquaredError() {
        return this.errorSumSqr / this.getDegreesOfFreedomError();
    }

    /** @return regression sum of squares divided by the regression degrees of freedom */
    public double getMeanSquaredRegression() {
        return this.regressionSumSqr / this.getDegreesOfFreedomRegression();
    }

    /** @return the number of predictor columns, excluding the intercept column */
    public int getNumOfPredictors() {
        return this.regressorMatrix[0].length - 1;
    }

    /** @return the sum of frequency weights; {@code 0.0} until {@link #solve()} has run */
    public double getNumRecordsWeighted() {
        return this.numRecordsWeighted;
    }

    /** @return the number of records (length of the target array) */
    public int getNumRecordsUnweighted() {
        return this.targetValues.length;
    }

    /** @return the fitted coefficients, or {@code null} before {@link #solve()} */
    public double[] getCoefficients() {
        return this.coefficients;
    }

    /** @return regression sum of squares divided by total sum of squares */
    public double etaSquared() {
        return this.getRegressionSumSquares() / this.getTotalSumSquares();
    }

    /**
     * Computes a studentized residual for every record, fitting the model first if needed.
     *
     * <p>For record {@code i} the value is {@code e_i / (s * sqrt((1 - h_ii) / w_i))}, where
     * {@code e_i} is the raw residual, {@code s} the residual standard error,
     * {@code h_ii} the leverage from the weighted design matrix and {@code w_i} the
     * record's regression weight.
     *
     * @return studentized residuals, one per record
     */
    public double[] getStudentizedResiduals() {
        if (!this.isFit) {
            this.solve();
        }
        // Only the diagonal of the hat matrix is needed; avoid materialising the n x n matrix.
        double[] leverages = Regression.getLeverages(this.matrix);
        // The residual standard error is identical for every record, so compute it once.
        double rse = Math.sqrt(this.getMeanSquaredError());
        double[] studentizedResiduals = new double[this.targetValues.length];
        for (int i = 0; i < this.targetValues.length; ++i) {
            double error = this.targetValues[i] - this.predicted(i);
            double weight = this.regWeights[i];
            studentizedResiduals[i] = error / (rse * Math.sqrt((1.0 - leverages[i]) / weight));
        }
        return studentizedResiduals;
    }

    /**
     * Reports whether {@code XᵀX} of the weighted design matrix is non-singular.
     *
     * <p>Works both before and after {@link #solve()}. Before fitting, a temporary
     * weighted design matrix is built; it is not stored, so {@link #designMatrix()}
     * is unaffected.
     *
     * @return {@code true} if the LU decomposition of {@code XᵀX} reports a non-singular matrix
     */
    public boolean canDecompose() {
        RealMatrix design = this.matrix != null ? this.matrix : this.makeWeightedDesignMatrix();
        RealMatrix mTm = design.transpose().multiply(design);
        return new LUDecomposition(mTm).getSolver().isNonSingular();
    }

    /** @return {@code true} once {@link #solve()} has completed */
    public boolean isFit() {
        return this.isFit;
    }

    /** @return the weighted design matrix built by {@link #solve()}, or {@code null} before that */
    public RealMatrix designMatrix() {
        return this.matrix;
    }

    /**
     * Evaluates the fitted linear model on one design row.
     *
     * @param row design-matrix row (intercept entry first) with at least as many entries as coefficients
     * @return the dot product of the coefficients with {@code row}
     */
    public double predicted(double[] row) {
        double v = 0.0;
        for (int i = 0; i < this.coefficients.length; ++i) {
            v += this.coefficients[i] * row[i];
        }
        return v;
    }

    /** Builds the response vector with each entry scaled by the square root of its record weight. */
    private RealVector makeWeightedResponseVector() {
        ArrayRealVector weightedResponseVector = new ArrayRealVector(this.targetValues);
        for (int row = 0; row < weightedResponseVector.getDimension(); ++row) {
            double weight = Math.sqrt(this.weight(row));
            weightedResponseVector.setEntry(row, this.targetValues[row] * weight);
        }
        return weightedResponseVector;
    }

    /** Builds a copy of the design matrix with each row scaled by the square root of its record weight. */
    private Array2DRowRealMatrix makeWeightedDesignMatrix() {
        Array2DRowRealMatrix weightedMatrix = new Array2DRowRealMatrix(this.regressorMatrix);
        MutatingCellVisitor cellVisitor = (row, column, value) -> {
            double weight = Math.sqrt(this.weight(row));
            return value * weight;
        };
        weightedMatrix.walkInRowOrder((RealMatrixChangingVisitor) cellVisitor);
        return weightedMatrix;
    }

    /** Recomputes all post-fit statistics from the current coefficients. */
    private void calculateStats() {
        this.calculateErrorAndRegressionSumOfSquares();
    }

    /**
     * Computes the diagonal of the hat matrix {@code X (XᵀX)⁻¹ Xᵀ} without forming it.
     *
     * <p>Uses {@code h_ii = x_iᵀ (XᵀX)⁻¹ x_i}, which needs O(n·p) memory instead of O(n²).
     *
     * @param matrix design matrix {@code X}
     * @return leverage of each row
     */
    private static double[] getLeverages(RealMatrix matrix) {
        RealMatrix mTm = matrix.transpose().multiply(matrix);
        RealMatrix inverse = new LUDecomposition(mTm).getSolver().getInverse();
        int rows = matrix.getRowDimension();
        int cols = matrix.getColumnDimension();
        double[] leverages = new double[rows];
        for (int i = 0; i < rows; ++i) {
            double[] x = matrix.getRow(i);
            // Row i of X * inverse.
            double[] xTimesInverse = inverse.preMultiply(x);
            double h = 0.0;
            for (int k = 0; k < cols; ++k) {
                h += xTimesInverse[k] * x[k];
            }
            leverages[i] = h;
        }
        return leverages;
    }

    /** Computes the weighted error, regression and total sums of squares about the weighted target mean. */
    private void calculateErrorAndRegressionSumOfSquares() {
        double targetMean = this.getTargetMean();
        double errorSumSquares = 0.0;
        double regressionSumSquares = 0.0;
        double totalSumSquares = 0.0;
        for (int i = 0; i < this.targetValues.length; ++i) {
            double weight = this.weight(i);
            double predicted = this.predicted(i);
            double a = predicted - targetMean;
            regressionSumSquares += weight * a * a;
            double b = this.targetValues[i] - predicted;
            errorSumSquares += weight * (b * b);
            double c = this.targetValues[i] - targetMean;
            totalSumSquares += weight * (c * c);
        }
        this.errorSumSqr = errorSumSquares;
        this.regressionSumSqr = regressionSumSquares;
        this.totalSumSqr = totalSumSquares;
    }

    /**
     * Returns the weighted mean of the target values.
     *
     * <p>Side effect: stores the sum of frequency weights in {@link #numRecordsWeighted}.
     */
    private double getTargetMean() {
        double meanTarget = 0.0;
        double nRecordsWeighted = 0.0;
        for (int i = 0; i < this.targetValues.length; ++i) {
            meanTarget += this.weight(i) * this.targetValues[i];
            nRecordsWeighted += this.freqWeights[i];
        }
        this.numRecordsWeighted = nRecordsWeighted;
        return meanTarget / this.totalWeight;
    }

    /** Prediction for design row {@code row} of the unweighted regressor matrix. */
    private double predicted(int row) {
        return this.predicted(this.regressorMatrix[row]);
    }

    /**
     * Prepends an intercept column of 1.0 to the inputs.
     *
     * <p>Records beyond {@code inputValues.length} get zeros for all predictor columns.
     */
    private static double[][] constructRegressorMatrix(double[] targetValues, double[][] inputValues) {
        double[][] xValues = new double[targetValues.length][inputValues[0].length + 1];
        for (int i = 0; i < targetValues.length; ++i) {
            xValues[i][0] = 1.0;
            if (i >= inputValues.length) {
                continue;
            }
            double[] inputs = inputValues[i];
            for (int j = 0; j < inputs.length; ++j) {
                xValues[i][j + 1] = inputs[j];
            }
        }
        return xValues;
    }

    /** Combined weight of a record: regression weight times frequency weight. */
    private double weight(int row) {
        return this.regWeights[row] * this.freqWeights[row];
    }

    /** Sum of the combined weights over the regression-weight array. */
    private double totalWeight() {
        double total = 0.0;
        for (int i = 0; i < this.regWeights.length; ++i) {
            total += this.weight(i);
        }
        return total;
    }

    /** Builder that defaults missing count and weight arrays to all ones. */
    public static class Builder {
        private double[] targetValues;
        private double[][] inputValues;
        private int degree;
        private double[] countValues;
        private double[] weightValues;

        public Builder(double[] targetValues, double[][] inputValues, int degree) {
            this.targetValues = targetValues;
            this.inputValues = inputValues;
            this.degree = degree;
        }

        /** Sets the frequency (count) weights. */
        public Builder withCounts(double[] countValues) {
            this.countValues = countValues;
            return this;
        }

        /** Sets the regression weights. */
        public Builder withWeights(double[] weightValues) {
            this.weightValues = weightValues;
            return this;
        }

        /**
         * Creates the regression, filling in all-ones counts and weights if they were not set.
         *
         * @throws InvalidDataException if the combined weight is zero
         */
        public Regression build() {
            int size = this.targetValues.length;
            if (this.countValues == null) {
                this.countValues = Builder.unitVector(size);
            }
            if (this.weightValues == null) {
                this.weightValues = Builder.unitVector(size);
            }
            return new Regression(this.targetValues, this.inputValues, this.countValues, this.weightValues, this.degree);
        }

        /** Returns an array of {@code size} ones. */
        private static double[] unitVector(int size) {
            double[] v = new double[size];
            Arrays.fill(v, 1.0);
            return v;
        }
    }
}

