/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.algorithms.tree;

import com.ibm.bi.predict.algorithms.tree.DecisionTreeContext;
import com.ibm.bi.predict.algorithms.tree.NodeContent;
import com.ibm.bi.predict.algorithms.tree.NodePartition;
import com.ibm.bi.predict.algorithms.tree.PartitionNodesMerger;
import com.ibm.bi.predict.algorithms.tree.TreeNodeBuilder;
import com.ibm.bi.predict.algorithms.tree.eval.ModelEvaluation;
import com.ibm.bi.predict.algorithms.tree.summary.SummaryStats;
import com.ibm.bi.predict.algorithms.tree.util.TreeParameters;
import com.ibm.bi.predict.data.DataColumn;
import com.ibm.bi.predict.fastpattern.util.IntList;
import com.ibm.bi.predict.fastpattern.util.IntListPool;
import com.ibm.bi.predict.graph.TreeNode;
import com.ibm.bi.predict.utils.Logger;
import com.ibm.bi.predict.utils.PredictLoggerFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.Function;
import java.util.function.ToDoubleBiFunction;
import java.util.stream.Stream;
import org.apache.commons.lang.mutable.MutableDouble;

/**
 * A decision tree for a single target column, grown by repeatedly splitting
 * nodes on the predictor whose partition compares best, and optionally pruned
 * back to a maximum size.
 *
 * <p>Instances created through {@link #make} start unfitted and must be fitted
 * with {@link #fit()} (or implicitly through one of the {@code evaluateOn...}
 * methods). Instances created through {@link #DecisionTree(TreeNode, DataColumn)}
 * wrap an existing tree and are considered fitted from the start; they carry no
 * context, parameters, predictors or row sets.
 */
public class DecisionTree {
    private static final Logger LOG = PredictLoggerFactory.getLogger(DecisionTree.class);

    /**
     * Exclusive upper bound of the 1-based pass counter in {@link #buildTree()},
     * i.e. at most {@code ITERATION_LIMIT - 1} splitting passes are made.
     */
    private static final int ITERATION_LIMIT = 100;

    private final DataColumn target;
    private final Collection<DataColumn> predictors;
    private final DecisionTreeContext context;
    private final TreeParameters params;
    private final int[] trainingRows;
    private final int[] testingRows;
    private final double minimumInformationGain;
    private final int minNodeSize;
    private final int maxTreeSize;
    private final boolean isPruningEnabled;
    private final IntListPool pool;
    private final PartitionNodesMerger partitionNodesMerger;
    private final TreeNode<NodeContent> root;
    private final DataColumn firstSplit;
    private boolean fitted;

    /**
     * Returns the column recorded as {@code splitBy} in the content of the given node.
     *
     * @param node the node to inspect
     * @return the value of {@code node.content().splitBy()}
     */
    public static DataColumn splitField(TreeNode<NodeContent> node) {
        return node.content().splitBy();
    }

    /**
     * Creates an unfitted tree, letting {@link ModelEvaluation} partition the
     * target's rows into a training set and a test set.
     *
     * @param target     the column to predict
     * @param topOneWay  column to force as the first split; ignored unless it is
     *                   contained in {@code predictors} (may be {@code null})
     * @param predictors the columns that may be used for splitting
     * @param context    supplies the tree parameters for the target's type
     * @return a new, unfitted tree
     */
    public static DecisionTree make(DataColumn target, DataColumn topOneWay, Collection<DataColumn> predictors, DecisionTreeContext context) {
        int[][] partitionedRows = ModelEvaluation.partitionRowsIntoTrainingAndTest(
                target.rowCount(), context.treeParametersForTargetType(target.getType()));
        // Index 0 is used as the training rows, index 1 as the testing rows.
        return new DecisionTree(target, topOneWay, predictors, context, partitionedRows[0], partitionedRows[1]);
    }

    /**
     * Creates an unfitted tree using caller-supplied training and testing rows.
     *
     * @param target       the column to predict
     * @param topOneWay    column to force as the first split; ignored unless it is
     *                     contained in {@code predictors} (may be {@code null})
     * @param predictors   the columns that may be used for splitting
     * @param trainingRows rows used to grow the tree
     * @param testingRows  rows used by {@link #evaluateOnTestData()}
     * @param context      supplies the tree parameters for the target's type
     * @return a new, unfitted tree
     */
    public static DecisionTree make(DataColumn target, DataColumn topOneWay, Collection<DataColumn> predictors, int[] trainingRows, int[] testingRows, DecisionTreeContext context) {
        return new DecisionTree(target, topOneWay, predictors, context, trainingRows, testingRows);
    }

