/*
 * Decompiled with CFR 0.152.
 */
package smile.validation;

import java.util.Arrays;
import java.util.function.BiFunction;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.math.MathEx;
import smile.regression.DataFrameRegression;
import smile.regression.Regression;
import smile.sort.QuickSort;
import smile.stat.Sampling;
import smile.util.IntSet;
import smile.validation.Bag;
import smile.validation.ClassificationValidation;
import smile.validation.ClassificationValidations;
import smile.validation.RegressionValidation;
import smile.validation.RegressionValidations;

/**
 * Cross validation utilities. The index-level factories ({@code of}, {@code stratify},
 * {@code nonoverlap}) return one {@link Bag} per fold, built as {@code new Bag(train, test)}
 * from arrays of sample indices. The model-level overloads train and evaluate a model on
 * those bags via {@link ClassificationValidation} or {@link RegressionValidation}.
 */
public interface CrossValidation {
    /**
     * Creates the bags of k-fold cross validation over {@code n} samples.
     * The sample indices are shuffled with {@code MathEx.permutate(n)} and cut into
     * {@code k} contiguous chunks of {@code n / k} indices; the last chunk also takes
     * the {@code n % k} leftover indices. Chunk {@code i} is the test set of fold
     * {@code i}; every other index belongs to its training set.
     *
     * @param n the number of samples.
     * @param k the number of folds.
     * @return {@code k} bags, one per fold.
     * @throws IllegalArgumentException if {@code n < 0}, {@code k < 1} or {@code k > n}.
     */
    public static Bag[] of(int n, int k) {
        if (n < 0) {
            throw new IllegalArgumentException("Invalid sample size: " + n);
        }
        // k == 0 previously slipped through here and failed later with an
        // ArithmeticException at n / k.
        if (k < 1 || k > n) {
            throw new IllegalArgumentException("Invalid number of CV rounds: " + k);
        }
        Bag[] bags = new Bag[k];
        int[] index = MathEx.permutate(n);
        int chunk = n / k;
        for (int i = 0; i < k; ++i) {
            int start = chunk * i;
            // The last fold absorbs the remainder so every sample is tested exactly once.
            int end = (i == k - 1) ? n : chunk * (i + 1);
            int[] test = Arrays.copyOfRange(index, start, end);
            // Training set = everything before and after the test chunk, in permuted order.
            int[] train = new int[n - (end - start)];
            System.arraycopy(index, 0, train, 0, start);
            System.arraycopy(index, end, train, start, n - end);
            bags[i] = new Bag(train, test);
        }
        return bags;
    }

    /**
     * Creates the bags of stratified k-fold cross validation. The samples are
     * grouped with {@code Sampling.strata(category)}; presumably each stratum holds the
     * indices of the samples sharing one category value (the code uses its elements as
     * sample indices). Each stratum is shuffled and cut into {@code k} chunks, so that
     * every fold's test set draws from every class. A warning is logged when the smallest
     * class has fewer than {@code k} members, in which case some folds get no test
     * samples of that class.
     *
     * @param category the class label of each sample.
     * @param k the number of folds.
     * @return {@code k} bags, one per fold, with shuffled training and test indices.
     * @throws IllegalArgumentException if {@code k < 1} or {@code category} is empty.
     */
    public static Bag[] stratify(int[] category, int k) {
        // k == 0 previously failed with an ArithmeticException at strata[i].length / k.
        if (k < 1) {
            throw new IllegalArgumentException("Invalid number of folds: " + k);
        }
        // Without samples there are no strata and min().getAsInt() would throw
        // NoSuchElementException.
        if (category.length == 0) {
            throw new IllegalArgumentException("Empty category array");
        }
        int[][] strata = Sampling.strata(category);
        int min = Arrays.stream(strata).mapToInt(stratum -> stratum.length).min().getAsInt();
        if (min < k) {
            Logger logger = LoggerFactory.getLogger(CrossValidation.class);
            logger.warn("The least populated class has only {} members, which is less than k={}.", min, k);
        }
        int n = category.length;
        int m = strata.length;
        // Shuffle each stratum in place so every fold gets a random subset of each class.
        for (int[] stratum : strata) {
            MathEx.permutate(stratum);
        }
        int[] chunk = new int[m];
        for (int j = 0; j < m; ++j) {
            // At least one sample per fold, so classes smaller than k still reach some test sets.
            chunk[j] = Math.max(1, strata[j].length / k);
        }
        Bag[] bags = new Bag[k];
        for (int i = 0; i < k; ++i) {
            int p = 0;
            int q = 0;
            int[] train = new int[n];
            int[] test = new int[n];
            for (int j = 0; j < m; ++j) {
                int[] stratum = strata[j];
                int size = stratum.length;
                int start = chunk[j] * i;
                // The last fold takes whatever is left of the stratum.
                int end = (i == k - 1) ? size : chunk[j] * (i + 1);
                for (int l = 0; l < size; ++l) {
                    if (l >= start && l < end) {
                        test[q++] = stratum[l];
                    } else {
                        train[p++] = stratum[l];
                    }
                }
            }
            train = Arrays.copyOf(train, p);
            test = Arrays.copyOf(test, q);
            // Mix the classes, which were appended stratum by stratum.
            MathEx.permutate(train);
            MathEx.permutate(test);
            bags[i] = new Bag(train, test);
        }
        return bags;
    }

