/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.algorithms.tree.testutilities;

import com.ibm.bi.predict.algorithms.tree.DecisionTree;
import com.ibm.bi.predict.algorithms.tree.NodeContent;
import com.ibm.bi.predict.algorithms.tree.eval.EvaluationStatistic;
import com.ibm.bi.predict.algorithms.tree.eval.ModelEvaluation;
import com.ibm.bi.predict.algorithms.tree.summary.SummaryStats;
import com.ibm.bi.predict.algorithms.tree.testutilities.LabelGenerator;
import com.ibm.bi.predict.data.DataColumn;
import com.ibm.bi.predict.dataaccess.types.FieldType;
import com.ibm.bi.predict.graph.TreeNode;
import java.io.PrintStream;
import java.text.SimpleDateFormat;
import java.util.Date;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Test utility that writes textual descriptions of a {@link DecisionTree} to a
 * {@link PrintStream}: a short summary, an indented listing of the splits, a
 * Graphviz DOT graph, or a one-row CSV summary.
 *
 * <p>Instances are only created by the static entry points; each call wraps its
 * arguments in a short-lived printer.
 */
public class TreePrinter {
    /** Maximum number of '+' / '-' marks shown by {@link #compareNumericTarget(SummaryStats)}. */
    private static final int MAX_DIFF_MARKS = 5;

    /** Number of (field, importance) column pairs written by {@link #printSummaryAsCSV}. */
    private static final int IMPORTANCE_COLUMNS = 3;

    private final LabelGenerator columnValueNamer;
    private final PrintStream out;
    private final DecisionTree tree;

    private TreePrinter(DecisionTree tree, LabelGenerator namer, PrintStream out) {
        this.tree = tree;
        this.columnValueNamer = namer;
        this.out = out;
    }

    /**
     * Prints the basic tree information followed by one line per non-root node.
     *
     * @param tree  the tree to describe
     * @param namer turns numeric column values into display names
     * @param out   destination stream
     */
    public static void print(DecisionTree tree, LabelGenerator namer, PrintStream out) {
        TreePrinter printer = new TreePrinter(tree, namer, out);
        printer.printBasicTreeInfo();
        printer.printStructure();
    }

    /**
     * Writes the tree as a Graphviz {@code digraph}.
     *
     * @param tree  the tree to describe
     * @param namer turns numeric column values into display names
     * @param out   destination stream
     */
    public static void toDotFormat(DecisionTree tree, LabelGenerator namer, PrintStream out) {
        new TreePrinter(tree, namer, out).outputDotFormat();
    }

    /**
     * Prints only the summary lines (row count, node count, depth, evaluations,
     * field importance).
     *
     * @param tree  the tree to describe
     * @param namer turns numeric column values into display names (not used by this output)
     * @param out   destination stream
     */
    public static void printBasicInfo(DecisionTree tree, LabelGenerator namer, PrintStream out) {
        new TreePrinter(tree, namer, out).printBasicTreeInfo();
    }

    /**
     * Formats an importance map as {@code "key:NN%, key:NN%"}, in the map's
     * iteration order, with each value multiplied by 100 and rounded.
     *
     * @param map importance per column
     * @return the comma-separated description; empty for an empty map
     */
    public static String importanceToString(Map<DataColumn, Double> map) {
        return map.entrySet().stream()
                .map(e -> String.format("%s:%d%%", e.getKey(), Math.round((Double) e.getValue() * 100.0)))
                .collect(Collectors.joining(", "));
    }