    /**
     * Computes {@code 1 - (sum of totalImpurity over visited leaves) / totalImpurity(node)}.
     *
     * <p>The result is not a finite number when the node's own total impurity is zero.
     *
     * @param node the node whose subtree is walked
     * @return the fraction of the node's impurity removed by its leaves
     */
    static double totalImpurityReduction(TreeNode<NodeContent> node) {
        MutableDouble leafImpurity = new MutableDouble();
        // NOTE(review): the meaning of the boolean flag passed to walk() is not visible here.
        node.walk(n -> {
            if (n.isLeaf()) {
                leafImpurity.add(((NodeContent)n.content()).totalImpurity());
            }
        }, true);
        return 1.0 - leafImpurity.doubleValue() / node.content().totalImpurity();
    }

    /**
     * Builds an unfitted tree whose root holds all training rows.
     *
     * <p>Size and gain thresholds are derived from the tree parameters together
     * with the root's row count and total impurity.
     */
    private DecisionTree(DataColumn target, DataColumn oneWay, Collection<DataColumn> predictors, DecisionTreeContext context, int[] trainingRows, int[] testingRows) {
        this.target = target;
        this.predictors = predictors;
        this.trainingRows = trainingRows;
        this.testingRows = testingRows;
        this.context = context;
        // Use the parameter directly rather than the overridable accessor target().
        this.params = context.treeParametersForTargetType(target.getType());
        this.pool = new IntListPool();
        this.root = TreeNodeBuilder.makeRoot(target, IntList.wrapping((int[])trainingRows));

        NodeContent rootContent = this.root.content();
        this.minNodeSize = this.params.minNodeSize(rootContent.rowCount());
        this.minimumInformationGain = this.params.minimumInformationGain(rootContent.totalImpurity());
        this.maxTreeSize = this.params.maxTreeSize();
        this.isPruningEnabled = this.params.isPruningEnabled();
        // The forced first split is only honoured when it is one of the predictors.
        this.firstSplit = oneWay != null && predictors.contains(oneWay) ? oneWay : null;
        this.partitionNodesMerger = new PartitionNodesMerger(
                target,
                this.minNodeSize,
                this.params.minimumInformationGainForMerge(rootContent.totalImpurity()),
                this.pool);
    }

    /**
     * Wraps an already-built tree. The resulting instance is marked as fitted and
     * has no context, parameters, predictors, row sets or pool (all {@code null}).
     *
     * @param root   root of the existing tree
     * @param target the column the tree predicts
     */
    public DecisionTree(TreeNode<NodeContent> root, DataColumn target) {
        this.target = target;
        this.predictors = null;
        this.trainingRows = null;
        this.testingRows = null;
        this.context = null;
        this.params = null;
        this.pool = null;
        this.minNodeSize = 0;
        this.minimumInformationGain = 0.0;
        this.maxTreeSize = 0;
        this.isPruningEnabled = false;
        this.partitionNodesMerger = null;
        this.root = root;
        this.firstSplit = null;
        this.fitted = true;
    }

    /**
     * @return the context this tree was created with, or {@code null} for a tree
     *         created from an existing root
     */
    public DecisionTreeContext context() {
        return this.context;
    }

    /**
     * Fits the tree if necessary and creates an evaluation over the testing rows.
     *
     * @return a new {@link ModelEvaluation} for this tree and its testing rows
     */
    public ModelEvaluation evaluateOnTestData() {
        this.ensureFitted();
        return new ModelEvaluation(this, this.testingRows);
    }

    /**
     * Fits the tree if necessary and creates an evaluation over the training rows.
     *
     * @return a new {@link ModelEvaluation} for this tree and its training rows
     */
    public ModelEvaluation evaluateOnTrainingData() {
        this.ensureFitted();
        return new ModelEvaluation(this, this.trainingRows);
    }

