/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.ie.crf;

import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.ArrayList;
import java.util.Arrays;

/**
 * A dense table of log-space values, one per assignment of {@code windowSize} labels, where each
 * label is drawn from {@code numClasses} classes.
 *
 * <p>An assignment {@code [l_0, ..., l_{w-1}]} is stored at the mixed-radix index
 * {@code (...((l_0 * numClasses) + l_1) * numClasses + ...) + l_{w-1}}. The first label is therefore
 * the most significant digit, and all entries sharing a label prefix form one contiguous run.
 * Entries start at {@link Double#NEGATIVE_INFINITY}. Sums are done with log-sum-exp, and
 * "multiplying" factors is done by adding their log values.
 */
public class FactorTable {
    private static final Redwood.RedwoodChannels log = Redwood.channels(FactorTable.class);
    /** Number of distinct labels each window position can take. */
    private final int numClasses;
    /** Number of label positions covered by this table. */
    private final int windowSize;
    /** Log values in the index order described above; length is numClasses^windowSize. */
    private final double[] table;

    /**
     * Creates a table with every entry set to negative infinity.
     *
     * @param numClasses number of possible labels per position
     * @param windowSize number of positions in the window
     */
    public FactorTable(int numClasses, int windowSize) {
        this.numClasses = numClasses;
        this.windowSize = windowSize;
        this.table = new double[SloppyMath.intPow(numClasses, windowSize)];
        Arrays.fill(this.table, Double.NEGATIVE_INFINITY);
    }

    /**
     * Creates an independent copy of {@code t}.
     *
     * @param t table to copy; its values array is cloned, not shared
     */
    public FactorTable(FactorTable t) {
        this.numClasses = t.numClasses;
        this.windowSize = t.windowSize;
        this.table = t.table.clone();
    }

    /** Returns whether any entry of the table is NaN. */
    public boolean hasNaN() {
        return ArrayMath.hasNaN(this.table);
    }

    /**
     * Renders every assignment with its probability, i.e. {@code exp(value - totalMass())}.
     *
     * @return one {@code [labels]: probability} line per entry, wrapped in braces
     */
    public String toProbString() {
        // Compute the normalizer once. Calling prob() per row would redo the full log-sum each
        // time and make this quadratic in the table size.
        double logZ = this.totalMass();
        StringBuilder sb = new StringBuilder("{\n");
        for (int i = 0; i < this.table.length; ++i) {
            sb.append(Arrays.toString(this.toArray(i)));
            sb.append(": ");
            sb.append(Math.exp(this.table[i] - logZ));
            sb.append('\n');
        }
        sb.append('}');
        return sb.toString();
    }

    /**
     * Renders every assignment with its exponentiated (unnormalized) value.
     *
     * @return one {@code [labels]: exp(value)} line per entry, wrapped in braces
     */
    public String toNonLogString() {
        StringBuilder sb = new StringBuilder("{\n");
        for (int i = 0; i < this.table.length; ++i) {
            sb.append(Arrays.toString(this.toArray(i)));
            sb.append(": ");
            sb.append(Math.exp(this.getValue(i)));
            sb.append('\n');
        }
        sb.append('}');
        return sb.toString();
    }

    /**
     * Renders every assignment, with label ids mapped through {@code classIndex}, and its raw value.
     *
     * @param classIndex maps label ids to label objects
     * @param <L> label type
     * @return one {@code [labels]: value} line per entry, wrapped in braces
     */
    public <L> String toString(Index<L> classIndex) {
        StringBuilder sb = new StringBuilder("{\n");
        for (int i = 0; i < this.table.length; ++i) {
            sb.append(FactorTable.toString(this.toArray(i), classIndex));
            sb.append(": ");
            sb.append(this.getValue(i));
            sb.append('\n');
        }
        sb.append('}');
        return sb.toString();
    }

    /** Renders every assignment (as label ids) with its raw log value. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("{\n");
        for (int i = 0; i < this.table.length; ++i) {
            sb.append(Arrays.toString(this.toArray(i)));
            sb.append(": ");
            sb.append(this.getValue(i));
            sb.append('\n');
        }
        sb.append('}');
        return sb.toString();
    }

    /** Maps each label id in {@code array} through {@code classIndex} and formats the result as a list. */
    private static <L> String toString(int[] array, Index<L> classIndex) {
        ArrayList<L> labels = new ArrayList<L>(array.length);
        for (int item : array) {
            labels.add(classIndex.get(item));
        }
        return labels.toString();
    }

