package smile.feature.extraction;

import smile.data.DataFrame;
import smile.math.MathEx;
import smile.math.blas.UPLO;
import smile.math.matrix.Matrix;

/* loaded from: input_file:smile/feature/extraction/ProbabilisticPCA.class */
/**
 * Probabilistic principal component analysis (PPCA).
 *
 * <p>The model is fitted in closed form from the eigen-decomposition of the
 * sample covariance matrix of the training data:
 * <ul>
 *   <li>The isotropic noise variance is the mean of the eigenvalues that are
 *       not retained (indices {@code k..p-1}).</li>
 *   <li>The loading matrix {@code W} (p x k) scales each of the leading
 *       {@code k} eigenvectors by {@code sqrt(lambda_j - noise)}.</li>
 *   <li>The projection matrix is {@code M^-1 W'} with
 *       {@code M = W'W + noise * I}.</li>
 * </ul>
 *
 * <p>The mean projected through that matrix is subtracted in
 * {@link #postprocess(double[])}. This centers the output, assuming the
 * {@code Projection} base class multiplies inputs by the projection matrix
 * before calling {@code postprocess}. TODO confirm against {@code Projection}.
 */
public class ProbabilisticPCA extends Projection {
    private static final long serialVersionUID = 2;

    /** Column means of the training data. */
    private final double[] mu;
    /** The mean multiplied by the projection matrix; subtracted from projected vectors. */
    private final double[] pmu;
    /** Estimated isotropic noise variance. */
    private final double noise;
    /** Loading matrix (input dimension x number of components). */
    private final Matrix loading;

    /**
     * Constructor.
     *
     * @param noise the estimated noise variance.
     * @param mu the column means of the training data.
     * @param loading the loading matrix.
     * @param projection the projection matrix; its number of rows is the
     *                   output dimension.
     * @param columns the names of the columns to transform.
     */
    public ProbabilisticPCA(double noise, double[] mu, Matrix loading, Matrix projection, String... columns) {
        super(projection, "PPCA", columns);
        this.noise = noise;
        this.mu = mu;
        this.loading = loading;
        // Project the mean once so postprocess can center directly in the reduced space.
        this.pmu = new double[projection.nrow()];
        projection.mv(mu, this.pmu);
    }

    /**
     * Returns the loading matrix.
     *
     * @return the loading matrix (the internal instance, not a copy).
     */
    public Matrix loadings() {
        return loading;
    }

    /**
     * Returns the center (column means) of the training data.
     *
     * @return the internal mean array; callers must not modify it.
     */
    public double[] center() {
        return mu;
    }

    /**
     * Returns the estimated noise variance.
     *
     * @return the mean of the eigenvalues that were not retained at fit time.
     */
    public double variance() {
        return noise;
    }

    /**
     * Centers an already-projected vector by subtracting the projected mean
     * in place.
     *
     * @param x the projected vector; it is modified in place.
     * @return {@code x}.
     */
    @Override // smile.feature.extraction.Projection
    protected double[] postprocess(double[] x) {
        MathEx.sub(x, pmu);
        return x;
    }

    /**
     * Fits probabilistic PCA on the given columns of a data frame.
     *
     * @param data the training data.
     * @param k the number of principal components to keep,
     *          {@code 1 <= k < number of columns}.
     * @param columns the columns to fit PPCA on.
     * @return the fitted model.
     * @throws IllegalArgumentException if the data is empty or {@code k} is out of range.
     */
    public static ProbabilisticPCA fit(DataFrame data, int k, String... columns) {
        return fit(data.toArray(columns), k, columns);
    }

    /**
     * Fits probabilistic PCA by the closed-form maximum likelihood solution.
     *
     * @param data the training data; every row must have the same length {@code p}.
     * @param k the number of principal components to keep, {@code 1 <= k < p}.
     *          At least one eigenvalue must be discarded so that the noise
     *          variance, their mean, is defined.
     * @param columns the column names recorded in the model.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code data} is null or empty, if the
     *         rows differ in length, or if {@code k} is outside {@code [1, p)}.
     */
    public static ProbabilisticPCA fit(double[][] data, int k, String... columns) {
        if (data == null || data.length == 0) {
            throw new IllegalArgumentException("Empty training data");
        }

        int p = data[0].length;
        for (int i = 1; i < data.length; i++) {
            if (data[i].length != p) {
                throw new IllegalArgumentException(
                        "Row " + i + " has " + data[i].length + " columns, expected " + p);
            }
        }

        if (k < 1 || k >= p) {
            // k == p would make the noise estimate 0/0 = NaN;
            // k > p would index past the eigenvalues.
            throw new IllegalArgumentException(
                    "Invalid number of components: " + k + " (input dimension " + p + ")");
        }

        double[] mu = MathEx.colMeans(data);
        Matrix cov = covariance(data, mu);

        // The noise estimate below relies on sort() ordering eigenpairs by
        // decreasing eigenvalue, so the discarded ones are at indices k..p-1.
        Matrix.EVD eigen = cov.eigen(false, true, true).sort();
        double[] evalues = eigen.wr;
        Matrix evectors = eigen.Vr;

        // ML estimate of the noise: mean of the discarded eigenvalues.
        double noise = 0.0;
        for (int i = k; i < p; i++) {
            noise += evalues[i];
        }
        noise /= (p - k);

        // W = U_k * sqrt(Lambda_k - noise * I). The per-column scale is loop-invariant.
        double[] scale = new double[k];
        for (int j = 0; j < k; j++) {
            scale[j] = Math.sqrt(evalues[j] - noise);
        }

        Matrix loading = new Matrix(p, k);
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < k; j++) {
                loading.set(i, j, evectors.get(i, j) * scale[j]);
            }
        }

        // M = W'W + noise * I. The projection is M^-1 times the transpose of W,
        // with M inverted via its Cholesky factorization.
        Matrix m = loading.ata();
        m.addDiag(noise);
        Matrix projection = m.cholesky(true).inverse().mt(loading);

        return new ProbabilisticPCA(noise, mu, loading, projection, columns);
    }

    /**
     * Computes the covariance of the rows of {@code data} around {@code mu},
     * dividing by the number of rows (the maximum likelihood estimate).
     *
     * <p>Only the lower triangle is accumulated. It is then mirrored into the
     * upper triangle, and the matrix is flagged as {@code UPLO.LOWER}.
     */
    private static Matrix covariance(double[][] data, double[] mu) {
        int n = data.length;
        int p = mu.length;
        Matrix cov = new Matrix(p, p);
        for (double[] x : data) {
            for (int i = 0; i < p; i++) {
                double xi = x[i] - mu[i];
                for (int j = 0; j <= i; j++) {
                    cov.add(i, j, xi * (x[j] - mu[j]));
                }
            }
        }

        for (int i = 0; i < p; i++) {
            for (int j = 0; j <= i; j++) {
                cov.div(i, j, n);
                cov.set(j, i, cov.get(i, j));
            }
        }

        cov.uplo(UPLO.LOWER);
        return cov;
    }
}
