package smile.manifold;

import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.util.Arrays;
import java.util.stream.IntStream;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.stat.distribution.GaussianDistribution;

/* loaded from: input_file:smile/manifold/TSNE.class */
/**
 * t-distributed stochastic neighbor embedding: maps {@code n} samples to
 * {@code d}-dimensional coordinates by gradient descent (with momentum and
 * per-coordinate adaptive gains) on the divergence between an input-space
 * affinity matrix {@code P} and an embedding-space affinity matrix {@code Q}.
 *
 * <p>The embedding is computed in the constructor; {@link #update(int)} may be
 * called afterwards to run additional iterations on the same instance.
 *
 * <p>No {@code $deserializeLambda$} method is written by hand here: the
 * compiler synthesizes it for the serializable {@code MathEx::squaredDistance}
 * method reference used in the constructor.
 */
public class TSNE implements Serializable {
    private static final long serialVersionUID = 2;
    private static final Logger logger = LoggerFactory.getLogger(TSNE.class);

    /** Factor by which P is inflated until the momentum switch iteration. */
    private static final double EXAGGERATION = 12.0;
    /** Lower bound applied to probabilities before they are used in a ratio or logarithm. */
    private static final double MIN_PROBABILITY = 1E-16;
    /** Maximum number of bisection steps when searching a row's kernel precision. */
    private static final int MAX_BETA_ITERATIONS = 50;

    /** The embedding coordinates, one row per sample. Centered to zero column means by {@link #update(int)}. */
    public final double[][] coordinates;
    /** The learning rate. */
    private final double eta;
    /** The number of gradient iterations run so far across all {@link #update(int)} calls. */
    private int totalIter = 0;
    /** The current momentum factor. */
    private double momentum = 0.5;
    // The next three stay instance fields (rather than static constants) so the
    // serialized form of the class is unchanged.
    /** The momentum used once {@link #totalIter} reaches {@link #momentumSwitchIter}. */
    private final double finalMomentum = 0.8;
    /** The iteration at which momentum is switched and the inflation of P is removed. */
    private final int momentumSwitchIter = 250;
    /** The floor of the adaptive gains. */
    private final double minGain = 0.01;
    /** Per-coordinate adaptive gains that scale the learning rate. */
    private final double[][] gains;
    /** Symmetrized affinity matrix of the input space. */
    private final double[][] P;
    /** Unnormalized affinity matrix of the embedding space. */
    private final double[][] Q;
    /** The sum of all entries of {@link #Q}. */
    private double Qsum;
    /** The most recently computed cost value. */
    private double cost;

    /**
     * Fits an embedding with perplexity 20, learning rate 200 and 1000 iterations.
     *
     * @param data the input samples, or a square matrix that is used directly
     *             in place of pairwise squared distances.
     * @param d    the dimension of the embedding.
     */
    public TSNE(double[][] data, int d) {
        this(data, d, 20.0, 200.0, 1000);
    }

    /**
     * Fits an embedding.
     *
     * <p>If {@code data} is square ({@code data.length == data[0].length}) it
     * is used as-is as the pairwise squared-distance matrix; otherwise the
     * pairwise squared Euclidean distances of its rows are computed. Note that
     * this means a data set with exactly as many features as samples is
     * treated as a distance matrix.
     *
     * @param data       the input samples or a square distance matrix; must be non-empty.
     * @param d          the dimension of the embedding.
     * @param perplexity the target perplexity of each row of the input affinity matrix.
     * @param eta        the learning rate.
     * @param iterations the number of gradient iterations to run.
     */
    public TSNE(double[][] data, int d, double perplexity, double eta, int iterations) {
        this.eta = eta;
        int n = data.length;

        double[][] dist;
        if (n == data[0].length) {
            dist = data;
        } else {
            dist = new double[n][n];
            MathEx.pdist(data, dist, MathEx::squaredDistance);
        }

        this.coordinates = new double[n][d];
        this.gains = new double[n][d];

        // Start from small random coordinates and unit gains.
        GaussianDistribution gaussian = new GaussianDistribution(0.0, 1.0E-4);
        for (int i = 0; i < n; i++) {
            Arrays.fill(gains[i], 1.0);
            double[] Yi = coordinates[i];
            for (int k = 0; k < d; k++) {
                Yi[k] = gaussian.rand();
            }
        }

        // A loose tolerance is used for the kernel-width search.
        this.P = expd(dist, perplexity, 0.001);
        this.Q = new double[n][n];

        // Symmetrize P, scale it by 1 / (2n), and inflate it by EXAGGERATION.
        // The inflation is undone in update() at momentumSwitchIter.
        double Psum = 2 * n;
        for (int i = 0; i < n; i++) {
            double[] Pi = P[i];
            for (int j = 0; j < i; j++) {
                double p = EXAGGERATION * (Pi[j] + P[j][i]) / Psum;
                if (Double.isNaN(p) || p < MIN_PROBABILITY) {
                    p = MIN_PROBABILITY;
                }
                Pi[j] = p;
                P[j][i] = p;
            }
        }

        update(iterations);
    }