    /**
     * Decodes a flat table index into its label assignment (inverse of the index encoding).
     *
     * @param index flat position in the table
     * @return array of {@code windowSize} labels, first label most significant
     */
    public int[] toArray(int index) {
        int[] indices = new int[this.windowSize];
        for (int i = indices.length - 1; i >= 0; --i) {
            indices[i] = index % this.numClasses;
            index /= this.numClasses;
        }
        return indices;
    }

    /** Encodes a label sequence as a mixed-radix number in base {@code numClasses}. */
    private int indexOf(int[] entry) {
        int index = 0;
        for (int item : entry) {
            index = index * this.numClasses + item;
        }
        return index;
    }

    /** Encodes {@code front} followed by the single label {@code end}. */
    private int indexOf(int[] front, int end) {
        return this.indexOf(front) * this.numClasses + end;
    }

    /** Encodes the single label {@code front} followed by {@code end}. */
    private int indexOf(int front, int[] end) {
        int index = front;
        for (int item : end) {
            index = index * this.numClasses + item;
        }
        return index;
    }

    /**
     * Returns the flat indices of every full assignment whose trailing labels equal {@code entries}.
     * Successive indices differ by numClasses^entries.length, one for each setting of the leading
     * positions.
     */
    private int[] indicesEnd(int[] entries) {
        int index = this.indexOf(entries);
        int[] indices = new int[SloppyMath.intPow(this.numClasses, this.windowSize - entries.length)];
        int offset = SloppyMath.intPow(this.numClasses, entries.length);
        for (int i = 0; i < indices.length; ++i) {
            indices[i] = index;
            index += offset;
        }
        return indices;
    }

    /**
     * Returns the first flat index of the contiguous run of assignments whose leading labels equal
     * {@code entries}.
     */
    private int indicesFront(int[] entries) {
        return this.indexOf(entries) * SloppyMath.intPow(this.numClasses, this.windowSize - entries.length);
    }

    /** Returns the number of label positions covered by this table. */
    public int windowSize() {
        return this.windowSize;
    }

    /** Returns the number of possible labels per position. */
    public int numClasses() {
        return this.numClasses;
    }

    /** Returns the number of entries, numClasses^windowSize. */
    public int size() {
        return this.table.length;
    }

    /** Returns the log of the summed exponentiated entries (log normalizer). */
    public double totalMass() {
        return ArrayMath.logSum(this.table);
    }

    /** Returns the raw log value stored for the full assignment {@code label}. */
    public double unnormalizedLogProb(int[] label) {
        return this.getValue(label);
    }

    /** Returns the normalized log probability of the full assignment {@code label}. */
    public double logProb(int[] label) {
        return this.unnormalizedLogProb(label) - this.totalMass();
    }

    /** Returns the normalized probability of the full assignment {@code label}. */
    public double prob(int[] label) {
        return Math.exp(this.unnormalizedLogProb(label) - this.totalMass());
    }

    /**
     * Computes log P(last label = {@code of} | preceding labels = {@code given}).
     *
     * @param given the first {@code windowSize - 1} labels
     * @param of the final label
     * @return the conditional log probability
     * @throws IllegalArgumentException if {@code given.length != windowSize - 1}
     */
    public double conditionalLogProbGivenPrevious(int[] given, int of) {
        if (given.length != this.windowSize - 1) {
            throw new IllegalArgumentException("conditionalLogProbGivenPrevious requires given one less than clique size (" + this.windowSize + ") but was " + Arrays.toString(given));
        }
        // All numClasses completions of `given` are adjacent because the last label is least significant.
        int startIndex = this.indicesFront(given);
        double z = ArrayMath.logSum(this.table, startIndex, startIndex + this.numClasses);
        return this.table[startIndex + of] - z;
    }

