/*
 * Decompiled with CFR 0.152.
 */
package edu.stanford.nlp.parser.nndep;

import edu.stanford.nlp.international.Language;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.IndexedWord;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.parser.nndep.ArcStandard;
import edu.stanford.nlp.parser.nndep.ClassifierFastLoad;
import edu.stanford.nlp.parser.nndep.Config;
import edu.stanford.nlp.parser.nndep.Configuration;
import edu.stanford.nlp.parser.nndep.Dataset;
import edu.stanford.nlp.parser.nndep.DependencyTree;
import edu.stanford.nlp.parser.nndep.ParsingSystem;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.trees.EnglishGrammaticalRelations;
import edu.stanford.nlp.trees.EnglishGrammaticalStructure;
import edu.stanford.nlp.trees.GrammaticalRelation;
import edu.stanford.nlp.trees.GrammaticalStructure;
import edu.stanford.nlp.trees.TreeGraphNode;
import edu.stanford.nlp.trees.TypedDependency;
import edu.stanford.nlp.trees.UniversalEnglishGrammaticalRelations;
import edu.stanford.nlp.trees.UniversalEnglishGrammaticalStructure;
import edu.stanford.nlp.trees.international.pennchinese.ChineseGrammaticalRelations;
import edu.stanford.nlp.trees.international.pennchinese.ChineseGrammaticalStructure;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.RuntimeInterruptedException;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

/**
 * Neural-network, transition-based dependency parser whose weights are read from a plain-text
 * model file ({@link #loadModelFile(String)} / {@link #loadFromModelFile(String)}) and which
 * parses part-of-speech-tagged sentences with {@link #predict(CoreMap)} or
 * {@link #predict(List)}.
 *
 * <p>Parsing uses an {@link ArcStandard} transition system. Every parser configuration is
 * encoded as 48 integer ids (18 word, 18 POS and 12 arc-label features) that index rows of the
 * embedding matrix loaded from the model file; see {@link #getFeatures(Configuration)}.
 */
public class DependencyParserFastLoad {
    private static final Redwood.RedwoodChannels log = Redwood.channels(DependencyParserFastLoad.class);
    public static final String DEFAULT_MODEL = "edu/stanford/nlp/models/parser/nndep/english_UD.gz";

    /** Vocabularies read from the model file, each in file order. */
    private List<String> knownWords;
    private List<String> knownPos;
    private List<String> knownLabels;

    /**
     * Embedding-row ids. Words, POS tags and labels share one contiguous id space
     * (see {@link #generateIDs()}), matching the row order of the embedding matrix.
     */
    private Map<String, Integer> wordIDs;
    private Map<String, Integer> posIDs;
    private Map<String, Integer> labelIDs;

    /** Pre-computation ids read from the model file (or selected by {@link #genTrainExamples}). */
    private List<Integer> preComputed;
    private ClassifierFastLoad classifier;
    private ParsingSystem system;
    private final Config config;
    private final Language language;

    /** Offset of the POS feature group in a feature vector (after the 18 word features). */
    private static final int POS_OFFSET = 18;
    /** Offset of the arc-label feature group (after 18 word + 18 POS features). */
    private static final int DEP_OFFSET = 36;
    /** Index, within the word and POS groups, of the first child feature (after 3 stack + 3 buffer tokens). */
    private static final int STACK_OFFSET = 6;
    /** Number of child features extracted for each of the top two stack tokens. */
    private static final int STACK_NUMBER = 6;
    /** Total number of features per configuration: 18 word + 18 POS + 12 label = 48. */
    private static final int NUM_FEATURES = DEP_OFFSET + 2 * STACK_NUMBER;

    /** Argument counts for command-line flags; only populated (not read) in this class. */
    private static final Map<String, Integer> numArgs = new HashMap<>();

