/*
 * Decompiled with CFR 0.152.
 */
package com.ibm.bi.predict.data;

import com.ibm.bi.predict.data.Category;
import com.ibm.bi.predict.data.ColumnDataBuilder;
import com.ibm.bi.predict.data.DataColumn;
import com.ibm.bi.predict.data.DataContext;
import com.ibm.bi.predict.data.DataFrame;
import com.ibm.bi.predict.data.store.DataArray;
import com.ibm.bi.predict.dataaccess.DataAccessProvider;
import com.ibm.bi.predict.dataaccess.MetaData;
import com.ibm.bi.predict.dataaccess.types.FieldType;
import com.ibm.bi.predict.exceptions.BadParametersException;
import com.ibm.bi.predict.result.DataPrepResult;
import com.ibm.bi.predict.result.Message;
import com.ibm.bi.predict.result.MessageCode;
import com.ibm.bi.predict.result.StatusCode;
import com.ibm.bi.predict.utils.Logger;
import com.ibm.bi.predict.utils.PredictLoggerFactory;
import com.ibm.bi.predict.utils.Tuple;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.BiPredicate;
import java.util.function.Predicate;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Prepares model input data: reads an optional target column and a set of potential
 * driver columns from a {@link DataAccessProvider}, lets callers transform or filter
 * those columns, and assembles them into a {@link DataFrame}.
 *
 * <p>Data is read lazily. Any accessor that needs columns calls {@link #readData} with
 * accept-all filters if data has not been read yet. Problems encountered along the way
 * are recorded as warning/error {@link Message}s plus a {@link StatusCode}, which
 * {@link #buildResult()} packages into a {@link DataPrepResult}.
 *
 * <p>Instances are mutable and contain no synchronization.
 */
public class DataPrep {
    /**
     * Source of the raw data. May be set to {@code null} after {@link #readData} releases
     * it (controlled by the {@code releaseProvider} context flag).
     */
    DataAccessProvider provider;
    /** Whether a target column was requested, i.e. the target index is present. */
    final boolean hasTarget;
    private DataColumn target;
    /** Driver columns; stays {@code null} until data has been read. */
    private List<DataColumn> drivers;
    /** Columns flagged by {@link #isDegenerate}, in the order they were found. */
    private final List<DataColumn> degenerate = new ArrayList<DataColumn>();
    private final List<Message> warningMessages = new ArrayList<Message>();
    private final List<Message> errorMessages = new ArrayList<Message>();
    private StatusCode status = StatusCode.SUCCESS;
    private final DataContext context;
    private final Optional<Integer> targetIndex;
    private final int[] driversIndexes;
    private static final Logger LOG = PredictLoggerFactory.getLogger(DataPrep.class);

    /**
     * Creates a preparer and validates the requested field indexes against the provider's
     * metadata. No column data is read here.
     *
     * @param provider       data source; its metadata supplies the field count
     * @param context        configuration lookups (locale, flags, allocation timeout)
     * @param targetIndex    field index of the target column, or empty for no target
     * @param driversIndexes field indexes of the potential driver columns
     * @throws BadParametersException if the target index or any driver index is outside
     *                                {@code [0, fieldCount)}, a driver index is repeated,
     *                                or the target index also appears among the drivers
     */
    public DataPrep(DataAccessProvider provider, DataContext context, Optional<Integer> targetIndex, int[] driversIndexes) {
        Objects.requireNonNull(targetIndex, "targetIndex");
        Objects.requireNonNull(driversIndexes, "driversIndexes");
        this.provider = provider;
        this.hasTarget = targetIndex.isPresent();
        this.context = context;
        this.targetIndex = targetIndex;
        this.driversIndexes = driversIndexes;
        int fieldCount = provider.getMetaData().fieldCount();
        if (this.hasTarget) {
            int targetValue = targetIndex.get();
            if (targetValue < 0 || targetValue >= fieldCount) {
                throw new BadParametersException("Invalid target index");
            }
        }
        this.checkDriverIndices(fieldCount);
    }

    /**
     * Reads all requested columns without any filtering or transformation and builds a
     * {@link DataFrame} from them.
     */
    public static DataFrame makeWithoutPreparation(DataAccessProvider provider, DataContext context, Optional<Integer> targetIndex, int[] driversIndexes) {
        return new DataPrep(provider, context, targetIndex, driversIndexes).makeFrame();
    }

    /**
     * Validates the driver indexes: each must be a valid field index, unique, and distinct
     * from the target index.
     */
    private void checkDriverIndices(int fieldCount) {
        Set<Integer> observed = new HashSet<Integer>();
        for (int idx : this.driversIndexes) {
            // Valid field indexes are 0..fieldCount-1, the same range enforced for the target.
            if (idx < 0 || idx >= fieldCount) {
                throw new BadParametersException("Invalid factor index");
            }
            if (!observed.add(idx)) {
                throw new BadParametersException("Duplicate potential drivers included in parameters");
            }
            if (this.targetIndex.isPresent() && this.targetIndex.get() == idx) {
                throw new BadParametersException("Target included in potential drivers");
            }
        }
    }

    /**
     * Packages the current status, errors and warnings into a {@link DataPrepResult}.
     *
     * <p>Unless the status is {@code FAILURE}, extra warnings are appended to the result:
     * <ul>
     *   <li>ALL_MISSING when the target column has no rows;</li>
     *   <li>TARGET_TRANS_ZERO when the target is flagged zero-inflated;</li>
     *   <li>SAMPLED_DATA when the {@code sampledData} context flag is set.</li>
     * </ul>
     * This instance's own warning list is not modified, so repeated calls give the same
     * messages.
     */
    public DataPrepResult buildResult() {
        DataPrepResult dataPrepResult = new DataPrepResult(this.status, this);
        dataPrepResult.addErrorMessages(this.getErrorMessages());
        if (dataPrepResult.getStatus().equals(StatusCode.FAILURE)) {
            dataPrepResult.addWarningMessages(this.getWarningMessages());
            return dataPrepResult;
        }
        // Work on a copy so the recorded warnings reach the result exactly once and this
        // instance's warning list does not grow on every call.
        List<Message> warnings = new ArrayList<Message>(this.getWarningMessages());
        if (this.hasTarget) {
            DataColumn targetCol = this.targetColumn();
            if (targetCol.rowCount() < 1) {
                warnings.add(new Message(MessageCode.ALL_MISSING, this.context.getLocale("locale")));
            }
            if (targetCol.getStatus().contains(DataColumn.Status.ZERO_INFLATED)) {
                ArrayList<String> columnIds = new ArrayList<String>();
                columnIds.add(targetCol.getId());
                warnings.add(new Message(MessageCode.TARGET_TRANS_ZERO, columnIds, this.context.getLocale("locale"), new Object[0]));
            }
        }
        if (this.context.getBoolean("sampledData", false)) {
            warnings.add(new Message(MessageCode.SAMPLED_DATA, this.context.getLocale("locale")));
        }
        dataPrepResult.addWarningMessages(warnings);
        return dataPrepResult;
    }

    /** Returns the (live, mutable) list of driver columns, reading data first if needed. */
    public List<DataColumn> driverColumns() {
        this.ensureDataRead();
        return this.drivers;
    }

    /** Returns the (live, mutable) list of columns flagged by {@link #isDegenerate}. */
    public List<DataColumn> degenerateColumns() {
        return this.degenerate;
    }

    /**
     * Builds a {@link DataFrame} from the target column (if any) followed by the driver
     * columns, reading data first if needed.
     */
    public DataFrame makeFrame() {
        this.ensureDataRead();
        List<DataColumn> columns = this.allColumns();
        ArrayList<DataColumn.ColumnMeta> metas = new ArrayList<DataColumn.ColumnMeta>(columns.size());
        ArrayList<Integer> cats = new ArrayList<Integer>(columns.size());
        for (DataColumn column : columns) {
            metas.add(column.getMeta());
            cats.add(column.getCategoryCount());
        }
        DataArray[] columnArray = columns.stream().map(DataColumn::getDataArray).toArray(DataArray[]::new);
        if (LOG.isTraceEnabled()) {
            // Column names are only needed for this trace output.
            LOG.trace(String.format("Creating DataFrame:%n%n%s%n%s",
                    columns.stream().map(c -> c.getMeta().getId()).collect(Collectors.toList()),
                    Arrays.deepToString(columnArray)));
        }
        return new DataFrame(metas, cats, columnArray, this.hasTarget);
    }

    /**
     * Returns the metadata of the target column (if any) followed by the driver columns,
     * reading data first if needed.
     */
    public List<DataColumn.ColumnMeta> getColumnMeta() {
        this.ensureDataRead();
        return this.allColumns().stream().map(DataColumn::getMeta).collect(Collectors.toList());
    }

    /** Returns a new list holding the target column (when present) followed by the drivers. */
    private List<DataColumn> allColumns() {
        List<DataColumn> columns = new ArrayList<DataColumn>(this.drivers.size() + 1);
        if (this.hasTarget) {
            columns.add(this.target);
        }
        columns.addAll(this.drivers);
        return columns;
    }

    /**
     * Reads the target column (if any) and the driver columns accepted by
     * {@code columnFilter} from the provider. Afterwards it releases the provider unless the
     * {@code releaseProvider} context flag is {@code false}.
     *
     * @param columnFilter      decides, per driver field index, whether the driver is read
     * @param targetValueFilter passed through to {@link ColumnDataBuilder}
     * @return this instance, for chaining
     * @throws IllegalStateException if the provider has already been released
     */
    public DataPrep readData(BiPredicate<Integer, MetaData> columnFilter, BiPredicate<Double, FieldType> targetValueFilter) {
        if (this.provider == null) {
            throw new IllegalStateException("Data provider has already been released");
        }
        MetaData metaData = this.provider.getMetaData();
        this.logFieldTypes(metaData);
        IntStream filteredDrivers = IntStream.of(this.driversIndexes).filter(i -> columnFilter.test(i, metaData));
        // The target, when present, is always the first desired column.
        int[] desiredColumns = this.hasTarget
                ? IntStream.concat(IntStream.of((int) this.targetIndex.get()), filteredDrivers).toArray()
                : filteredDrivers.toArray();
        ColumnDataBuilder columnDataBuilder = new ColumnDataBuilder(this.provider, desiredColumns, this.targetIndex, targetValueFilter, this.context.getInt("allocationTimeout", 10000));
        this.assembleDataColumns(metaData, desiredColumns, columnDataBuilder);
        this.releaseProvider();
        return this;
    }

    /**
     * Builds a {@link DataColumn} for every desired field and stores it as target or driver.
     * Stops at the first column whose data cannot be built, recording INSUFFICIENT_MEMORY
     * and FAILURE status.
     */
    private void assembleDataColumns(MetaData metaData, int[] desiredColumns, ColumnDataBuilder columnDataBuilder) {
        this.drivers = new ArrayList<DataColumn>(desiredColumns.length);
        for (int c = 0; c < desiredColumns.length; ++c) {
            int index = desiredColumns[c];
            DataColumn.ColumnMeta columnMeta = buildColumnMeta(metaData, index);
            List<Category> categories = columnDataBuilder.buildCategories(c);
            Optional<double[]> columnData = columnDataBuilder.buildData(c);
            if (!columnData.isPresent()) {
                // NOTE(review): presumably buildData returns empty when allocation fails; confirm in ColumnDataBuilder.
                this.errorMessages.add(new Message(MessageCode.INSUFFICIENT_MEMORY, this.context.getLocale("locale")));
                this.status = StatusCode.FAILURE;
                break;
            }
            double[] data = columnData.get();
            DataColumn col = new DataColumn(columnMeta, categories, data, c);
            if (col.hasNestedDataColumns()) {
                Tuple<List<Set<Category>>, List<DataArray>> nestedInfo = columnDataBuilder.buildNestedCategories(c);
                List<DataArray> adjustedNested = this.adjustDataArray(nestedInfo._2, data);
                col.setNestedFields(Tuple.of(nestedInfo._1, adjustedNested));
            }
            if (this.hasTarget && this.targetIndex.get() == index) {
                this.target = col;
            } else {
                this.drivers.add(col);
            }
        }
    }

    /** Collects all metadata for field {@code index} into a {@link DataColumn.ColumnMeta}. */
    private static DataColumn.ColumnMeta buildColumnMeta(MetaData metaData, int index) {
        return new DataColumn.ColumnMeta(metaData.getFieldIdentifier(index), metaData.getFieldDisplayLabel(index), metaData.getFieldType(index), metaData.getFieldAggregation(index), metaData.getConcepts(index), metaData.getFieldIdentifiers(index), metaData.getFieldDisplayLabels(index), metaData.getUniqueIdentifiersMap(index), metaData.getConceptsMap());
    }

    /**
     * Returns {@code nestedDataArrays} unchanged when it is empty or its first array already
     * has {@code parentData.length} entries. Otherwise it builds, for each nested array, a
     * new array of that length whose entry {@code j} is the nested value at position
     * {@code (int) parentData[j]}.
     * NOTE(review): this assumes parentData holds integral positions into the nested arrays; confirm.
     */
    private List<DataArray> adjustDataArray(List<DataArray> nestedDataArrays, double[] parentData) {
        if (nestedDataArrays.isEmpty() || parentData.length == nestedDataArrays.get(0).size()) {
            return nestedDataArrays;
        }
        return nestedDataArrays.stream().map(nested -> {
            double[] newData = new double[parentData.length];
            for (int j = 0; j < parentData.length; ++j) {
                newData[j] = nested.value((int) parentData[j]);
            }
            return DataArray.of(newData);
        }).collect(Collectors.toList());
    }

    /** Releases and drops the provider unless the {@code releaseProvider} context flag is false. */
    private void releaseProvider() {
        if (this.context.getBoolean("releaseProvider", true)) {
            this.provider.release();
            this.provider = null;
        }
    }

    /** Logs, at debug level, the identifier, type and category count of every field. */
    private void logFieldTypes(MetaData metaData) {
        for (int i = 0; i < metaData.fieldCount(); ++i) {
            LOG.debug("Index: {}, Field: {}, type: {}, number of categories: {}", new Object[]{i, metaData.getFieldIdentifier(i), metaData.getFieldType(i), metaData.getFieldType(i) == FieldType.CATEGORICAL ? Integer.valueOf(metaData.getFieldCategories(i)) : "n/a"});
        }
    }

    /**
     * Replaces every driver column with {@code columnTransform}'s result and drops
     * {@code null} results.
     *
     * <p>If the transform runs out of memory, the drivers are left unchanged and
     * INSUFFICIENT_MEMORY / FAILURE are recorded.
     *
     * @throws IllegalStateException if a transformed column's row count differs from the
     *                               row count of the first original driver
     */
    public DataPrep replaceDrivers(UnaryOperator<DataColumn> columnTransform) {
        this.ensureDataRead();
        if (this.drivers.isEmpty()) {
            return this;
        }
        int expectedRows = this.drivers.get(0).rowCount();
        try {
            this.drivers = this.drivers.stream().map(columnTransform).filter(Objects::nonNull).collect(Collectors.toList());
        }
        catch (OutOfMemoryError e) {
            // Deliberate: report the failure through the status/messages instead of crashing.
            this.status = StatusCode.FAILURE;
            this.errorMessages.add(new Message(MessageCode.INSUFFICIENT_MEMORY, this.context.getLocale("locale")));
        }
        for (DataColumn driver : this.drivers) {
            if (driver.rowCount() != expectedRows) {
                throw new IllegalStateException("Column transform changed driver column length for column " + driver.getId());
            }
        }
        return this;
    }

    /**
     * Replaces the target column with {@code columnTransform}'s result. Does nothing when
     * there is no target.
     *
     * @throws NullPointerException  if the transform returns {@code null}
     * @throws IllegalStateException if the transform changes the target's row count; the
     *                               original target is kept in that case
     */
    public DataPrep replaceTarget(UnaryOperator<DataColumn> columnTransform) {
        if (!this.hasTarget) {
            return this;
        }
        this.ensureDataRead();
        int oldRowCount = this.target.rowCount();
        DataColumn transformed = Objects.requireNonNull(columnTransform.apply(this.target), "Column transform returned null for the target column");
        if (oldRowCount != transformed.rowCount()) {
            throw new IllegalStateException("Column transform changed target column length");
        }
        this.target = transformed;
        return this;
    }

    /** Drops driver columns with DEGENERATE status, recording them in {@link #degenerateColumns()}. */
    public DataPrep removeDegenerate() {
        return this.retainDrivers(c -> !this.isDegenerate(c));
    }

    /** Keeps only the driver columns accepted by {@code columnFilter}. */
    public DataPrep retainDrivers(Predicate<DataColumn> columnFilter) {
        this.ensureDataRead();
        this.drivers = this.drivers.stream().filter(columnFilter).collect(Collectors.toList());
        return this;
    }

    /**
     * Returns whether {@code column} has DEGENERATE status. As a side effect, a degenerate
     * column is appended to the degenerate-column list.
     */
    public boolean isDegenerate(DataColumn column) {
        if (column.hasStatus(DataColumn.Status.DEGENERATE)) {
            this.degenerate.add(column);
            return true;
        }
        return false;
    }

    /** Returns the target column, reading data first if needed; {@code null} when there is no target. */
    public DataColumn targetColumn() {
        this.ensureDataRead();
        return this.target;
    }

    /** Reads all requested columns with accept-all filters if data has not been read yet. */
    private void ensureDataRead() {
        if (this.drivers == null) {
            this.readData((i, meta) -> true, (v, type) -> true);
        }
    }

    public void addWarningMessage(Message message) {
        this.warningMessages.add(message);
    }

    /** Returns the live list of recorded warning messages. */
    public List<Message> getWarningMessages() {
        return this.warningMessages;
    }

    public void addErrorMessage(Message message) {
        this.errorMessages.add(message);
    }

    /** Returns the live list of recorded error messages. */
    public List<Message> getErrorMessages() {
        return this.errorMessages;
    }

    public void setStatus(StatusCode status) {
        this.status = status;
    }

    public StatusCode getStatus() {
        return this.status;
    }
}