    /**
     * Creates the bags of k-fold cross validation in which no group appears in both the
     * training and the test set of the same fold. Groups are assigned to folds greedily,
     * from the largest to the smallest group, each going to the fold that currently holds
     * the fewest samples. Indices within each bag are in ascending sample order.
     *
     * @param group the group label of each sample.
     * @param k the number of folds.
     * @return {@code k} bags, one per fold.
     * @throws IllegalArgumentException if {@code k < 1} or {@code k} exceeds the number of
     *         distinct groups.
     */
    public static Bag[] nonoverlap(int[] group, int k) {
        // k == 0 previously reached MathEx.whichMin on an empty array, or unique[0] for
        // empty input.
        if (k < 1) {
            throw new IllegalArgumentException("Invalid number of folds: " + k);
        }
        int[] unique = MathEx.unique(group);
        int m = unique.length;
        if (k > m) {
            throw new IllegalArgumentException("k-fold must be not greater than the number of groups: k = " + k + ", groups = " + m);
        }
        Arrays.sort(unique);
        int n = group.length;
        // Re-encode labels to 0..m-1 unless they already are. Because unique is sorted and
        // distinct, checking the two ends is sufficient.
        int[] y = group;
        if (unique[0] != 0 || unique[m - 1] != m - 1) {
            IntSet encoder = new IntSet(unique);
            y = new int[n];
            for (int j = 0; j < n; ++j) {
                y[j] = encoder.indexOf(group[j]);
            }
        }
        // Number of samples in each (encoded) group.
        int[] ni = new int[m];
        for (int label : y) {
            ni[label]++;
        }
        // NOTE(review): relies on QuickSort.sort sorting ni ascending in place and returning
        // the original group ids in that order, so ni[i] is the size of group index[i].
        int[] index = QuickSort.sort(ni);
        int[] foldSize = new int[k];
        int[] group2Fold = new int[m];
        for (int i = m - 1; i >= 0; --i) {
            int smallestFold = MathEx.whichMin(foldSize);
            foldSize[smallestFold] += ni[i];
            group2Fold[index[i]] = smallestFold;
        }
        Bag[] bags = new Bag[k];
        for (int fold = 0; fold < k; ++fold) {
            int[] train = new int[n - foldSize[fold]];
            int[] test = new int[foldSize[fold]];
            int trainIndex = 0;
            int testIndex = 0;
            for (int j = 0; j < n; ++j) {
                if (group2Fold[y[j]] == fold) {
                    test[testIndex++] = j;
                } else {
                    train[trainIndex++] = j;
                }
            }
            // Build the bag only after the arrays are filled, so the result is correct
            // whether Bag keeps references or copies its arguments.
            bags[fold] = new Bag(train, test);
        }
        return bags;
    }

    /**
     * Runs k-fold cross validation of a classifier on array data.
     *
     * @param k the number of folds.
     * @param x the samples.
     * @param y the class labels.
     * @param trainer builds a model from a fold's training samples and labels.
     * @param <T> the sample type.
     * @param <M> the model type.
     * @return the per-fold validation results.
     */
    public static <T, M extends Classifier<T>> ClassificationValidations<M> classification(int k, T[] x, int[] y, BiFunction<T[], int[], M> trainer) {
        Bag[] bags = CrossValidation.of(x.length, k);
        return ClassificationValidation.of(bags, x, y, trainer);
    }