    /**
     * Returns the POS tags known to the loaded model, excluding the special entries
     * {@code -NULL-}, {@code -UNKNOWN-} and {@code -ROOT-}, and including {@code .$$.}.
     *
     * @return an unmodifiable view of a fresh set of tags
     */
    public Set<String> getPosSet() {
        Set<String> foo = Generics.newHashSet(this.knownPos);
        foo.remove("-NULL-");
        foo.remove("-UNKNOWN-");
        foo.remove("-ROOT-");
        foo.add(".$$.");
        return Collections.unmodifiableSet(foo);
    }

    DependencyParserFastLoad() {
        this(new Properties());
    }

    /**
     * Creates an unloaded parser configured from {@code properties}; a model must be loaded
     * before parsing.
     */
    public DependencyParserFastLoad(Properties properties) {
        this.config = new Config(properties);
        this.language = this.config.language;
    }

    /** Returns the embedding id of word {@code s}, or the id of {@code -UNKNOWN-} if unseen. */
    public int getWordID(String s) {
        Integer id = this.wordIDs.get(s);
        return id != null ? id : this.wordIDs.get("-UNKNOWN-");
    }

    /** Returns the embedding id of POS tag {@code s}, or the id of {@code -UNKNOWN-} if unseen. */
    public int getPosID(String s) {
        Integer id = this.posIDs.get(s);
        return id != null ? id : this.posIDs.get("-UNKNOWN-");
    }

    /**
     * Returns the embedding id of arc label {@code s}.
     * Unboxing throws {@link NullPointerException} if the label is unknown to the model.
     */
    public int getLabelID(String s) {
        return this.labelIDs.get(s);
    }

    /**
     * Extracts the 48 feature ids for configuration {@code c}.
     *
     * <p>Layout, per group of word ids [0,18) and POS ids [18,36):
     * <ul>
     *   <li>0-2: stack tokens {@code getStack(2)}, {@code getStack(1)}, {@code getStack(0)};</li>
     *   <li>3-5: buffer tokens {@code getBuffer(0..2)};</li>
     *   <li>6-11: children of {@code getStack(0)}; 12-17: children of {@code getStack(1)},
     *       each in the order leftChild, rightChild, leftChild(k, 2), rightChild(k, 2),
     *       leftChild(leftChild), rightChild(rightChild).</li>
     * </ul>
     * Label ids [36,48) follow the same child order (6 for each stack token).
     *
     * @return a new mutable list of 48 ids
     */
    public List<Integer> getFeatures(Configuration c) {
        int[] ids = this.getFeatureArray(c);
        List<Integer> feature = new ArrayList<>(ids.length);
        for (int id : ids) {
            feature.add(id);
        }
        return feature;
    }

    /** Array form of {@link #getFeatures(Configuration)}; same layout, no boxing. */
    private int[] getFeatureArray(Configuration c) {
        int[] feature = new int[NUM_FEATURES];
        // Top three stack tokens, deepest first: s2, s1, s0.
        for (int j = 2; j >= 0; --j) {
            int index = c.getStack(j);
            feature[2 - j] = this.getWordID(c.getWord(index));
            feature[POS_OFFSET + 2 - j] = this.getPosID(c.getPOS(index));
        }
        // First three buffer tokens: b0, b1, b2.
        for (int j = 0; j <= 2; ++j) {
            int index = c.getBuffer(j);
            feature[3 + j] = this.getWordID(c.getWord(index));
            feature[POS_OFFSET + 3 + j] = this.getPosID(c.getPOS(index));
        }
        // Child features of s0 (slots 0-5) and s1 (slots 6-11).
        for (int j = 0; j <= 1; ++j) {
            int k = c.getStack(j);
            int base = j * STACK_NUMBER;
            this.setChildFeature(feature, c, base, c.getLeftChild(k));
            this.setChildFeature(feature, c, base + 1, c.getRightChild(k));
            this.setChildFeature(feature, c, base + 2, c.getLeftChild(k, 2));
            this.setChildFeature(feature, c, base + 3, c.getRightChild(k, 2));
            this.setChildFeature(feature, c, base + 4, c.getLeftChild(c.getLeftChild(k)));
            this.setChildFeature(feature, c, base + 5, c.getRightChild(c.getRightChild(k)));
        }
        return feature;
    }

