/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.smarts.similarity.classifier.common.core;

import com.ibm.smarts.model.builder.SampleExtraFeatures;
import com.ibm.smarts.nlp.embedding.WordEmbedding;
import com.ibm.smarts.similarity.classifier.common.core.IEncoder;
import com.ibm.smarts.similarity.classifier.common.core.IOovEncoder;
import com.ibm.smarts.similarity.classifier.common.core.PreProcessor;
import com.ibm.smarts.similarity.classifier.common.core.VectorUtility;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base {@link IEncoder} that turns a list of tokens into a single vector.
 *
 * <p>Each token is lowercased and looked up in a {@link WordEmbedding}. Tokens missing from the
 * embedding table fall back to an optional {@link IOovEncoder}. The resulting per-token vectors
 * (one per distinct lowercased token) are then pooled by averaging, element-wise max, or a
 * blend of the two.
 */
public abstract class AbstractEncoder
implements IEncoder {
    private static final Logger LOGGER = LoggerFactory.getLogger(AbstractEncoder.class);
    private static final String NULL_EMBEDDINGS_MSG = "The word embeddings seem to have been corrupted. Unable to encode tokens.";
    /** Weight passed to {@code VectorUtility.convex} when blending the max- and average-pooled vectors. */
    private static final float SMOOTHING_WEIGHT = 0.7f;
    /** Embedding lookup table; when null, every encode method returns an empty array. */
    protected WordEmbedding wordEmbedding;
    /** Optional fallback for tokens the embedding table has no vector for; may be null. */
    protected IOovEncoder oovEncoder = null;

    /**
     * Creates an encoder with an out-of-vocabulary fallback.
     *
     * @param wordEmbedding embedding lookup table (may be null; encoding then yields empty arrays)
     * @param oovEncoder fallback used when {@code wordEmbedding} has no vector for a token; may be null
     */
    public AbstractEncoder(WordEmbedding wordEmbedding, IOovEncoder oovEncoder) {
        this.wordEmbedding = wordEmbedding;
        this.oovEncoder = oovEncoder;
    }

    /**
     * Creates an encoder without an out-of-vocabulary fallback; unknown tokens are ignored.
     *
     * @param wordEmbedding embedding lookup table (may be null; encoding then yields empty arrays)
     */
    public AbstractEncoder(WordEmbedding wordEmbedding) {
        this(wordEmbedding, null);
    }

    /**
     * Encodes a raw sentence by preparing it with {@code PreProcessor.prepareSentenceForInference}
     * and pooling the resulting tokens with {@link #encodeTokensViaSmoothedMax(List)}.
     *
     * @param sentence raw query text
     * @return the pooled vector; see {@link #encodeTokensViaSmoothedMax(List)} for the empty/null cases
     */
    @Override
    public float[] encodeQuery(String sentence) {
        return this.encodeTokensViaSmoothedMax(PreProcessor.prepareSentenceForInference(sentence));
    }

    /**
     * Encodes tokens as the average of their distinct (case-insensitive) embeddings.
     *
     * @param tokens tokens to encode; null elements are skipped
     * @return an empty array if no embedding table is configured, {@code null} if no token
     *     resolved to a vector, otherwise {@code VectorUtility.avg} of the resolved vectors
     */
    public float[] encodeTokensViaAveraging(List<String> tokens) {
        if (this.wordEmbedding == null) {
            LOGGER.debug(NULL_EMBEDDINGS_MSG);
            return new float[0];
        }
        ArrayList<float[]> tokenEmbeddings = this.collectUniqueTokenEmbeddings(tokens);
        if (tokenEmbeddings.isEmpty()) {
            return null;
        }
        return VectorUtility.avg(tokenEmbeddings);
    }

    /**
     * Encodes tokens as the {@code VectorUtility.max} of their distinct (case-insensitive) embeddings.
     *
     * @param tokens tokens to encode; null elements are skipped
     * @return an empty array if no embedding table is configured, {@code null} if no token
     *     resolved to a vector, otherwise the max-pooled vector
     */
    public float[] encodeTokensViaMax(List<String> tokens) {
        if (this.wordEmbedding == null) {
            LOGGER.debug(NULL_EMBEDDINGS_MSG);
            return new float[0];
        }
        ArrayList<float[]> tokenEmbeddings = this.collectUniqueTokenEmbeddings(tokens);
        if (tokenEmbeddings.isEmpty()) {
            return null;
        }
        return VectorUtility.max(tokenEmbeddings);
    }

    /**
     * Encodes tokens by blending their max-pooled and average-pooled embeddings with
     * {@code VectorUtility.convex(max, avg, SMOOTHING_WEIGHT)}.
     *
     * @param tokens tokens to encode; null elements are skipped
     * @return an empty array if no embedding table is configured, {@code null} if no token
     *     resolved to a vector, a copy of the single vector if exactly one distinct token
     *     resolved, otherwise the blended vector
     */
    public float[] encodeTokensViaSmoothedMax(List<String> tokens) {
        if (this.wordEmbedding == null) {
            LOGGER.debug(NULL_EMBEDDINGS_MSG);
            return new float[0];
        }
        ArrayList<float[]> tokenEmbeddings = this.collectUniqueTokenEmbeddings(tokens);
        if (tokenEmbeddings.isEmpty()) {
            return null;
        }
        if (tokenEmbeddings.size() == 1) {
            // Copy so callers cannot mutate the array handed out by the lookup, which may be
            // the embedding table's own storage (NOTE(review): confirm getEmbedding's contract).
            return tokenEmbeddings.get(0).clone();
        }
        float[] max = VectorUtility.max(tokenEmbeddings);
        float[] avg = VectorUtility.avg(tokenEmbeddings);
        return VectorUtility.convex(max, avg, SMOOTHING_WEIGHT);
    }

    /**
     * Reports the shortest and longest string length found in {@code samples}.
     *
     * <p>NOTE(review): an empty list yields {@code (Integer.MAX_VALUE, Integer.MIN_VALUE)}
     * sentinels; confirm that callers handle this.
     *
     * @param samples sample strings; must not be null or contain null elements
     * @return a {@link SampleExtraFeatures} built from (min length, max length)
     */
    @Override
    public SampleExtraFeatures createSampleExtraFeatures(List<String> samples) {
        int minLength = Integer.MAX_VALUE;
        int maxLength = Integer.MIN_VALUE;
        for (String sample : samples) {
            int length = sample.length();
            minLength = Math.min(minLength, length);
            maxLength = Math.max(maxLength, length);
        }
        return new SampleExtraFeatures(minLength, maxLength);
    }

    /**
     * Resolves each distinct lowercased token to a vector, preserving first-occurrence order.
     * Tokens that resolve to no vector are dropped. Requires {@code wordEmbedding} to be non-null.
     */
    private ArrayList<float[]> collectUniqueTokenEmbeddings(List<String> tokens) {
        Set<String> seenTokens = new HashSet<>();
        ArrayList<float[]> tokenEmbeddings = new ArrayList<>();
        for (String token : tokens) {
            if (token == null) {
                continue;
            }
            // Locale.ROOT keeps lowercasing independent of the JVM default locale
            // (e.g. Turkish would map "I" to dotless "ı" and miss vocabulary entries).
            String key = token.toLowerCase(Locale.ROOT);
            if (seenTokens.contains(key)) {
                continue;
            }
            float[] vector = this.lookupEmbedding(key);
            if (vector != null) {
                seenTokens.add(key);
                tokenEmbeddings.add(vector);
            }
        }
        return tokenEmbeddings;
    }

    /** Looks up {@code key} in the embedding table, falling back to the OOV encoder if configured. */
    private float[] lookupEmbedding(String key) {
        float[] vector = this.wordEmbedding.getEmbedding(key);
        if (vector == null && this.oovEncoder != null) {
            vector = this.oovEncoder.encode(key);
        }
        return vector;
    }
}

