/*
 * Decompiled with CFR 0.152.
 */
package smile.validation;

import java.util.Arrays;
import java.util.function.BiFunction;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;
import smile.stat.Sampling;
import smile.validation.Bag;
import smile.validation.ClassificationValidation;
import smile.validation.ClassificationValidations;
import smile.validation.RegressionValidation;
import smile.validation.RegressionValidations;

/**
 * Bootstrap resampling for model validation.
 *
 * <p>Each bootstrap round is represented as a {@link Bag}:
 * <ul>
 *   <li>the training indices are a resample of the data and may contain repeats;</li>
 *   <li>the test indices are the out-of-bag positions, i.e. every index that never
 *       appears in that round's training indices, listed in ascending order.</li>
 * </ul>
 */
public interface Bootstrap {
    /**
     * Generates {@code k} bootstrap rounds over a data set of {@code n} samples.
     *
     * <p>Each round fills a training array of length {@code n} by calling
     * {@link MathEx#randomInt(int)} {@code n} times. The result is used directly as an
     * index into a length-{@code n} array, so it is expected to lie in {@code [0, n)}.
     * Because draws may repeat, some indices are never drawn; those are collected in
     * ascending order as the round's test indices.
     *
     * @param n the number of samples in the data set.
     * @param k the number of bootstrap rounds.
     * @return an array of {@code k} bags of training/test indices.
     * @throws IllegalArgumentException if {@code n} or {@code k} is negative.
     */
    public static Bag[] of(int n, int k) {
        if (n < 0) {
            throw new IllegalArgumentException("Invalid sample size: " + n);
        }
        if (k < 0) {
            throw new IllegalArgumentException("Invalid number of bootstrap: " + k);
        }

        Bag[] rounds = new Bag[k];
        // Scratch flags marking which indices were drawn in the current round.
        // They are cleared at the start of every round.
        boolean[] drawn = new boolean[n];
        for (int round = 0; round < k; round++) {
            Arrays.fill(drawn, false);

            int[] sample = new int[n];
            int distinct = 0;
            for (int i = 0; i < n; i++) {
                int index = MathEx.randomInt(n);
                sample[i] = index;
                if (!drawn[index]) {
                    drawn[index] = true;
                    distinct++;
                }
            }

            // Out-of-bag indices: everything that was never drawn above.
            int[] outOfBag = new int[n - distinct];
            for (int i = 0, p = 0; i < n; i++) {
                if (!drawn[i]) {
                    outOfBag[p++] = i;
                }
            }
            rounds[round] = new Bag(sample, outOfBag);
        }
        return rounds;
    }

    /**
     * Generates {@code k} stratified bootstrap rounds.
     *
     * <p>Each round takes its training indices from
     * {@code Sampling.stratify(category, 1.0)}. This is presumably a per-stratum resample
     * with replacement; confirm against {@link Sampling}. The test indices are every
     * position of {@code category} that does not occur in the training indices, in
     * ascending order.
     *
     * @param category the category (stratum) label of each sample; its length is the
     *     data set size.
     * @param k the number of bootstrap rounds.
     * @return an array of {@code k} bags of training/test indices.
     * @throws IllegalArgumentException if {@code k} is negative.
     */
    public static Bag[] of(int[] category, int k) {
        if (k < 0) {
            throw new IllegalArgumentException("Invalid number of bootstrap: " + k);
        }

        int size = category.length;
        // Scratch flags marking which indices occur in the current round's sample.
        boolean[] inSample = new boolean[size];
        Bag[] rounds = new Bag[k];
        for (int round = 0; round < k; round++) {
            int[] sample = Sampling.stratify(category, 1.0);
            Arrays.fill(inSample, false);

            int distinct = 0;
            for (int index : sample) {
                if (!inSample[index]) {
                    inSample[index] = true;
                    distinct++;
                }
            }

            int[] outOfBag = new int[size - distinct];
            for (int i = 0, p = 0; i < size; i++) {
                if (!inSample[i]) {
                    outOfBag[p++] = i;
                }
            }
            rounds[round] = new Bag(sample, outOfBag);
        }
        return rounds;
    }

    /**
     * Runs bootstrap validation of a classifier trained on arrays.
     *
     * <p>The bags come from the non-stratified {@link #of(int, int)} over
     * {@code x.length} samples. Training and evaluation are delegated to
     * {@code ClassificationValidation.of}.
     *
     * @param k the number of bootstrap rounds.
     * @param x the samples.
     * @param y the class labels of the samples.
     * @param trainer the function that fits a model; presumably invoked on each bag's
     *     training subset.
     * @param <T> the sample type.
     * @param <M> the model type.
     * @return the validation results assembled by {@code ClassificationValidation.of}.
     */
    public static <T, M extends Classifier<T>> ClassificationValidations<M> classification(int k, T[] x, int[] y, BiFunction<T[], int[], M> trainer) {
        Bag[] bags = Bootstrap.of(x.length, k);
        return ClassificationValidation.of(bags, x, y, trainer);
    }

    /**
     * Runs bootstrap validation of a data-frame classifier.
     *
     * <p>The bags come from the non-stratified {@link #of(int, int)} over
     * {@code data.size()} rows. Training and evaluation are delegated to
     * {@code ClassificationValidation.of}.
     *
     * @param k the number of bootstrap rounds.
     * @param formula the model formula.
     * @param data the data frame.
     * @param trainer the function that fits a model; presumably invoked on each bag's
     *     training rows.
     * @param <M> the model type.
     * @return the validation results assembled by {@code ClassificationValidation.of}.
     */
    public static <M extends DataFrameClassifier> ClassificationValidations<M> classification(int k, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, M> trainer) {
        Bag[] bags = Bootstrap.of(data.size(), k);
        return ClassificationValidation.of(bags, formula, data, trainer);
    }

    /**
     * Runs bootstrap validation of a regression model trained on arrays.
     *
     * <p>The bags come from {@link #of(int, int)} over {@code x.length} samples.
     * Training and evaluation are delegated to {@code RegressionValidation.of}.
     *
     * @param k the number of bootstrap rounds.
     * @param x the samples.
     * @param y the response values of the samples.
     * @param trainer the function that fits a model; presumably invoked on each bag's
     *     training subset.
     * @param <T> the sample type.
     * @param <M> the model type.
     * @return the validation results assembled by {@code RegressionValidation.of}.
     */
    public static <T, M extends Regression<T>> RegressionValidations<M> regression(int k, T[] x, double[] y, BiFunction<T[], double[], M> trainer) {
        Bag[] bags = Bootstrap.of(x.length, k);
        return RegressionValidation.of(bags, x, y, trainer);
    }

    /**
     * Runs bootstrap validation of a data-frame regression model.
     *
     * <p>The bags come from {@link #of(int, int)} over {@code data.size()} rows.
     * Training and evaluation are delegated to {@code RegressionValidation.of}.
     *
     * @param k the number of bootstrap rounds.
     * @param formula the model formula.
     * @param data the data frame.
     * @param trainer the function that fits a model; presumably invoked on each bag's
     *     training rows.
     * @param <M> the model type.
     * @return the validation results assembled by {@code RegressionValidation.of}.
     */
    public static <M extends DataFrameRegression> RegressionValidations<M> regression(int k, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, M> trainer) {
        Bag[] bags = Bootstrap.of(data.size(), k);
        return RegressionValidation.of(bags, formula, data, trainer);
    }
}

