/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.smarts.generic.recommender.internal.raw.impl.classifiers.learningalgorithm;

import com.ibm.smarts.generic.recommender.internal.raw.impl.classifiers.learningalgorithm.ICostFunc;
import com.ibm.smarts.generic.recommender.internal.raw.impl.classifiers.learningalgorithm.IGradientFunc;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Iterative minimizer that repeatedly steps a parameter vector ("thetas")
 * against a supplied gradient function and keeps a step only when the
 * supplied cost function does not get worse.
 *
 * <p>The learning rate is adapted on every iteration: it grows by
 * {@link #LEARNING_RATE_INCREASE_PERCENTAGE} after an accepted step and
 * shrinks by {@link #LEARNING_RATE_DECREASE_PERCENTAGE} after a rejected one.
 */
public class GradientDescent {
    /** Relative growth applied to the learning rate after an accepted step. */
    private static final double LEARNING_RATE_INCREASE_PERCENTAGE = 0.05;
    /** Relative reduction applied to the learning rate after a rejected step. */
    private static final double LEARNING_RATE_DECREASE_PERCENTAGE = 0.5;

    /** Best parameter vector accepted so far; starts as all zeros. */
    private List<Double> thetas;
    private final int maxIterations;
    private final double startLearningRate;
    private final IGradientFunc gradient;
    private final ICostFunc costFunc;
    private final double maxAcceptedLoss;
    private final List<List<Double>> dataset;

    /**
     * Creates a minimizer whose parameter vector has one zero-initialised
     * entry per element of the first dataset row.
     *
     * @param func              gradient function, queried once per updated parameter index
     * @param costFunc          cost function used to accept or reject a candidate step
     * @param dataset           training rows; must contain at least one row, since the
     *                          first row's size determines the number of parameters
     * @param maxAcceptedLoss   cost at or below which {@link #minimize()} stops
     * @param startLearningRate learning rate used for the first iteration
     * @param maxIterations     upper bound on the number of iterations; at least one
     *                          iteration is always performed
     */
    public GradientDescent(IGradientFunc func, ICostFunc costFunc, List<List<Double>> dataset, double maxAcceptedLoss, double startLearningRate, int maxIterations) {
        this.thetas = new ArrayList<Double>(Collections.nCopies(dataset.get(0).size(), 0.0));
        this.gradient = func;
        this.dataset = dataset;
        this.costFunc = costFunc;
        this.maxAcceptedLoss = maxAcceptedLoss;
        this.startLearningRate = startLearningRate;
        this.maxIterations = maxIterations;
    }

    /**
     * Computes a candidate parameter vector one gradient step away from the
     * currently accepted parameters, without modifying them.
     *
     * @param learningRate step size multiplied with each gradient component
     * @return a new list holding the candidate parameters
     */
    private List<Double> calculateThetas(double learningRate) {
        List<Double> candidate = new ArrayList<Double>(Collections.nCopies(this.dataset.get(0).size(), 0.0));
        // NOTE(review): index 0 is never updated and therefore always stays 0.0.
        // Presumably that slot lines up with a non-feature column of the dataset
        // rows -- confirm against the IGradientFunc/ICostFunc implementations.
        for (int i = 1; i < this.thetas.size(); ++i) {
            double theta = this.thetas.get(i) - learningRate * this.gradient.calc(this.thetas, this.dataset, i);
            candidate.set(i, theta);
        }
        return candidate;
    }

    /**
     * Runs the descent until either the iteration limit is reached or the
     * cost of the accepted parameters is at or below the accepted loss.
     *
     * @return the last accepted parameter vector (the initial all-zero vector
     *         if no step was ever accepted)
     */
    public List<Double> minimize() {
        int iterations = 0;
        // Cost of the parameters currently held in this.thetas; infinite until
        // the first candidate has been accepted.
        double cost = Double.POSITIVE_INFINITY;
        double learningRate = this.startLearningRate;
        boolean done;
        do {
            List<Double> candidate = this.calculateThetas(learningRate);
            double candidateCost = this.costFunc.calc(candidate, this.dataset);
            if (candidateCost <= cost) {
                // Step did not make things worse: keep it and speed up.
                learningRate += learningRate * LEARNING_RATE_INCREASE_PERCENTAGE;
                cost = candidateCost;
                this.thetas = candidate;
            } else {
                // Step rejected (worse or NaN cost): slow down. The accepted
                // parameters are unchanged, so their cost must stay unchanged
                // too; otherwise a later, worse step could be accepted or the
                // loop could report convergence on a stale value.
                learningRate -= learningRate * LEARNING_RATE_DECREASE_PERCENTAGE;
            }
            done = ++iterations >= this.maxIterations || cost <= this.maxAcceptedLoss;
        } while (!done);
        return this.thetas;
    }
}