    /**
     * Computes the full conditional distribution over the last label given the preceding labels.
     *
     * @param given the first {@code windowSize - 1} labels
     * @return an array of length {@code numClasses}, log-normalized via {@code ArrayMath.logNormalize}
     * @throws IllegalArgumentException if {@code given.length != windowSize - 1}
     */
    public double[] conditionalLogProbsGivenPrevious(int[] given) {
        if (given.length != this.windowSize - 1) {
            throw new IllegalArgumentException("conditionalLogProbsGivenPrevious requires given one less than clique size (" + this.windowSize + ") but was " + Arrays.toString(given));
        }
        double[] result = new double[this.numClasses];
        for (int i = 0; i < this.numClasses; ++i) {
            result[i] = this.table[this.indexOf(given, i)];
        }
        ArrayMath.logNormalize(result);
        return result;
    }

    /**
     * Computes log P(remaining labels = {@code of} | first label = {@code given}).
     *
     * @param given the first label
     * @param of the last {@code windowSize - 1} labels
     * @return the conditional log probability
     * @throws IllegalArgumentException if {@code of.length != windowSize - 1}
     */
    public double conditionalLogProbGivenFirst(int given, int[] of) {
        if (of.length != this.windowSize - 1) {
            throw new IllegalArgumentException("conditionalLogProbGivenFirst requires of one less than clique size (" + this.windowSize + ") but was " + Arrays.toString(of));
        }
        int[] labels = new int[this.windowSize];
        labels[0] = given;
        System.arraycopy(of, 0, labels, 1, this.windowSize - 1);
        double probAll = this.unnormalizedLogProb(labels);
        double probGiven = this.unnormalizedLogProbFront(given);
        return probAll - probGiven;
    }

    /**
     * Returns the raw log value of the assignment {@code [given, of...]}, without dividing by the
     * marginal of {@code given}.
     *
     * @throws IllegalArgumentException if {@code of.length != windowSize - 1}
     */
    public double unnormalizedConditionalLogProbGivenFirst(int given, int[] of) {
        if (of.length != this.windowSize - 1) {
            throw new IllegalArgumentException("unnormalizedConditionalLogProbGivenFirst requires of one less than clique size (" + this.windowSize + ") but was " + Arrays.toString(of));
        }
        int[] labels = new int[this.windowSize];
        labels[0] = given;
        System.arraycopy(of, 0, labels, 1, this.windowSize - 1);
        return this.unnormalizedLogProb(labels);
    }

    /**
     * Computes log P(first label = {@code of} | following labels = {@code given}).
     *
     * @param given the last {@code windowSize - 1} labels
     * @param of the first label
     * @return the conditional log probability
     * @throws IllegalArgumentException if {@code given.length != windowSize - 1}
     */
    public double conditionalLogProbGivenNext(int[] given, int of) {
        if (given.length != this.windowSize - 1) {
            throw new IllegalArgumentException("conditionalLogProbGivenNext requires given one less than clique size (" + this.windowSize + ") but was " + Arrays.toString(given));
        }
        int[] positions = this.indicesEnd(given);
        double[] masses = new double[positions.length];
        for (int i = 0; i < masses.length; ++i) {
            masses[i] = this.table[positions[i]];
        }
        double z = ArrayMath.logSum(masses);
        return this.table[this.indexOf(of, given)] - z;
    }

    /** Returns the log-sum of all entries whose leading labels equal {@code labels}. */
    public double unnormalizedLogProbFront(int[] labels) {
        int startIndex = this.indicesFront(labels);
        int numCellsToSum = SloppyMath.intPow(this.numClasses, this.windowSize - labels.length);
        return ArrayMath.logSum(this.table, startIndex, startIndex + numCellsToSum);
    }

    /** Returns the normalized log marginal of the leading labels {@code label}. */
    public double logProbFront(int[] label) {
        return this.unnormalizedLogProbFront(label) - this.totalMass();
    }

    /** Returns the log-sum of all entries whose first label is {@code label}. */
    public double unnormalizedLogProbFront(int label) {
        return this.unnormalizedLogProbFront(new int[]{label});
    }

    /** Returns the normalized log marginal of the first label being {@code label}. */
    public double logProbFront(int label) {
        return this.unnormalizedLogProbFront(label) - this.totalMass();
    }

    /** Returns the log-sum of all entries whose trailing labels equal {@code labels}. */
    public double unnormalizedLogProbEnd(int[] labels) {
        int[] positions = this.indicesEnd(labels);
        double[] masses = new double[positions.length];
        for (int i = 0; i < masses.length; ++i) {
            masses[i] = this.table[positions[i]];
        }
        return ArrayMath.logSum(masses);
    }

