/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.algorithms.table;

import com.ibm.bi.predict.data.matrix.Matrix;
import java.util.Arrays;
import org.apache.commons.lang.mutable.MutableDouble;

/**
 * Computes an accuracy figure and an adjusted count R-squared figure from a
 * cross-tabulation of counts.
 *
 * <p>Three summary statistics are derived from the cells that survive the
 * optional row/column filtering:
 * <ul>
 *   <li>the sum over columns of each column's largest cell,</li>
 *   <li>the largest row total, and</li>
 *   <li>the sum of all included cells.</li>
 * </ul>
 * All statistics are computed once, in the constructor.
 *
 * <p>NOTE(review): which dimension holds the predicted categories and which the
 * actual categories is not visible from this class - confirm against callers.
 */
public class AdjustedCountR2 {
    private final Matrix crosstab;
    /** {@code includedRows[i]} is false when row {@code i} was filtered out. */
    private final boolean[] includedRows;
    /** {@code includedColumns[j]} is false when column {@code j} was filtered out. */
    private final boolean[] includedColumns;
    private double maxOfRowTotals;
    private double sumOfColumnMaxes;
    private double totalRowCount;

    /**
     * Builds the statistics over the whole cross-tabulation, with no filtering.
     *
     * @param crosstab the table of counts to summarise
     */
    public AdjustedCountR2(Matrix crosstab) {
        this(crosstab, 0);
    }

    /**
     * Builds the statistics, optionally ignoring sparse rows and columns.
     *
     * @param crosstab the table of counts to summarise
     * @param minTotal rows and columns whose total is strictly below this value
     *                 are excluded; a value of zero or less disables filtering
     */
    public AdjustedCountR2(Matrix crosstab, int minTotal) {
        this.crosstab = crosstab;
        this.includedRows = new boolean[crosstab.rowDimension()];
        Arrays.fill(this.includedRows, true);
        this.includedColumns = new boolean[crosstab.columnDimension()];
        Arrays.fill(this.includedColumns, true);
        if (minTotal > 0) {
            this.filterRowsAndColumns(minTotal);
        }
        this.computeStatistics();
    }

    /**
     * Marks as excluded every row and column whose total is strictly below
     * {@code minTotal}. Both sets of totals are taken directly from the
     * cross-tabulation, so excluding a row does not change the column totals
     * used for the column decision (and vice versa).
     */
    private void filterRowsAndColumns(int minTotal) {
        double[] rowTotals = this.crosstab.rowTotals();
        for (int i = 0; i < rowTotals.length; ++i) {
            if (rowTotals[i] < minTotal) {
                this.includedRows[i] = false;
            }
        }
        double[] colTotals = this.crosstab.columnTotals();
        for (int i = 0; i < colTotals.length; ++i) {
            if (colTotals[i] < minTotal) {
                this.includedColumns[i] = false;
            }
        }
    }

    /**
     * Returns {@code (sumOfColumnMaxes - maxOfRowTotals) / (total - maxOfRowTotals)}.
     *
     * @return the adjusted count R-squared; NaN or infinite when the denominator
     *         is zero (for example when no cells are included, or when a single
     *         row holds every included count)
     */
    public double getAdjustedCountR2() {
        return (this.sumOfColumnMaxes - this.maxOfRowTotals) / (this.totalRowCount - this.maxOfRowTotals);
    }

    /**
     * Returns the sum of the per-column maxima divided by the total of all
     * included cells.
     *
     * @return the accuracy; NaN when no counts are included
     */
    public double getAccuracy() {
        return this.sumOfColumnMaxes / this.totalRowCount;
    }

    /**
     * Walks the cross-tabulation once and accumulates the three summary
     * statistics over the cells whose row and column are both included.
     */
    private void computeStatistics() {
        // A mutable holder is needed because the lambda cannot assign to a local.
        MutableDouble totalRecordCount = new MutableDouble(0.0);
        double[] rowTotals = new double[this.crosstab.rowDimension()];
        double[] columnMaxes = new double[this.crosstab.columnDimension()];
        this.crosstab.walkNonZero((r, c, v) -> {
            if (this.includedRows[r] && this.includedColumns[c]) {
                totalRecordCount.add(v);
                rowTotals[r] += v;
                columnMaxes[c] = Math.max(columnMaxes[c], v);
            }
        });
        // A table with zero rows has no row totals; fall back to 0 instead of
        // letting OptionalDouble.getAsDouble() throw NoSuchElementException.
        this.maxOfRowTotals = Arrays.stream(rowTotals).max().orElse(0.0);
        this.sumOfColumnMaxes = Arrays.stream(columnMaxes).sum();
        this.totalRowCount = totalRecordCount.doubleValue();
    }
}

