/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.explore.relationship;

import com.ibm.bi.predict.algorithms.table.FrequencyChiSquareTest;
import com.ibm.bi.predict.data.DataFrame;
import com.ibm.bi.predict.data.DataPrep;
import com.ibm.bi.predict.data.FrequencyTable;
import com.ibm.bi.predict.data.matrix.Matrix;
import com.ibm.bi.predict.data.matrix.MatrixVectorFactory;
import com.ibm.bi.predict.explore.ExploreContext;
import com.ibm.bi.predict.explore.ExploreParams;
import com.ibm.bi.predict.explore.algorithm.ExploreAlgorithm;
import com.ibm.bi.predict.explore.relationship.RelationshipStrengthGroupingIndexer;
import com.ibm.bi.predict.explore.relationship.RelationshipStrengthsGraphAlgorithm;
import com.ibm.bi.predict.explore.relationship.RelationshipStrengthsResult;
import com.ibm.bi.predict.explore.result.RelationshipResult;
import com.ibm.bi.predict.explore.result.RelationshipStrength;
import com.ibm.bi.predict.result.ExecutionResult;
import com.ibm.bi.predict.result.Message;
import com.ibm.bi.predict.result.MessageCode;
import com.ibm.bi.predict.result.StatusCode;
import com.ibm.bi.predict.utils.Logger;
import com.ibm.bi.predict.utils.PredictLoggerFactory;
import com.ibm.bi.predict.utils.Tuple;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Explore algorithm that measures how strongly the fields of a data frame are associated with each
 * other and reports the strongest relationships for every requested grouping.
 *
 * <p>Each pair of fields is scored with Cramér's V taken from a chi-square test on the pair's
 * crosstab. {@link RelationshipStrengthsGraphAlgorithm} then selects nodes and links given the
 * target field, a node budget, a link budget and a minimum strength. Finally, for each grouping the
 * selected links are projected onto pairs of groups, keeping the strongest link per group pair.
 */
