/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.fastpattern.chaid;

import com.ibm.bi.predict.algorithms.tree.DecisionTree;
import com.ibm.bi.predict.algorithms.tree.DecisionTreeContext;
import com.ibm.bi.predict.algorithms.tree.eval.EvaluationStatistic;
import com.ibm.bi.predict.algorithms.tree.eval.ModelEvaluation;
import com.ibm.bi.predict.algorithms.tree.util.TreeParameters;
import com.ibm.bi.predict.data.Category;
import com.ibm.bi.predict.data.DataColumn;
import com.ibm.bi.predict.data.DataPrep;
import com.ibm.bi.predict.dataaccess.types.FieldType;
import com.ibm.bi.predict.fastpattern.FastPatternContext;
import com.ibm.bi.predict.fastpattern.algorithm.Algorithm;
import com.ibm.bi.predict.fastpattern.chaid.Tree;
import com.ibm.bi.predict.fastpattern.chaid.util.TreeUtils;
import com.ibm.bi.predict.fastpattern.keydrivers.KeyDriver;
import com.ibm.bi.predict.fastpattern.keydrivers.OneWayKeyDriver;
import com.ibm.bi.predict.fastpattern.keydrivers.TargetField;
import com.ibm.bi.predict.fastpattern.result.AlgorithmResult;
import com.ibm.bi.predict.fastpattern.result.InputField;
import com.ibm.bi.predict.graph.TreeNode;
import com.ibm.bi.predict.result.StatusCode;
import com.ibm.bi.predict.utils.Logger;
import com.ibm.bi.predict.utils.PredictLoggerFactory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * {@link Algorithm} implementation that fits a {@link DecisionTree} on the target column using
 * the columns referenced by the one-way key drivers as predictors, repeatedly drops predictors
 * whose importance is at or below a cut-off, and wraps the final tree in an
 * {@link AlgorithmResult}.
 */