    /**
     * Runs k-fold cross validation of a data frame classifier.
     *
     * @param k the number of folds.
     * @param formula the model formula.
     * @param data the data frame.
     * @param trainer builds a model from the formula and a fold's training rows.
     * @param <M> the model type.
     * @return the per-fold validation results.
     */
    public static <M extends DataFrameClassifier> ClassificationValidations<M> classification(int k, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, M> trainer) {
        Bag[] bags = CrossValidation.of(data.size(), k);
        return ClassificationValidation.of(bags, formula, data, trainer);
    }

    /**
     * Repeats k-fold cross validation of a classifier {@code round} times. Each round calls
     * {@link #of(int, int)} again, and the {@code round * k} bags are concatenated.
     *
     * @param round the number of repetitions.
     * @param k the number of folds.
     * @param x the samples.
     * @param y the class labels.
     * @param trainer builds a model from a fold's training samples and labels.
     * @param <T> the sample type.
     * @param <M> the model type.
     * @return the validation results over all rounds and folds.
     * @throws IllegalArgumentException if {@code round < 1}.
     */
    public static <T, M extends Classifier<T>> ClassificationValidations<M> classification(int round, int k, T[] x, int[] y, BiFunction<T[], int[], M> trainer) {
        if (round < 1) {
            throw new IllegalArgumentException("Invalid round: " + round);
        }
        Bag[] bags = IntStream.range(0, round).mapToObj(i -> CrossValidation.of(x.length, k)).flatMap(Arrays::stream).toArray(Bag[]::new);
        return ClassificationValidation.of(bags, x, y, trainer);
    }

    /**
     * Repeats k-fold cross validation of a data frame classifier {@code round} times. Each
     * round calls {@link #of(int, int)} again, and the {@code round * k} bags are concatenated.
     *
     * @param round the number of repetitions.
     * @param k the number of folds.
     * @param formula the model formula.
     * @param data the data frame.
     * @param trainer builds a model from the formula and a fold's training rows.
     * @param <M> the model type.
     * @return the validation results over all rounds and folds.
     * @throws IllegalArgumentException if {@code round < 1}.
     */
    public static <M extends DataFrameClassifier> ClassificationValidations<M> classification(int round, int k, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, M> trainer) {
        if (round < 1) {
            throw new IllegalArgumentException("Invalid round: " + round);
        }
        Bag[] bags = IntStream.range(0, round).mapToObj(i -> CrossValidation.of(data.size(), k)).flatMap(Arrays::stream).toArray(Bag[]::new);
        return ClassificationValidation.of(bags, formula, data, trainer);
    }

    /**
     * Runs stratified k-fold cross validation of a classifier on array data.
     *
     * @param k the number of folds.
     * @param x the samples.
     * @param y the class labels, also used for stratification.
     * @param trainer builds a model from a fold's training samples and labels.
     * @param <T> the sample type.
     * @param <M> the model type.
     * @return the per-fold validation results.
     */
    public static <T, M extends Classifier<T>> ClassificationValidations<M> stratify(int k, T[] x, int[] y, BiFunction<T[], int[], M> trainer) {
        Bag[] bags = CrossValidation.stratify(y, k);
        return ClassificationValidation.of(bags, x, y, trainer);
    }

    /**
     * Runs stratified k-fold cross validation of a data frame classifier. The labels used
     * for stratification are {@code formula.y(data).toIntArray()}.
     *
     * @param k the number of folds.
     * @param formula the model formula.
     * @param data the data frame.
     * @param trainer builds a model from the formula and a fold's training rows.
     * @param <M> the model type.
     * @return the per-fold validation results.
     */
    public static <M extends DataFrameClassifier> ClassificationValidations<M> stratify(int k, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, M> trainer) {
        int[] y = formula.y(data).toIntArray();
        Bag[] bags = CrossValidation.stratify(y, k);
        return ClassificationValidation.of(bags, formula, data, trainer);
    }

    /**
     * Repeats stratified k-fold cross validation of a classifier {@code round} times. Each
     * round calls {@link #stratify(int[], int)} again, and the bags are concatenated.
     *
     * @param round the number of repetitions.
     * @param k the number of folds.
     * @param x the samples.
     * @param y the class labels, also used for stratification.
     * @param trainer builds a model from a fold's training samples and labels.
     * @param <T> the sample type.
     * @param <M> the model type.
     * @return the validation results over all rounds and folds.
     * @throws IllegalArgumentException if {@code round < 1}.
     */
    public static <T, M extends Classifier<T>> ClassificationValidations<M> stratify(int round, int k, T[] x, int[] y, BiFunction<T[], int[], M> trainer) {
        if (round < 1) {
            throw new IllegalArgumentException("Invalid round: " + round);
        }
        Bag[] bags = IntStream.range(0, round).mapToObj(i -> CrossValidation.stratify(y, k)).flatMap(Arrays::stream).toArray(Bag[]::new);
        return ClassificationValidation.of(bags, x, y, trainer);
    }

