/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.forecasting;

import com.google.common.collect.Sets;
import com.ibm.bi.predict.algorithms.forecasting.ForecastingAlgorithm;
import com.ibm.bi.predict.algorithms.forecasting.ForecastingAlgorithmContext;
import com.ibm.bi.predict.algorithms.forecasting.exception.ForecastingParametersException;
import com.ibm.bi.predict.algorithms.forecasting.exception.ForecastingParametersExceptionKey;
import com.ibm.bi.predict.algorithms.forecasting.result.ForecastingResult;
import com.ibm.bi.predict.algorithms.forecasting.result.ForecastingResultData;
import com.ibm.bi.predict.algorithms.forecasting.result.ForecastingStatisticalDetails;
import com.ibm.bi.predict.algorithms.forecasting.result.SeriesResult;
import com.ibm.bi.predict.algorithms.forecasting.timedimension.TimeDimension;
import com.ibm.bi.predict.data.DataColumn;
import com.ibm.bi.predict.data.DataContext;
import com.ibm.bi.predict.data.DataPrep;
import com.ibm.bi.predict.data.StandardDataPrep;
import com.ibm.bi.predict.dataaccess.DataAccessProvider;
import com.ibm.bi.predict.exceptions.UserMessageException;
import com.ibm.bi.predict.forecasting.DataReader;
import com.ibm.bi.predict.forecasting.ForecastingContext;
import com.ibm.bi.predict.forecasting.TimeSeriesResult;
import com.ibm.bi.predict.graph.Tree;
import com.ibm.bi.predict.result.DataPrepResult;
import com.ibm.bi.predict.result.Message;
import com.ibm.bi.predict.result.MessageCode;
import com.ibm.bi.predict.result.StatusCode;
import com.ibm.bi.predict.service.DataProvider;
import com.ibm.bi.predict.service.PredictServiceFramework;
import com.ibm.bi.predict.service.PredictServiceRequest;
import com.ibm.bi.predict.service.PredictServiceResponse;
import com.ibm.bi.predict.source.ColumnGroup;
import com.ibm.bi.predict.source.DataColumnBuilder;
import com.ibm.bi.predict.source.DataSource;
import com.ibm.bi.predict.source.jsonstat.ColumnGroupType;
import com.ibm.bi.predict.utils.Logger;
import com.ibm.bi.predict.utils.PredictLoggerFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Connects the predict service framework to the forecasting algorithm.
 *
 * <p>Given a data provider and a {@link ForecastingContext}, this class obtains the data columns, derives the time
 * dimension, builds one {@link SeriesResult} per series via {@link DataReader#makeSeries}, drops series that are too
 * short or have too many missing values, and runs {@link ForecastingAlgorithm} on the remainder. Series that could
 * not be forecast are merged back into the result together with explanatory statistical details.
 */
public class ForecastingIntegration implements PredictServiceFramework {
    private static final Logger LOG = PredictLoggerFactory.getLogger(ForecastingIntegration.class);

    /** Fallback for {@code forecast.config.minimumSeriesLength}: fewest historical values a series may have. */
    protected static final int MINIMUM_SERIES_LENGTH = 5;

    /** Fallback for {@code forecast.config.missingValueThreshold}: largest tolerated fraction of missing values. */
    private static final double MISSING_VALUE_THRESHOLD = 0.33;

    /** Context key under which the locale used for user-facing messages is stored. */
    private static final String LOCALE_KEY = "locale";

    /**
     * Creates a new service instance.
     *
     * @return a fresh {@code ForecastingIntegration}
     */
    public static ForecastingIntegration getService() {
        return new ForecastingIntegration();
    }

    /**
     * Runs a forecast for the first request in {@code requests}.
     *
     * @param requests service requests; the context of the first one must be a {@link ForecastingContext}
     * @param locale   not read by this method
     * @return an empty response when {@code requests} is empty, otherwise a {@link TimeSeriesResult}
     * @throws ForecastingParametersException if the request's data provider or context is invalid
     */
    public PredictServiceResponse run(Collection<PredictServiceRequest> requests, Locale locale) {
        if (requests.isEmpty()) {
            return PredictServiceResponse.emptyResponse();
        }
        // NOTE(review): only the first request is processed; any further requests are ignored.
        PredictServiceRequest request = requests.iterator().next();
        ForecastingContext context = (ForecastingContext) request.getContext();
        this.validateArguments(request.getData(), context);
        ForecastingResult result =
                this.runForecast(context, this.requestToDataColumns(request), this.requestToColumnGroups(request));
        return new TimeSeriesResult(request.getData(), result, context);
    }

    /**
     * Resolves the data columns of a request: built directly from its data source when one is present, otherwise
     * obtained from the driver columns of a data-preparation pass over its data access provider.
     */
    private List<DataColumn> requestToDataColumns(PredictServiceRequest request) {
        DataProvider dataProvider = request.getData();
        if (dataProvider.asDataSource() != null) {
            return DataColumnBuilder.build((DataSource) dataProvider.asDataSource());
        }
        ForecastingContext context = (ForecastingContext) request.getContext();
        return ((DataPrep) prepareData(dataProvider.asDataAccessProvider(), context).getContent()).driverColumns();
    }

    /**
     * Resolves the column groups of a request: taken from its data source when one is present, otherwise from the
     * request's forecasting context.
     */
    private List<ColumnGroup> requestToColumnGroups(PredictServiceRequest request) {
        DataProvider dataProvider = request.getData();
        if (dataProvider.asDataSource() != null) {
            return dataProvider.asDataSource().groups();
        }
        ForecastingContext context = (ForecastingContext) request.getContext();
        return context.getColumnGroups();
    }

    /**
     * Validates the inputs, prepares the data and runs the forecast, bracketing the work with perf logging.
     *
     * @param dataProvider source of the data to forecast; must not be {@code null}
     * @param context      forecasting options; must not be {@code null}
     * @return the forecasting result
     * @throws ForecastingParametersException if the arguments or context options are invalid
     * @throws UserMessageException           if data preparation failed for lack of memory
     */
    public ForecastingResult runAlgorithms(DataAccessProvider dataProvider, ForecastingContext context) {
        LOG.perfStart();
        try {
            LOG.perfLog("Starting execution of forecasting.");
            this.validateArguments(dataProvider, context);
            List<DataColumn> columns = ((DataPrep) prepareData(dataProvider, context).getContent()).driverColumns();
            ForecastingResult algorithmResult = this.runForecast(context, columns, context.getColumnGroups());
            LOG.infoPerfLog("Finished execution of forecasting");
            return algorithmResult;
        } finally {
            // Balance perfStart() even when validation or data preparation throws.
            LOG.perfStop();
        }
    }

    /**
     * Core pipeline: set the time dimension, build and filter the series, run the algorithm and merge back any
     * series that were not forecast.
     *
     * <p>When no series survives filtering, or the time dimension reports the series as too short, the algorithm is
     * skipped and every series is returned as not forecast.
     */
    private ForecastingResult runForecast(ForecastingContext context, List<DataColumn> columns, List<ColumnGroup> groups) {
        List<Message> errorMessages = new ArrayList<>();
        List<Message> warningMessages = new ArrayList<>();
        setTimeDimension(context, columns, groups, warningMessages);
        Map<String, SeriesResult> unfilteredSeriesMap = DataReader.makeSeries(columns, context, groups);
        Map<String, SeriesResult> seriesMap =
                this.filterMissingValues(unfilteredSeriesMap, errorMessages, warningMessages, context);
        if (seriesMap.isEmpty() || context.getTimeDimension().isSeriesTooShort()) {
            return makeErrorSeriesMap(context, errorMessages, warningMessages, unfilteredSeriesMap);
        }
        List<Tree<SeriesResult>> hierarchyTree = makeHierarchyTree(seriesMap);
        ForecastingAlgorithm algorithm = new ForecastingAlgorithm(hierarchyTree, (ForecastingAlgorithmContext) context);
        LOG.perfLog("Running forecasting algorithm");
        ForecastingResult algorithmResult = algorithm.run(context.getIntOpt("forecast.config.forecastPeriods"));
        LOG.perfLog("Finished running forecasting algorithm");
        if (StatusCode.SUCCESS == algorithmResult.getStatus()) {
            algorithmResult = mergeUnforecastSeries(algorithmResult, unfilteredSeriesMap, context);
        }
        algorithmResult.addErrorMessages(errorMessages);
        algorithmResult.addWarningMessages(warningMessages);
        hierarchyTree.forEach(tree -> tree.nodes().forEach(
                node -> LOG.info(() -> String.format("ID: %s %s", node.id(), node.content()))));
        return algorithmResult;
    }

    /**
     * Adds every series whose forecast did not succeed to the algorithm's series results, with statistical details
     * attached, and rebuilds the result with {@code PARTIAL_SUCCESS} if any such series exists.
     */
    private static ForecastingResult mergeUnforecastSeries(ForecastingResult algorithmResult,
            Map<String, SeriesResult> unfilteredSeriesMap, ForecastingContext context) {
        Map<String, SeriesResult> seriesResults =
                ((ForecastingResultData) algorithmResult.getContent()).getSeriesResults();
        boolean allSuccess = true;
        for (Map.Entry<String, SeriesResult> entry : unfilteredSeriesMap.entrySet()) {
            SeriesResult series = entry.getValue();
            if (series.isForecastSuccessful()) {
                continue;
            }
            allSuccess = false;
            seriesResults.put(entry.getKey(),
                    series.setStatisticalDetails(generateForecastingStatisticalDetails(series, context)));
        }
        StatusCode status = allSuccess ? StatusCode.SUCCESS : StatusCode.PARTIAL_SUCCESS;
        return new ForecastingResult(status, seriesResults, context.getTimeDimension());
    }

    /**
     * Builds a {@code PARTIAL_SUCCESS} result in which every series is marked as not forecast and carries
     * statistical details (including any error note), plus the collected error and warning messages.
     */
    private static ForecastingResult makeErrorSeriesMap(ForecastingContext context, List<Message> errorMessages,
            List<Message> warningMessages, Map<String, SeriesResult> unfilteredSeriesMap) {
        for (SeriesResult series : unfilteredSeriesMap.values()) {
            series.setStatisticalDetails(
                    generateForecastingStatisticalDetails(series.setIsForecastSuccessful(false), context));
        }
        ForecastingResult result =
                new ForecastingResult(StatusCode.PARTIAL_SUCCESS, unfilteredSeriesMap, context.getTimeDimension());
        result.addErrorMessages(errorMessages);
        result.addWarningMessages(warningMessages);
        return result;
    }

    /**
     * Creates statistical details for a series and, when the series carries an error message whose code is a notes
     * message, appends the corresponding localized note.
     */
    private static ForecastingStatisticalDetails generateForecastingStatisticalDetails(SeriesResult series,
            ForecastingContext context) {
        ForecastingStatisticalDetails details =
                new ForecastingStatisticalDetails(series, (ForecastingAlgorithmContext) context);
        Message errorMessage = series.getErrorMessage();
        if (errorMessage != null) {
            addErrorNote(details, errorMessage.getMessageCode(), context.getLocale(LOCALE_KEY));
            details.concatNotes();
        }
        return details;
    }

    /** Adds a note for {@code messageCode} only if {@link MessageCode#isNotesMessage} accepts it. */
    private static void addErrorNote(ForecastingStatisticalDetails details, MessageCode messageCode, Locale locale) {
        if (MessageCode.isNotesMessage(messageCode)) {
            details.addNote(messageCode, locale, new Object[0]);
        }
    }

    /**
     * Returns the subset of {@code seriesMap} whose series are long and complete enough to forecast.
     *
     * <p>Rejected series receive an error message, and a warning is appended to {@code warningMessages} for every
     * rejected series and for every accepted series that still has (tolerable) missing values.
     *
     * @param seriesMap       series keyed by id; not modified structurally
     * @param errorMessages   currently not written by this method
     * @param warningMessages receives the warnings described above
     * @param context         supplies {@code forecast.config.missingValueThreshold},
     *                        {@code forecast.config.minimumSeriesLength} and the locale
     * @return a new map containing only the accepted entries
     */
    private Map<String, SeriesResult> filterMissingValues(Map<String, SeriesResult> seriesMap,
            List<Message> errorMessages, List<Message> warningMessages, ForecastingContext context) {
        // The options do not depend on the series, so read them once instead of once per entry.
        double missingValueThreshold =
                context.getDouble("forecast.config.missingValueThreshold", MISSING_VALUE_THRESHOLD);
        int minSeriesLength = context.getInt("forecast.config.minimumSeriesLength", MINIMUM_SERIES_LENGTH);
        Locale locale = context.getLocale(LOCALE_KEY);
        return seriesMap.entrySet().stream()
                .filter(entry -> isForecastable(
                        entry.getValue(), minSeriesLength, missingValueThreshold, locale, warningMessages))
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    /**
     * Checks one series against the minimum length and missing-value threshold, recording warnings and, on
     * rejection, an error message on the series.
     *
     * @return {@code true} if the series should be forecast
     */
    private static boolean isForecastable(SeriesResult series, int minSeriesLength, double missingValueThreshold,
            Locale locale, List<Message> warningMessages) {
        // NOTE(review): checks and the error message apply to the value returned by trimMissingValues(), while the
        // caller keeps the original map value; this presumes trimMissingValues() mutates and returns the same
        // instance — confirm.
        SeriesResult trimmed = series.trimMissingValues();
        double missingValueCount = trimmed.getHistoricalValuesMissingCount();
        int historicValueCount = trimmed.getHistoricalValues().length;
        if (historicValueCount < minSeriesLength) {
            warningMessages.add(new Message(MessageCode.FORECASTING_TOO_FEW_POINTS, locale));
            trimmed.setErrorMessage(new Message(MessageCode.FORECASTING_NOTES_TOO_FEW_POINTS, locale));
            return false;
        }
        if (missingValueCount > 0.0 && historicValueCount > 0) {
            if (missingValueCount > historicValueCount * missingValueThreshold) {
                warningMessages.add(new Message(MessageCode.FORECASTING_MISSING_EXCEEDS_THRESHOLD, locale));
                trimmed.setErrorMessage(new Message(MessageCode.FORECASTING_NOTES_TOO_MANY_MISSING, locale));
                return false;
            }
            warningMessages.add(new Message(MessageCode.FORECASTING_MISSING_BELOW_THRESHOLD, locale));
        }
        return true;
    }

    /**
     * Rejects a missing provider (or one exposing neither a data access provider nor a data source), then
     * validates the context.
     *
     * @throws ForecastingParametersException on invalid input
     */
    private void validateArguments(DataProvider dataProvider, ForecastingContext context) {
        if (dataProvider == null
                || (dataProvider.asDataAccessProvider() == null && dataProvider.asDataSource() == null)) {
            LOG.error("Null data provider passed in");
            throw new ForecastingParametersException("Data provider cannot be null", ForecastingParametersExceptionKey.DATA_SOURCE);
        }
        this.validateContext(context);
    }

    /**
     * Rejects a {@code null} data access provider, then validates the context.
     *
     * @throws ForecastingParametersException on invalid input
     */
    private void validateArguments(DataAccessProvider dataProvider, ForecastingContext context) {
        if (dataProvider == null) {
            LOG.error("Null data provider passed in");
            throw new ForecastingParametersException("Data provider cannot be null", ForecastingParametersExceptionKey.DATA_SOURCE);
        }
        this.validateContext(context);
    }

    /**
     * Validates the context: non-null, confidence level in [0.5, 1.0), period options within their bounds, and at
     * least one target.
     *
     * @throws ForecastingParametersException on the first violated constraint
     */
    private void validateContext(ForecastingContext context) {
        if (context == null) {
            LOG.error("Null context passed in");
            throw new ForecastingParametersException("Context cannot be null", ForecastingParametersExceptionKey.CONTEXT);
        }
        this.validateConfidenceLevel(context);
        int maxPeriods = context.getInt("forecast.config.maximumForecastPeriods", 1500);
        this.validateTimePeriod(context, "forecast.config.forecastPeriods", "Forecasting periods", ForecastingParametersExceptionKey.TIME_PERIODS, 0, maxPeriods);
        this.validateTimePeriod(context, "forecast.config.seasonalityPeriod", "Seasonality periods", ForecastingParametersExceptionKey.SEASONALITY_PERIODS, 0, 10000);
        this.validateTimePeriod(context, "forecast.config.ignoreLastNPeriods", "Ignore periods", ForecastingParametersExceptionKey.IGNORE_PERIODS, 0, 100);
        List<Integer> targetList = context.getIntList("target", new ArrayList<>());
        if (targetList.isEmpty()) {
            throw new ForecastingParametersException("Context should always have one or more targets", ForecastingParametersExceptionKey.TARGET);
        }
    }

    /**
     * If the integer option {@code key} is set, requires it to lie in {@code [minValue, maxValue]}.
     *
     * @throws ForecastingParametersException with {@code errorCode} when the value is out of range
     */
    private void validateTimePeriod(ForecastingContext context, String key, String label,
            ForecastingParametersExceptionKey errorCode, int minValue, int maxValue) {
        context.getIntOpt(key).ifPresent(timePeriods -> {
            if (timePeriods < minValue || timePeriods > maxValue) {
                throw new ForecastingParametersException(label + " parameter should be between " + minValue + " and " + maxValue, errorCode);
            }
        });
    }

    /**
     * If {@code forecast.config.confidenceLevel} is set, requires it to lie in {@code [0.5, 1.0)}.
     *
     * @throws ForecastingParametersException when the value is out of range
     */
    private void validateConfidenceLevel(ForecastingContext context) {
        context.getDoubleOpt("forecast.config.confidenceLevel").ifPresent(confidenceLevel -> {
            if (confidenceLevel < 0.5 || confidenceLevel >= 1.0) {
                throw new ForecastingParametersException("Confidence level option was not a legal enumeration value", ForecastingParametersExceptionKey.CONFIDENCE_LEVEL);
            }
        });
    }

    /** Sets the context's time dimension and cycle length, discarding any warnings produced along the way. */
    protected static void setTimeDimension(ForecastingContext context, List<DataColumn> dataColumns, List<ColumnGroup> groups) {
        setTimeDimension(context, dataColumns, groups, new ArrayList<>());
    }

    /**
     * Builds a {@link TimeDimension} from the columns in the TIME column group, expanding columns with nested
     * columns into those nested columns. Stores the dimension and its cycle length on the context.
     *
     * @param warningMessages receives a warning when the dimension reports invalid time categories
     */
    protected static void setTimeDimension(ForecastingContext context, List<DataColumn> columns,
            List<ColumnGroup> groups, List<Message> warningMessages) {
        int[] timeIndices = ColumnGroup.columnsIndices(groups, Sets.newHashSet(ColumnGroupType.TIME));
        List<DataColumn> timeColumns = new ArrayList<>();
        for (int index : timeIndices) {
            DataColumn column = columns.get(index);
            if (column.hasNestedDataColumns()) {
                timeColumns.addAll(column.getNestedDataColumns());
            } else {
                timeColumns.add(column);
            }
        }
        TimeDimension dimension =
                TimeDimension.makeForColumns(context.getLocale(LOCALE_KEY), timeColumns.toArray(new DataColumn[0]));
        reportInvalidTime(dimension, context, warningMessages);
        context.setTimeDimension(dimension);
        context.setCycleLength(dimension.findCycleLengthByTimeDelta());
    }

    /** Appends an invalid-time-categories warning when {@code dimension} reports any invalid categories. */
    private static void reportInvalidTime(TimeDimension dimension, ForecastingContext context, List<Message> warningMessages) {
        if (!dimension.getInvalidCategories().isEmpty()) {
            warningMessages.add(new Message(MessageCode.FORECASTING_INVALID_TIME_CATEGORIES, context.getLocale(LOCALE_KEY)));
        }
    }

    /** Builds a forest over the series keys in which no key has children, i.e. every series is its own root. */
    private static List<Tree<SeriesResult>> makeHierarchyTree(Map<String, SeriesResult> seriesMap) {
        return Tree.buildForest(seriesMap.keySet(), s -> Collections.emptyList(), seriesMap::get);
    }

    /**
     * Runs outlier data preparation and builds its result.
     *
     * @throws UserMessageException with {@code INSUFFICIENT_MEMORY} when preparation failed for lack of memory
     */
    private static DataPrepResult prepareData(DataAccessProvider provider, ForecastingContext context) {
        LOG.perfLog("Starting data preparation.");
        DataPrep preparation = StandardDataPrep.prepareOutliers(provider, (DataContext) context);
        logDegenerateFields(preparation);
        DataPrepResult dataPrepResult = preparation.buildResult();
        if (dataPrepOutOfMemory(dataPrepResult)) {
            throw new UserMessageException(MessageCode.INSUFFICIENT_MEMORY);
        }
        LOG.perfLog("Finished data preparation");
        return dataPrepResult;
    }

    /** Logs, at debug level, the degenerate columns removed by data preparation, if any. */
    private static void logDegenerateFields(DataPrep prep) {
        if (!prep.degenerateColumns().isEmpty()) {
            LOG.debug(() -> String.format("Removed degenerate fields: %s", prep.degenerateColumns()));
        }
    }

    /** Returns {@code true} when preparation failed and one of its error messages is {@code INSUFFICIENT_MEMORY}. */
    private static boolean dataPrepOutOfMemory(DataPrepResult dataPrepResult) {
        return dataPrepResult.getStatus() == StatusCode.FAILURE
                && dataPrepResult.getErrorMessages().stream()
                        .anyMatch(error -> error.getMessageCode() == MessageCode.INSUFFICIENT_MEMORY);
    }
}