    /**
     * Descends from the root to the node that matches the given row.
     *
     * @param rowData maps a column id to that column's value for the row
     * @return the node reached; see {@link #findLeaf(TreeNode, Function)} for when
     *         this can be an internal node
     */
    public TreeNode<NodeContent> findLeaf(Function<String, Object> rowData) {
        return this.findLeaf(this.root, rowData);
    }

    /**
     * Grows the tree on the training rows and, when pruning is enabled, prunes it.
     *
     * @return this tree
     * @throws IllegalStateException if the tree was already fitted
     */
    public DecisionTree fit() {
        if (this.fitted) {
            throw new IllegalStateException("Tree was already fitted to the data");
        }
        long start = System.currentTimeMillis();
        List<TreeNode<NodeContent>> forSplitting = this.buildTree();
        if (this.isPruningEnabled) {
            this.pruneTree();
        }
        // Nodes still queued for splitting mean buildTree() stopped before running out of work.
        if (!forSplitting.isEmpty()) {
            LOG.debug("Tree exceeded limit on number of nodes and fitting was terminated early");
        }
        this.fitted = true;
        long elapsed = System.currentTimeMillis() - start;
        // Lazy message: counting the nodes walks the tree, so only do it when debug is on.
        LOG.debug(() -> "Built tree of size " + this.nodes().count() + " in " + (double)elapsed / 1000.0 + "s");
        return this;
    }

    /**
     * @return {@code true} once the tree has been fitted (always {@code true} for
     *         a tree created from an existing root)
     */
    public boolean fitted() {
        return this.fitted;
    }

    /**
     * @return the column forced as the first split, or {@code null} if none
     */
    public DataColumn getFirstSplit() {
        return this.firstSplit;
    }

    /**
     * @return the training rows array held by this tree (not a copy); {@code null}
     *         for a tree created from an existing root
     */
    public int[] getTrainingRows() {
        return this.trainingRows;
    }

    /**
     * @return a stream over those descendants of the root that are leaves
     * @throws IllegalStateException if the tree has not been fitted
     */
    public Stream<TreeNode<NodeContent>> leafNodes() {
        this.checkFitted();
        return this.root.descendants().stream().filter(TreeNode::isLeaf);
    }

    /**
     * @return a stream over {@code root.descendants()}
     * @throws IllegalStateException if the tree has not been fitted
     */
    public Stream<TreeNode<NodeContent>> nodes() {
        this.checkFitted();
        return this.root.descendants().stream();
    }

    /**
     * @return the summary statistics stored in the root node's content
     */
    public SummaryStats overallStats() {
        return this.root.content().stats();
    }

    /**
     * @return the tree parameters, or {@code null} for a tree created from an existing root
     */
    public TreeParameters params() {
        return this.params;
    }

    /**
     * @return the predictor collection passed at creation (not a copy), or
     *         {@code null} for a tree created from an existing root
     */
    public Collection<DataColumn> predictors() {
        return this.predictors;
    }

    /**
     * @return the root node of the tree
     */
    public TreeNode<NodeContent> rootNode() {
        return this.root;
    }

    /**
     * Scores each of the given rows with the expected value of the node it falls into.
     *
     * @param predictorValues returns the value of a column (by id) for a row number
     * @param rows            the row numbers to score
     * @return one score per entry of {@code rows}, in the same order
     * @throws IllegalStateException if the tree has not been fitted
     */
    public double[] score(ToDoubleBiFunction<String, Integer> predictorValues, int[] rows) {
        this.checkFitted();
        double[] scores = new double[rows.length];
        for (int i = 0; i < scores.length; ++i) {
            int row = rows[i];
            TreeNode<NodeContent> leaf = this.findLeaf(columnId -> predictorValues.applyAsDouble(columnId, row));
            scores[i] = leaf.content().expectedValue();
        }
        return scores;
    }

    /**
     * @return the column this tree predicts
     */
    public DataColumn target() {
        return this.target;
    }

    /**
     * Creates a new, unfitted tree that shares this tree's target, first split,
     * context and row sets but uses a different set of predictors.
     *
     * <p>The context is dereferenced during construction, so this is not usable
     * on a tree created from an existing root (whose context is {@code null}).
     *
     * @param predictors the predictors for the new tree
     * @return the new tree
     */
    public DecisionTree withDifferentPredictors(Collection<DataColumn> predictors) {
        return new DecisionTree(this.target, this.firstSplit, predictors, this.context, this.trainingRows, this.testingRows);
    }