public class TreeAlgorithm
extends Algorithm<AlgorithmResult> {
    private static final Logger LOG = PredictLoggerFactory.getLogger(TreeAlgorithm.class);

    /**
     * Divisor applied to the number of training rows when computing the importance cut-off.
     * NOTE(review): the rationale for 10000 is not visible in this file - confirm with the
     * algorithm owner before changing it.
     */
    private static final double TRAINING_ROWS_DIVISOR = 10000.0;

    /**
     * Column of the first one-way key driver, or {@code null} when that driver's accuracy does
     * not exceed the minimum accuracy configured for the target's field type.
     */
    private final DataColumn topOneWay;

    /**
     * Creates the algorithm and decides whether the first one-way driver column is handed to
     * the tree builder.
     *
     * @param dataPrep prepared data providing the target column and the driver columns
     * @param context  fast-pattern context (locale, sorting strategy, tree settings)
     * @param oneway   one-way key drivers; the first element is read unconditionally, so the
     *                 list must not be empty
     */
    public TreeAlgorithm(DataPrep dataPrep, FastPatternContext context, List<OneWayKeyDriver> oneway) {
        super(dataPrep, context, oneway);
        OneWayKeyDriver topDriver = oneway.get(0);
        // The factor index minus one is used as the position in the driver column list.
        DataColumn topOneWayColumn = (DataColumn)dataPrep.driverColumns().get(topDriver.getFactorIndex() - 1);
        double topOneWayAccuracy = topDriver.getAccuracy();
        LOG.debug("Top oneway: {}, Accuracy: {}", (Object)topOneWayColumn, (Object)topOneWayAccuracy);

        FieldType targetType = dataPrep.targetColumn().getType();
        boolean accurateNumerical = targetType == FieldType.NUMERICAL
                && topOneWayAccuracy > TreeParameters.continuous().minAccuracy();
        boolean accurateCategorical = targetType == FieldType.CATEGORICAL
                && topOneWayAccuracy > TreeParameters.categorical().minAccuracy();
        this.topOneWay = accurateNumerical || accurateCategorical ? topOneWayColumn : null;
    }

    /**
     * Fits the tree (re-fitting while weak predictors remain) and converts it into a result.
     *
     * @return a {@link StatusCode#SUCCESS} result holding either the tree key driver or an
     *         empty list when the tree's accuracy does not exceed the minimum accuracy
     */
    @Override
    public AlgorithmResult getResults() {
        LOG.perfStart();
        LOG.perfLog("Starting tree computation.");
        List<DataColumn> predictors = this.oneway.stream()
                .map(v -> (DataColumn)this.dataPrep.driverColumns().get(v.getFactorIndex() - 1))
                .collect(Collectors.toList());
        // Locale.ROOT keeps the lower-casing independent of the JVM default locale.
        String typeAsString = this.dataPrep.targetColumn().getType().toString().toLowerCase(Locale.ROOT);
        LOG.debug("Defining Decision tree for {} target with {} predictors", (Object)typeAsString, (Object)predictors.size());
        LOG.debug("predictors = {}", predictors);
        DecisionTree tree = this.rebuildIfNeeded(predictors);
        if (tree.nodes().count() > 1L) {
            LOG.debug("First split of the tree: {}", (Object)DecisionTree.splitField((TreeNode)tree.rootNode().children()[0]));
        }
        LOG.perfStop("Finished tree computation.");
        return this.buildResult(tree);
    }

    /**
     * Fits a tree on {@code predictors} and, while the fitted tree reports predictors with weak
     * importance, re-fits on a reduced predictor list: only the predictors used as split fields
     * are kept, and one weak predictor is additionally removed from them.
     *
     * <p>The supplied list is never modified. The loop stops as soon as a round fails to shrink
     * the predictor list, because re-fitting on the same predictors would not make progress.
     *
     * @param predictors candidate predictor columns
     * @return the last fitted tree
     */
    private DecisionTree rebuildIfNeeded(List<DataColumn> predictors) {
        List<DataColumn> current = new ArrayList<>(predictors);
        while (true) {
            DecisionTree tree = DecisionTree.make((DataColumn)this.dataPrep.targetColumn(), (DataColumn)this.topOneWay, current, (DecisionTreeContext)TreeUtils.decisionTreeContext(this.context)).fit();
            List<DataColumn> weakPredictors = this.predictorsWithWeakImportance(tree);
            LOG.debug("Number of predictors with zero importance removed: {}, Tree Stats: {}", (Object)weakPredictors.size(), (Object)tree.evaluateOnTrainingData());
            if (weakPredictors.isEmpty()) {
                return tree;
            }
            List<DataColumn> usefulPredictors = this.getUsedFieldInTree(tree, current);
            // The weak field is looked up among the predictors that are still in use.
            usefulPredictors.remove(TreeAlgorithm.findFieldToRemove(weakPredictors, usefulPredictors));
            if (usefulPredictors.size() == current.size()) {
                // Nothing could be removed; another fit would see identical input.
                return tree;
            }
            current = usefulPredictors;
        }
    }

    /**
     * Picks the weak predictor to drop: the one appearing last in {@code predictors}, or the
     * first weak predictor when none of them is contained in {@code predictors}.
     *
     * @param fieldsWithZeroImportance non-empty list of weak predictors
     * @param predictors               predictors to search, scanned from the end
     * @return the field to remove
     */
    private static DataColumn findFieldToRemove(List<DataColumn> fieldsWithZeroImportance, List<DataColumn> predictors) {
        for (int i = predictors.size() - 1; i >= 0; --i) {
            DataColumn field = predictors.get(i);
            if (fieldsWithZeroImportance.contains(field)) {
                return field;
            }
        }
        return fieldsWithZeroImportance.get(0);
    }

    /**
     * Returns the predictors that occur as a split field of at least one node of {@code tree}.
     *
     * @param tree                    fitted tree
     * @param predictorsInCurrentTree predictors the tree was fitted on; left unmodified
     * @return a new mutable list with the used predictors, in their original order
     */
    private List<DataColumn> getUsedFieldInTree(DecisionTree tree, List<DataColumn> predictorsInCurrentTree) {
        List<DataColumn> used = new ArrayList<>(predictorsInCurrentTree);
        used.retainAll(tree.nodes().map(DecisionTree::splitField).filter(Objects::nonNull).collect(Collectors.toSet()));
        return used;
    }

    /**
     * Collects the predictors whose importance on the training data is less than or equal to
     * the cut-off computed by {@link #calculateCutOffThreshold}.
     *
     * @param tree fitted tree
     * @return the weak predictors; empty when every importance exceeds the cut-off
     */
    private List<DataColumn> predictorsWithWeakImportance(DecisionTree tree) {
        ModelEvaluation eval = tree.evaluateOnTrainingData();
        double thresholdImp = TreeAlgorithm.calculateCutOffThreshold(tree, eval);
        Map<DataColumn, Double> importanceByField = eval.fieldImportance();
        return importanceByField.entrySet().stream()
                .filter(e -> e.getValue() <= thresholdImp)
                .peek(s -> LOG.debug("\tPredictor with importance less than threshold: {}", s))
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
    }

    /**
     * Computes the importance cut-off. It is {@code 0.0} unless the target is categorical and
     * the classification improvement plus {@code rows / 10000} is below the tree's predictor
     * importance multiplier; in that case the shortfall scales the configured adjustment, which
     * is added to the configured cut-off.
     *
     * @param tree fitted tree supplying target type, training rows and parameters
     * @param eval evaluation of {@code tree} on the training data
     * @return the importance value at or below which a predictor counts as weak
     */
    private static double calculateCutOffThreshold(DecisionTree tree, ModelEvaluation eval) {
        double rowTerm = tree.getTrainingRows().length / TRAINING_ROWS_DIVISOR;
        if (!tree.target().getType().equals(FieldType.CATEGORICAL)) {
            return 0.0;
        }
        double treeAcc = eval.valueOf(EvaluationStatistic.CLASSIFICATION_IMPROVEMENT).doubleValue();
        TreeParameters categoricalTreeParams = tree.params();
        double multiplier = categoricalTreeParams.getPredictorImportanceMultipler();
        if (treeAcc + rowTerm < multiplier) {
            double shortfall = multiplier - treeAcc - rowTerm;
            return categoricalTreeParams.getPredictorImportanceCutoff()
                    + shortfall * categoricalTreeParams.getPredictorImportanceAdjustment();
        }
        return 0.0;
    }

    /**
     * Wraps the fitted tree in a {@link Tree} key driver and returns it only when its accuracy
     * is strictly greater than the tree parameters' minimum accuracy.
     *
     * @param tree fitted tree
     * @return a successful result with one key driver, or with none when the tree is filtered out
     */
    private AlgorithmResult buildResult(DecisionTree tree) {
        LOG.perfStart();
        LOG.perfLog("Starting to build tree result.");
        List<InputField> inputFields = this.buildInputFields(tree);
        List<Category> targetCategories = this.getSortedTargetCategories(tree.target().getCategories());
        // Both stay null when the target exposes no category list.
        Map<Integer, Integer> categoryIndexMap = null;
        String[] targetLabels = null;
        if (targetCategories != null) {
            categoryIndexMap = this.getTargetCategoryIndexMap(targetCategories);
            targetLabels = this.getTargetCategoryLabels(targetCategories);
        }
        Tree keyDriver = new Tree(TreeAlgorithm.buildTreeId(inputFields), new TargetField(tree.target().getId(), targetLabels), inputFields, tree, this.context, categoryIndexMap);
        LOG.perfStop("Finished building tree result.");
        if (((KeyDriver)keyDriver).getAccuracy() > tree.params().minAccuracy()) {
            return new AlgorithmResult(StatusCode.SUCCESS, Arrays.asList(keyDriver));
        }
        // NOTE(review): the message hard-codes 0.1 while the check uses params().minAccuracy().
        LOG.debug("Tree accuracy less or equal to 0.1 and filtered out: {}, target: {}", (Object)keyDriver, (Object)keyDriver.getTargetId());
        return new AlgorithmResult(StatusCode.SUCCESS, Collections.emptyList());
    }

    /**
     * Stores each category's current list position as its sort index, then returns the
     * categories ordered by the context's target category sorting strategy.
     *
     * @param targetCategories categories of the target, may be {@code null}
     * @return a new sorted list, or {@code null} when the input is {@code null}
     */
    private List<Category> getSortedTargetCategories(List<Category> targetCategories) {
        if (targetCategories == null) {
            return null;
        }
        for (int i = 0; i < targetCategories.size(); ++i) {
            targetCategories.get(i).setSortIndex(i);
        }
        return targetCategories.stream().sorted(this.context.getTargetCategorySortingStrategy()).collect(Collectors.toList());
    }

    /**
     * Maps each position in the sorted list to the sort index stored on the category there.
     *
     * @param sortedCategories categories in sorted order
     * @return map from sorted position to the category's sort index
     */
    private Map<Integer, Integer> getTargetCategoryIndexMap(List<Category> sortedCategories) {
        Map<Integer, Integer> categoryIndexMap = new HashMap<>();
        for (int i = 0; i < sortedCategories.size(); ++i) {
            categoryIndexMap.put(i, sortedCategories.get(i).getSortIndex());
        }
        return categoryIndexMap;
    }

    /**
     * Renders the sorted categories as strings using the context's {@code "locale"} setting.
     *
     * @param sortedCategories categories in sorted order
     * @return labels in the same order
     */
    private String[] getTargetCategoryLabels(List<Category> sortedCategories) {
        Locale locale = this.context.getLocale("locale");
        return sortedCategories.stream().map(c -> c.asString(locale)).toArray(String[]::new);
    }

    /**
     * Builds the tree identifier: {@code "tree"} followed by {@code "_" + id} per input field.
     *
     * @param inputFields input fields of the tree
     * @return the identifier
     */
    private static String buildTreeId(List<InputField> inputFields) {
        StringBuilder idBuilder = new StringBuilder("tree");
        for (InputField inputField : inputFields) {
            idBuilder.append('_').append(inputField.getId());
        }
        return idBuilder.toString();
    }

    /**
     * Creates one {@link InputField} per entry of the tree's field importance map.
     *
     * @param tree fitted tree
     * @return input fields in the iteration order of the importance map
     */
    private List<InputField> buildInputFields(DecisionTree tree) {
        Map<DataColumn, Double> fieldImportance = tree.evaluateOnTrainingData().fieldImportance();
        List<InputField> inputFields = new ArrayList<>();
        fieldImportance.forEach((dataColumn, importance) -> this.addInputField(inputFields, dataColumn, importance));
        return inputFields;
    }

    /**
     * Appends an {@link InputField} describing {@code dataColumn} to {@code inputFields}.
     *
     * @param inputFields list to append to
     * @param dataColumn  column supplying id, label and category labels
     * @param importance  importance of the column in the tree
     */
    private void addInputField(List<InputField> inputFields, DataColumn dataColumn, Double importance) {
        inputFields.add(new InputField(dataColumn.getId(), dataColumn.getLabel(), importance, this.getCategoryLabels(dataColumn)));
    }

    /**
     * Returns the category labels of a column rendered with the context's {@code "locale"}
     * setting; numerical columns yield an empty list.
     *
     * @param dataColumn column to describe
     * @return category labels in the column's category order
     */
    private List<String> getCategoryLabels(DataColumn dataColumn) {
        if (dataColumn.getType() == FieldType.NUMERICAL) {
            return Collections.emptyList();
        }
        List<Category> categories = dataColumn.getCategories();
        List<String> categoryLabels = new ArrayList<>();
        // Resolved once; the locale does not depend on the category.
        Locale locale = this.context.getLocale("locale");
        for (Category category : categories) {
            categoryLabels.add(category.asString(locale));
        }
        return categoryLabels;
    }
}