    /** Writes the word, POS and label ids of token {@code index} into child slot {@code slot} (0-11). */
    private void setChildFeature(int[] feature, Configuration c, int slot, int index) {
        feature[STACK_OFFSET + slot] = this.getWordID(c.getWord(index));
        feature[POS_OFFSET + STACK_OFFSET + slot] = this.getPosID(c.getPOS(index));
        feature[DEP_OFFSET + slot] = this.getLabelID(c.getLabel(index));
    }

    /**
     * Builds oracle training examples from projective gold trees.
     *
     * <p>Non-projective trees are skipped. For every oracle step, the label vector marks the
     * oracle transition with 1, other applicable transitions with 0, and inapplicable ones
     * with -1. As a side effect, {@code preComputed} is replaced by the (up to
     * {@code config.numPreComputed}) most frequent (feature id, position) codes.
     *
     * @param sents sentences, parallel to {@code trees}
     * @param trees gold dependency trees
     * @return the generated dataset
     */
    public Dataset genTrainExamples(List<CoreMap> sents, List<DependencyTree> trees) {
        int numTrans = this.system.numTransitions();
        Dataset ret = new Dataset(NUM_FEATURES, numTrans);
        IntCounter<Integer> tokPosCount = new IntCounter<>();
        log.info("###################");
        log.info("Generate training examples...");
        for (int i = 0; i < sents.size(); ++i) {
            if (i > 0) {
                if (i % 1000 == 0) {
                    log.info(i + " ");
                }
                if (i % 10000 == 0 || i == sents.size() - 1) {
                    log.info();
                }
            }
            if (!trees.get(i).isProjective()) {
                continue;
            }
            Configuration c = this.system.initialConfiguration(sents.get(i));
            while (!this.system.isTerminal(c)) {
                String oracle = this.system.getOracle(c, trees.get(i));
                List<Integer> feature = this.getFeatures(c);
                List<Integer> label = new ArrayList<>(numTrans);
                for (int j = 0; j < numTrans; ++j) {
                    String str = (String) this.system.transitions.get(j);
                    if (str.equals(oracle)) {
                        label.add(1);
                    } else if (this.system.canApply(c, str)) {
                        label.add(0);
                    } else {
                        label.add(-1);
                    }
                }
                ret.addExample(feature, label);
                int numFeatures = feature.size();
                for (int j = 0; j < numFeatures; ++j) {
                    // Encode (feature id, position) as a single integer key.
                    tokPosCount.incrementCount(feature.get(j) * numFeatures + j);
                }
                this.system.apply(c, oracle);
            }
        }
        log.info("#Train Examples: " + ret.n);
        List<Integer> sortedTokens = Counters.toSortedList(tokPosCount, false);
        this.preComputed = new ArrayList<>(
                sortedTokens.subList(0, Math.min(this.config.numPreComputed, sortedTokens.size())));
        return ret;
    }

    /**
     * Assigns ids to all known words, then POS tags, then labels, using one running counter
     * so that ids coincide with embedding-matrix rows as laid out in the model file.
     */
    private void generateIDs() {
        this.wordIDs = new HashMap<>();
        this.posIDs = new HashMap<>();
        this.labelIDs = new HashMap<>();
        int index = 0;
        for (String word : this.knownWords) {
            this.wordIDs.put(word, index++);
        }
        for (String pos : this.knownPos) {
            this.posIDs.put(pos, index++);
        }
        for (String label : this.knownLabels) {
            this.labelIDs.put(label, index++);
        }
    }

    /** Creates a parser with default properties and loads {@code modelFile} (non-verbose). */
    public static DependencyParserFastLoad loadFromModelFile(String modelFile) {
        return DependencyParserFastLoad.loadFromModelFile(modelFile, null);
    }

