/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.smarts.generic.recommender.internal.raw.impl.classifiers.logisticregression;

import com.ibm.smarts.generic.recommender.api.java.classifiers.ClassifierException;
import com.ibm.smarts.generic.recommender.api.java.classifiers.ClassifierType;
import com.ibm.smarts.generic.recommender.api.java.classifiers.DatasetException;
import com.ibm.smarts.generic.recommender.api.java.classifiers.DatasetFeatureDescriptorException;
import com.ibm.smarts.generic.recommender.api.java.classifiers.IClassification;
import com.ibm.smarts.generic.recommender.api.java.classifiers.SupportedFileFormats;
import com.ibm.smarts.generic.recommender.internal.raw.impl.classifiers.Classification;
import com.ibm.smarts.generic.recommender.internal.raw.impl.classifiers.Classifier;
import com.ibm.smarts.generic.recommender.internal.raw.impl.classifiers.datalayer.DatasetManager;
import com.ibm.smarts.generic.recommender.internal.raw.impl.classifiers.datalayer.GenericTransformerException;
import com.ibm.smarts.generic.recommender.internal.raw.impl.classifiers.logisticregression.LogisticRegressionClassifier;
import com.ibm.smarts.generic.recommender.internal.raw.impl.classifiers.logisticregression.Scheduler;
import com.ibm.smarts.generic.recommender.internal.raw.impl.classifiers.logisticregression.Task;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * One-vs-all (OvA) multi-class classifier built from binary logistic regression models.
 *
 * <p>Training creates one {@link LogisticRegressionClassifier} per label id reported by the
 * {@link DatasetManager}, each trained on the dataset returned by
 * {@code mgr.convertToBinaryDataset(label)} for that label (presumably "label vs. rest").
 * At prediction time every binary model scores the transformed feature vector and labels are
 * ranked by descending score. Scores are used as returned by each binary model; they are not
 * renormalized across labels.
 *
 * <p>Serialization: {@code mgr} goes through default serialization, while the transient
 * per-label models are written afterwards as a {@code HashMap} from label id to the list returned
 * by {@code LogisticRegressionClassifier.getModel()}, and rebuilt from that list on read.
 */