    /**
     * Returns the cost value computed by the most recent {@link #update(int)}.
     *
     * @return the last computed cost.
     */
    public double cost() {
        return cost;
    }

    /**
     * Runs additional gradient iterations, then re-centers the coordinates to
     * zero column means. The cost is recomputed and logged every 100
     * iterations of this call, and once more at the end when
     * {@code iterations} is not a multiple of 100.
     *
     * @param iterations the number of iterations to run.
     */
    public void update(int iterations) {
        double[][] Y = coordinates;
        int n = Y.length;
        int d = Y[0].length;

        // dY holds the previous step (velocity); dC holds the current gradient.
        double[][] dY = new double[n][d];
        double[][] dC = new double[n][d];

        for (int iter = 1; iter <= iterations; iter++, totalIter++) {
            Qsum = computeQ(Y, Q);

            // Gradient and gain update; each task touches only its own row.
            IntStream.range(0, n).parallel().forEach(i -> sne(i, dY[i], dC[i]));

            // Step with momentum and gains.
            IntStream.range(0, n).parallel().forEach(i -> {
                double[] Yi = Y[i];
                double[] dYi = dY[i];
                double[] dCi = dC[i];
                double[] g = gains[i];
                for (int k = 0; k < d; k++) {
                    dYi[k] = momentum * dYi[k] - eta * g[k] * dCi[k];
                    Yi[k] += dYi[k];
                }
            });

            if (totalIter == momentumSwitchIter) {
                momentum = finalMomentum;
                // Remove the inflation applied to P in the constructor.
                for (int i = 0; i < n; i++) {
                    double[] Pi = P[i];
                    for (int j = 0; j < n; j++) {
                        Pi[j] /= EXAGGERATION;
                    }
                }
            }

            if (iter % 100 == 0) {
                cost = computeCost(P, Q);
                logger.info("Error after {} iterations: {}", iter, cost);
            }
        }

        // Make the solution have zero mean in every dimension.
        double[] colMeans = MathEx.colMeans(Y);
        IntStream.range(0, n).parallel().forEach(i -> {
            double[] Yi = Y[i];
            for (int k = 0; k < d; k++) {
                Yi[k] -= colMeans[k];
            }
        });

        if (iterations % 100 != 0) {
            cost = computeCost(P, Q);
            logger.info("Error after {} iterations: {}", iterations, cost);
        }
    }

    /**
     * Computes the gradient for sample {@code i} and adapts that sample's gains.
     *
     * @param i  the sample index.
     * @param dY the previous step of sample {@code i}; read only.
     * @param dC receives the gradient of sample {@code i}; overwritten.
     */
    private void sne(int i, double[] dY, double[] dC) {
        double[][] Y = coordinates;
        int n = Y.length;
        int d = Y[0].length;

        double[] Yi = Y[i];
        double[] Pi = P[i];
        double[] Qi = Q[i];
        double[] g = gains[i];

        Arrays.fill(dC, 0.0);
        for (int j = 0; j < n; j++) {
            if (i != j) {
                double[] Yj = Y[j];
                double q = Qi[j];
                double z = (Pi[j] - (q / Qsum)) * q;
                for (int k = 0; k < d; k++) {
                    dC[k] += 4.0 * (Yi[k] - Yj[k]) * z;
                }
            }
        }

        // Grow the gain additively when the gradient and the previous step
        // differ in sign, shrink it multiplicatively otherwise; clamp at minGain.
        for (int k = 0; k < d; k++) {
            g[k] = Math.signum(dC[k]) != Math.signum(dY[k]) ? g[k] + 0.2 : g[k] * 0.8;
            if (g[k] < minGain) {
                g[k] = minGain;
            }
        }
    }

