/*
 * Decompiled with CFR 0.152.
 */
package smile.validation;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.validation.Bag;
import smile.validation.ClassificationMetrics;
import smile.validation.ClassificationValidations;
import smile.validation.metric.ConfusionMatrix;

/**
 * The result of validating a classification model on a held-out test set:
 * the fitted model, the ground-truth and predicted labels, optionally the
 * posteriori probabilities, the confusion matrix and the summary metrics.
 *
 * @param <M> the type of the classification model.
 */
public class ClassificationValidation<M>
implements Serializable {
    private static final long serialVersionUID = 2L;
    /** The fitted model. */
    public final M model;
    /** The ground-truth labels of the test data. */
    public final int[] truth;
    /** The predicted labels of the test data. */
    public final int[] prediction;
    /** The posteriori probabilities of the test data; {@code null} when constructed without them. */
    public final double[][] posteriori;
    /** The confusion matrix of {@link #truth} versus {@link #prediction}. */
    public final ConfusionMatrix confusion;
    /** The classification metrics, including fitting and scoring times. */
    public final ClassificationMetrics metrics;

    /**
     * Constructor for a validation without posteriori probabilities.
     *
     * @param model the fitted model.
     * @param fitTime the time spent fitting the model, in milliseconds.
     * @param scoreTime the time spent predicting the test data, in milliseconds.
     * @param truth the ground-truth labels of the test data.
     * @param prediction the predicted labels of the test data.
     */
    public ClassificationValidation(M model, double fitTime, double scoreTime, int[] truth, int[] prediction) {
        // Arguments are evaluated left to right, so the confusion matrix is still
        // built before the metrics, exactly as before.
        this(model, truth, prediction, null,
             ConfusionMatrix.of(truth, prediction),
             ClassificationMetrics.of(fitTime, scoreTime, truth, prediction));
    }

    /**
     * Constructor for a validation with posteriori probabilities.
     *
     * @param model the fitted model.
     * @param fitTime the time spent fitting the model, in milliseconds.
     * @param scoreTime the time spent predicting the test data, in milliseconds.
     * @param truth the ground-truth labels of the test data.
     * @param prediction the predicted labels of the test data.
     * @param posteriori the posteriori probabilities of the test data.
     */
    public ClassificationValidation(M model, double fitTime, double scoreTime, int[] truth, int[] prediction, double[][] posteriori) {
        this(model, truth, prediction, posteriori,
             ConfusionMatrix.of(truth, prediction),
             ClassificationMetrics.of(fitTime, scoreTime, truth, prediction, posteriori));
    }

    /** Shared field initialization for the public constructors. */
    private ClassificationValidation(M model, int[] truth, int[] prediction, double[][] posteriori,
                                     ConfusionMatrix confusion, ClassificationMetrics metrics) {
        this.model = model;
        this.truth = truth;
        this.prediction = prediction;
        this.posteriori = posteriori;
        this.confusion = confusion;
        this.metrics = metrics;
    }

    @Override
    public String toString() {
        return this.metrics.toString();
    }

    /** Returns the milliseconds elapsed since the given {@link System#nanoTime()} reading. */
    private static double elapsedMillis(long start) {
        return (System.nanoTime() - start) / 1E6;
    }

    /**
     * Trains a model on the training data and validates it on the test data.
     * If the model is soft, posteriori probabilities are also collected.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param testx the test samples.
     * @param testy the test labels.
     * @param trainer the lambda to train the model from samples and labels.
     * @param <T> the data type of samples.
     * @param <M> the model type.
     * @return the validation result.
     */
    public static <T, M extends Classifier<T>> ClassificationValidation<M> of(T[] x, int[] y, T[] testx, int[] testy, BiFunction<T[], int[], M> trainer) {
        long start = System.nanoTime();
        M model = trainer.apply(x, y);
        double fitTime = elapsedMillis(start);

        if (model.soft()) {
            // Allocate outside the timed region so scoreTime measures prediction only,
            // consistent with the DataFrame overload.
            double[][] posteriori = new double[testx.length][model.numClasses()];
            start = System.nanoTime();
            int[] prediction = model.predict(testx, posteriori);
            double scoreTime = elapsedMillis(start);
            return new ClassificationValidation<>(model, fitTime, scoreTime, testy, prediction, posteriori);
        }

        start = System.nanoTime();
        int[] prediction = model.predict(testx);
        double scoreTime = elapsedMillis(start);
        return new ClassificationValidation<>(model, fitTime, scoreTime, testy, prediction);
    }

    /**
     * Runs one validation round per bag. Each round trains on the samples indexed
     * by {@code bag.samples} and tests on the samples indexed by {@code bag.oob}.
     *
     * @param bags the train/test index splits.
     * @param x the samples.
     * @param y the labels.
     * @param trainer the lambda to train the model from samples and labels.
     * @param <T> the data type of samples.
     * @param <M> the model type.
     * @return the validation results of all rounds.
     */
    public static <T, M extends Classifier<T>> ClassificationValidations<M> of(Bag[] bags, T[] x, int[] y, BiFunction<T[], int[], M> trainer) {
        List<ClassificationValidation<M>> rounds = new ArrayList<>(bags.length);
        for (Bag bag : bags) {
            T[] trainx = MathEx.slice(x, bag.samples);
            int[] trainy = MathEx.slice(y, bag.samples);
            T[] testx = MathEx.slice(x, bag.oob);
            int[] testy = MathEx.slice(y, bag.oob);
            rounds.add(of(trainx, trainy, testx, testy, trainer));
        }
        return new ClassificationValidations<>(rounds);
    }

    /**
     * Trains a model on the training data frame and validates it on the test data frame.
     * If the model is soft, posteriori probabilities are also collected.
     *
     * @param formula the model formula; its response supplies the test labels.
     * @param train the training data.
     * @param test the test data.
     * @param trainer the lambda to train the model from a formula and data frame.
     * @param <M> the model type.
     * @return the validation result.
     */
    public static <M extends DataFrameClassifier> ClassificationValidation<M> of(Formula formula, DataFrame train, DataFrame test, BiFunction<Formula, DataFrame, M> trainer) {
        int[] testy = formula.y(test).toIntArray();

        long start = System.nanoTime();
        M model = trainer.apply(formula, train);
        double fitTime = elapsedMillis(start);

        int n = test.nrow();
        int[] prediction = new int[n];
        if (model.soft()) {
            double[][] posteriori = new double[n][model.numClasses()];
            start = System.nanoTime();
            for (int i = 0; i < n; ++i) {
                prediction[i] = model.predict(test.get(i), posteriori[i]);
            }
            double scoreTime = elapsedMillis(start);
            return new ClassificationValidation<>(model, fitTime, scoreTime, testy, prediction, posteriori);
        }

        start = System.nanoTime();
        for (int i = 0; i < n; ++i) {
            prediction[i] = model.predict(test.get(i));
        }
        double scoreTime = elapsedMillis(start);
        return new ClassificationValidation<>(model, fitTime, scoreTime, testy, prediction);
    }

    /**
     * Runs one validation round per bag. Each round trains on the rows indexed
     * by {@code bag.samples} and tests on the rows indexed by {@code bag.oob}.
     *
     * @param bags the train/test index splits.
     * @param formula the model formula.
     * @param data the data frame.
     * @param trainer the lambda to train the model from a formula and data frame.
     * @param <M> the model type.
     * @return the validation results of all rounds.
     */
    public static <M extends DataFrameClassifier> ClassificationValidations<M> of(Bag[] bags, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, M> trainer) {
        List<ClassificationValidation<M>> rounds = new ArrayList<>(bags.length);
        for (Bag bag : bags) {
            rounds.add(of(formula, data.of(bag.samples), data.of(bag.oob), trainer));
        }
        return new ClassificationValidations<>(rounds);
    }
}