    /** Returns the normalized log marginal of the trailing labels {@code labels}. */
    public double logProbEnd(int[] labels) {
        return this.unnormalizedLogProbEnd(labels) - this.totalMass();
    }

    /** Returns the log-sum of all entries whose last label is {@code label}. */
    public double unnormalizedLogProbEnd(int label) {
        return this.unnormalizedLogProbEnd(new int[]{label});
    }

    /** Returns the normalized log marginal of the last label being {@code label}. */
    public double logProbEnd(int label) {
        return this.unnormalizedLogProbEnd(label) - this.totalMass();
    }

    /** Returns the raw value at flat position {@code index}. */
    public double getValue(int index) {
        return this.table[index];
    }

    /** Returns the raw value for the full assignment {@code label}. */
    public double getValue(int[] label) {
        return this.table[this.indexOf(label)];
    }

    /** Overwrites the raw value at flat position {@code index}. */
    public void setValue(int index, double value) {
        this.table[index] = value;
    }

    /** Overwrites the raw value for the full assignment {@code label}. */
    public void setValue(int[] label, double value) {
        this.table[this.indexOf(label)] = value;
    }

    /** Adds {@code value} directly to the stored value for {@code label} (a product in probability space). */
    public void incrementValue(int[] label, double value) {
        this.incrementValue(this.indexOf(label), value);
    }

    /** Adds {@code value} directly to the stored value at flat position {@code index}. */
    public void incrementValue(int index, double value) {
        this.table[index] += value;
    }

    /** Replaces the entry at {@code index} with log(exp(entry) + exp(value)). */
    void logIncrementValue(int index, double value) {
        this.table[index] = SloppyMath.logAdd(this.table[index], value);
    }

    /** Replaces the entry for {@code label} with log(exp(entry) + exp(value)). */
    public void logIncrementValue(int[] label, double value) {
        this.logIncrementValue(this.indexOf(label), value);
    }

    /**
     * Adds to each entry the value that {@code other} stores for this entry's leading
     * {@code other.windowSize()} labels (a factor product aligned at the front of the window).
     */
    public void multiplyInFront(FactorTable other) {
        int divisor = SloppyMath.intPow(this.numClasses, this.windowSize - other.windowSize());
        for (int i = 0; i < this.table.length; ++i) {
            this.table[i] += other.getValue(i / divisor);
        }
    }

    /**
     * Adds to each entry the value that {@code other} stores for this entry's trailing
     * {@code other.windowSize()} labels (a factor product aligned at the end of the window).
     */
    public void multiplyInEnd(FactorTable other) {
        int divisor = SloppyMath.intPow(this.numClasses, other.windowSize());
        for (int i = 0; i < this.table.length; ++i) {
            this.table[i] += other.getValue(i % divisor);
        }
    }

    /**
     * Marginalizes out the last position.
     *
     * @return a new table of window size {@code windowSize - 1}
     */
    public FactorTable sumOutEnd() {
        FactorTable ft = new FactorTable(this.numClasses, this.windowSize - 1);
        int sz = ft.size();
        for (int i = 0; i < sz; ++i) {
            // The numClasses settings of the last label for a fixed prefix are contiguous.
            ft.table[i] = ArrayMath.logSum(this.table, i * this.numClasses, (i + 1) * this.numClasses);
        }
        return ft;
    }

    /**
     * Marginalizes out the first position.
     *
     * @return a new table of window size {@code windowSize - 1}
     */
    public FactorTable sumOutFront() {
        FactorTable ft = new FactorTable(this.numClasses, this.windowSize - 1);
        int stride = ft.size();
        for (int i = 0; i < stride; ++i) {
            // Entries sharing a suffix are spaced numClasses^(windowSize-1) apart.
            ft.setValue(i, ArrayMath.logSum(this.table, i, this.table.length, stride));
        }
        return ft;
    }

    /**
     * Subtracts {@code other}'s entries from this table's entries (division in probability space).
     * Positions where both entries are negative infinity are left unchanged, because the subtraction
     * would produce NaN.
     */
    public void divideBy(FactorTable other) {
        for (int i = 0; i < this.table.length; ++i) {
            if (this.table[i] == Double.NEGATIVE_INFINITY && other.table[i] == Double.NEGATIVE_INFINITY) {
                continue;
            }
            this.table[i] -= other.table[i];
        }
    }

