/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.smarts.similarity.classifier.common.core;

import com.ibm.smarts.similarity.classifier.common.core.ClusterLabel;
import com.ibm.smarts.similarity.classifier.common.core.KMeansResult;
import com.ibm.smarts.similarity.classifier.common.core.NormedVector;
import com.ibm.smarts.similarity.classifier.common.core.WordEmbeddingUtility;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Cosine-similarity k-means clustering over {@link NormedVector} embeddings.
 *
 * <p>Centroids are seeded with the first {@code k} input vectors. Each token is assigned to the
 * centroid with the highest {@link WordEmbeddingUtility#cosine} score, and each centroid is then
 * recomputed as the mean of its members. Clusters that end up with no members are dropped, so a
 * result may contain fewer centroids than requested.
 *
 * <p>Stateless utility class; all entry points are static.
 */
public final class KMeans {
    private static final Logger LOGGER = LoggerFactory.getLogger(KMeans.class);

    /** Iteration cap used by {@link #fit(int, List)}. */
    private static final int DEFAULT_MAX_ITERATIONS = 100;

    /**
     * Convergence threshold used by {@link #fit(int, List)}: iteration stops once every centroid's
     * cosine similarity to its previous position exceeds this value.
     */
    private static final double DEFAULT_CONVERGENCE_THRESHOLD = 0.99;

    private KMeans() {
    }

    /**
     * Clusters the given embeddings with at most 100 iterations and a convergence threshold of 0.99.
     *
     * @param numberOfClusters requested number of clusters; capped at the number of tokens
     * @param tokenEmbeddings embeddings to cluster; all are assumed to have the same length as the
     *     first one (not checked)
     * @return the final centroids (empty clusters removed) and the label of every token against them
     * @throws IllegalArgumentException if {@code numberOfClusters < 1} or there are no embeddings
     */
    public static KMeansResult fit(int numberOfClusters, List<NormedVector> tokenEmbeddings) {
        return fit(numberOfClusters, tokenEmbeddings, DEFAULT_MAX_ITERATIONS, DEFAULT_CONVERGENCE_THRESHOLD);
    }

    /**
     * Clusters the given embeddings with a caller-chosen iteration cap and convergence threshold.
     *
     * @param numberOfClusters requested number of clusters; capped at the number of tokens
     * @param tokenEmbeddings embeddings to cluster; all are assumed to have the same length as the
     *     first one (not checked)
     * @param maxIterations upper bound on assignment/update rounds; must be at least 1
     * @param convergenceThreshold iteration stops early once the minimum cosine similarity between
     *     each centroid and its previous position is strictly greater than this value
     * @return the final centroids (empty clusters removed) and the label of every token against them
     * @throws IllegalArgumentException if {@code numberOfClusters < 1}, there are no embeddings, or
     *     {@code maxIterations < 1}
     */
    public static KMeansResult fit(int numberOfClusters, List<NormedVector> tokenEmbeddings,
            int maxIterations, double convergenceThreshold) {
        long startTime = System.currentTimeMillis();
        if (numberOfClusters < 1) {
            LOGGER.error("Number of clusters should be at least 1.");
            throw new IllegalArgumentException("Number of clusters is less than 1.");
        }
        if (tokenEmbeddings == null || tokenEmbeddings.isEmpty()) {
            LOGGER.error("No tokens are found.");
            throw new IllegalArgumentException("No tokens are found.");
        }
        if (maxIterations < 1) {
            throw new IllegalArgumentException(
                    "Maximum number of iterations must be at least 1, got " + maxIterations);
        }
        int numTokens = tokenEmbeddings.size();
        int ndims = tokenEmbeddings.get(0).getVector().length;
        ClusterLabel[] clusterLabels = new ClusterLabel[numTokens];
        List<NormedVector> centroids = initialize(tokenEmbeddings, Math.min(numberOfClusters, numTokens));

        int round = 0;
        boolean done;
        do {
            assignLabels(tokenEmbeddings, centroids, clusterLabels);
            List<NormedVector> newCentroids =
                    updateCentroids(tokenEmbeddings, clusterLabels, centroids.size(), ndims);
            ++round;
            // Short-circuit: the convergence test is skipped once the iteration cap is reached.
            done = round >= maxIterations || converged(centroids, newCentroids, convergenceThreshold);
            centroids = newCentroids;
        } while (!done);

        // The labels from the last round refer to the pre-update centroids (whose indices may have
        // shifted if clusters were dropped), so relabel against the final centroids.
        assignLabels(tokenEmbeddings, centroids, clusterLabels);
        LOGGER.debug("Compute k-means clustering took {}ms. Number of iteration for convergence is {}",
                System.currentTimeMillis() - startTime, round);
        return new KMeansResult(centroids, clusterLabels);
    }

    /** Writes into {@code labels[i]} the nearest centroid for {@code tokens.get(i)}. */
    private static void assignLabels(List<NormedVector> tokens, List<NormedVector> centroids,
            ClusterLabel[] labels) {
        for (int i = 0; i < tokens.size(); ++i) {
            labels[i] = labelToken(tokens.get(i), centroids);
        }
    }

    /**
     * Recomputes centroids as the mean of the tokens assigned to each cluster.
     *
     * <p>Clusters with no members are omitted. The surviving centroids keep their relative order, so
     * indices after a dropped cluster shift down by one.
     *
     * @param tokens the embeddings being clustered
     * @param clusterLabels label of each token; {@code getLabel()} must lie in {@code [0, numClusters)}
     * @param numClusters number of clusters the labels refer to
     * @param ndims length of every embedding vector
     * @return the non-empty cluster means, in ascending cluster-index order
     */
    private static List<NormedVector> updateCentroids(List<NormedVector> tokens, ClusterLabel[] clusterLabels,
            int numClusters, int ndims) {
        // NOTE(review): relies on NormedVector.add/multiply mutating the receiver in place (their
        // return values are ignored); the 0.0 constructor argument is presumably the initial norm.
        List<NormedVector> sums = new ArrayList<NormedVector>(numClusters);
        for (int c = 0; c < numClusters; ++c) {
            sums.add(new NormedVector(new float[ndims], 0.0));
        }
        int[] counts = new int[numClusters];
        for (int i = 0; i < tokens.size(); ++i) {
            int label = clusterLabels[i].getLabel();
            sums.get(label).add(tokens.get(i));
            counts[label]++;
        }
        List<NormedVector> centroids = new ArrayList<NormedVector>(numClusters);
        for (int c = 0; c < numClusters; ++c) {
            if (counts[c] > 0) {
                NormedVector centroid = sums.get(c);
                centroid.multiply(1.0f / counts[c]);
                centroids.add(centroid);
            }
        }
        return centroids;
    }

    /**
     * Finds the centroid most cosine-similar to {@code token}.
     *
     * <p>Ties go to the lowest index. If no score exceeds {@code -Double.MAX_VALUE} (e.g. every score
     * is NaN), label 0 is returned with that sentinel similarity.
     */
    private static ClusterLabel labelToken(NormedVector token, List<NormedVector> centroids) {
        double maxSimilarity = -Double.MAX_VALUE;
        int idx = 0;
        for (int i = 0; i < centroids.size(); ++i) {
            double score = WordEmbeddingUtility.cosine(token, centroids.get(i));
            if (score > maxSimilarity) {
                maxSimilarity = score;
                idx = i;
            }
        }
        return new ClusterLabel(idx, maxSimilarity);
    }

    /**
     * Returns whether every centroid moved less than the threshold allows, i.e. the minimum cosine
     * similarity between each old centroid and its updated counterpart is strictly greater than
     * {@code threshold}.
     */
    private static boolean converged(List<NormedVector> centroids, List<NormedVector> newCentroids, double threshold) {
        // updateCentroids drops empty clusters, which shifts the index of every later centroid; an
        // index-wise comparison would then pair unrelated clusters. A change in cluster count is
        // therefore never treated as convergence.
        if (centroids.size() != newCentroids.size()) {
            return false;
        }
        double minSimilarity = Double.MAX_VALUE;
        for (int i = 0; i < newCentroids.size(); ++i) {
            double similarity = WordEmbeddingUtility.cosine(centroids.get(i), newCentroids.get(i));
            // NOTE(review): a NaN similarity never compares less, so it is effectively ignored here.
            if (similarity < minSimilarity) {
                minSimilarity = similarity;
            }
        }
        return minSimilarity > threshold;
    }

    /**
     * Seeds the centroids with the first {@code numClusters} embeddings.
     *
     * <p>Returns a copy, so the working centroid list is not a live view of the caller's list.
     */
    private static List<NormedVector> initialize(List<NormedVector> embeddingsList, int numClusters) {
        return new ArrayList<NormedVector>(embeddingsList.subList(0, numClusters));
    }
}

