/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.smarts.similarity.classifier.api;

import com.ibm.smarts.conversation.nlu.schema.entity.Entity;
import com.ibm.smarts.core.util.RequestContext;
import com.ibm.smarts.model.builder.ColumnEmbedding;
import com.ibm.smarts.model.builder.SampleExtraFeatures;
import com.ibm.smarts.nlp.embedding.WordEmbedding;
import com.ibm.smarts.schema.Feature;
import com.ibm.smarts.schema.FeatureType;
import com.ibm.smarts.schema.MatchReason;
import com.ibm.smarts.schema.MatchedEntity;
import com.ibm.smarts.schema.MatchedFeature;
import com.ibm.smarts.similarity.classifier.common.api.WordEmbeddingProvider;
import com.ibm.smarts.similarity.classifier.common.core.BigramEncoder;
import com.ibm.smarts.similarity.classifier.common.core.IEncoder;
import com.ibm.smarts.similarity.classifier.common.core.IOovEncoder;
import com.ibm.smarts.similarity.classifier.common.core.KMeansEncoder;
import com.ibm.smarts.similarity.classifier.common.core.NlpUtility;
import com.ibm.smarts.similarity.classifier.common.core.PreProcessor;
import com.ibm.smarts.similarity.classifier.common.core.Token;
import com.ibm.smarts.similarity.classifier.common.core.Tokenizer;
import com.ibm.smarts.similarity.classifier.common.core.VectorUtility;
import com.ibm.smarts.similarity.classifier.common.core.WordEmbeddingUtility;
import com.ibm.smarts.similarity.classifier.common.provider.ISimilarityClassifierProvider;
import com.ibm.smarts.similarity.classifier.core.ScoredColumnEmbedding;
import com.ibm.smarts.similarity.classifier.provider.SimilarityClassifierProvider;
import com.ibm.smarts.store.api.provider.IPersistenceProvider;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static utility that matches pieces of a query against dataset columns.
 *
 * <p>Each query fragment is encoded with an {@link IEncoder} and scored by cosine similarity
 * against every embedding vector of every {@link ColumnEmbedding}, and against the encoded
 * outlier values of the column when it has any. Matches whose score does not exceed
 * {@link #CONFIDENCE_THR} are discarded before the result is converted to {@link MatchedEntity}
 * objects.
 */
public class ValueToColumnClassifier {
    private static final String SPACE = " ";
    private static final Logger LOGGER = LoggerFactory.getLogger(ValueToColumnClassifier.class);
    private static final String JOIN_TOKENS_PATTERN = ".*(\\s+(or|and)[,;]*\\s+|[,;]+(?!$)).*";
    /** Compiled once: the check runs for every pair of adjacent tokens during merging. */
    private static final Pattern JOIN_TOKENS = Pattern.compile(JOIN_TOKENS_PATTERN);
    /** Norm below which an encoded query is reported (debug only) as a trivial vector. */
    private static final float EPS = 0.001f;
    /** Score deduction applied when the query length is outside a column's length window. */
    private static final float LENGTH_PENALTY = 0.1f;
    /** Relative tolerance around a column's min/max sample length. */
    private static final double LENGTH_MARGIN = 0.5;
    /** Score deduction applied to outlier matches that are not near-exact. */
    private static final float OUTLIER_PENALTY = 0.2f;
    /** Maximum number of scored columns kept per query fragment. */
    private static final int TOP_N = 250;
    /** Score margin used when comparing a merged fragment with its two parts. */
    private static final float MIN_INCREASE = 0.07f;
    /** Scores must be strictly greater than this to survive pruning. */
    private static final float CONFIDENCE_THR = 0.5f;
    /** A merged fragment scoring at least this much is always accepted. */
    private static final float CONFIDENCE_THR_MERGE = 0.85f;
    /** Outlier similarity at or above this value is not penalised. */
    private static final float EXACT_MATCH_THRESHOLD = 0.98f;

    private ValueToColumnClassifier() {
    }

    /**
     * Classifies a raw query string using the globally provided word embedding.
     *
     * @param requestContext      context handed to the {@link SimilarityClassifierProvider}
     * @param persistenceProvider persistence handed to the {@link SimilarityClassifierProvider}
     * @param smartsModuleId      module whose column embeddings are looked up
     * @param query               the query text
     * @return the matched entities, or an empty list when no word embedding is available
     * @deprecated use {@link #classify(RequestContext, IPersistenceProvider, String, List)}
     */
    @Deprecated
    public static List<MatchedEntity> classify(RequestContext requestContext, IPersistenceProvider persistenceProvider, String smartsModuleId, String query) {
        WordEmbedding wordEmbedding = WordEmbeddingProvider.getWordEmbedding();
        if (wordEmbedding == null) {
            LOGGER.error("Failed to classify query. Similarity Classifier may not have been initialized properly");
            return Collections.emptyList();
        }
        SimilarityClassifierProvider provider = new SimilarityClassifierProvider(requestContext, persistenceProvider);
        IEncoder kMeansEncoder = getEncoder(wordEmbedding);
        return classify((ISimilarityClassifierProvider) provider, kMeansEncoder, smartsModuleId, query);
    }

    /**
     * Classifies NLU entities using the globally provided word embedding.
     *
     * @param requestContext      context handed to the {@link SimilarityClassifierProvider}
     * @param persistenceProvider persistence handed to the {@link SimilarityClassifierProvider}
     * @param smartsModuleId      module whose column embeddings are looked up
     * @param nluEntities         entities to classify
     * @return the matched entities, or an empty list when no word embedding is available
     */
    public static List<MatchedEntity> classify(RequestContext requestContext, IPersistenceProvider persistenceProvider, String smartsModuleId, List<Entity> nluEntities) {
        WordEmbedding wordEmbedding = WordEmbeddingProvider.getWordEmbedding();
        if (wordEmbedding == null) {
            LOGGER.error("Failed to classify query. Similarity Classifier may not have been initialized properly");
            return Collections.emptyList();
        }
        SimilarityClassifierProvider provider = new SimilarityClassifierProvider(requestContext, persistenceProvider);
        IEncoder kMeansEncoder = getEncoder(wordEmbedding);
        return classify((ISimilarityClassifierProvider) provider, kMeansEncoder, smartsModuleId, nluEntities);
    }

    /**
     * Tokenizes the query, scores every token against the columns, then greedily tries to merge
     * each token with its right-hand neighbour (see
     * {@link #forwardMerge(IEncoder, List, Map, String, IntermediateMatch, IntermediateMatch)}).
     *
     * @param provider       source of the column embeddings
     * @param encoder        encoder used for query fragments and outliers
     * @param smartsModuleId module whose column embeddings are looked up
     * @param query          the query text
     * @return entities for all fragments that keep at least one match above the confidence threshold
     * @deprecated use {@link #classify(ISimilarityClassifierProvider, IEncoder, String, List)}
     */
    @Deprecated
    public static List<MatchedEntity> classify(ISimilarityClassifierProvider provider, IEncoder encoder, String smartsModuleId, String query) {
        long startTime = System.currentTimeMillis();
        List<ColumnEmbedding> col2vec = provider.getColumnEmbeddings(smartsModuleId);
        List<Token> tokenizedQuery = Tokenizer.tokenize(query, true);
        Map<String, float[]> outlierEmbeddings = encodeOutliers(encoder, col2vec);

        ArrayDeque<IntermediateMatch> pending = new ArrayDeque<>();
        for (Token token : tokenizedQuery) {
            String text = token.getText();
            float[] tokenEmbedding = encoder.encodeQuery(text);
            List<ScoredColumnEmbedding> tokenMatches = findSimilarColumnsTopN(encoder, tokenEmbedding, text.length(), col2vec, outlierEmbeddings, TOP_N);
            pending.add(new IntermediateMatch(token, tokenMatches));
        }

        List<IntermediateMatch> refinedMatches = new ArrayList<>();
        while (pending.size() >= 2) {
            IntermediateMatch head = pending.pollFirst();
            IntermediateMatch merged = forwardMerge(encoder, col2vec, outlierEmbeddings, query, head, pending.peekFirst());
            if (merged == null) {
                // A head that is itself the product of an earlier merge is already recorded.
                if (!refinedMatches.contains(head)) {
                    refinedMatches.add(head);
                }
                continue;
            }
            // NOTE(review): when a merged match is merged again on the next iteration, the
            // shorter merged match recorded here stays in the result next to the longer one.
            // Kept as is; confirm whether overlapping entities are intended.
            refinedMatches.add(merged);
            pending.pollFirst();
            if (!pending.isEmpty()) {
                // Put the merged match back so it can be extended with the next token.
                pending.addFirst(merged);
            }
        }
        if (!pending.isEmpty()) {
            refinedMatches.add(pending.pollFirst());
        }

        List<MatchedEntity> result = pruneMatches(refinedMatches).stream()
                .map(ValueToColumnClassifier::createMatchedEntity)
                .collect(Collectors.toList());
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("Query=[{}]\nClassified in [{}]ms", query, System.currentTimeMillis() - startTime);
        }
        return result;
    }

    /**
     * Scores every NLU entity against the columns and converts the confident matches to
     * {@link MatchedEntity} objects.
     *
     * @param provider       source of the column embeddings
     * @param encoder        encoder used for query fragments and outliers
     * @param smartsModuleId module whose column embeddings are looked up
     * @param nluEntities    entities to classify; {@code null} yields an empty result
     * @return entities that keep at least one match above the confidence threshold
     */
    public static List<MatchedEntity> classify(ISimilarityClassifierProvider provider, IEncoder encoder, String smartsModuleId, List<Entity> nluEntities) {
        if (nluEntities == null) {
            return Collections.emptyList();
        }
        long startTime = System.currentTimeMillis();
        List<ColumnEmbedding> col2vec = provider.getColumnEmbeddings(smartsModuleId);
        Map<String, float[]> outlierEmbeddings = encodeOutliers(encoder, col2vec);
        List<Token> entityTokens = nluEntities.stream()
                .map(e -> new Token((int) e.getStart(), e.getText()))
                .collect(Collectors.toList());

        List<IntermediateMatch> intermediateMatches = new ArrayList<>();
        for (Token entityToken : entityTokens) {
            IntermediateMatch subQueryMatches = getSubQueryMatches(entityToken, encoder, col2vec, outlierEmbeddings);
            if (subQueryMatches != null) {
                intermediateMatches.add(subQueryMatches);
            }
        }

        List<MatchedEntity> result = pruneMatches(intermediateMatches).stream()
                .map(ValueToColumnClassifier::createMatchedEntity)
                .collect(Collectors.toList());
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("[{}] entities were classified in [{}]ms\nThese are the entities:\n", nluEntities.size(), System.currentTimeMillis() - startTime);
            for (Entity entity : nluEntities) {
                LOGGER.debug("Entity: [{}]\n", entity.getText());
            }
        }
        return result;
    }

    /**
     * Scores one sub-query against the columns.
     *
     * <p>Leading and trailing stop words are removed first. If the remaining text contains a
     * space, the words before the last space are scored on their own; when the last word equals
     * (ignoring case) the id of their best-scoring column, or {@link NlpUtility#isPluralOf}
     * accepts the pair, the returned match covers only the words before the last one. Otherwise
     * the whole text is scored.
     *
     * @return the match for the sub-query; the list of scored columns may be empty
     */
    private static IntermediateMatch getSubQueryMatches(Token subQuery, IEncoder encoder, List<ColumnEmbedding> col2vec, Map<String, float[]> outlierEmbeddings) {
        Token stripped = PreProcessor.removeTrailingAndLeadingStopWords(subQuery);
        String subQueryText = stripped.getText();
        int lastSpace = subQueryText.lastIndexOf(SPACE);
        if (lastSpace > 0) {
            String lastWord = subQueryText.substring(lastSpace + 1);
            String filter = subQueryText.substring(0, lastSpace);
            float[] filterEmbedding = encoder.encodeQuery(filter);
            List<ScoredColumnEmbedding> filterMatches = findSimilarColumnsTopN(encoder, filterEmbedding, filter.length(), col2vec, outlierEmbeddings, TOP_N);
            if (!filterMatches.isEmpty()) {
                String topColumnId = filterMatches.get(0).getColumnEmbedding().getId();
                if (lastWord.equalsIgnoreCase(topColumnId) || NlpUtility.isPluralOf(lastWord, topColumnId)) {
                    return new IntermediateMatch(new Token(stripped.getStartCharIndex(), filter), filterMatches);
                }
            }
        }
        float[] subQueryEmbedding = encoder.encodeQuery(subQueryText);
        int len = stripped.getText().length();
        List<ScoredColumnEmbedding> subQueryMatches = findSimilarColumnsTopN(encoder, subQueryEmbedding, len, col2vec, outlierEmbeddings, TOP_N);
        Token subQueryToken = new Token(stripped.getStartCharIndex(), stripped.getText());
        return new IntermediateMatch(subQueryToken, subQueryMatches);
    }

    /**
     * Keeps the matches whose first scored column exceeds {@link #CONFIDENCE_THR} and, within
     * each kept match, replaces the scored columns with those exceeding the same threshold.
     * The surviving {@link IntermediateMatch} objects are modified in place.
     */
    private static List<IntermediateMatch> pruneMatches(List<IntermediateMatch> iMatches) {
        List<IntermediateMatch> confidentMatches = iMatches.stream()
                .filter(im -> im.getMatches() != null)
                .filter(im -> getHighestScoringMatch(im) > CONFIDENCE_THR)
                .collect(Collectors.toList());
        for (IntermediateMatch match : confidentMatches) {
            List<ScoredColumnEmbedding> kept = match.getMatches().stream()
                    .filter(scored -> scored.getScore() > CONFIDENCE_THR)
                    .collect(Collectors.toList());
            match.setMatches(kept);
        }
        return confidentMatches;
    }

    /**
     * Returns the score of the first scored column of the match, or {@code 0} when the match is
     * {@code null} or has no scored columns. Lists produced by
     * {@link #findSimilarColumnsTopN} are sorted by descending score, so the first entry is the best.
     */
    private static float getHighestScoringMatch(IntermediateMatch iMatch) {
        if (iMatch == null) {
            return 0.0f;
        }
        List<ScoredColumnEmbedding> matches = iMatch.getMatches();
        return matches == null || matches.isEmpty() ? 0.0f : matches.get(0).getScore();
    }

    /**
     * Returns the expression ids of all scored columns of the match, or an empty list when the
     * match is {@code null} or has no scored columns.
     */
    private static List<String> getMatchedColumnIds(IntermediateMatch iMatch) {
        if (iMatch == null) {
            return Collections.emptyList();
        }
        List<ScoredColumnEmbedding> matches = iMatch.getMatches();
        if (matches == null || matches.isEmpty()) {
            return Collections.emptyList();
        }
        return matches.stream()
                .map(m -> m.getColumnEmbedding().getIdForExpression())
                .collect(Collectors.toList());
    }

    /**
     * Tries to merge two adjacent matches into one covering the query text from the start of
     * {@code m1} to the end of {@code m2}.
     *
     * <p>The merge is rejected when either side has no scored columns, when the merged text
     * matches {@link #JOIN_TOKENS_PATTERN}, when the merged text yields no scored columns, or
     * when {@link #isColumnName} holds for the pair. Otherwise it is accepted if the best merged
     * score is at least {@link #CONFIDENCE_THR_MERGE}, or beats the better of the two parts by
     * {@link #MIN_INCREASE}, or is within {@link #MIN_INCREASE} below it while its best column
     * is one that either part already matched.
     *
     * @return the merged match, or {@code null} when the two matches should stay separate
     */
    private static IntermediateMatch forwardMerge(IEncoder encoder, List<ColumnEmbedding> col2vec, Map<String, float[]> outlierEmbeddings, String query, IntermediateMatch m1, IntermediateMatch m2) {
        if (m1.getMatches().isEmpty() || m2.getMatches().isEmpty()) {
            return null;
        }
        float m1Score = getHighestScoringMatch(m1);
        float m2Score = getHighestScoringMatch(m2);
        List<String> m1MatchedColumns = getMatchedColumnIds(m1);
        List<String> m2MatchedColumns = getMatchedColumnIds(m2);

        String mergedQuery = query.substring(m1.getToken().getStartCharIndex(), m2.getToken().getEndCharIndex());
        if (JOIN_TOKENS.matcher(mergedQuery).matches()) {
            return null;
        }

        float[] mergedQueryEmbedding = encoder.encodeQuery(mergedQuery);
        if (LOGGER.isDebugEnabled() && (mergedQueryEmbedding == null || VectorUtility.norm(mergedQueryEmbedding) < (double) EPS)) {
            LOGGER.debug("The encoder returned a trivial vector for the query [{}]. The query may have been comprised of OOVs and stop words, which resulted in a null embedding.", mergedQuery);
        }
        List<ScoredColumnEmbedding> mergedQueryMatches = findSimilarColumnsTopN(encoder, mergedQueryEmbedding, mergedQuery.length(), col2vec, outlierEmbeddings, TOP_N);
        if (mergedQueryMatches.isEmpty()) {
            return null;
        }
        if (isColumnName(m1, m2)) {
            return null;
        }

        ScoredColumnEmbedding topMerged = mergedQueryMatches.get(0);
        float topMergedQueryScore = topMerged.getScore();
        String topMergedQueryColumn = topMerged.getColumnEmbedding().getIdForExpression();
        float bestPartScore = Math.max(m1Score, m2Score);

        boolean accept = topMergedQueryScore >= CONFIDENCE_THR_MERGE
                || topMergedQueryScore >= bestPartScore + MIN_INCREASE
                || (topMergedQueryScore >= bestPartScore - MIN_INCREASE
                        && (m1MatchedColumns.contains(topMergedQueryColumn) || m2MatchedColumns.contains(topMergedQueryColumn)));
        if (!accept) {
            return null;
        }
        Token mergedToken = new Token(m1.getToken().getStartCharIndex(), mergedQuery);
        return new IntermediateMatch(mergedToken, mergedQueryMatches);
    }

    /**
     * Builds a {@link MatchedEntity} spanning the match's token, with one
     * {@link MatchedFeature} (type {@code COLUMN_NAME}, reason {@code DATA_VALUE}) per scored column.
     */
    private static MatchedEntity createMatchedEntity(IntermediateMatch iMatch) {
        Token token = iMatch.getToken();
        MatchedEntity entity = new MatchedEntity();
        entity.setCharOffsetBegin(token.getStartCharIndex());
        entity.setCharOffsetEnd(token.getEndCharIndex());
        entity.setCoveredText(token.getText());
        for (ScoredColumnEmbedding match : iMatch.getMatches()) {
            ColumnEmbedding column = match.getColumnEmbedding();
            Feature feature = new Feature();
            feature.setIdForExpression(column.getIdForExpression());
            feature.setDatasetRef(column.getDatasetId());
            feature.setColumnInfoRef(column.getId());

            MatchedFeature matchedFeature = new MatchedFeature();
            matchedFeature.setFeatureType(FeatureType.COLUMN_NAME);
            matchedFeature.setMatchReason(MatchReason.DATA_VALUE);
            matchedFeature.setConfidence(match.getScore());
            matchedFeature.setFeature(feature);
            entity.getMatchedFeatures().add(matchedFeature);
        }
        return entity;
    }

    /**
     * Scores every column against the query vector and returns the best {@code topN}.
     *
     * <p>A column's score is the highest cosine similarity between {@code v} and any of the
     * column's embedding vectors, minus the length penalty, or the column's outlier score when
     * that is higher.
     *
     * @param encoder           not used by the scoring itself; kept for interface compatibility
     * @param v                 encoded query; {@code null} yields an empty result
     * @param len               character length of the query text, used for the length penalty
     * @param col2vec           columns to score; {@code null} yields an empty result
     * @param outlierEmbeddings encoded outlier values keyed by outlier text; may be {@code null}
     * @param topN              maximum number of results
     * @return at most {@code topN} scored columns, sorted by descending score
     */
    public static List<ScoredColumnEmbedding> findSimilarColumnsTopN(IEncoder encoder, float[] v, int len, List<ColumnEmbedding> col2vec, Map<String, float[]> outlierEmbeddings, int topN) {
        if (v == null || col2vec == null) {
            return Collections.emptyList();
        }
        boolean haveOutlierEmbeddings = outlierEmbeddings != null && !outlierEmbeddings.isEmpty();
        List<ScoredColumnEmbedding> scored = new ArrayList<>(col2vec.size());
        for (ColumnEmbedding col : col2vec) {
            float bestScore = -Float.MAX_VALUE;
            for (float[] centroid : col.getEmbedding()) {
                float score = (float) WordEmbeddingUtility.cosine(centroid, v);
                if (score > bestScore) {
                    bestScore = score;
                }
            }
            bestScore -= lengthPenalty(col, len);

            List<String> outliers = col.getOutlier();
            if (haveOutlierEmbeddings && outliers != null && !outliers.isEmpty()) {
                bestScore = Math.max(findSimilarOutlier(v, outliers, outlierEmbeddings), bestScore);
            }
            scored.add(new ScoredColumnEmbedding(col, bestScore));
        }
        return scored.stream()
                .sorted(Comparator.comparing(ScoredColumnEmbedding::getScore).reversed())
                .limit(topN)
                .collect(Collectors.toList());
    }

    /**
     * Returns the highest cosine similarity between {@code v} and the encoded outliers, reduced
     * by {@link #OUTLIER_PENALTY} unless it reaches {@link #EXACT_MATCH_THRESHOLD}. Outliers
     * without an encoded vector are skipped.
     */
    private static float findSimilarOutlier(float[] v, List<String> outliers, Map<String, float[]> outlierEmbeddings) {
        float bestScore = -Float.MAX_VALUE;
        for (String outlier : outliers) {
            float[] outlierEmbedding = outlierEmbeddings.get(outlier);
            if (outlierEmbedding == null) {
                continue;
            }
            float score = (float) WordEmbeddingUtility.cosine(outlierEmbedding, v);
            if (score > bestScore) {
                bestScore = score;
            }
        }
        return bestScore < EXACT_MATCH_THRESHOLD ? bestScore - OUTLIER_PENALTY : bestScore;
    }

    /**
     * Returns {@link #LENGTH_PENALTY} when {@code len} lies outside the column's sample length
     * range widened by {@link #LENGTH_MARGIN} on both ends, otherwise {@code 0}. Columns without
     * extra features are never penalised.
     */
    private static float lengthPenalty(ColumnEmbedding col, int len) {
        SampleExtraFeatures extraFeatures = col.getExtraFeatures();
        if (extraFeatures == null) {
            return 0.0f;
        }
        int maxL = (int) ((double) extraFeatures.getMaxLength() * (1.0 + LENGTH_MARGIN));
        int minL = (int) ((double) extraFeatures.getMinLength() * (1.0 - LENGTH_MARGIN));
        return len > maxL || len < minL ? LENGTH_PENALTY : 0.0f;
    }

    /**
     * Tells whether the text of {@code m2}'s token equals (ignoring case) the id of
     * {@code m1}'s first scored column. {@code m1} must have at least one scored column.
     */
    private static boolean isColumnName(IntermediateMatch m1, IntermediateMatch m2) {
        String colName = m1.getMatches().get(0).getColumnEmbedding().getId();
        return colName.equalsIgnoreCase(m2.getToken().getText());
    }

    /**
     * Creates a {@link KMeansEncoder} for the word embedding, passing the {@link BigramEncoder}
     * instance as OOV encoder when one is available.
     */
    private static IEncoder getEncoder(WordEmbedding wordEmbedding) {
        BigramEncoder bigramEncoder = BigramEncoder.getInstance();
        if (bigramEncoder == null) {
            LOGGER.debug("Failed to get embedding for oov words. Similarity Classifier may not have been initialized properly or oov embedding is disabled");
            return new KMeansEncoder(wordEmbedding);
        }
        return new KMeansEncoder(wordEmbedding, (IOovEncoder) bigramEncoder);
    }

    /**
     * Encodes the outlier values of all columns.
     *
     * @param encoder encoder applied to each outlier text
     * @param col2vec columns whose outliers are encoded; {@code null} yields an empty map
     * @return outlier text mapped to its encoded vector; outliers the encoder returns
     *         {@code null} for are left out
     */
    public static Map<String, float[]> encodeOutliers(IEncoder encoder, List<ColumnEmbedding> col2vec) {
        Map<String, float[]> outliersEmbeddings = new HashMap<>();
        if (col2vec == null) {
            return outliersEmbeddings;
        }
        for (ColumnEmbedding col : col2vec) {
            List<String> outliers = col.getOutlier();
            if (outliers == null || outliers.isEmpty()) {
                continue;
            }
            for (String outlier : outliers) {
                float[] encoded = encoder.encodeQuery(outlier);
                if (encoded != null) {
                    outliersEmbeddings.put(outlier, encoded);
                }
            }
        }
        return outliersEmbeddings;
    }

    /** A query fragment together with the columns scored against it. */
    static class IntermediateMatch {
        private Token token;
        private List<ScoredColumnEmbedding> matches;

        public IntermediateMatch(Token token, List<ScoredColumnEmbedding> matches) {
            this.token = token;
            this.matches = matches;
        }

        public Token getToken() {
            return this.token;
        }

        public List<ScoredColumnEmbedding> getMatches() {
            return this.matches;
        }

        public void setMatches(List<ScoredColumnEmbedding> matches) {
            this.matches = matches;
        }
    }
}