    /**
     * Builds the row-wise affinity matrix {@code exp(-beta_i * D[i][j])},
     * searching each row's {@code beta_i} by bisection so that the row's
     * entropy (in bits) matches {@code log2(perplexity)} within {@code tol}.
     *
     * @param D          the pairwise squared-distance matrix.
     * @param perplexity the target perplexity.
     * @param tol        the tolerance on the entropy difference.
     * @return the affinity matrix; the diagonal is zero and a row is
     *         normalized by its sum once its search has converged.
     */
    private double[][] expd(double[][] D, double perplexity, double tol) {
        int n = D.length;
        double[][] affinity = new double[n][n];
        double[] rowSums = MathEx.rowSums(D);

        IntStream.range(0, n).parallel().forEach(i -> {
            double logU = MathEx.log2(perplexity);
            double[] Pi = affinity[i];
            double[] Di = D[i];

            // Initialize beta from the inverse of the row's average distance.
            double beta = Math.sqrt((n - 1) / rowSums[i]);
            double betaMin = 0.0;
            double betaMax = Double.POSITIVE_INFINITY;
            logger.debug("initial beta[{}] = {}", i, beta);

            double Hdiff = Double.MAX_VALUE;
            for (int iter = 0; Math.abs(Hdiff) > tol && iter < MAX_BETA_ITERATIONS; iter++) {
                double PiSum = 0.0;
                double H = 0.0;
                for (int j = 0; j < n; j++) {
                    double scaled = beta * Di[j];
                    double p = Math.exp(-scaled);
                    Pi[j] = p;
                    PiSum += p;
                    H += p * scaled;
                }

                // The diagonal entry is excluded; subtracting 1 from the sum
                // assumes D[i][i] == 0 so that its term was exp(0) = 1.
                Pi[i] = 0.0;
                PiSum -= 1.0;

                H = MathEx.log2(PiSum) + (H / PiSum);
                Hdiff = H - logU;

                if (Math.abs(Hdiff) <= tol) {
                    // Converged: normalize the row.
                    // NOTE(review): if the search runs out of iterations without
                    // converging, the row is left unnormalized — confirm intended.
                    for (int j = 0; j < n; j++) {
                        Pi[j] /= PiSum;
                    }
                } else if (Hdiff > 0.0) {
                    // Entropy too high: raise beta.
                    betaMin = beta;
                    beta = Double.isInfinite(betaMax) ? beta * 2.0 : (beta + betaMax) / 2.0;
                } else {
                    // Entropy too low: lower beta.
                    betaMax = beta;
                    beta = (beta + betaMin) / 2.0;
                }

                logger.debug("Hdiff = {}, beta[{}] = {}, H = {}, logU = {}", Hdiff, i, beta, H, logU);
            }
        });

        return affinity;
    }

    /**
     * Fills {@code Q} with {@code 1 / (1 + ||Y[i] - Y[j]||^2)} for every pair
     * (including the diagonal) and returns the sum of all entries.
     *
     * @param Y the embedding coordinates.
     * @param Q receives the affinities; overwritten.
     * @return the sum of all entries written to {@code Q}.
     */
    private double computeQ(double[][] Y, double[][] Q) {
        int n = Y.length;
        // Row sums are collected into an array and summed by MathEx.sum
        // rather than by DoubleStream.sum.
        return MathEx.sum(IntStream.range(0, n).parallel().mapToDouble(i -> {
            double[] Yi = Y[i];
            double[] Qi = Q[i];
            double rowSum = 0.0;
            for (int j = 0; j < n; j++) {
                double q = 1.0 / (1.0 + MathEx.squaredDistance(Yi, Y[j]));
                Qi[j] = q;
                rowSum += q;
            }
            return rowSum;
        }).toArray());
    }

    /**
     * Computes the cost {@code 2 * sum_{j < i} p_ij * log2(p_ij / q_ij)},
     * where {@code q_ij = Q[i][j] / Qsum} floored at a small positive value.
     *
     * @param P the input-space affinity matrix.
     * @param Q the unnormalized embedding-space affinity matrix.
     * @return the cost value.
     */
    private double computeCost(double[][] P, double[][] Q) {
        return 2.0 * IntStream.range(0, Q.length).parallel().mapToDouble(i -> {
            double[] Pi = P[i];
            double[] Qi = Q[i];
            double rowCost = 0.0;
            // Lower triangle only; the matrices are symmetric, hence the factor 2.
            for (int j = 0; j < i; j++) {
                double p = Pi[j];
                double q = Qi[j] / Qsum;
                if (Double.isNaN(q) || q < MIN_PROBABILITY) {
                    q = MIN_PROBABILITY;
                }
                rowCost += p * MathEx.log2(p / q);
            }
            return rowCost;
        }).sum();
    }
}