    /**
     * Finds the best partition for the node and, if there is one, installs its
     * parts as the node's children.
     *
     * @param node the node to split
     * @return the new child nodes, or {@code null} if no suitable partition exists
     */
    TreeNode<NodeContent>[] makeBestSplit(TreeNode<NodeContent> node) {
        LOG.debug(() -> "Making best split for node " + node.toString());
        NodePartition best = this.findBestPartition(node);
        if (best == null) {
            LOG.trace(() -> "Making best split found no suitable partition");
            return null;
        }
        node.setChildren(best.parts);
        TreeNode<NodeContent>[] result = best.parts;
        LOG.debug(() -> "Making best split used " + best + ", resulting in " + result.length + " new nodes");
        return result;
    }

    /**
     * Builds the partition of a node by the values of one column.
     *
     * <p>The raw partition is discarded (and its lists returned to the pool) when
     * its information gain is below the minimum; otherwise it is handed to the
     * {@link PartitionNodesMerger} and the merger's result is returned.
     *
     * @param node            the node to partition
     * @param splittingColumn the column to split by
     * @return the resulting partition, or {@code null} if the column yields no
     *         valid partition or too little information gain
     */
    NodePartition makePartition(TreeNode<NodeContent> node, DataColumn splittingColumn) {
        LOG.debug(() -> "Making partition for node " + node.toString() + " using " + splittingColumn.toString());
        TreeNode<NodeContent>[] parts = TreeNodeBuilder.splitByColumnValues(node, splittingColumn, this.pool);
        if (parts == null) {
            LOG.trace(() -> "Making partition did not find a valid partition");
            return null;
        }
        LOG.trace(() -> "Making partition created " + parts.length + " parts");
        NodePartition partition = new NodePartition(splittingColumn, parts, node);
        if (partition.informationGain < this.minimumInformationGain) {
            partition.releaseListsToPool(this.pool);
            return null;
        }
        NodePartition merged = this.partitionNodesMerger.merge(partition);
        LOG.trace(() -> "Making partition resulted in " + merged);
        return merged;
    }

    /**
     * Grows the tree level by level: every pass splits all current split
     * candidates and queues the newly created children for the next pass.
     *
     * <p>Growth stops when nothing is left to split, when the pass counter reaches
     * {@link #ITERATION_LIMIT}, or when the tree holds more than twice
     * {@code maxTreeSize} descendants.
     *
     * @return the nodes that were still queued for splitting when growth stopped
     */
    private List<TreeNode<NodeContent>> buildTree() {
        List<TreeNode<NodeContent>> forSplitting = Collections.singletonList(this.root);
        LOG.debug(() -> String.format("Building tree for %s using %d predictors", this.target.getId(), this.predictors.size()));
        LOG.debug(() -> String.format("Predictors = %s", this.predictors.toString()));
        int iteration = 1;
        while (iteration < ITERATION_LIMIT && !forSplitting.isEmpty()) {
            int oldSize = forSplitting.size();
            // Each successful split yields an array of children; flatten them into one list.
            forSplitting = this.streamOf(forSplitting)
                    .filter(this::isSplitCandidate)
                    .map(this::makeBestSplit)
                    .filter(Objects::nonNull)
                    .<List<TreeNode<NodeContent>>>collect(ArrayList::new, Collections::addAll, List::addAll);
            // long arithmetic so that a very large maxTreeSize cannot overflow and end growth early.
            if (this.root.descendants().size() > 2L * this.maxTreeSize) {
                break;
            }
            int idx = iteration++;
            int newSize = forSplitting.size();
            LOG.debug(() -> String.format("Tree Iteration #%d: #nodes=%d, impurity=%2.5f, remaining nodes was %d and now is %d", idx, this.root.descendants().size(), DecisionTree.totalImpurityReduction(this.root), oldSize, newSize));
        }
        return forSplitting;
    }

    /**
     * @throws IllegalStateException if the tree has not been fitted
     */
    private void checkFitted() {
        if (!this.fitted) {
            throw new IllegalStateException("Tree must be fitted first");
        }
    }

    /** Fits the tree if that has not happened yet. */
    private void ensureFitted() {
        if (!this.fitted) {
            this.fit();
        }
    }

