/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Properties;
import java.util.stream.IntStream;
import smile.classification.Classifier;
import smile.data.CategoricalEncoder;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;

/**
 * A classifier that predicts directly on {@link DataFrame} rows ({@link Tuple}s).
 * Each model carries a {@link Formula} that is used to extract its predictors and a
 * {@link StructType} describing them.
 */
public interface DataFrameClassifier extends Classifier<Tuple> {
    /**
     * Returns the formula associated with the model.
     *
     * @return the model formula.
     */
    Formula formula();

    /**
     * Returns the schema of the model's predictors. For models built by
     * {@link #of(Formula, DataFrame, Properties, Classifier.Trainer)} this is the schema of
     * {@code formula.x(data)} on the training data.
     *
     * @return the predictor schema.
     */
    StructType schema();

    /**
     * Predicts the class label of every row of a data frame.
     * The formula is first bound to {@code data}'s schema, then rows are classified
     * sequentially with {@code predict(Tuple)}.
     *
     * @param data the data frame to classify.
     * @return the predicted label of each row, in row order.
     */
    default int[] predict(DataFrame data) {
        formula().bind(data.schema());
        return data.stream().mapToInt(this::predict).toArray();
    }

    /**
     * Predicts the class label of every row of a data frame and collects the per-row
     * posteriori probabilities.
     *
     * <p>One {@code double[numClasses()]} array per row is appended to {@code posteriori}
     * (in row order) before prediction starts; {@code predict(Tuple, double[])} then fills
     * each row's array. Rows are processed with a parallel stream, so
     * {@code predict(Tuple, double[])} may be invoked concurrently.
     *
     * @param data the data frame to classify.
     * @param posteriori an output list that receives one posteriori array per row.
     * @return the predicted label of each row, in row order.
     */
    default int[] predict(DataFrame data, List<double[]> posteriori) {
        formula().bind(data.schema());
        int n = data.size();
        int k = numClasses();
        double[][] prob = new double[n][k];
        // Every row owns its own output array, so the parallel workers never share writes.
        Collections.addAll(posteriori, prob);
        return IntStream.range(0, n)
                .parallel()
                .map(i -> predict((Tuple) data.get(i), prob[i]))
                .toArray();
    }

    /**
     * Fits a vector-based classifier on a data frame and wraps it as a
     * {@code DataFrameClassifier}.
     *
     * <p>The predictors {@code formula.x(data)} are converted to a {@code double[][]} with
     * {@link CategoricalEncoder#DUMMY} encoding (the {@code false} argument presumably
     * disables a bias column; confirm against {@code DataFrame.toArray}). The response
     * {@code formula.y(data)} is converted to integer labels, and both are passed to
     * {@code trainer}.
     *
     * @param formula the model formula.
     * @param data the training data.
     * @param params the hyper-parameters forwarded to the trainer.
     * @param trainer the trainer of the underlying vector classifier.
     * @return a classifier that applies {@code formula} to each tuple before delegating to
     *         the fitted model.
     */
    static DataFrameClassifier of(final Formula formula, DataFrame data, Properties params, Classifier.Trainer<double[], ?> trainer) {
        DataFrame X = formula.x(data);
        final StructType schema = X.schema();
        double[][] features = X.toArray(false, CategoricalEncoder.DUMMY, new String[0]);
        int[] labels = formula.y(data).toIntArray();
        // The trainer's wildcard result is bounded by Classifier<double[]>, so no cast is needed.
        final Classifier<double[]> model = trainer.fit(features, labels, params);

        return new DataFrameClassifier() {
            @Override
            public Formula formula() {
                return formula;
            }

            @Override
            public StructType schema() {
                return schema;
            }

            @Override
            public int numClasses() {
                return model.numClasses();
            }

            @Override
            public int[] classes() {
                return model.classes();
            }

            // NOTE(review): the prediction path encodes tuples with Tuple.toArray(String...)
            // rather than passing CategoricalEncoder.DUMMY explicitly as training did;
            // confirm both produce the same encoding for categorical predictors.
            @Override
            public int predict(Tuple x) {
                return model.predict(formula.x(x).toArray(new String[0]));
            }

            @Override
            public int predict(Tuple x, double[] posteriori) {
                return model.predict(formula.x(x).toArray(new String[0]), posteriori);
            }
        };
    }

    /**
     * Combines several data frame classifiers into a voting ensemble.
     *
     * <ul>
     *   <li>{@code predict(Tuple)} returns the most frequent label among the members
     *       (hard voting via {@code MathEx.mode}).</li>
     *   <li>{@code predict(Tuple, double[])} averages the members' posteriori
     *       probabilities and returns the index of the largest one.</li>
     *   <li>{@code soft()} / {@code online()} are true only if every member reports
     *       them.</li>
     *   <li>{@code update} forwards the sample to every member.</li>
     *   <li>{@code numClasses}, {@code classes}, {@code formula} and {@code schema}
     *       delegate to the first member. Agreement between members is not checked.</li>
     * </ul>
     *
     * @param models the ensemble members.
     * @return the ensemble classifier.
     * @throws IllegalArgumentException if {@code models} is empty.
     */
    static DataFrameClassifier ensemble(final DataFrameClassifier... models) {
        // An empty ensemble would otherwise be constructed successfully and then fail with
        // ArrayIndexOutOfBoundsException on the first models[0] access.
        if (models.length == 0) {
            throw new IllegalArgumentException("Empty ensemble");
        }

        return new DataFrameClassifier() {
            private final boolean soft = Arrays.stream(models).allMatch(Classifier::soft);
            private final boolean online = Arrays.stream(models).allMatch(Classifier::online);

            @Override
            public boolean soft() {
                return soft;
            }

            @Override
            public boolean online() {
                return online;
            }

            @Override
            public int numClasses() {
                return models[0].numClasses();
            }

            @Override
            public int[] classes() {
                return models[0].classes();
            }

            @Override
            public Formula formula() {
                return models[0].formula();
            }

            @Override
            public StructType schema() {
                return models[0].schema();
            }

            @Override
            public int predict(Tuple x) {
                int[] labels = new int[models.length];
                for (int i = 0; i < models.length; i++) {
                    labels[i] = models[i].predict(x);
                }
                return MathEx.mode(labels);
            }

            @Override
            public int predict(Tuple x, double[] posteriori) {
                Arrays.fill(posteriori, 0.0);
                // Scratch buffer reused for each member's posteriori.
                double[] prob = new double[posteriori.length];
                for (DataFrameClassifier model : models) {
                    model.predict(x, prob);
                    for (int i = 0; i < prob.length; i++) {
                        posteriori[i] += prob[i];
                    }
                }
                for (int i = 0; i < posteriori.length; i++) {
                    posteriori[i] /= models.length;
                }
                return MathEx.whichMax(posteriori);
            }

            @Override
            public void update(Tuple x, int y) {
                for (DataFrameClassifier model : models) {
                    model.update(x, y);
                }
            }
        };
    }

    /**
     * Trainer of data frame classifiers.
     *
     * @param <M> the type of model produced.
     */
    interface Trainer<M extends DataFrameClassifier> {
        /**
         * Fits a model with empty hyper-parameters.
         *
         * @param formula the model formula.
         * @param data the training data.
         * @return the fitted model.
         */
        default M fit(Formula formula, DataFrame data) {
            return fit(formula, data, new Properties());
        }

        /**
         * Fits a model.
         *
         * @param formula the model formula.
         * @param data the training data.
         * @param params the hyper-parameters.
         * @return the fitted model.
         */
        M fit(Formula formula, DataFrame data, Properties params);
    }
}

