/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.fastpattern.chaid;

import com.ibm.bi.predict.algorithms.tree.DecisionTree;
import com.ibm.bi.predict.algorithms.tree.eval.EvaluationStatistic;
import com.ibm.bi.predict.algorithms.tree.eval.ModelEvaluation;
import com.ibm.bi.predict.algorithms.tree.util.TreeParameters;
import com.ibm.bi.predict.dataaccess.types.FieldType;
import com.ibm.bi.predict.fastpattern.FastPatternContext;
import com.ibm.bi.predict.fastpattern.chaid.serialize.TreeJsonBuilder;
import com.ibm.bi.predict.fastpattern.keydrivers.KeyDriver;
import com.ibm.bi.predict.fastpattern.keydrivers.TargetField;
import com.ibm.bi.predict.fastpattern.result.InputField;
import com.ibm.bi.predict.fastpattern.util.KeyDriverType;
import com.ibm.bi.predict.graph.TreeNode;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.json.JSONObject;

/**
 * A key driver backed by a {@link DecisionTree}.
 *
 * <p>On construction the tree is serialized to JSON, its node count and maximum node depth are
 * recorded, and it is evaluated (on test data when the global tree parameters reserve a test
 * fraction, otherwise on training data) to obtain an accuracy figure.
 */
public class Tree
extends KeyDriver {
    /** JSON form of the tree, produced by the builder returned from {@link #newJsonBuilder}. */
    private final JSONObject structure;
    /**
     * Headline accuracy: fraction of variance explained for numerical targets, classification
     * improvement for other targets.
     */
    private final double accuracy;
    /** Plain classification accuracy for non-numerical targets; 0.0 for numerical targets. */
    private final double simpleClassificationAccuracy;
    /** Number of nodes reported by {@code tree.nodes()}, or 0 when it returns null. */
    private final int nodeCount;
    /** Largest node depth reported by {@code tree.nodes()}, or 0 when there are no nodes. */
    private final int treeMaxDepth;
    /** The wrapped decision tree. */
    private final DecisionTree decisionTree;

    /**
     * Creates a tree key driver whose target field is derived from the tree itself.
     *
     * <p>The target field uses the tree target's id and its categories, converted with the
     * context's {@code "locale"} locale. No target-category sort conversion is applied.
     *
     * @param id          identifier of this key driver
     * @param inputFields input fields used by the tree
     * @param tree        the fitted decision tree
     * @param context     fast-pattern context that supplies the locale
     */
    public Tree(String id, List<InputField> inputFields, DecisionTree tree, FastPatternContext context) {
        this(id, new TargetField(tree.target().getId(), KeyDriver.toCategoryArray(tree.target().getCategories(), context.getLocale("locale"))), inputFields, tree, context, null);
    }

    /**
     * Creates a tree key driver with an explicit target field.
     *
     * @param id                        identifier of this key driver
     * @param targetField               target field description passed to {@link KeyDriver}
     * @param inputFields               input fields used by the tree
     * @param tree                      the fitted decision tree
     * @param context                   fast-pattern context that supplies the locale
     * @param targetCategorySortConvert mapping handed to the JSON builder's {@code build}; may be
     *                                  {@code null} (the other constructor passes {@code null})
     */
    public Tree(String id, TargetField targetField, List<InputField> inputFields, DecisionTree tree, FastPatternContext context, Map<Integer, Integer> targetCategorySortConvert) {
        super(id, inputFields, targetField, context);
        // NOTE(review): newJsonBuilder is overridable and is invoked before this object is fully
        // initialised; overriding implementations must not depend on subclass state.
        this.structure = this.newJsonBuilder(tree, context.getLocale("locale")).build(targetCategorySortConvert);
        if (tree.nodes() != null) {
            this.nodeCount = (int) tree.nodes().count();
            // orElse(0) rather than getAsInt(): an empty node stream would otherwise throw
            // NoSuchElementException out of the constructor.
            this.treeMaxDepth = tree.nodes().mapToInt(TreeNode::depth).max().orElse(0);
        } else {
            this.nodeCount = 0;
            this.treeMaxDepth = 0;
        }
        this.decisionTree = tree;

        boolean numericalTarget = tree.target().getType() == FieldType.NUMERICAL;
        // Evaluate on held-out data only when a test fraction is reserved.
        // NOTE(review): the fraction is read from the global TreeParameters defaults, not from the
        // parameters this particular tree was built with; confirm the two always agree.
        double testFraction = numericalTarget
                ? TreeParameters.continuous().getTestFraction()
                : TreeParameters.categorical().getTestFraction();
        ModelEvaluation modelEval = testFraction > 0.0 ? tree.evaluateOnTestData() : tree.evaluateOnTrainingData();
        if (numericalTarget) {
            this.accuracy = modelEval.valueOf(EvaluationStatistic.FRACTION_VARIANCE_EXPLAINED);
            this.simpleClassificationAccuracy = 0.0;
        } else {
            this.accuracy = modelEval.valueOf(EvaluationStatistic.CLASSIFICATION_IMPROVEMENT);
            this.simpleClassificationAccuracy = modelEval.valueOf(EvaluationStatistic.CLASSIFICATION_ACCURACY);
        }
    }

    /**
     * Returns the plain classification accuracy of the tree.
     *
     * @return the {@code CLASSIFICATION_ACCURACY} statistic for non-numerical targets, or 0.0 for
     *         numerical targets
     */
    public double getSimpleClassificationAccuracy() {
        return this.simpleClassificationAccuracy;
    }

    /**
     * Returns the p-value of this key driver.
     *
     * @return always 0.0; presumably a p-value is not computed for tree models
     */
    @Override
    public double getPValue() {
        return 0.0;
    }

    /**
     * Returns the headline accuracy of the tree.
     *
     * @return fraction of variance explained for numerical targets, otherwise classification
     *         improvement
     */
    @Override
    public double getAccuracy() {
        return this.accuracy;
    }

    /**
     * Returns the key-driver type.
     *
     * @return {@link KeyDriverType#TREE}
     */
    @Override
    public KeyDriverType getType() {
        return KeyDriverType.TREE;
    }

    /**
     * Returns the JSON form of the tree built during construction.
     *
     * @return the serialized tree structure
     */
    public JSONObject getStructure() {
        return this.structure;
    }

    /**
     * Returns the number of nodes in the tree.
     *
     * @return node count, or 0 when the tree exposed no nodes
     */
    public int getNodeCount() {
        return this.nodeCount;
    }

    /**
     * Returns the largest node depth in the tree.
     *
     * @return maximum {@code TreeNode.depth()}, or 0 when the tree exposed no nodes
     */
    public int getMaxDepth() {
        return this.treeMaxDepth;
    }

    @Override
    public String toString() {
        return "\n Tree: ID: " + this.id + " - Accuracy: " + this.getAccuracy();
    }

    /**
     * Creates the builder used to serialize the tree to JSON. Subclasses may override this to
     * customise the serialization. It is called from the constructor.
     *
     * @param tree   the tree to serialize
     * @param locale locale used for formatting
     * @return a new JSON builder for {@code tree}
     */
    protected TreeJsonBuilder newJsonBuilder(DecisionTree tree, Locale locale) {
        return new TreeJsonBuilder(tree, locale);
    }

    /**
     * Two trees are equal when the inherited {@code compareTo} ranks them as equal.
     *
     * <p>NOTE(review): {@link #hashCode()} hashes {@code id} and accuracy. Confirm that
     * {@code compareTo(...) == 0} implies both are equal; otherwise the equals/hashCode contract
     * is violated.
     */
    @Override
    public boolean equals(Object o) {
        return o instanceof Tree && this.compareTo((Tree) o) == 0;
    }

    @Override
    public int hashCode() {
        return Objects.hash(this.id, this.getAccuracy());
    }

    /**
     * Returns the wrapped decision tree.
     *
     * @return the decision tree passed at construction
     */
    public DecisionTree getDecisionTree() {
        return this.decisionTree;
    }
}