    /**
     * Creates a parser configured from {@code extraProperties} (or defaults when {@code null})
     * and loads {@code modelFile} (non-verbose).
     */
    public static DependencyParserFastLoad loadFromModelFile(String modelFile, Properties extraProperties) {
        DependencyParserFastLoad parser = extraProperties == null
                ? new DependencyParserFastLoad()
                : new DependencyParserFastLoad(extraProperties);
        parser.loadModelFile(modelFile, false);
        return parser;
    }

    /** Loads {@code modelFile} into this parser (verbose). */
    public void loadModelFile(String modelFile) {
        this.loadModelFile(modelFile, true);
    }

    /**
     * Reads a text model file and initializes the classifier and transition system.
     *
     * <p>Format, as read below:
     * <ol>
     *   <li>seven {@code key=value} header lines: dictionary size, POS count, label count,
     *       embedding size, hidden size, number of feature tokens, number of pre-computed ids;</li>
     *   <li>one line per word, POS tag and label: the token followed by {@code eSize}
     *       space-separated embedding values;</li>
     *   <li>W1, one line per column ({@code eSize * nTokens} lines of {@code hSize} values);</li>
     *   <li>b1 on a single line ({@code hSize} values);</li>
     *   <li>W2, one line per column ({@code hSize} lines of {@code 2 * nLabel - 1} values);</li>
     *   <li>pre-computed ids, space-separated over as many lines as needed.</li>
     * </ol>
     *
     * @throws RuntimeIOException on I/O failure or if the file ends prematurely
     */
    private void loadModelFile(String modelFile, boolean verbose) {
        Timing t = new Timing();
        log.info("Loading depparse model file: " + modelFile + " ... ");
        // try-with-resources: the reader is closed even if parsing fails part-way.
        try (BufferedReader input = IOUtils.readerFromString(modelFile)) {
            int nDict = readHeaderValue(input, modelFile);
            int nPOS = readHeaderValue(input, modelFile);
            int nLabel = readHeaderValue(input, modelFile);
            int eSize = readHeaderValue(input, modelFile);
            int hSize = readHeaderValue(input, modelFile);
            int nTokens = readHeaderValue(input, modelFile);
            int nPreComputed = readHeaderValue(input, modelFile);

            this.knownWords = new ArrayList<>();
            this.knownPos = new ArrayList<>();
            this.knownLabels = new ArrayList<>();
            double[][] E = new double[nDict + nPOS + nLabel][eSize];
            int row = readEmbeddings(input, modelFile, nDict, eSize, this.knownWords, E, 0);
            row = readEmbeddings(input, modelFile, nPOS, eSize, this.knownPos, E, row);
            readEmbeddings(input, modelFile, nLabel, eSize, this.knownLabels, E, row);
            this.generateIDs();

            double[][] W1 = readColumnMajorMatrix(input, modelFile, hSize, eSize * nTokens);
            double[] b1 = new double[hSize];
            String[] splits = readRequiredLine(input, modelFile).split(" ");
            for (int i = 0; i < b1.length; ++i) {
                b1[i] = Double.parseDouble(splits[i]);
            }
            double[][] W2 = readColumnMajorMatrix(input, modelFile, nLabel * 2 - 1, hSize);

            this.preComputed = new ArrayList<>();
            while (this.preComputed.size() < nPreComputed) {
                for (String split : readRequiredLine(input, modelFile).split(" ")) {
                    this.preComputed.add(Integer.parseInt(split));
                }
            }

            this.config.hiddenSize = hSize;
            this.config.embeddingSize = eSize;
            this.classifier = new ClassifierFastLoad(this.config, E, W1, b1, W2, this.preComputed);
        } catch (IOException e) {
            throw new RuntimeIOException(e);
        }
        this.initialize(verbose);
        t.done(log, "Initializing dependency parser");
    }

    /** Returns the next line, failing with a descriptive error instead of returning {@code null} at EOF. */
    private static String readRequiredLine(BufferedReader input, String modelFile) throws IOException {
        String line = input.readLine();
        if (line == null) {
            throw new IOException("Unexpected end of depparse model file: " + modelFile);
        }
        return line;
    }