    /**
     * Repeats stratified k-fold cross validation of a data frame classifier {@code round}
     * times. The labels used for stratification are {@code formula.y(data).toIntArray()}.
     *
     * @param round the number of repetitions.
     * @param k the number of folds.
     * @param formula the model formula.
     * @param data the data frame.
     * @param trainer builds a model from the formula and a fold's training rows.
     * @param <M> the model type.
     * @return the validation results over all rounds and folds.
     * @throws IllegalArgumentException if {@code round < 1}.
     */
    public static <M extends DataFrameClassifier> ClassificationValidations<M> stratify(int round, int k, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, M> trainer) {
        if (round < 1) {
            throw new IllegalArgumentException("Invalid round: " + round);
        }
        int[] y = formula.y(data).toIntArray();
        Bag[] bags = IntStream.range(0, round).mapToObj(i -> CrossValidation.stratify(y, k)).flatMap(Arrays::stream).toArray(Bag[]::new);
        return ClassificationValidation.of(bags, formula, data, trainer);
    }

    /**
     * Runs k-fold cross validation of a regression model on array data.
     *
     * @param k the number of folds.
     * @param x the samples.
     * @param y the response values.
     * @param trainer builds a model from a fold's training samples and responses.
     * @param <T> the sample type.
     * @param <M> the model type.
     * @return the per-fold validation results.
     */
    public static <T, M extends Regression<T>> RegressionValidations<M> regression(int k, T[] x, double[] y, BiFunction<T[], double[], M> trainer) {
        Bag[] bags = CrossValidation.of(x.length, k);
        return RegressionValidation.of(bags, x, y, trainer);
    }

    /**
     * Runs k-fold cross validation of a data frame regression model.
     *
     * @param k the number of folds.
     * @param formula the model formula.
     * @param data the data frame.
     * @param trainer builds a model from the formula and a fold's training rows.
     * @param <M> the model type.
     * @return the per-fold validation results.
     */
    public static <M extends DataFrameRegression> RegressionValidations<M> regression(int k, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, M> trainer) {
        Bag[] bags = CrossValidation.of(data.size(), k);
        return RegressionValidation.of(bags, formula, data, trainer);
    }

    /**
     * Repeats k-fold cross validation of a regression model {@code round} times. Each round
     * calls {@link #of(int, int)} again, and the {@code round * k} bags are concatenated.
     *
     * @param round the number of repetitions.
     * @param k the number of folds.
     * @param x the samples.
     * @param y the response values.
     * @param trainer builds a model from a fold's training samples and responses.
     * @param <T> the sample type.
     * @param <M> the model type.
     * @return the validation results over all rounds and folds.
     * @throws IllegalArgumentException if {@code round < 1}.
     */
    public static <T, M extends Regression<T>> RegressionValidations<M> regression(int round, int k, T[] x, double[] y, BiFunction<T[], double[], M> trainer) {
        if (round < 1) {
            throw new IllegalArgumentException("Invalid round: " + round);
        }
        Bag[] bags = IntStream.range(0, round).mapToObj(i -> CrossValidation.of(x.length, k)).flatMap(Arrays::stream).toArray(Bag[]::new);
        return RegressionValidation.of(bags, x, y, trainer);
    }

    /**
     * Repeats k-fold cross validation of a data frame regression model {@code round} times.
     * Each round calls {@link #of(int, int)} again, and the bags are concatenated.
     *
     * @param round the number of repetitions.
     * @param k the number of folds.
     * @param formula the model formula.
     * @param data the data frame.
     * @param trainer builds a model from the formula and a fold's training rows.
     * @param <M> the model type.
     * @return the validation results over all rounds and folds.
     * @throws IllegalArgumentException if {@code round < 1}.
     */
    public static <M extends DataFrameRegression> RegressionValidations<M> regression(int round, int k, Formula formula, DataFrame data, BiFunction<Formula, DataFrame, M> trainer) {
        if (round < 1) {
            throw new IllegalArgumentException("Invalid round: " + round);
        }
        Bag[] bags = IntStream.range(0, round).mapToObj(i -> CrossValidation.of(data.size(), k)).flatMap(Arrays::stream).toArray(Bag[]::new);
        return RegressionValidation.of(bags, formula, data, trainer);
    }
}