    /**
     * Tries every candidate predictor on the node and keeps the partition that
     * compares greatest according to {@link NodePartition#compareTo}.
     *
     * <p>Every partition that does not end up as the result has its lists
     * returned to the pool, including a former best that is displaced by a
     * better candidate.
     *
     * @param node the node to partition
     * @return the best partition, or {@code null} if no predictor produced one
     */
    private NodePartition findBestPartition(TreeNode<NodeContent> node) {
        NodePartition best = null;
        for (DataColumn predictor : this.predictors) {
            if (!this.isCandidatePredictor(predictor, node)) {
                continue;
            }
            NodePartition candidate = this.makePartition(node, predictor);
            if (candidate == null) {
                continue;
            }
            if (best == null) {
                best = candidate;
            } else if (best.compareTo(candidate) < 0) {
                // The displaced best is no longer reachable; give its lists back for reuse.
                best.releaseListsToPool(this.pool);
                best = candidate;
            } else {
                candidate.releaseListsToPool(this.pool);
            }
        }
        return best;
    }

    /**
     * Recursively descends into the first child whose content contains the row's
     * value for the column that child was split by.
     *
     * @param node    the node to start from
     * @param rowData maps a column id to that column's value for the row
     * @return the leaf reached, or the deepest internal node none of whose
     *         children contains the row's value
     */
    private TreeNode<NodeContent> findLeaf(TreeNode<NodeContent> node, Function<String, Object> rowData) {
        if (node.isLeaf()) {
            return node;
        }
        for (TreeNode<NodeContent> child : node.children()) {
            NodeContent childContent = child.content();
            Object value = rowData.apply(childContent.splitBy().getId());
            if (childContent.containsValue(value)) {
                return this.findLeaf(child, rowData);
            }
        }
        // No child accepts the value: stop at this internal node.
        return node;
    }

    /**
     * Decides whether a predictor may be used to split the given node.
     *
     * <p>At the root, when a first split is forced, only that column qualifies.
     * Otherwise every predictor qualifies except the column the node itself was
     * split by. Both comparisons are by reference.
     */
    private boolean isCandidatePredictor(DataColumn p, TreeNode<NodeContent> node) {
        if (node.equals(this.root) && this.firstSplit != null) {
            return p == this.firstSplit;
        }
        return p != DecisionTree.splitField(node);
    }

    /**
     * @return {@code true} if the node is not a leaf and all of its children are leaves
     */
    private boolean isPruneCandidate(TreeNode<NodeContent> node) {
        if (node.isLeaf()) {
            return false;
        }
        for (TreeNode<NodeContent> child : node.children()) {
            if (!child.isLeaf()) {
                return false;
            }
        }
        return true;
    }

    /**
     * @return {@code true} if the node has at least twice the minimum node size in
     *         rows and its total impurity is at least the minimum information gain
     */
    private boolean isSplitCandidate(TreeNode<NodeContent> node) {
        NodeContent content = node.content();
        // long arithmetic guards the doubling against int overflow.
        return content.rowCount() >= 2L * this.minNodeSize
                && content.totalImpurity() >= this.minimumInformationGain;
    }

    /**
     * Removes children until the tree has at most {@code maxTreeSize} descendants.
     *
     * <p>Each step picks, among the prune candidates, the node with the smallest
     * {@link #totalImpurityReduction} and clears its children.
     *
     * @throws NoSuchElementException if the tree is still too large but contains
     *                                no prune candidate
     */
    private void pruneTree() {
        int count = this.root.descendants().size();
        while (count > this.maxTreeSize) {
            TreeNode<NodeContent> forPruning = this.root.descendants().stream()
                    .filter(this::isPruneCandidate)
                    .min(Comparator.comparingDouble(DecisionTree::totalImpurityReduction))
                    .orElseThrow(() -> new NoSuchElementException("No prune candidate left while tree still exceeds maximum size"));
            int removed = forPruning.children().length;
            forPruning.setChildren(null);
            count -= removed;
        }
    }

    /**
     * @return a parallel stream over the collection when the context flag
     *         {@code useParallelProcessing} is set (default {@code false}),
     *         otherwise a sequential stream
     */
    private <T> Stream<T> streamOf(Collection<T> nodes) {
        return this.context.getBoolean("useParallelProcessing", false) ? nodes.parallelStream() : nodes.stream();
    }
}