    /** Reads a {@code key=value} header line and returns its value as an int. */
    private static int readHeaderValue(BufferedReader input, String modelFile) throws IOException {
        String line = readRequiredLine(input, modelFile);
        return Integer.parseInt(line.substring(line.indexOf('=') + 1));
    }

    /**
     * Reads {@code count} embedding lines, appending each line's token to {@code vocab} and its
     * {@code eSize} values to consecutive rows of {@code E} starting at {@code row}.
     *
     * @return the next unused row of {@code E}
     */
    private static int readEmbeddings(BufferedReader input, String modelFile, int count, int eSize,
                                      List<String> vocab, double[][] E, int row) throws IOException {
        for (int k = 0; k < count; ++k, ++row) {
            String[] splits = readRequiredLine(input, modelFile).split(" ");
            vocab.add(splits[0]);
            for (int i = 0; i < eSize; ++i) {
                E[row][i] = Double.parseDouble(splits[i + 1]);
            }
        }
        return row;
    }

    /** Reads a {@code rows x cols} matrix stored one column per line ({@code rows} values each). */
    private static double[][] readColumnMajorMatrix(BufferedReader input, String modelFile,
                                                    int rows, int cols) throws IOException {
        double[][] m = new double[rows][cols];
        for (int j = 0; j < cols; ++j) {
            String[] splits = readRequiredLine(input, modelFile).split(" ");
            for (int i = 0; i < rows; ++i) {
                m[i][j] = Double.parseDouble(splits[i]);
            }
        }
        return m;
    }

    /**
     * Greedily parses {@code sentence}: at each step, applies the highest-scoring transition
     * that the system reports as applicable.
     *
     * @throws RuntimeInterruptedException if the current thread is interrupted (the flag is cleared)
     * @throws IllegalStateException if no transition is applicable in a non-terminal configuration
     */
    private DependencyTree predictInner(CoreMap sentence) {
        int numTrans = this.system.numTransitions();
        Configuration c = this.system.initialConfiguration(sentence);
        while (!this.system.isTerminal(c)) {
            if (Thread.interrupted()) {
                throw new RuntimeInterruptedException();
            }
            double[] scores = this.classifier.computeScores(this.getFeatureArray(c));
            double optScore = Double.NEGATIVE_INFINITY;
            String optTrans = null;
            for (int j = 0; j < numTrans; ++j) {
                String trans = (String) this.system.transitions.get(j);
                if (scores[j] > optScore && this.system.canApply(c, trans)) {
                    optScore = scores[j];
                    optTrans = trans;
                }
            }
            if (optTrans == null) {
                // Previously this fell through to apply(c, null); fail with a clear message instead.
                throw new IllegalStateException(
                        "No applicable transition found for a non-terminal parser configuration");
            }
            this.system.apply(c, optTrans);
        }
        return c.tree;
    }

    /**
     * Parses a sentence whose {@code TokensAnnotation} holds POS-tagged tokens.
     *
     * @return the typed dependencies wrapped in a language-specific {@link GrammaticalStructure}
     * @throws IllegalStateException if no model has been loaded
     */
    public GrammaticalStructure predict(CoreMap sentence) {
        if (this.system == null) {
            throw new IllegalStateException("Parser has not been loaded and initialized; first load a model.");
        }
        DependencyTree result = this.predictInner(sentence);
        List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
        List<TypedDependency> dependencies = new ArrayList<>();
        IndexedWord root = new IndexedWord((Label) new Word("ROOT"));
        root.set(CoreAnnotations.IndexAnnotation.class, 0);
        for (int i = 1; i <= result.n; ++i) {
            int head = result.getHead(i);
            String label = result.getLabel(i);
            IndexedWord thisWord = new IndexedWord(tokens.get(i - 1));
            IndexedWord headWord = head == 0 ? root : new IndexedWord(tokens.get(head - 1));
            GrammaticalRelation relation = head == 0 ? GrammaticalRelation.ROOT : this.makeGrammaticalRelation(label);
            dependencies.add(new TypedDependency(relation, headWord, thisWord));
        }
        TreeGraphNode rootNode = new TreeGraphNode((Label) root);
        return this.makeGrammaticalStructure(dependencies, rootNode);
    }

