/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.math;

import com.ibm.bi.predict.math.OptimizationError;
import com.ibm.bi.predict.utils.Tuple;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.function.ToDoubleFunction;
import java.util.stream.IntStream;
import org.apache.commons.math3.exception.MaxCountExceededException;
import org.apache.commons.math3.optim.ConvergenceChecker;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleBounds;
import org.apache.commons.math3.optim.SimpleValueChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.MultivariateOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.BOBYQAOptimizer;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.CMAESOptimizer;
import org.apache.commons.math3.optim.univariate.BrentOptimizer;
import org.apache.commons.math3.optim.univariate.SearchInterval;
import org.apache.commons.math3.optim.univariate.UnivariateObjectiveFunction;
import org.apache.commons.math3.optim.univariate.UnivariatePointValuePair;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.random.Well19937c;

/**
 * Bounded numerical optimizer over a box of {@code double} parameters.
 *
 * <p>Each candidate parameter vector is turned into a {@code T} by {@code makeFunction} and then
 * scored by {@code scoreFunction`}; the score is minimized or maximized according to the
 * {@link GoalType}. A single parameter is optimized with Brent's method; two or more parameters
 * are optimized with CMA-ES (the default) or BOBYQA, chosen via
 * {@link #multivariateMethod(MultivariateMethod)}. Each {@link #run()} executes on a fresh worker
 * thread and is abandoned after 30 seconds.
 *
 * <p>Configuration and run statistics are held in plain mutable fields without synchronization,
 * so an instance should not be shared between concurrent callers.
 *
 * @param <T> the type built from a parameter vector and then scored
 */
public class Optimizer<T> {
    /** Maximum time, in milliseconds, that {@link #run()} waits for the optimizer to finish. */
    private static final int MAX_OPTIMIZER_WAIT = 30000;
    /** Evaluation budget for the multivariate optimizers (BOBYQA and CMA-ES). */
    private static final int MULTIVARIATE_MAX_EVALUATIONS = 1000;
    /** Iteration budget passed to the CMA-ES constructor. */
    private static final int CMAES_MAX_ITERATIONS = 1000;
    /** Fixed seed for the CMA-ES random generator, so repeated runs are reproducible. */
    private static final int CMAES_SEED = 13;
    /** Evaluation budget for the univariate (Brent) optimizer. */
    private static final int UNIVARIATE_MAX_EVALUATIONS = 100;
    /** Relative tolerance passed to the Brent optimizer. */
    private static final double BRENT_RELATIVE_TOLERANCE = 1.0E-10;

    private final ToDoubleFunction<T> scoreFunction;
    private final Function<double[], T> makeFunction;
    private final GoalType goal;
    private MultivariateMethod multiMethod;
    private double requiredPrecision = 0.001;
    private double[] loBounds;
    private double[] hiBounds;
    private int evaluations;
    private int iterations;

    /**
     * Creates an optimizer that searches for the parameters giving the lowest score.
     *
     * @param makeFunction  builds a candidate from a parameter vector
     * @param scoreFunction scores a candidate
     * @param <S>           the candidate type
     * @return a new minimizing optimizer; bounds must still be set before {@link #run()}
     */
    public static <S> Optimizer<S> makeMinimizer(Function<double[], S> makeFunction, ToDoubleFunction<S> scoreFunction) {
        return new Optimizer<>(scoreFunction, makeFunction, GoalType.MINIMIZE);
    }

    /**
     * Creates an optimizer that searches for the parameters giving the highest score.
     *
     * @param makeFunction  builds a candidate from a parameter vector
     * @param scoreFunction scores a candidate
     * @param <S>           the candidate type
     * @return a new maximizing optimizer; bounds must still be set before {@link #run()}
     */
    public static <S> Optimizer<S> makeMaximizer(Function<double[], S> makeFunction, ToDoubleFunction<S> scoreFunction) {
        return new Optimizer<>(scoreFunction, makeFunction, GoalType.MAXIMIZE);
    }

    /**
     * Creates an optimizer using CMA-ES for multivariate problems.
     *
     * @param scoreFunction scores a candidate
     * @param makeFunction  builds a candidate from a parameter vector
     * @param goal          whether the score is minimized or maximized
     */
    public Optimizer(ToDoubleFunction<T> scoreFunction, Function<double[], T> makeFunction, GoalType goal) {
        this.scoreFunction = scoreFunction;
        this.makeFunction = makeFunction;
        this.goal = goal;
        this.multiMethod = MultivariateMethod.CMAES;
    }

