/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.function.BiFunction;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.AbstractClassifier;
import smile.classification.ClassLabels;
import smile.classification.Classifier;
import smile.classification.DataFrameClassifier;
import smile.classification.PlattScaling;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.util.IntSet;

/**
 * One-versus-one reduction of a multiclass problem to binary problems.
 *
 * <p>For {@code k} classes, one binary classifier is kept for every pair of
 * class indices {@code (i, j)} with {@code j < i}, stored in the lower
 * triangle {@code classifiers[i][j]}. Hard predictions are made by majority
 * voting over all pairs. When a {@link PlattScaling} was fitted for every
 * pair, posterior probabilities are estimated by coupling the pairwise
 * probabilities.
 *
 * @param <T> the type of the model input object.
 */
public class OneVersusOne<T>
extends AbstractClassifier<T> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(OneVersusOne.class);
    /** The number of classes. */
    private final int k;
    /** The pairwise binary classifiers; row {@code i} has {@code i} entries. */
    private final Classifier<T>[][] classifiers;
    /** The pairwise probability calibrators, or {@code null} if unavailable. */
    private final PlattScaling[][] platt;

    /**
     * Constructor that uses {@code IntSet.of(classifiers.length)} as the class labels.
     *
     * @param classifiers the pairwise binary classifiers, {@code classifiers[i][j]} with {@code j < i}.
     * @param platt the matching Platt scaling models, or {@code null} if
     *              posterior probabilities are not supported.
     */
    public OneVersusOne(Classifier<T>[][] classifiers, PlattScaling[][] platt) {
        this(classifiers, platt, IntSet.of(classifiers.length));
    }

    /**
     * Constructor.
     *
     * @param classifiers the pairwise binary classifiers, {@code classifiers[i][j]} with {@code j < i}.
     * @param platt the matching Platt scaling models, or {@code null} if
     *              posterior probabilities are not supported.
     * @param labels the class label encoder handed to the super class.
     */
    public OneVersusOne(Classifier<T>[][] classifiers, PlattScaling[][] platt, IntSet labels) {
        super(labels);
        this.classifiers = classifiers;
        this.platt = platt;
        this.k = classifiers.length;
    }

    /**
     * Fits a one-versus-one model, labelling the two classes of each binary
     * problem as {@code +1} and {@code -1}.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param trainer the function that trains a binary classifier.
     * @param <T> the type of the model input object.
     * @return the model.
     */
    public static <T> OneVersusOne<T> fit(T[] x, int[] y, BiFunction<T[], int[], Classifier<T>> trainer) {
        return OneVersusOne.fit(x, y, 1, -1, trainer);
    }

    /**
     * Fits a one-versus-one model. The binary problems are trained in parallel.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param pos the label given to the first class {@code i} of a pair. It
     *            should be positive because {@link #predict(Object)} treats
     *            a binary prediction {@code > 0} as a vote for class {@code i}.
     * @param neg the label given to the second class {@code j} of a pair.
     * @param trainer the function that trains a binary classifier.
     * @param <T> the type of the model input object.
     * @return the model.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in
     *         length or there are no more than two classes.
     */
    @SuppressWarnings("unchecked")
    public static <T> OneVersusOne<T> fit(T[] x, int[] y, int pos, int neg, BiFunction<T[], int[], Classifier<T>> trainer) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        ClassLabels codec = ClassLabels.fit(y);
        int k = codec.k;
        if (k <= 2) {
            throw new IllegalArgumentException(String.format("Only %d classes", k));
        }

        int[] ni = codec.ni;
        int[] labels = codec.y;

        // Generic arrays cannot be created directly; the raw arrays only ever
        // hold Classifier<T> instances produced by the trainer.
        Classifier<T>[][] classifiers = new Classifier[k][];
        PlattScaling[][] platts = new PlattScaling[k][];
        for (int i = 1; i < k; i++) {
            classifiers[i] = new Classifier[i];
            platts[i] = new PlattScaling[i];
        }

        IntStream.range(0, k * (k - 1) / 2).parallel().forEach(index -> {
            // Decode the linear index of the strict triangle into the pair (i, j), j < i.
            int j = k - 2 - (int) Math.floor(Math.sqrt(-8 * index + 4 * k * (k - 1) - 7) / 2.0 - 0.5);
            int i = index + j + 1 - k * (k - 1) / 2 + (k - j) * (k - j - 1) / 2;

            // Collect the samples of class i (relabelled pos) and class j (relabelled neg).
            int n = ni[i] + ni[j];
            T[] xij = (T[]) Array.newInstance(x.getClass().getComponentType(), n);
            int[] yij = new int[n];
            int q = 0;
            for (int l = 0; l < labels.length; l++) {
                if (labels[l] == i) {
                    xij[q] = x[l];
                    yij[q] = pos;
                    q++;
                } else if (labels[l] == j) {
                    xij[q] = x[l];
                    yij[q] = neg;
                    q++;
                }
            }

            classifiers[i][j] = trainer.apply(xij, yij);

            try {
                platts[i][j] = PlattScaling.fit(classifiers[i][j], xij, yij);
            } catch (UnsupportedOperationException ex) {
                // Best effort: without a score function the model simply
                // offers no posterior probabilities.
                logger.info("The classifier doesn't support score function. Don't fit Platt scaling.");
            }
        });

        // Pass the label encoder along so that predictions are reported in the
        // caller's original labels rather than the internal indices [0, k).
        return new OneVersusOne<>(classifiers, platts[1][0] == null ? null : platts, codec.classes);
    }

    /**
     * Fits a one-versus-one model on a data frame.
     *
     * @param formula a symbolic description of the model to be fitted.
     * @param data the data frame of the explanatory and response variables.
     * @param trainer the function that trains a binary classifier.
     * @return the model.
     */
    public static DataFrameClassifier fit(final Formula formula, DataFrame data, BiFunction<Formula, DataFrame, DataFrameClassifier> trainer) {
        Tuple[] x = data.stream().toArray(Tuple[]::new);
        int[] y = formula.y(data).toIntArray();

        // NOTE(review): the relabelled 0/1 array is not used here; the trainer
        // sees the original response values contained in the rows. Confirm
        // that the trained binary models still return a value > 0 only for
        // the first class of the pair, as predict(T) expects.
        @SuppressWarnings("unchecked")
        final OneVersusOne<Tuple> model = OneVersusOne.fit(x, y, 1, 0, (rows, labels) -> {
            DataFrame df = DataFrame.of(Arrays.asList(rows));
            return (Classifier<Tuple>) trainer.apply(formula, df);
        });

        final StructType schema = formula.x((Tuple) data.get(0)).schema();

        return new DataFrameClassifier() {
            @Override
            public int numClasses() {
                return model.numClasses();
            }

            @Override
            public int[] classes() {
                return model.classes();
            }

            @Override
            public int predict(Tuple x) {
                return model.predict(x);
            }

            @Override
            public Formula formula() {
                return formula;
            }

            @Override
            public StructType schema() {
                return schema;
            }
        };
    }

    /**
     * Predicts the class label by majority voting over all pairwise classifiers.
     *
     * @param x the instance to classify.
     * @return the predicted class label.
     */
    @Override
    public int predict(T x) {
        int[] count = new int[k];
        for (int i = 1; i < k; i++) {
            for (int j = 0; j < i; j++) {
                // A positive binary prediction is a vote for class i, otherwise for class j.
                if (classifiers[i][j].predict(x) > 0) {
                    count[i]++;
                } else {
                    count[j]++;
                }
            }
        }

        return classes.valueOf(MathEx.whichMax(count));
    }

    /**
     * Returns true if this model can estimate posterior probabilities, i.e.
     * Platt scaling is available for the pairwise classifiers.
     *
     * @return true if {@link #predict(Object, double[])} is supported.
     */
    @Override
    public boolean soft() {
        return platt != null;
    }

    /**
     * Predicts the class label and estimates the posterior probabilities.
     *
     * @param x the instance to classify.
     * @param posteriori the output array of length {@code k} that receives
     *                   the posterior probabilities.
     * @return the predicted class label.
     * @throws UnsupportedOperationException if Platt scaling is not available.
     * @throws IllegalArgumentException if {@code posteriori} does not have
     *         exactly one entry per class.
     */
    @Override
    public int predict(T x, double[] posteriori) {
        if (platt == null) {
            throw new UnsupportedOperationException("Platt scaling is not available");
        }

        // A longer array would leave stale entries that whichMax could pick,
        // a shorter one would fail with an index error inside coupling.
        if (posteriori.length != k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
        }

        // r[i][j] is the calibrated probability of class i versus class j.
        double[][] r = new double[k][k];
        for (int i = 1; i < k; i++) {
            for (int j = 0; j < i; j++) {
                r[i][j] = platt[i][j].scale(classifiers[i][j].score(x));
                r[j][i] = 1.0 - r[i][j];
            }
        }

        coupling(r, posteriori);
        return classes.valueOf(MathEx.whichMax(posteriori));
    }

    /**
     * Combines pairwise probabilities into class probabilities by a
     * fixed-point iteration. The scheme appears to follow the pairwise
     * coupling method of Wu, Lin and Weng (2004).
     *
     * @param r the pairwise probabilities, {@code r[i][j] + r[j][i] = 1} for {@code i != j}.
     * @param p the output class probabilities; the first {@code k} entries are overwritten.
     */
    private void coupling(double[][] r, double[] p) {
        double[][] Q = new double[k][k];
        double[] Qp = new double[k];
        double eps = 0.005 / k;

        // Start from the uniform distribution and build the symmetric matrix Q.
        for (int t = 0; t < k; t++) {
            p[t] = 1.0 / k;
            Q[t][t] = 0.0;
            for (int j = 0; j < t; j++) {
                Q[t][t] += r[j][t] * r[j][t];
                Q[t][j] = Q[j][t];
            }
            for (int j = t + 1; j < k; j++) {
                Q[t][t] += r[j][t] * r[j][t];
                Q[t][j] = -r[j][t] * r[t][j];
            }
        }

        int maxIter = Math.max(100, k);
        int iter = 0;
        for (; iter < maxIter; iter++) {
            // Recompute Qp and p'Qp from scratch in every sweep.
            double pQp = 0.0;
            for (int t = 0; t < k; t++) {
                Qp[t] = 0.0;
                for (int j = 0; j < k; j++) {
                    Qp[t] += Q[t][j] * p[j];
                }
                pQp += p[t] * Qp[t];
            }

            // Stop once every component of Qp is within eps of p'Qp.
            double maxError = 0.0;
            for (int t = 0; t < k; t++) {
                double error = Math.abs(Qp[t] - pQp);
                if (error > maxError) {
                    maxError = error;
                }
            }
            if (maxError < eps) {
                break;
            }

            // Update one coordinate at a time, then rescale p by (1 + diff).
            for (int t = 0; t < k; t++) {
                double diff = (-Qp[t] + pQp) / Q[t][t];
                p[t] += diff;
                pQp = (pQp + diff * (diff * Q[t][t] + 2.0 * Qp[t])) / (1.0 + diff) / (1.0 + diff);
                for (int j = 0; j < k; j++) {
                    Qp[j] = (Qp[j] + diff * Q[t][j]) / (1.0 + diff);
                    p[j] /= (1.0 + diff);
                }
            }
        }

        if (iter >= maxIter) {
            logger.warn("coupling reaches maximal iterations");
        }
    }
}