    /**
     * Writes a CSV header line and one data row summarising the tree.
     *
     * <p>The row holds: name, timestamp, target, node count, leaf count (nodes
     * whose {@code children()} is {@code null}), maximum depth, build time, an
     * empty {@code stability} cell, every {@link EvaluationStatistic} on test
     * data, every statistic on training data, and up to three
     * (field, importance) pairs taken from the training evaluation.
     *
     * @param name      label written in the first column
     * @param tree      the tree to summarise
     * @param buildTime value written verbatim in the {@code time} column
     * @param out       destination stream
     */
    public static void printSummaryAsCSV(String name, DecisionTree tree, double buildTime, PrintStream out) {
        out.println("name,date,target,nodes,leaves,depth,time,stability,"
                + statisticHeaders("test_") + "," + statisticHeaders("train_")
                + ",field1,importance1,field2,importance2,field3,importance3");

        out.print(q(name));
        // 'HH' (0-23): the 12-hour 'hh' without an AM/PM marker made timestamps ambiguous.
        out.print("," + new SimpleDateFormat("yyyy-MM-dd'T'HH.mm.ss").format(new Date()));
        out.print("," + q(tree.target()));
        out.print("," + tree.nodes().count());
        out.print("," + tree.nodes().filter(n -> n.children() == null).count());
        out.print("," + tree.nodes().mapToInt(TreeNode::depth).max().orElse(0));
        out.print("," + buildTime);
        // The header declares a "stability" column but no value is available here;
        // emit an empty cell so the remaining values line up with their headers.
        out.print(",");

        printStatistics(tree.evaluateOnTestData(), out);
        ModelEvaluation training = tree.evaluateOnTrainingData();
        printStatistics(training, out);

        Map<DataColumn, Double> importance = training.fieldImportance();
        // NOTE(review): takes the first entries in map iteration order; this is only the
        // "top" fields if fieldImportance() returns an importance-ordered map - confirm.
        List<DataColumn> columns = importance.entrySet().stream()
                .limit(IMPORTANCE_COLUMNS)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
        for (DataColumn column : columns) {
            out.print("," + q(column.getId()) + "," + importance.get(column));
        }
        // Pad with empty (field, importance) pairs so every row has the same width.
        for (int i = columns.size(); i < IMPORTANCE_COLUMNS; ++i) {
            out.print(",,");
        }
        out.println();
    }

    /** Joins the names of all evaluation statistics, each with the given prefix, using commas. */
    private static String statisticHeaders(String prefix) {
        return Stream.of(EvaluationStatistic.values())
                .map(s -> prefix + s)
                .collect(Collectors.joining(","));
    }

    /** Writes one {@code ,value} cell per evaluation statistic; {@code N/A} when a value is null. */
    private static void printStatistics(ModelEvaluation evaluation, PrintStream out) {
        for (EvaluationStatistic stat : EvaluationStatistic.values()) {
            Double val = evaluation.valueOf(stat);
            out.print("," + (val == null ? "N/A" : val.toString()));
        }
    }

    /**
     * Quotes a value for a CSV cell. Embedded double quotes become single quotes
     * and commas become spaces; {@code null} is written as {@code "null"}.
     */
    private static String q(Object o) {
        return "\"" + String.valueOf(o).replace("\"", "'").replace(",", " ") + "\"";
    }

    /** Escapes double quotes so the text can be placed inside a DOT quoted string. */
    private static String dotEscape(Object text) {
        return String.valueOf(text).replace("\"", "\\\"");
    }

    /** Rounded percentage of the tree's overall row count that the given stats cover. */
    private long percentOfRows(SummaryStats stats) {
        return Math.round(100.0 * (double) stats.n() / (double) this.tree.overallStats().n());
    }

    /**
     * Describes a node for a categorical target: the name of its expected value,
     * its total impurity divided by its row count, and its share of all rows.
     */
    private String compareCategoricalTarget(SummaryStats stats) {
        return String.format("%10s - %5f [%2d%%]",
                this.columnValueNamer.nameOf(this.tree.target(), stats.expectedValue()),
                stats.totalImpurity() / (double) stats.n(),
                percentOfRows(stats));
    }

    /**
     * Describes a node for a numeric target as a run of up to five '+' or '-'
     * marks followed by the node's share of all rows. The number of marks grows
     * with the distance between the node's expected value and the overall
     * expected value, measured in units of
     * {@code sqrt(overall totalImpurity / overall n)}.
     */
    private String compareNumericTarget(SummaryStats stats) {
        double sd = Math.sqrt(this.tree.overallStats().totalImpurity() / (double) this.tree.overallStats().n());
        double d = (stats.expectedValue() - this.tree.overallStats().expectedValue()) / sd;
        String mark = d > 0.0 ? "+" : "-";
        // NaN (e.g. sd == 0 with equal means) casts to 0, so no marks are shown.
        int n = (int) (2.0 * Math.abs(d) - 0.25);

        StringBuilder diff = new StringBuilder();
        for (int i = 0; i < n && i < MAX_DIFF_MARKS; ++i) {
            diff.append(mark);
        }
        return String.format("%5s [%2d%%]", diff, percentOfRows(stats));
    }

