/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.algorithms.tree.summary;

import com.ibm.bi.predict.algorithms.tree.summary.SummaryStats;
import com.ibm.bi.predict.data.DataColumn;
import com.ibm.bi.predict.fastpattern.util.IntList;

/**
 * Summary statistics of a categorical target column over a set of rows.
 *
 * <p>Holds the number of rows falling into each category (indexed by the target value cast to
 * {@code int}) and the Gini impurity {@code 1 - sum_i (count_i / n)^2} derived from those counts.
 * The "expected value" of a node is the index of its most frequent category.
 */
public class CategoricalSummaryStats
extends SummaryStats {
    /** Gini impurity of {@link #frequencies}; computed once at construction. */
    private final double gini;
    /** Row count per category index; length is the category count at construction time. */
    private final int[] frequencies;

    /**
     * Builds statistics for the given rows of a categorical target.
     *
     * @param target categorical target column; each row's value is truncated to an {@code int}
     *               and used as a category index
     * @param rows   indices of the rows to summarize
     */
    public CategoricalSummaryStats(DataColumn target, IntList rows) {
        this(target, rows.len(), CategoricalSummaryStats.makeFrequencies(target, rows));
    }

    /**
     * Internal constructor that adopts an already-computed frequency array (no copy is made).
     *
     * @param target      categorical target column
     * @param len         number of rows the frequencies were counted over
     * @param frequencies per-category row counts; ownership passes to this instance
     */
    private CategoricalSummaryStats(DataColumn target, int len, int[] frequencies) {
        super(target, len);
        this.frequencies = frequencies;
        this.gini = this.calculateGini();
    }

    /**
     * Computes the Gini impurity {@code 1 - sum(count^2) / n^2} of the current frequencies.
     *
     * @return the impurity, or {@code 0.0} for an empty node (avoids a 0/0 NaN result)
     */
    private double calculateGini() {
        if (this.n == 0) {
            return 0.0;
        }
        double sumFrequencySquared = 0.0;
        for (int count : this.frequencies) {
            sumFrequencySquared += (double) count * count;
        }
        // Divide by n twice (rather than by n*n) to keep results bit-identical with prior versions.
        return 1.0 - sumFrequencySquared / (double) this.n / (double) this.n;
    }

    /**
     * Returns the index of the most frequent category (the mode).
     *
     * <p>Ties resolve to the highest category index. When there are no categories, returns 0.
     */
    @Override
    public double expectedValue() {
        int mode = 0;
        for (int i = 1; i < this.frequencies.length; ++i) {
            // ">=" makes later categories win ties.
            if (this.frequencies[i] >= this.frequencies[mode]) {
                mode = i;
            }
        }
        return mode;
    }

    /**
     * Not defined for a categorical target.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double std() {
        throw new UnsupportedOperationException("Cannot compute standard deviation for categorical node");
    }

    /** Returns the Gini impurity of this node. */
    @Override
    public double impurity() {
        return this.gini;
    }

    @Override
    public String toString() {
        return String.format("mode=%d, (%2.1f%%), #rows=%d", (int)this.expectedValue(), this.modeFraction() * 100.0, this.n);
    }

    /**
     * Returns the per-category row counts.
     *
     * <p>This is the live internal array, not a copy; callers must not modify it.
     */
    @Override
    public int[] categoryFrequencies() {
        return this.frequencies;
    }

    /**
     * Fraction of rows belonging to the mode category.
     *
     * @return the fraction in [0, 1], or {@code 0.0} for an empty node or when there are no
     *         categories (avoids NaN and out-of-bounds indexing)
     */
    private double modeFraction() {
        if (this.n == 0 || this.frequencies.length == 0) {
            return 0.0;
        }
        return (double) this.frequencies[(int) this.expectedValue()] / (double) this.n;
    }

    /**
     * Counts how many of the given rows fall into each category of {@code target}.
     *
     * @param target categorical column; values are truncated to {@code int} category indices
     * @param rows   row indices to count
     * @return an array of length {@code target.getCategoryCount()} holding per-category counts
     */
    private static int[] makeFrequencies(DataColumn target, IntList rows) {
        int[] counts = new int[target.getCategoryCount()];
        int len = rows.len();
        for (int i = 0; i < len; ++i) {
            int category = (int) target.getValue(rows.value(i));
            counts[category]++;
        }
        return counts;
    }

    /**
     * Combines this node's statistics with another categorical node's by summing counts.
     *
     * @param other statistics to merge; must be a {@code CategoricalSummaryStats}
     * @return a new instance covering the rows of both operands
     * @throws ClassCastException       if {@code other} is not categorical
     * @throws IllegalArgumentException if the two operands have different category counts
     */
    @Override
    public SummaryStats merge(SummaryStats other) {
        CategoricalSummaryStats otherStats = (CategoricalSummaryStats) other;
        int categoryCount = this.frequencies.length;
        if (otherStats.frequencies.length != categoryCount) {
            throw new IllegalArgumentException("Cannot merge categorical stats with different category counts: "
                    + categoryCount + " vs " + otherStats.frequencies.length);
        }
        int[] combined = new int[categoryCount];
        for (int i = 0; i < categoryCount; ++i) {
            combined[i] = this.frequencies[i] + otherStats.frequencies[i];
        }
        return new CategoricalSummaryStats(this.target, this.n + otherStats.n, combined);
    }
}