    /**
     * Looks {@code label} up in the relation table for this parser's language (English,
     * UniversalEnglish or Chinese); otherwise creates a new relation under
     * {@link GrammaticalRelation#DEPENDENT}.
     */
    private GrammaticalRelation makeGrammaticalRelation(String label) {
        GrammaticalRelation stored;
        switch (this.language) {
            case English:
                stored = (GrammaticalRelation) EnglishGrammaticalRelations.shortNameToGRel.get(label);
                break;
            case UniversalEnglish:
                stored = (GrammaticalRelation) UniversalEnglishGrammaticalRelations.shortNameToGRel.get(label);
                break;
            case Chinese:
                stored = (GrammaticalRelation) ChineseGrammaticalRelations.shortNameToGRel.get(label);
                break;
            default:
                stored = null;
                break;
        }
        if (stored != null) {
            return stored;
        }
        return new GrammaticalRelation(this.language, label, null, GrammaticalRelation.DEPENDENT);
    }

    /** Builds the language-specific grammatical structure, defaulting to UniversalEnglish. */
    private GrammaticalStructure makeGrammaticalStructure(List<TypedDependency> dependencies, TreeGraphNode rootNode) {
        switch (this.language) {
            case English:
                return new EnglishGrammaticalStructure(dependencies, rootNode);
            case Chinese:
                return new ChineseGrammaticalStructure(dependencies, rootNode);
            case UniversalEnglish:
            default:
                return new UniversalEnglishGrammaticalStructure(dependencies, rootNode);
        }
    }

    /**
     * Parses a list of tagged words.
     *
     * <p>{@link CoreLabel} inputs are used directly and have their index overwritten (1-based);
     * other {@link HasWord}s must also implement {@link HasTag} and are copied into new labels.
     *
     * @throws IllegalArgumentException if any word lacks a part-of-speech tag
     */
    public GrammaticalStructure predict(List<? extends HasWord> sentence) {
        CoreLabel sentenceLabel = new CoreLabel();
        List<CoreLabel> tokens = new ArrayList<>();
        int i = 1;
        for (HasWord hasWord : sentence) {
            CoreLabel label;
            if (hasWord instanceof CoreLabel) {
                label = (CoreLabel) hasWord;
                if (label.tag() == null) {
                    throw new IllegalArgumentException("Parser requires words with part-of-speech tag annotations");
                }
            } else {
                label = new CoreLabel();
                label.setValue(hasWord.word());
                label.setWord(hasWord.word());
                if (!(hasWord instanceof HasTag)) {
                    throw new IllegalArgumentException("Parser requires words with part-of-speech tag annotations");
                }
                label.setTag(((HasTag) hasWord).tag());
            }
            label.setIndex(i);
            ++i;
            tokens.add(label);
        }
        sentenceLabel.set(CoreAnnotations.TokensAnnotation.class, tokens);
        return this.predict((CoreMap) sentenceLabel);
    }

    /**
     * Creates the arc-standard transition system from the known labels minus the first entry
     * (presumably the null label — confirm against the model writer), then runs the
     * classifier's pre-computation when {@code config.numPreComputed > 0}.
     *
     * @throws IllegalStateException if no model has been loaded
     */
    private void initialize(boolean verbose) {
        if (this.knownLabels == null) {
            throw new IllegalStateException("Model has not been loaded or trained");
        }
        List<String> lDict = new ArrayList<>(this.knownLabels);
        lDict.remove(0);
        this.system = new ArcStandard(this.config.tlp, lDict, verbose);
        if (this.config.numPreComputed > 0) {
            this.classifier.preCompute();
        }
    }

    static {
        numArgs.put("textFile", 1);
        numArgs.put("outFile", 1);
    }
}