    /**
     * Small demonstration: builds a few tables and prints derived quantities to the log and to
     * standard error.
     */
    public static void main(String[] args) {
        int numClasses = 6;
        final int cliqueSize = 3;
        System.err.printf("Creating factor table with %d classes and window (clique) size %d%n", numClasses, cliqueSize);
        FactorTable ft = new FactorTable(numClasses, cliqueSize);
        for (int i = 0; i < numClasses; ++i) {
            for (int j = 0; j < numClasses; ++j) {
                for (int k = 0; k < numClasses; ++k) {
                    ft.setValue(new int[]{i, j, k}, (double) (i * 4 + j * 2 + k));
                }
            }
        }
        log.info(ft);

        // Entries are log values, so the partition function is the sum of their exponentials.
        // Summing the raw log values does not give Z.
        double normalization = 0.0;
        for (int i = 0; i < numClasses; ++i) {
            for (int j = 0; j < numClasses; ++j) {
                for (int k = 0; k < numClasses; ++k) {
                    normalization += Math.exp(ft.unnormalizedLogProb(new int[]{i, j, k}));
                }
            }
        }
        log.info("Normalization Z = " + normalization);
        log.info(ft.sumOutFront());

        FactorTable ft2 = new FactorTable(numClasses, 2);
        for (int i = 0; i < numClasses; ++i) {
            for (int j = 0; j < numClasses; ++j) {
                ft2.setValue(new int[]{i, j}, (double) (i * numClasses + j));
            }
        }
        log.info(ft2);

        // For every (i, j), the probabilities P(last = k | i, j) should sum to 1.
        for (int i = 0; i < numClasses; ++i) {
            for (int j = 0; j < numClasses; ++j) {
                int[] b = {i, j};
                double t = 0.0;
                for (int k = 0; k < numClasses; ++k) {
                    double p = Math.exp(ft.conditionalLogProbGivenPrevious(b, k));
                    t += p;
                    System.err.println(k + "|" + i + ',' + j + " : " + p);
                }
                log.info(t);
            }
        }

        log.info("conditionalLogProbGivenFirst");
        for (int j = 0; j < numClasses; ++j) {
            for (int k = 0; k < numClasses; ++k) {
                int[] b = {j, k};
                double t = 0.0;
                for (int i = 0; i < numClasses; ++i) {
                    double v = ft.unnormalizedConditionalLogProbGivenFirst(i, b);
                    t += v;
                    System.err.println(i + "|" + j + ',' + k + " : " + v);
                }
                log.info(t);
            }
        }

        log.info("conditionalLogProbGivenNext");
        // For every (i, j), the probabilities P(first = k | i, j) should sum to 1.
        for (int i = 0; i < numClasses; ++i) {
            for (int j = 0; j < numClasses; ++j) {
                int[] b = {i, j};
                double t = 0.0;
                for (int k = 0; k < numClasses; ++k) {
                    double p = Math.exp(ft.conditionalLogProbGivenNext(b, k));
                    t += p;
                    System.err.println(i + "," + j + '|' + k + " : " + p);
                }
                log.info(t);
            }
        }

        numClasses = 2;
        FactorTable ft3 = new FactorTable(numClasses, cliqueSize);
        ft3.setValue(new int[]{0, 0, 0}, Math.log(0.25));
        ft3.setValue(new int[]{0, 0, 1}, Math.log(0.35));
        ft3.setValue(new int[]{0, 1, 0}, Math.log(0.05));
        ft3.setValue(new int[]{0, 1, 1}, Math.log(0.07));
        ft3.setValue(new int[]{1, 0, 0}, Math.log(0.08));
        ft3.setValue(new int[]{1, 0, 1}, Math.log(0.16));
        ft3.setValue(new int[]{1, 1, 0}, Math.log(1.0E-50));
        ft3.setValue(new int[]{1, 1, 1}, Math.log(1.0E-50));
        FactorTable ft4 = ft3.sumOutFront();
        log.info(ft4.toNonLogString());
        FactorTable ft5 = ft3.sumOutEnd();
        log.info(ft5.toNonLogString());
    }
}