    /** Chooses the categorical or numeric comparison depending on the target's type. */
    private String getValueComparison(TreeNode<NodeContent> node) {
        SummaryStats stats = node.content().stats();
        if (this.tree.target().getType() == FieldType.CATEGORICAL) {
            return this.compareCategoricalTarget(stats);
        }
        return this.compareNumericTarget(stats);
    }

    /**
     * Builds a node label: the expected value (named for categorical targets,
     * two decimals otherwise) followed by the row count in parentheses.
     */
    private String label(TreeNode<NodeContent> node) {
        String base;
        if (this.tree.target().getType() == FieldType.CATEGORICAL) {
            base = this.columnValueNamer.nameOf(this.tree.target(), node.content().expectedValue());
        } else {
            base = String.format("%1.2f", node.content().expectedValue());
        }
        return base + " (" + node.content().rowCount() + ")";
    }

    /**
     * Writes the tree as a DOT digraph: one labelled vertex per node, then one
     * edge per non-root node whose pen width is proportional to the node's share
     * of all rows and whose head/tail labels show the split values and field.
     */
    private void outputDotFormat() {
        this.out.println("digraph " + this.tree.target().getId().replace(" ", "_") + " {");

        // Insertion-ordered so vertices and edges are emitted in traversal order,
        // which keeps the output reproducible between runs.
        Map<TreeNode<NodeContent>, String> idOf = new LinkedHashMap<>();
        this.tree.nodes().forEach(node ->
                idOf.put((TreeNode<NodeContent>) node, "n" + idOf.size() + "_" + node.depth()));

        idOf.forEach((node, id) ->
                this.out.println("  " + id + "[label=\"" + dotEscape(this.label(node)) + "\"]"));
        this.out.println();

        idOf.forEach((node, id) -> {
            if (node.parent() == null) {
                return; // the root has no incoming edge
            }
            DataColumn field = DecisionTree.splitField(node);
            double wid = (double) node.content().rowCount() * 5.0 / (double) this.tree.overallStats().n();
            this.out.print("  " + idOf.get(node.parent()) + " -> " + id);
            this.out.println(" [color=gray, penwidth=" + wid
                    + ", headlabel=\"" + dotEscape(this.values(node, field))
                    + "\", taillabel=\"" + dotEscape(field)
                    + "\", fontsize=9, ]");
        });
        this.out.println("}");
    }

    /**
     * Prints the target id and training row count, node count and depth, the
     * test and training evaluations, and the training field importance.
     */
    private void printBasicTreeInfo() {
        this.out.println(String.format("Tree for '%s' with %d training rows",
                this.tree.target().getId(), this.tree.overallStats().n()));
        this.out.println(String.format(" Nodes = %d, Depth = %d",
                this.tree.nodes().count(),
                this.tree.nodes().mapToInt(TreeNode::depth).max().orElse(0)));
        this.out.println(" Accuracy on test: " + this.tree.evaluateOnTestData().toString());

        // Evaluate on the training data once and reuse it for both lines below.
        ModelEvaluation training = this.tree.evaluateOnTrainingData();
        this.out.println(" Accuracy on training: " + training.toString());
        this.out.println(" Field Importance: " + TreePrinter.importanceToString(training.fieldImportance()));
    }

    /**
     * Prints one line per non-root node: the value comparison, two spaces of
     * indentation per depth level, then {@code field = values}.
     */
    private void printStructure() {
        this.tree.nodes().forEach(node -> {
            if (node.parent() == null) {
                return; // the root has no split leading to it
            }
            this.out.print(this.getValueComparison((TreeNode<NodeContent>) node));
            for (int level = 0; level < node.depth(); ++level) {
                this.out.print("  ");
            }
            this.out.println(DecisionTree.splitField(node) + " = "
                    + this.values((TreeNode<NodeContent>) node, DecisionTree.splitField(node)));
        });
    }

    /**
     * Lists the display names of the node's split values for the given field,
     * sorted and comma-separated; names that are blank after trimming are shown
     * as {@code <MISSING>}.
     */
    private String values(TreeNode<NodeContent> node, DataColumn field) {
        return node.content().splitValues().stream()
                .map(i -> this.columnValueNamer.nameOf(field, (Double) i))
                .sorted()
                .map(s -> s.trim().isEmpty() ? "<MISSING>" : s)
                .collect(Collectors.joining(", "));
    }
}