public class RelationshipStrengthsAlgorithm
extends ExploreAlgorithm<RelationshipStrengthsResult> {
    private static final Logger LOGGER = PredictLoggerFactory.getLogger(RelationshipStrengthsAlgorithm.class);

    /**
     * Default number of links requested from the graph algorithm per selected node, used when the
     * {@code maxRelationships} setting is absent or not positive.
     */
    public static final double DEFAULT_RATIO_OF_EDGES_TO_NODES = 1.8;

    /** Explore request id, included in every log line for correlation. */
    final String requestId;

    public RelationshipStrengthsAlgorithm(ExploreContext context) {
        super(context);
        this.requestId = context.getExploreRequestId();
    }

    /**
     * Builds the relationship-strengths result for the frame produced by {@code dataPrep}.
     *
     * <p>When the frame has no fields, or the {@code target} setting does not name one of its
     * fields, an empty result with status SUCCESS and a warning message is returned. Otherwise all
     * field pairs are scored and one {@link RelationshipResult} is produced per grouping.
     *
     * @param dataPrep source of the data frame to analyse
     * @return the result; never {@code null}
     */
    @Override
    public RelationshipStrengthsResult prepareResults(DataPrep dataPrep) {
        LOGGER.debug("PEXRSA - {} - starting algorithm", this.requestId);
        LOGGER.perfStart();
        try {
            LOGGER.perfLog("PEXRSA - {} - start makeFrame  ", this.requestId);
            DataFrame df = dataPrep.makeFrame();
            LOGGER.debug("PEXRSA - {} - Finding relationships in {} records", this.requestId, df.getNumRows());
            DataframeIndexer dfIndex = new DataframeIndexer(df);
            String targetName = this.context.getString("target", null);
            RelationshipStrengthsResult result;
            if (df.getNumFields() == 0) {
                LOGGER.warn("PEXRSA - {} - no usable fields were found, returning empty relationships", this.requestId);
                result = this.makeEmptyResult();
                result.addWarningMessage(new Message(MessageCode.EXPLORE_RELATIONSHIP_NO_USABLE_FIELDS, this.context.getLocale("locale")));
            } else if (!dfIndex.dfIndexes.containsKey(targetName)) {
                // Also covers a missing (null) target, since no field is indexed under null.
                LOGGER.warn("PEXRSA - {} - requested target was not a usable field, returning empty relationships", this.requestId);
                result = this.makeEmptyResult();
                result.addWarningMessage(new Message(MessageCode.EXPLORE_RELATIONSHIP_INVALID_TARGET, this.context.getLocale("locale")));
            } else {
                result = this.computeRelationships(df, dfIndex, targetName);
            }
            LOGGER.perfLog("PEXRSA - {} - finished", this.requestId);
            return result;
        } finally {
            // Keep perfStart/perfStop balanced even if frame construction or scoring throws.
            LOGGER.perfStop();
        }
    }

    /**
     * Scores every field pair, runs the graph selection and assembles the per-grouping results,
     * plus a field-count message and any warnings raised while scoring pairs.
     *
     * @param df the data frame being analysed
     * @param dfIndex name/index lookup for {@code df}
     * @param targetName target field name; must be a key of {@code dfIndex.dfIndexes}
     */
    private RelationshipStrengthsResult computeRelationships(DataFrame df, DataframeIndexer dfIndex, String targetName) {
        LOGGER.perfLog("PEXRSA - {} - start field pairs", this.requestId);
        ExecutionResult<FrequencyTable> table = this.computeFieldRelationships(df);

        LOGGER.perfLog("PEXRSA - {} - start graph algorithm", this.requestId);
        int target = dfIndex.dfIndexes.get(targetName);
        Tuple<List<Integer>, List<RelationshipStrengthsGraphAlgorithm.GraphLink>> graph =
                this.runGraphAlgorithm((FrequencyTable) table.getContent(), target);
        List<Integer> nodeIndexes = graph._1;
        List<RelationshipStrengthsGraphAlgorithm.GraphLink> links = graph._2;
        Set<String> nodes = nodeIndexes.stream()
                .map(ix -> df.getFieldName(ix.intValue()))
                .collect(Collectors.toSet());

        LOGGER.perfLog("PEXRSA - {} - start grouping results", this.requestId);
        List<RelationshipResult> results = this.context.getGroupings().stream()
                .map(grouping -> this.computeGroupingResult((ExploreParams.Grouping) grouping, dfIndex, nodes, links))
                .collect(Collectors.toList());
        RelationshipStrengthsResult result = new RelationshipStrengthsResult(StatusCode.SUCCESS, results, this.context);

        // Input fields = requested candidates plus the target, de-duplicated. An explicit HashSet is
        // used because we add to it and Collectors.toSet() makes no mutability guarantee.
        Set<String> inputFields = new HashSet<>(this.context.getStringList("explore-params.candidates", new ArrayList<>()));
        inputFields.add(targetName);
        int nInputFields = inputFields.size();
        int nOutputFields = nodeIndexes.size();
        result.addWarningMessage(new Message(MessageCode.EXPLORE_RELATIONSHIP_FIELD_COUNTS, Collections.emptyList(), this.context.getLocale("locale"), new Object[]{nInputFields, nOutputFields}));
        result.addMessages(table);
        return result;
    }

    /**
     * Builds a SUCCESS result with one relationship-free grouping per configured grouping. The
     * target field id, when configured, is still passed to each grouping.
     */
    private RelationshipStrengthsResult makeEmptyResult() {
        String target = this.context.getString("target", null);
        Set<String> fieldIds = target == null ? Collections.<String>emptySet() : Collections.singleton(target);
        List<RelationshipResult> results = this.context.getGroupings().stream()
                .map(grouping -> new RelationshipResult(
                        RelationshipStrengthGroupingIndexer.createGrouping(grouping.getId(), grouping.getGroups(), fieldIds),
                        new ArrayList<RelationshipStrength>()))
                .collect(Collectors.toList());
        return new RelationshipStrengthsResult(StatusCode.SUCCESS, results, this.context);
    }

    /**
     * Scores every unordered pair of fields and stores the strength symmetrically in an
     * {@code nFields x nFields} table. Pairs whose scoring fails are skipped; for each such pair a
     * warning is added to the returned result.
     */
    private ExecutionResult<FrequencyTable> computeFieldRelationships(DataFrame df) {
        int nFields = df.getNumFields();
        FrequencyTable resultTable = new FrequencyTable(nFields, nFields);
        ExecutionResult<FrequencyTable> result = new ExecutionResult<>(StatusCode.SUCCESS, resultTable);
        for (int i = 0; i < nFields; ++i) {
            for (int j = i + 1; j < nFields; ++j) {
                this.computeFieldRelationship(result, df, i, j);
            }
        }
        return result;
    }

    /**
     * Runs the graph selection over the pairwise strength table.
     *
     * <p>The node budget comes from {@code explore-params.nnodes} and is clamped to the table width
     * when absent, non-positive or too large. The link budget comes from {@code maxRelationships}
     * and falls back to {@link #DEFAULT_RATIO_OF_EDGES_TO_NODES} times the node budget when absent
     * or non-positive. The minimum strength comes from {@code minStrength} (default 0.1).
     *
     * @return the selected node indexes and links, as returned by the graph algorithm
     */
    private Tuple<List<Integer>, List<RelationshipStrengthsGraphAlgorithm.GraphLink>> runGraphAlgorithm(FrequencyTable table, int target) {
        int nfields = table.getNumCols();
        int nnodes = this.context.getIntOpt("explore-params.nnodes").orElse(-1);
        if (nnodes <= 0 || nnodes > nfields) {
            nnodes = nfields;
        }
        int nlinks = this.context.getInt("maxRelationships", 0);
        if (nlinks <= 0) {
            nlinks = (int) Math.round(DEFAULT_RATIO_OF_EDGES_TO_NODES * nnodes);
        }
        double minStrength = this.context.getDouble("minStrength", 0.1);
        return new RelationshipStrengthsGraphAlgorithm(table, target, nnodes, nlinks, minStrength).run();
    }

    /**
     * Scores one field pair and writes the strength into both {@code (first, second)} and
     * {@code (second, first)} of the result table.
     *
     * <p>Any exception is deliberately contained to this pair. It is logged, and a warning naming
     * both fields is added to {@code result}, so one bad pair does not abort the whole analysis.
     */
    private void computeFieldRelationship(ExecutionResult<FrequencyTable> result, DataFrame df, int first, int second) {
        try {
            double strength = RelationshipStrengthsAlgorithm.strengthMeasure(this.getChiSquareTest(df, first, second));
            FrequencyTable table = (FrequencyTable) result.getContent();
            table.set(first, second, strength);
            table.set(second, first, strength);
        }
        catch (Exception e) {
            String fi = df.getFieldName(first);
            String fj = df.getFieldName(second);
            LOGGER.warn("PEXRSA - " + this.requestId + " - Exception computing chisquare from " + this.context.getLoggableName(fi) + " to " + this.context.getLoggableName(fj), e);
            List<String> fields = new ArrayList<>();
            fields.add(fi);
            fields.add(fj);
            result.addWarningMessage(new Message(MessageCode.EXPLORE_RELATIONSHIP_EXCEPTION, fields, this.context.getLocale("locale"), new Object[0]));
        }
    }

    /** Returns Cramér's V from the test's overall chi-square result. */
    protected static double strengthMeasure(FrequencyChiSquareTest test) {
        return test.computeOverallChiSquare().cramersv;
    }

    /**
     * Projects the selected field links onto group pairs for one grouping.
     *
     * <p>Each link endpoint may map to several groups. Every (group of a, group of b) combination
     * is keyed with the lexicographically smaller group first, and combinations within the same
     * group are skipped. Only the strongest link is kept per group pair. Strengths are returned
     * sorted by descending strength, then by group names.
     */
    private RelationshipResult computeGroupingResult(ExploreParams.Grouping grouping, DataframeIndexer dfIndexer, Set<String> nodeIds, List<RelationshipStrengthsGraphAlgorithm.GraphLink> links) {
        Map<Tuple<String, String>, RelationshipStrength> maxStrengths = new HashMap<>();
        RelationshipStrengthGroupingIndexer indexer = new RelationshipStrengthGroupingIndexer(this.context, grouping, dfIndexer.dfIndexes);
        for (RelationshipStrengthsGraphAlgorithm.GraphLink link : links) {
            List<String> aGroups = indexer.indexToGroup(link.a);
            String aField = dfIndexer.dfNames[link.a];
            List<String> bGroups = indexer.indexToGroup(link.b);
            String bField = dfIndexer.dfNames[link.b];
            for (String aGroup : aGroups) {
                for (String bGroup : bGroups) {
                    int v = aGroup.compareTo(bGroup);
                    if (v < 0) {
                        this.updateRelationshipStrength(maxStrengths, aGroup, aField, bGroup, bField, link.strength);
                    } else if (v > 0) {
                        this.updateRelationshipStrength(maxStrengths, bGroup, bField, aGroup, aField, link.strength);
                    }
                    // v == 0: both ends fall in the same group, which is not reported as a relationship.
                }
            }
        }
        List<RelationshipStrength> strengths = maxStrengths.values().stream()
                .sorted(RelationshipStrengthsAlgorithm::sortStrengths)
                .collect(Collectors.toList());
        return new RelationshipResult(indexer.createGrouping(nodeIds), strengths);
    }

    /**
     * Records a link for group pair {@code (aGroup, bGroup)} unless an equal-or-stronger link is
     * already recorded. On ties, the existing entry wins.
     *
     * <p>NOTE(review): relies on {@link Tuple} implementing value-based equals/hashCode for the map
     * lookup; confirm.
     */
    private void updateRelationshipStrength(Map<Tuple<String, String>, RelationshipStrength> maxStrengths, String aGroup, String aField, String bGroup, String bField, double strength) {
        Tuple<String, String> key = new Tuple<>(aGroup, bGroup);
        RelationshipStrength existing = maxStrengths.get(key);
        if (existing != null && existing.getStrength() >= strength) {
            return;
        }
        maxStrengths.put(key, new RelationshipStrength(aGroup, bGroup, strength, aField, bField));
    }

    /** Orders by descending strength, then ascending group {@code a}, then ascending group {@code b}. */
    private static int sortStrengths(RelationshipStrength s1, RelationshipStrength s2) {
        int v = Double.compare(s2.getStrength(), s1.getStrength());
        if (v != 0) {
            return v;
        }
        v = s1.getA().compareTo(s2.getA());
        if (v != 0) {
            return v;
        }
        return s1.getB().compareTo(s2.getB());
    }

    /** Builds a chi-square test over the crosstab of fields {@code first} and {@code second}. */
    private FrequencyChiSquareTest getChiSquareTest(DataFrame df, int first, int second) {
        FrequencyTable crosstab = FrequencyTable.crosstab(df, first, second);
        Matrix counts = MatrixVectorFactory.matrixFromData((double[][]) crosstab.getData());
        return new FrequencyChiSquareTest(counts);
    }

    /** Two-way lookup between field names and their column indexes in a data frame. */
    private static class DataframeIndexer {
        /** Field name to column index. */
        final Map<String, Integer> dfIndexes;
        /** Column index to field name. */
        final String[] dfNames;

        DataframeIndexer(DataFrame df) {
            int nFields = df.getNumFields();
            this.dfIndexes = new HashMap<String, Integer>();
            this.dfNames = new String[nFields];
            for (int i = 0; i < nFields; ++i) {
                String name = df.getFieldName(i);
                this.dfIndexes.put(name, i);
                this.dfNames[i] = name;
            }
        }
    }
}

