/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Properties;
import smile.classification.AbstractClassifier;
import smile.classification.DiscriminantAnalysis;
import smile.math.MathEx;
import smile.math.matrix.Matrix;
import smile.sort.QuickSort;
import smile.util.IntSet;

/**
 * Fisher's linear discriminant (FLD) classifier.
 *
 * <p>Training computes a projection matrix ({@code scaling}) with {@code p} rows, one per
 * input feature, and {@code L} columns, one per projected dimension. The overall mean and
 * the per-class means are stored in that projected space. {@link #predict(double[])}
 * projects an input and returns the label of the class whose projected mean is closest,
 * as measured by {@code MathEx.distance}.
 */
public class FLD
extends AbstractClassifier<double[]> {
    private static final long serialVersionUID = 2L;
    /** Dimensionality of the input space. */
    private final int p;
    /** Number of classes. */
    private final int k;
    /** Projection matrix: {@code p} rows, one column per projected dimension. */
    private final Matrix scaling;
    /** Overall mean in the projected space ({@code scaling^T * mean}). */
    private final double[] mean;
    /** Class means in the projected space ({@code scaling^T * mu[i]}), one row per class. */
    private final double[][] mu;

    /**
     * Creates a model whose labels come from {@code IntSet.of(mu.length)}.
     * (NOTE(review): presumably the labels 0..k-1; confirm against IntSet.)
     *
     * @param mean the overall mean in the input space.
     * @param mu the class means in the input space, already centered by {@code mean}.
     * @param scaling the projection matrix with {@code mean.length} rows.
     */
    public FLD(double[] mean, double[][] mu, Matrix scaling) {
        this(mean, mu, scaling, IntSet.of(mu.length));
    }

    /**
     * Creates a model from precomputed means and a projection matrix.
     *
     * <p>{@link #project(double[])} subtracts the projected overall mean from the projected
     * input. For {@link #predict(double[])} to compare like with like, {@code mu} must hold
     * class means with {@code mean} already subtracted. {@link #fit} arranges this, because
     * its helpers center the class means in place before calling this constructor.
     *
     * @param mean the overall mean in the input space; its length defines {@code p}.
     * @param mu the centered class means, one row of length {@code p} per class.
     * @param scaling the projection matrix with {@code p} rows.
     * @param labels the class labels.
     * @throws IllegalArgumentException if the dimensions of {@code scaling} or {@code mu}
     *     do not match {@code mean}.
     */
    public FLD(double[] mean, double[][] mu, Matrix scaling, IntSet labels) {
        super(labels);
        this.k = mu.length;
        this.p = mean.length;
        if (scaling.nrow() != this.p) {
            throw new IllegalArgumentException(String.format("Invalid projection matrix rows: %d, expected: %d", scaling.nrow(), this.p));
        }
        for (int i = 0; i < this.k; ++i) {
            if (mu[i].length != this.p) {
                throw new IllegalArgumentException(String.format("Invalid class mean size: %d, expected: %d", mu[i].length, this.p));
            }
        }
        this.scaling = scaling;
        int L = scaling.ncol();
        // Store the means already projected, so prediction works entirely in the L-dim space.
        this.mean = new double[L];
        scaling.tv(mean, this.mean);
        this.mu = new double[this.k][L];
        for (int i = 0; i < this.k; ++i) {
            scaling.tv(mu[i], this.mu[i]);
        }
    }

    /**
     * Fits the model with the default dimensionality and tolerance 1E-4.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @return the trained model.
     */
    public static FLD fit(double[][] x, int[] y) {
        return FLD.fit(x, y, -1, 1.0E-4);
    }

    /**
     * Fits the model with hyperparameters read from {@code params}.
     *
     * <p>Recognized keys are {@code smile.fisher.dimension} (default {@code -1}, meaning the
     * default dimensionality) and {@code smile.fisher.tolerance} (default {@code 1E-4}).
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param params the hyperparameters.
     * @return the trained model.
     */
    public static FLD fit(double[][] x, int[] y, Properties params) {
        int L = Integer.parseInt(params.getProperty("smile.fisher.dimension", "-1"));
        double tol = Double.parseDouble(params.getProperty("smile.fisher.tolerance", "1E-4"));
        return FLD.fit(x, y, L, tol);
    }

    /**
     * Fits the model.
     *
     * @param x the training samples.
     * @param y the training labels.
     * @param L the dimensionality of the projected space. A non-positive value selects
     *     {@code min(k - 1, p)}, where {@code k} is the number of classes and {@code p} the
     *     number of features.
     * @param tol the tolerance passed to {@code DiscriminantAnalysis}. In the
     *     {@code n - k < p} path its square is also the threshold below which singular values
     *     are treated as zero.
     * @return the trained model.
     * @throws IllegalArgumentException if sizes of {@code x} and {@code y} differ, if
     *     {@code L >= k}, or if {@code L > p}.
     */
    public static FLD fit(double[][] x, int[] y, int L, double tol) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        DiscriminantAnalysis da = DiscriminantAnalysis.fit(x, y, null, tol);
        int n = x.length;
        int k = da.k;
        int p = da.mean.length;
        if (L >= k) {
            throw new IllegalArgumentException(String.format("The dimensionality of mapped space is too high: %d >= %d", L, k));
        }
        if (L <= 0) {
            // Both solvers can produce at most p directions (p eigenvectors in fld, a p-row
            // basis in small), so the default must not exceed p.
            L = Math.min(k - 1, p);
        }
        if (L > p) {
            throw new IllegalArgumentException(String.format("The dimensionality of mapped space is too high: %d > %d", L, p));
        }
        double[] mean = da.mean;
        double[][] mu = da.mu;
        // Both helpers center mu in place; the constructor relies on that.
        Matrix scaling = n - k < p ? FLD.small(L, x, mean, mu, da.priori, tol) : FLD.fld(L, x, mean, mu, tol);
        return new FLD(mean, mu, scaling, da.labels);
    }

    /** Subtracts {@code mean} from every class mean in {@code mu}, in place. */
    private static void center(double[][] mu, double[] mean) {
        for (double[] mui : mu) {
            for (int j = 0; j < mean.length; ++j) {
                mui[j] -= mean[j];
            }
        }
    }

    /** Returns {@code 1/sqrt(s)} if {@code s > threshold}, otherwise 0 (direction dropped). */
    private static double inverseSqrt(double s, double threshold) {
        return s > threshold ? 1.0 / Math.sqrt(s) : 0.0;
    }

    /**
     * Computes the projection from the eigenvectors of {@code Sw^-1 * Sb}.
     *
     * <p>Side effect: centers {@code mu} in place by {@code mean}.
     *
     * @return a {@code p x L} matrix whose columns are the right eigenvectors of
     *     {@code Sw^-1 * Sb} with the largest eigenvalue magnitudes.
     */
    private static Matrix fld(int L, double[][] x, double[] mean, double[][] mu, double tol) {
        int k = mu.length;
        int p = mean.length;
        // Total scatter.
        Matrix St = DiscriminantAnalysis.St(x, mean, k, tol);
        center(mu, mean);
        // Between-class scatter: (1/k) * sum_i mu_i * mu_i^T over the centered class means.
        // Only the upper triangle is accumulated here.
        Matrix Sb = new Matrix(p, p);
        for (double[] mui : mu) {
            for (int j = 0; j < p; ++j) {
                for (int i = 0; i <= j; ++i) {
                    Sb.add(i, j, mui[i] * mui[j]);
                }
            }
        }
        // Scale the upper triangle and mirror it into the lower one.
        for (int j = 0; j < p; ++j) {
            for (int i = 0; i <= j; ++i) {
                Sb.div(i, j, k);
                Sb.set(j, i, Sb.get(i, j));
            }
        }
        // Within-class scatter Sw = St - Sb.
        Matrix Sw = St.sub(Sb);
        Matrix SwInvSb = Sw.inverse().mm(Sb);
        Matrix.EVD evd = SwInvSb.eigen(false, true, true);
        // Rank eigenvalues by squared modulus. The values are negated so that sorting puts the
        // largest magnitudes first (assumes QuickSort.sort orders ascending and returns the
        // permutation).
        double[] w = new double[p];
        for (int i = 0; i < p; ++i) {
            w[i] = -(evd.wr[i] * evd.wr[i] + evd.wi[i] * evd.wi[i]);
        }
        int[] index = QuickSort.sort(w);
        Matrix scaling = new Matrix(p, L);
        for (int j = 0; j < L; ++j) {
            int col = index[j];
            for (int i = 0; i < p; ++i) {
                scaling.set(i, j, evd.Vr.get(i, col));
            }
        }
        return scaling;
    }

    /**
     * Computes the projection via SVDs, avoiding the explicit inverse of the scatter matrix.
     * {@link #fit} uses this path when {@code n - k < p}.
     *
     * <p>Side effect: centers {@code mu} in place by {@code mean}.
     *
     * @return a {@code p x L} projection matrix.
     */
    private static Matrix small(int L, double[][] x, double[] mean, double[][] mu, double[] priori, double tol) {
        int k = mu.length;
        int p = mean.length;
        int n = x.length;
        double sqrtn = Math.sqrt(n);
        // Centered samples scaled by 1/sqrt(n), one column per sample (p x n).
        Matrix X = new Matrix(p, n);
        for (int i = 0; i < n; ++i) {
            double[] xi = x[i];
            for (int j = 0; j < p; ++j) {
                X.set(j, i, (xi[j] - mean[j]) / sqrtn);
            }
        }
        center(mu, mean);
        // Centered class means weighted by sqrt(prior), one column per class (p x k).
        Matrix M = new Matrix(p, k);
        for (int i = 0; i < k; ++i) {
            double pi = Math.sqrt(priori[i]);
            double[] mui = mu[i];
            for (int j = 0; j < p; ++j) {
                M.set(j, i, pi * mui[j]);
            }
        }
        Matrix.SVD svd = X.svd(true, true);
        Matrix U = svd.U;
        double[] s = svd.s;
        // Assumes X.svd returns the thin SVD, so there is one singular value (and one row of
        // U^T * ...) per column of U, i.e. min(p, n) of them. The n - k < p dispatch allows
        // n > p, so loop over s rather than over the n samples.
        int r = s.length;
        double threshold = tol * tol;
        Matrix UTM = U.tm(M);
        for (int i = 0; i < r; ++i) {
            double si = inverseSqrt(s[i], threshold);
            for (int j = 0; j < k; ++j) {
                UTM.mul(i, j, si);
            }
        }
        Matrix StInvM = U.mm(UTM);
        Matrix U2 = U.tm(StInvM.svd(true, true).U.submatrix(0, 0, p - 1, L - 1));
        for (int i = 0; i < r; ++i) {
            double si = inverseSqrt(s[i], threshold);
            for (int j = 0; j < L; ++j) {
                U2.mul(i, j, si);
            }
        }
        return U.mm(U2);
    }

    /** Throws if an input vector does not have {@code p} elements. */
    private void checkInputSize(int length) {
        if (length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", length, this.p));
        }
    }

    /**
     * Predicts the class of {@code x}: the label of the class whose projected mean is nearest
     * to the projected input. Ties go to the lowest class index.
     *
     * @throws IllegalArgumentException if {@code x.length != p}.
     */
    @Override
    public int predict(double[] x) {
        // project() validates the input size.
        double[] wx = this.project(x);
        int y = 0;
        double nearest = Double.POSITIVE_INFINITY;
        for (int i = 0; i < this.k; ++i) {
            double d = MathEx.distance(wx, this.mu[i]);
            if (d < nearest) {
                nearest = d;
                y = i;
            }
        }
        return this.classes.valueOf(y);
    }

    /**
     * Projects a sample into the discriminant space: {@code scaling^T * x - scaling^T * mean}.
     *
     * @param x a sample of length {@code p}.
     * @return a new array of length {@code L}.
     * @throws IllegalArgumentException if {@code x.length != p}.
     */
    public double[] project(double[] x) {
        this.checkInputSize(x.length);
        double[] y = this.scaling.tv(x);
        MathEx.sub(y, this.mean);
        return y;
    }

    /**
     * Projects each sample into the discriminant space.
     *
     * @param x samples, each of length {@code p}.
     * @return a new {@code x.length x L} array.
     * @throws IllegalArgumentException if any sample's length is not {@code p}.
     */
    public double[][] project(double[][] x) {
        double[][] y = new double[x.length][this.scaling.ncol()];
        for (int i = 0; i < x.length; ++i) {
            this.checkInputSize(x[i].length);
            this.scaling.tv(x[i], y[i]);
            MathEx.sub(y[i], this.mean);
        }
        return y;
    }

    /**
     * Returns the projection matrix ({@code p} rows, {@code L} columns). This is the internal
     * instance, not a copy.
     */
    public Matrix getProjection() {
        return this.scaling;
    }
}