    /**
     * Sets the inclusive search box. The arrays are copied, so later changes by the caller do not
     * affect this optimizer.
     *
     * @param loBounds lower bound of each parameter
     * @param hiBounds upper bound of each parameter
     * @return this optimizer, for chaining
     * @throws IllegalStateException if either array is null, their lengths differ, or they are
     *                               empty; the previously configured bounds are left unchanged
     */
    public Optimizer<T> bounds(double[] loBounds, double[] hiBounds) {
        if (loBounds == null || hiBounds == null) {
            throw new IllegalStateException("Bounds must not be null");
        }
        if (loBounds.length != hiBounds.length) {
            throw new IllegalStateException("Upper and lower bounds must match in length");
        }
        if (loBounds.length < 1) {
            throw new IllegalStateException("Must have at least one parameter to optimize over");
        }
        // Store only after validation, and store copies: the arrays are read later on the
        // worker thread started by run().
        this.loBounds = loBounds.clone();
        this.hiBounds = hiBounds.clone();
        return this;
    }

    /**
     * Returns the number of objective evaluations reported by the most recently completed
     * optimization, or 0 if none has completed.
     *
     * @return the evaluation count
     */
    public int evaluations() {
        return this.evaluations;
    }

    /**
     * Returns the number of iterations reported by the most recently completed optimization,
     * or 0 if none has completed.
     *
     * @return the iteration count
     */
    public int iterations() {
        return this.iterations;
    }

    /**
     * Sets the convergence tolerance (default 0.001). It is used as BOBYQA's stopping
     * trust-region radius, as the absolute threshold of the CMA-ES value checker, and, halved,
     * as Brent's absolute tolerance.
     *
     * @param requiredPrecision the tolerance
     * @return this optimizer, for chaining
     */
    public Optimizer<T> requiredPrecision(double requiredPrecision) {
        this.requiredPrecision = requiredPrecision;
        return this;
    }

    /**
     * Chooses the algorithm used when there are two or more parameters. Any value other than
     * {@link MultivariateMethod#BOBYQA} results in CMA-ES.
     *
     * @param method the multivariate algorithm
     * @return this optimizer, for chaining
     */
    public Optimizer<T> multivariateMethod(MultivariateMethod method) {
        this.multiMethod = method;
        return this;
    }

    /**
     * Runs the optimization on a dedicated worker thread and returns the candidate built from the
     * best parameters found.
     *
     * @return the optimal candidate
     * @throws IllegalStateException if bounds have not been set
     * @throws OptimizationError     if the optimizer fails, times out, or the calling thread is
     *                               interrupted while waiting (its interrupt status is restored)
     */
    public T run() throws OptimizationError {
        if (this.loBounds == null || this.hiBounds == null) {
            throw new IllegalStateException("Must set lower and upper bounds before optimizing");
        }
        ExecutorService executor = Executors.newSingleThreadExecutor();
        try {
            Future<T> future = executor.submit(
                    () -> this.loBounds.length == 1 ? this.optimizeUnivariate() : this.optimizeMultivariate());
            return future.get(MAX_OPTIMIZER_WAIT, TimeUnit.MILLISECONDS);
        }
        catch (TimeoutException te) {
            throw new OptimizationError("Timed out waiting for optimizer", te);
        }
        catch (InterruptedException e) {
            // Throwing InterruptedException cleared this thread's interrupt status; restore it so
            // callers further up the stack can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new OptimizationError("Interrupted", e);
        }
        catch (ExecutionException e) {
            if (e.getCause() instanceof OptimizationError) {
                throw (OptimizationError) e.getCause();
            }
            throw new OptimizationError("Execution Exception", e);
        }
        finally {
            // shutdownNow() only requests interruption of the worker. An optimizer that does not
            // check for interrupts may keep running in the background after a timeout.
            executor.shutdownNow();
        }
    }

    /** Builds the candidate for {@code point} and returns its score. */
    private double score(double[] point) {
        return this.scoreFunction.applyAsDouble(this.makeFunction.apply(point));
    }

    /** Optimizes two or more parameters with the configured multivariate method. */
    private T optimizeMultivariate() throws OptimizationError {
        // Start from the centre of the search box.
        double[] initial = IntStream.range(0, this.loBounds.length)
                .mapToDouble(i -> (this.loBounds[i] + this.hiBounds[i]) / 2.0)
                .toArray();
        try {
            Tuple<MultivariateOptimizer, PointValuePair> result = this.multiMethod == MultivariateMethod.BOBYQA
                    ? this.makeBOBYQAOptimizer(initial)
                    : this.makeCMAESOptimizer(initial);
            MultivariateOptimizer optimizer = (MultivariateOptimizer) result._1;
            PointValuePair optimum = (PointValuePair) result._2;
            this.evaluations = optimizer.getEvaluations();
            this.iterations = optimizer.getIterations();
            return this.makeFunction.apply(optimum.getPoint());
        }
        catch (MaxCountExceededException e) {
            throw new OptimizationError("Multivariate optimizer could not optimize within required number of iterations", e);
        }
        catch (IllegalArgumentException e) {
            throw new OptimizationError("Bad parameter set in the multivariate optimizer", e);
        }
        catch (Exception e) {
            throw new OptimizationError("Multivariate optimizer experienced an unknown error", e);
        }
    }