public class OvALogisticRegressionClassifier
extends Classifier {
    private static final long serialVersionUID = -2768175376488908000L;
    private static final Logger LOGGER = LoggerFactory.getLogger(OvALogisticRegressionClassifier.class);
    /** Maximum number of binary models handed to one {@link Scheduler} run. */
    private static final int TRAINING_BATCH_SIZE = 700;
    private DatasetManager mgr = new DatasetManager();
    /** Label id -> binary model for that label; rebuilt manually in {@code readObject}. */
    private transient Map<Double, LogisticRegressionClassifier> classifiersMap = new HashMap<Double, LogisticRegressionClassifier>();

    public OvALogisticRegressionClassifier() {
        super(ClassifierType.OVA_LOGISTIC_REGRESSION);
    }

    /** Returns the binary dataset manager the model for {@code label} is trained on. */
    private DatasetManager getBinaryDatasetMgrForLabel(double label) {
        return this.mgr.convertToBinaryDataset(label);
    }

    /**
     * Trains one binary model per label of the currently loaded dataset, in batches of at most
     * {@link #TRAINING_BATCH_SIZE} models per scheduler run.
     *
     * @throws ClassifierException if the training thread is interrupted
     */
    private void train() throws ClassifierException {
        // Drop models from a previous training run so labels missing from the new dataset
        // are not ranked any more.
        this.classifiersMap.clear();
        ArrayList<Double> labels = new ArrayList<Double>(this.mgr.getLableIds());
        for (int start = 0; start < labels.size(); start += TRAINING_BATCH_SIZE) {
            int end = Math.min(start + TRAINING_BATCH_SIZE, labels.size());
            // A fresh task list per batch: reusing one list would re-schedule (and retrain)
            // every model from the earlier batches.
            ArrayList<Task> batch = new ArrayList<Task>(end - start);
            for (Double label : labels.subList(start, end)) {
                LogisticRegressionClassifier classifier = new LogisticRegressionClassifier(this.getBinaryDatasetMgrForLabel(label));
                this.classifiersMap.put(label, classifier);
                batch.add(new Task(classifier));
            }
            runTrainingBatch(batch);
        }
    }

    /**
     * Runs one batch of training tasks on a dedicated {@link Scheduler}, always shutting it down.
     *
     * @throws ClassifierException if interrupted; the thread's interrupt flag is restored
     */
    private static void runTrainingBatch(ArrayList<Task> batch) throws ClassifierException {
        Scheduler scheduler = new Scheduler();
        try {
            scheduler.schedluleTraining(batch);
        }
        catch (InterruptedException e) {
            LOGGER.error("Training thread interrupted", e);
            Thread.currentThread().interrupt();
            throw new ClassifierException("Training thread interrupted ", e);
        }
        finally {
            scheduler.shutdown();
        }
    }

    /**
     * Loads a CSV dataset and its feature spec from file paths, then trains all binary models.
     *
     * @throws DatasetException if {@code format} is not CSV or the dataset cannot be read
     */
    @Override
    public void train(String datasetPath, SupportedFileFormats format, String specPath) throws DatasetException, DatasetFeatureDescriptorException, ClassifierException {
        if (format != SupportedFileFormats.CSV) {
            throw new DatasetException("Invalid Dataset format");
        }
        try {
            this.mgr.readDataset(datasetPath, specPath);
        }
        catch (IOException e) {
            throw new DatasetException("Invalid Dataset", e);
        }
        this.train();
    }

    /**
     * Loads a CSV dataset and its feature spec from streams, then trains all binary models.
     *
     * @throws DatasetException if {@code format} is not CSV or the dataset cannot be read
     */
    @Override
    public void train(InputStream datasetSream, SupportedFileFormats format, InputStream specStream) throws DatasetException, DatasetFeatureDescriptorException, ClassifierException {
        if (format != SupportedFileFormats.CSV) {
            throw new DatasetException("Invalid Dataset format");
        }
        try {
            this.mgr.readDataset(datasetSream, specStream);
        }
        catch (IOException e) {
            throw new DatasetException("Invalid Dataset", e);
        }
        this.train();
    }

    /**
     * Returns the label whose binary model scores {@code featuresVector} highest, or {@code ""}
     * when no models are trained.
     */
    @Override
    public String classify(List<?> featuresVector) throws ClassifierException, GenericTransformerException {
        List<IClassification> best = this.getProbabilityDistribution(featuresVector, 1);
        return best.isEmpty() ? "" : best.get(0).getLabel();
    }

    /**
     * Returns the top-scoring label for each input vector, concatenated in input order
     * (empty contribution for a vector when no models are trained).
     */
    @Override
    public List<IClassification> classifyBulk(List<List<?>> featuresVectors) throws ClassifierException {
        return this.getProbabilityDistributionBulk(featuresVectors, 1);
    }

    /**
     * Scores {@code featuresVector} with every binary model and returns labels by descending score.
     *
     * @param probNum number of top entries to return; {@code -1} returns all of them
     * @throws GenericTransformerException if the vector cannot be transformed
     */
    @Override
    public List<IClassification> getProbabilityDistribution(List<?> featuresVector, int probNum) throws GenericTransformerException {
        return this.rankLabels(this.transform(featuresVector), probNum).collect(Collectors.toList());
    }

    /**
     * Bulk form of {@link #getProbabilityDistribution}: the per-vector rankings are concatenated
     * into one list, in input order.
     *
     * @param probNum number of top entries per vector; {@code -1} returns all of them
     * @throws ClassifierException wrapping the transformer error if any vector cannot be transformed
     */
    @Override
    public List<IClassification> getProbabilityDistributionBulk(List<List<?>> featuresVectors, int probNum) throws ClassifierException {
        List<List<Double>> transformedVectors;
        try {
            transformedVectors = featuresVectors.stream().map(this::transformUnchecked).collect(Collectors.toList());
        }
        catch (ClassifierWrapperException e) {
            throw new ClassifierException(e.getCause());
        }
        return transformedVectors.stream()
                .flatMap(vector -> this.rankLabels(vector, probNum))
                .collect(Collectors.toList());
    }

    /** Converts raw feature values to strings and lets the dataset manager transform them. */
    private List<Double> transform(List<?> featuresVector) throws GenericTransformerException {
        return this.mgr.transformVector(featuresVector.stream().map(Object::toString).collect(Collectors.toList()));
    }

    /** {@link #transform} for use inside streams: the checked error is tunnelled out unchecked. */
    private List<Double> transformUnchecked(List<?> featuresVector) {
        try {
            return this.transform(featuresVector);
        }
        catch (GenericTransformerException e) {
            throw new ClassifierWrapperException(e);
        }
    }

    /**
     * Scores one transformed vector with every binary model, sorted by descending score.
     *
     * @param probNum {@code -1} keeps every entry; any other value is passed to
     *     {@link Stream#limit}, which rejects other negative values
     */
    private Stream<IClassification> rankLabels(List<Double> transformedVector, int probNum) {
        Stream<IClassification> ranked = this.classifiersMap.entrySet().stream()
                .<IClassification>map(entry -> new Classification(
                        this.mgr.getLabelForIndex(entry.getKey()),
                        entry.getValue().getProbabilityDistributionBinary(transformedVector).getProbability()))
                .sorted(Comparator.comparing(IClassification::getProbability).reversed());
        return probNum == -1 ? ranked : ranked.limit(probNum);
    }

    /**
     * Writes default fields, then a {@code HashMap} of label id to each binary model's
     * {@code getModel()} list (the transient models themselves are not serialized).
     */
    private void writeObject(ObjectOutputStream os) throws IOException {
        os.defaultWriteObject();
        @SuppressWarnings("rawtypes")
        HashMap<Double, List> modelsMap = new HashMap<Double, List>();
        for (Map.Entry<Double, LogisticRegressionClassifier> entry : this.classifiersMap.entrySet()) {
            modelsMap.put(entry.getKey(), entry.getValue().getModel());
        }
        os.writeObject(modelsMap);
    }

    /**
     * Reads default fields, then rebuilds every binary model from the serialized model lists.
     * NOTE(review): Java native deserialization must only be used on trusted model data.
     */
    private void readObject(ObjectInputStream os) throws IOException, ClassNotFoundException {
        os.defaultReadObject();
        @SuppressWarnings({"unchecked", "rawtypes"})
        Map<Double, List> modelsMap = (Map<Double, List>) os.readObject();
        // Explicit HashMap: train() clears this map, so it must be mutable.
        Map<Double, LogisticRegressionClassifier> restored = new HashMap<Double, LogisticRegressionClassifier>();
        for (@SuppressWarnings("rawtypes") Map.Entry<Double, List> entry : modelsMap.entrySet()) {
            restored.put(entry.getKey(), new LogisticRegressionClassifier(entry.getValue()));
        }
        this.classifiersMap = restored;
    }

    /** Unchecked carrier used to move a checked transformer exception out of a stream lambda. */
    private static class ClassifierWrapperException
    extends RuntimeException {
        private static final long serialVersionUID = 897339537178989889L;

        ClassifierWrapperException(Exception cause) {
            super(cause);
        }
    }
}