    /** Runs BOBYQA from {@code initial} and returns the optimizer together with its optimum. */
    private Tuple<MultivariateOptimizer, PointValuePair> makeBOBYQAOptimizer(double[] initial) {
        // Squared length of the search box's diagonal; half the diagonal is the initial
        // trust-region radius.
        double diam2 = IntStream.range(0, this.loBounds.length)
                .mapToDouble(i -> this.hiBounds[i] - this.loBounds[i])
                .map(x -> x * x)
                .sum();
        int interpolationPoints = 2 * initial.length + 1;
        BOBYQAOptimizer optimizer = new BOBYQAOptimizer(interpolationPoints, Math.sqrt(diam2) / 2.0, this.requiredPrecision);
        PointValuePair optimum = optimizer.optimize(
                new ObjectiveFunction(this::score),
                new MaxEval(MULTIVARIATE_MAX_EVALUATIONS),
                new InitialGuess(initial),
                new SimpleBounds(this.loBounds, this.hiBounds),
                this.goal);
        return new Tuple<>(optimizer, optimum);
    }

    /** Runs CMA-ES from {@code initial} and returns the optimizer together with its optimum. */
    private Tuple<MultivariateOptimizer, PointValuePair> makeCMAESOptimizer(double[] initial) {
        int dimension = this.loBounds.length;
        // Initial step size per coordinate: half the width of the box in that coordinate.
        double[] sigma = IntStream.range(0, dimension)
                .mapToDouble(i -> (this.hiBounds[i] - this.loBounds[i]) / 2.0)
                .toArray();
        int population = (int) (4L + Math.round(3.0 * Math.log(dimension)));
        CMAESOptimizer optimizer = new CMAESOptimizer(
                CMAES_MAX_ITERATIONS,
                0.0,    // stopFitness
                true,   // isActiveCMA
                0,      // diagonalOnly
                0,      // checkFeasableCount
                new Well19937c(CMAES_SEED),
                false,  // generateStatistics
                new SimpleValueChecker(0.0, this.requiredPrecision));
        PointValuePair optimum = optimizer.optimize(
                new ObjectiveFunction(this::score),
                new MaxEval(MULTIVARIATE_MAX_EVALUATIONS),
                new InitialGuess(initial),
                new SimpleBounds(this.loBounds, this.hiBounds),
                new CMAESOptimizer.PopulationSize(population),
                new CMAESOptimizer.Sigma(sigma),
                this.goal);
        return new Tuple<>(optimizer, optimum);
    }

    /** Optimizes a single parameter over {@code [loBounds[0], hiBounds[0]]} with Brent's method. */
    private T optimizeUnivariate() throws OptimizationError {
        try {
            // Constructed inside the try so that constructor argument errors (e.g. a bad
            // requiredPrecision) are reported through the same catch clauses as the multivariate path.
            BrentOptimizer optimizer = new BrentOptimizer(BRENT_RELATIVE_TOLERANCE, this.requiredPrecision / 2.0);
            UnivariatePointValuePair optimum = optimizer.optimize(
                    new UnivariateObjectiveFunction(v -> this.score(new double[] {v})),
                    new MaxEval(UNIVARIATE_MAX_EVALUATIONS),
                    new SearchInterval(this.loBounds[0], this.hiBounds[0]),
                    this.goal);
            this.evaluations = optimizer.getEvaluations();
            this.iterations = optimizer.getIterations();
            return this.makeFunction.apply(new double[] {optimum.getPoint()});
        }
        catch (MaxCountExceededException e) {
            throw new OptimizationError("Univariate optimizer could not optimize within required number of iterations", e);
        }
        catch (IllegalArgumentException e) {
            throw new OptimizationError("Bad parameter set in the univariate optimizer", e);
        }
        catch (Exception e) {
            throw new OptimizationError("Univariate optimizer experienced an unknown error", e);
        }
    }

    /** Algorithm used when optimizing two or more parameters. */
    public enum MultivariateMethod {
        /** Bound-constrained trust-region method (Commons Math {@code BOBYQAOptimizer}). */
        BOBYQA,
        /** Covariance Matrix Adaptation Evolution Strategy (Commons Math {@code CMAESOptimizer}). */
        CMAES;
    }
}

