package smile.manifold;

import java.io.Serializable;
import java.util.Collection;
import java.util.Iterator;
import java.util.Objects;
import java.util.stream.IntStream;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.graph.AdjacencyList;
import smile.graph.Graph;
import smile.math.DifferentiableMultivariateFunction;
import smile.math.LevenbergMarquardt;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.math.distance.EuclideanDistance;
import smile.math.matrix.ARPACK;
import smile.math.matrix.Matrix;
import smile.math.matrix.SparseMatrix;
import smile.stat.distribution.GaussianDistribution;

/* loaded from: input_file:smile/manifold/UMAP.class */
/**
 * Uniform Manifold Approximation and Projection (UMAP).
 * <p>
 * The embedding is computed in four stages:
 * <ol>
 * <li>build a k-nearest-neighbor graph and keep its largest connected component;</li>
 * <li>turn it into a symmetric fuzzy simplicial set (a weighted undirected graph);</li>
 * <li>initialize the low-dimensional coordinates with a spectral layout of that graph;</li>
 * <li>refine the coordinates by stochastic gradient descent with negative sampling.</li>
 * </ol>
 */
public class UMAP implements Serializable {
    private static final long serialVersionUID = 2;

    /** The embedding coordinates, one row per embedded sample. */
    public final double[][] coordinates;
    /** The index array reported by {@code NearestNeighborGraph.largest}; one entry per embedded sample. */
    public final int[] index;
    /** The fuzzy simplicial set (symmetric weighted graph) of the embedded samples. */
    public final AdjacencyList graph;

    private static final Logger logger = LoggerFactory.getLogger(UMAP.class);

    /** Edge weights whose magnitude is below this are treated as zero (e.g. self-distances). */
    private static final double ZERO = 1E-8;
    /** Convergence tolerance of the per-vertex sigma binary search. */
    private static final double SMOOTH_K_TOLERANCE = 1E-5;
    /** Lower bound on sigma, relative to the mean neighbor distance. */
    private static final double MIN_K_DIST_SCALE = 1E-3;
    /** Per-coordinate gradient clipping bound used by {@link #clamp(double)}. */
    private static final double GRADIENT_CLIP = 4.0;

    /**
     * The low-dimensional similarity curve {@code 1 / (1 + a * x^(2b))} whose parameters
     * {@code a} and {@code b} are fitted by {@link #fitCurve(double, double)}.
     * The argument vector is {@code [a, b, x]} where {@code x} is a (non-squared) distance.
     * <p>
     * The exponent is {@code 2b} so that the fitted parameters agree with
     * {@link #optimizeLayout}, whose attractive gradient corresponds to the kernel
     * {@code 1 / (1 + a * (x^2)^b)} evaluated on squared distances.
     */
    private static final DifferentiableMultivariateFunction func = new DifferentiableMultivariateFunction() {
        @Override
        public double f(double[] x) {
            return 1.0 / (1.0 + x[0] * Math.pow(x[2], 2.0 * x[1]));
        }

        @Override
        public double g(double[] x, double[] gradient) {
            double a = x[0];
            double b = x[1];
            double pow = Math.pow(x[2], 2.0 * b);
            double denom = 1.0 + a * pow;
            double denom2 = denom * denom;
            // df/da = -x^(2b) / (1 + a x^(2b))^2
            gradient[0] = -pow / denom2;
            // df/db = -2a ln(x) x^(2b) / (1 + a x^(2b))^2
            gradient[1] = -2.0 * a * Math.log(x[2]) * pow / denom2;
            return 1.0 / denom;
        }
    };

    /**
     * Constructor.
     *
     * @param index the index array of the embedded samples.
     * @param coordinates the embedding coordinates.
     * @param graph the fuzzy simplicial set of the embedded samples.
     */
    public UMAP(int[] index, double[][] coordinates, AdjacencyList graph) {
        this.index = index;
        this.coordinates = coordinates;
        this.graph = graph;
    }

    /**
     * Runs UMAP with Euclidean distance and {@code k = 15} nearest neighbors.
     *
     * @param data the input data.
     * @return the model.
     */
    public static UMAP of(double[][] data) {
        return of(data, 15);
    }

    /**
     * Runs UMAP with {@code k = 15} nearest neighbors.
     *
     * @param data the input data.
     * @param distance the distance function.
     * @param <T> the type of input objects.
     * @return the model.
     */
    public static <T> UMAP of(T[] data, Distance<T> distance) {
        return of(data, distance, 15);
    }

    /**
     * Runs UMAP with Euclidean distance.
     *
     * @param data the input data.
     * @param k the number of nearest neighbors.
     * @return the model.
     */
    public static UMAP of(double[][] data, int k) {
        return of(data, new EuclideanDistance(), k);
    }

    /**
     * Runs UMAP with default hyper-parameters: {@code d = 2}, 200 epochs when there
     * are more than 10000 samples (500 otherwise), {@code learningRate = 1.0},
     * {@code minDist = 0.1}, {@code spread = 1.0}, {@code negativeSamples = 5},
     * and {@code repulsionStrength = 1.0}.
     *
     * @param data the input data.
     * @param distance the distance function.
     * @param k the number of nearest neighbors.
     * @param <T> the type of input objects.
     * @return the model.
     */
    public static <T> UMAP of(T[] data, Distance<T> distance, int k) {
        return of(data, distance, k, 2, data.length > 10000 ? 200 : 500, 1.0, 0.1, 1.0, 5, 1.0);
    }

    /**
     * Runs UMAP with Euclidean distance.
     *
     * @param data the input data.
     * @param k the number of nearest neighbors.
     * @param d the dimension of the embedding.
     * @param epochs the number of optimization epochs.
     * @param learningRate the initial learning rate.
     * @param minDist the minimum distance between embedded points.
     * @param spread the effective scale of embedded points.
     * @param negativeSamples the number of negative samples per positive sample.
     * @param repulsionStrength the weight of negative samples.
     * @return the model.
     */
    public static UMAP of(double[][] data, int k, int d, int epochs, double learningRate, double minDist,
                          double spread, int negativeSamples, double repulsionStrength) {
        return of(data, new EuclideanDistance(), k, d, epochs, learningRate, minDist, spread,
                negativeSamples, repulsionStrength);
    }

    /**
     * Runs UMAP.
     *
     * @param data the input data.
     * @param distance the distance function.
     * @param k the number of nearest neighbors; at least 2.
     * @param d the dimension of the embedding; at least 2.
     * @param epochs the number of optimization epochs; at least 10.
     * @param learningRate the initial learning rate; positive.
     * @param minDist the minimum distance between embedded points; in {@code (0, spread]}.
     * @param spread the effective scale of embedded points.
     * @param negativeSamples the number of negative samples per positive sample; positive.
     * @param repulsionStrength the weight of negative samples.
     * @param <T> the type of input objects.
     * @return the model.
     * @throws IllegalArgumentException if a hyper-parameter is out of range.
     * @throws NullPointerException if {@code data} or {@code distance} is null.
     */
    public static <T> UMAP of(T[] data, Distance<T> distance, int k, int d, int epochs, double learningRate,
                              double minDist, double spread, int negativeSamples, double repulsionStrength) {
        if (d < 2) {
            throw new IllegalArgumentException("d must be greater than 1: " + d);
        }
        if (k < 2) {
            throw new IllegalArgumentException("k must be greater than 1: " + k);
        }
        if (minDist <= 0.0) {
            throw new IllegalArgumentException("minDist must greater than 0: " + minDist);
        }
        if (minDist > spread) {
            throw new IllegalArgumentException("minDist must be less than or equal to spread: " + minDist + ",spread=" + spread);
        }
        if (epochs < 10) {
            throw new IllegalArgumentException("epochs must be a positive integer of at least 10: " + epochs);
        }
        if (learningRate <= 0.0) {
            throw new IllegalArgumentException("learningRate must greater than 0: " + learningRate);
        }
        if (negativeSamples <= 0) {
            throw new IllegalArgumentException("negativeSamples must greater than 0: " + negativeSamples);
        }
        Objects.requireNonNull(data, "data");
        Objects.requireNonNull(distance, "distance");

        NearestNeighborGraph nng = NearestNeighborGraph.largest(NearestNeighborGraph.of(data, distance, k, true, null));
        AdjacencyList fuzzy = computeFuzzySimplicialSet(nng.graph, k, 64);
        // The matrix is a separate object; computeEpochPerSample rewrites its values in place.
        SparseMatrix weights = fuzzy.toMatrix();

        double[][] coordinates = spectralLayout(fuzzy, d);
        logger.info("Finish initialization with spectral layout");

        double[] curve = fitCurve(spread, minDist);
        logger.info("Finish fitting the curve parameters");

        SparseMatrix epochsPerSample = computeEpochPerSample(weights, epochs);
        logger.info("Start optimizing the layout");
        optimizeLayout(coordinates, curve, epochsPerSample, epochs, learningRate, negativeSamples, repulsionStrength);
        return new UMAP(nng.index, coordinates, fuzzy);
    }

    /**
     * Fits the parameters {@code a} and {@code b} of {@link #func} to the target
     * similarity, which is 1 below {@code minDist} and {@code exp(-(x - minDist) / spread)}
     * above it, sampled at 300 points in {@code (0, 3 * spread]}.
     *
     * @param spread the effective scale of embedded points.
     * @param minDist the minimum distance between embedded points.
     * @return the fitted parameters {@code {a, b}}.
     */
    private static double[] fitCurve(double spread, double minDist) {
        int size = 300;
        double[] x = new double[size];
        double[] y = new double[size];
        double step = 3.0 * spread / size;
        for (int i = 0; i < size; i++) {
            x[i] = (i + 1) * step;
            y[i] = x[i] < minDist ? 1.0 : Math.exp(-(x[i] - minDist) / spread);
        }
        return LevenbergMarquardt.fit(func, x, y, new double[]{0.5, 0.0}).parameters;
    }

    /**
     * Converts a nearest-neighbor graph into a symmetric fuzzy simplicial set.
     * <p>
     * For each vertex {@code i}, {@code rho_i} is the smallest non-zero edge weight and
     * {@code sigma_i} is found by binary search so that
     * {@code sum_j exp(-max(0, w_ij - rho_i) / sigma_i) ~= log2(k)}. Each directed weight
     * becomes {@code exp(-max(0, w_ij - rho_i) / sigma_i)}, and the directed graph is
     * symmetrized with the fuzzy union {@code p + q - p * q}.
     *
     * @param nng the nearest-neighbor graph; its edge weights are treated as distances
     *            and are OVERWRITTEN in place with the directed membership strengths.
     * @param k the number of nearest neighbors.
     * @param iterations the maximum number of binary-search steps per vertex.
     * @return a new undirected graph holding the symmetrized membership strengths.
     */
    private static AdjacencyList computeFuzzySimplicialSet(AdjacencyList nng, int k, int iterations) {
        double target = MathEx.log2(k);
        int n = nng.getNumVertices();
        double[] sigma = new double[n];
        double[] rho = new double[n];

        // Mean of all non-zero edge weights; fallback scale for vertices with rho == 0.
        double meanDistance = IntStream.range(0, n)
                .mapToObj(nng::getEdges)
                .flatMapToDouble(edges -> edges.stream().mapToDouble(edge -> edge.weight))
                .filter(w -> !MathEx.isZero(w, ZERO))
                .average()
                .orElse(0.0);

        for (int i = 0; i < n; i++) {
            Collection<Graph.Edge> edges = nng.getEdges(i);
            rho[i] = edges.stream()
                    .mapToDouble(edge -> edge.weight)
                    .filter(w -> !MathEx.isZero(w, ZERO))
                    .min()
                    .orElse(0.0);

            double lo = 0.0;
            double hi = Double.POSITIVE_INFINITY;
            double mid = 1.0;
            for (int iter = 0; iter < iterations; iter++) {
                double psum = 0.0;
                for (Graph.Edge edge : edges) {
                    if (!MathEx.isZero(edge.weight, ZERO)) {
                        double delta = edge.weight - rho[i];
                        psum += delta > 0.0 ? Math.exp(-delta / mid) : 1.0;
                    }
                }

                if (Math.abs(psum - target) < SMOOTH_K_TOLERANCE) {
                    break;
                }

                if (psum > target) {
                    hi = mid;
                    mid = (lo + hi) / 2.0;
                } else {
                    lo = mid;
                    // No upper bound found yet: keep doubling.
                    mid = Double.isInfinite(hi) ? mid * 2.0 : (lo + hi) / 2.0;
                }
            }

            // Keep sigma away from zero, relative to the local (or global) mean distance.
            double scale;
            if (rho[i] > 0.0) {
                scale = edges.stream()
                        .mapToDouble(edge -> edge.weight)
                        .filter(w -> !MathEx.isZero(w, ZERO))
                        .average()
                        .orElse(0.0);
            } else {
                scale = meanDistance;
            }
            sigma[i] = Math.max(mid, MIN_K_DIST_SCALE * scale);
        }

        // Directed membership strengths.
        for (int i = 0; i < n; i++) {
            for (Graph.Edge edge : nng.getEdges(i)) {
                edge.weight = Math.exp(-Math.max(0.0, edge.weight - rho[i]) / sigma[i]);
            }
        }

        // Fuzzy union: p + q - p * q.
        AdjacencyList fuzzy = new AdjacencyList(n, false);
        for (int i = 0; i < n; i++) {
            for (Graph.Edge edge : nng.getEdges(i)) {
                double w = edge.weight;
                double wt = nng.getWeight(edge.v2, edge.v1);
                fuzzy.setWeight(edge.v1, edge.v2, w + wt - w * wt);
            }
        }
        return fuzzy;
    }

    /**
     * Computes the initial embedding from eigenvectors of the normalized graph
     * Laplacian {@code I - D^(-1/2) W D^(-1/2)}. The result is scaled, jittered with
     * small Gaussian noise, and each column is rescaled to {@code [0, 10]}.
     *
     * @param graph the fuzzy simplicial set.
     * @param d the dimension of the embedding.
     * @return the initial coordinates, an {@code n x d} array.
     */
    private static double[][] spectralLayout(AdjacencyList graph, int d) {
        int n = graph.getNumVertices();

        // D^(-1/2): inverse square root of each vertex's weighted degree.
        double[] invSqrtDegree = new double[n];
        for (int i = 0; i < n; i++) {
            double degree = 0.0;
            for (Graph.Edge edge : graph.getEdges(i)) {
                degree += edge.weight;
            }
            invSqrtDegree[i] = 1.0 / Math.sqrt(degree);
        }

        AdjacencyList laplacian = new AdjacencyList(n, false);
        for (int i = 0; i < n; i++) {
            laplacian.setWeight(i, i, 1.0);
            for (Graph.Edge edge : graph.getEdges(i)) {
                laplacian.setWeight(edge.v1, edge.v2, -invSqrtDegree[edge.v1] * edge.weight * invSqrtDegree[edge.v2]);
            }
        }

        // Requests up to 10 * (d + 1) smallest eigenpairs rather than d + 1.
        // NOTE(review): presumably to help ARPACK converge on the smallest ones — confirm.
        Matrix V = ARPACK.syev(laplacian.toMatrix(), ARPACK.SymmOption.SM, Math.min(10 * (d + 1), n - 1)).Vr;

        // Take columns ncol-2 .. ncol-d-1, skipping the last one.
        // NOTE(review): assumes the last column is the trivial eigenvector — confirm ARPACK's ordering.
        double[][] coordinates = new double[n][d];
        double maxAbs = 0.0;
        for (int j = 0; j < d; j++) {
            int column = V.ncol() - j - 2;
            for (int i = 0; i < n; i++) {
                double v = V.get(i, column);
                coordinates[i][j] = v;
                double abs = Math.abs(v);
                if (abs > maxAbs) {
                    maxAbs = abs;
                }
            }
        }

        double expansion = 10.0 / maxAbs;
        GaussianDistribution noise = new GaussianDistribution(0.0, 1E-4);
        for (double[] row : coordinates) {
            for (int j = 0; j < d; j++) {
                row[j] = row[j] * expansion + noise.rand();
            }
        }

        double[] colMax = MathEx.colMax(coordinates);
        double[] colMin = MathEx.colMin(coordinates);
        for (double[] row : coordinates) {
            for (int j = 0; j < d; j++) {
                row[j] = 10.0 * (row[j] - colMin[j]) / (colMax[j] - colMin[j]);
            }
        }
        return coordinates;
    }

    /**
     * Refines the embedding in place by stochastic gradient descent.
     * <p>
     * Each edge is sampled every {@code epochsPerSample} epochs. A sample applies an
     * attractive update to both endpoints, followed by repulsive updates against
     * randomly drawn vertices. The learning rate decays linearly from
     * {@code initialAlpha} toward 0.
     *
     * @param embedding the coordinates to optimize; updated in place.
     * @param curve the fitted curve parameters {@code {a, b}}.
     * @param epochsPerSample the sampling period of each edge; non-positive entries are never sampled.
     * @param epochs the number of epochs.
     * @param initialAlpha the initial learning rate.
     * @param negativeSamples the number of negative samples per positive sample.
     * @param gamma the weight of negative samples (repulsion strength).
     */
    private static void optimizeLayout(double[][] embedding, double[] curve, SparseMatrix epochsPerSample,
                                       int epochs, double initialAlpha, int negativeSamples, double gamma) {
        int n = embedding.length;
        int dim = embedding[0].length;
        double a = curve[0];
        double b = curve[1];
        double alpha = initialAlpha;

        SparseMatrix epochsPerNegativeSample = epochsPerSample.m1435clone();
        epochsPerNegativeSample.nonzeros().forEach(entry -> entry.update(entry.x / negativeSamples));
        SparseMatrix epochOfNextNegativeSample = epochsPerNegativeSample.m1435clone();
        SparseMatrix epochOfNextSample = epochsPerSample.m1435clone();

        for (int epoch = 1; epoch <= epochs; epoch++) {
            Iterator<SparseMatrix.Entry> it = epochOfNextSample.iterator();
            while (it.hasNext()) {
                SparseMatrix.Entry edge = it.next();
                if (edge.x > 0.0 && edge.x <= epoch) {
                    int i = edge.i;
                    int j = edge.j;
                    int k = edge.index;
                    double[] current = embedding[i];
                    double[] other = embedding[j];

                    // Attractive update along the edge (both endpoints move).
                    double dist2 = MathEx.squaredDistance(current, other);
                    if (dist2 > 0.0) {
                        double coeff = -2.0 * a * b * Math.pow(dist2, b - 1.0) / (a * Math.pow(dist2, b) + 1.0);
                        for (int c = 0; c < dim; c++) {
                            double grad = clamp(coeff * (current[c] - other[c]));
                            current[c] += grad * alpha;
                            other[c] -= grad * alpha;
                        }
                    }

                    edge.update(edge.x + epochsPerSample.get(k));

                    // Repulsive updates against random vertices (only the current point moves).
                    int numNegative = (int) ((epoch - epochOfNextNegativeSample.get(k)) / epochsPerNegativeSample.get(k));
                    for (int p = 0; p < numNegative; p++) {
                        int r = MathEx.randomInt(n);
                        if (r == i) {
                            continue;
                        }
                        double[] negative = embedding[r];
                        double negDist2 = MathEx.squaredDistance(current, negative);
                        // The 0.001 term avoids blow-up for nearly coincident points.
                        double coeff = negDist2 > 0.0
                                ? 2.0 * gamma * b / ((0.001 + negDist2) * (a * Math.pow(negDist2, b) + 1.0))
                                : 0.0;
                        for (int c = 0; c < dim; c++) {
                            double grad = coeff > 0.0 ? clamp(coeff * (current[c] - negative[c])) : GRADIENT_CLIP;
                            current[c] += grad * alpha;
                        }
                    }
                    epochOfNextNegativeSample.set(k, epochOfNextNegativeSample.get(k) + epochsPerNegativeSample.get(k) * numNegative);
                }
            }

            logger.info(String.format("The learning rate at %3d iterations: %.5f", epoch, alpha));
            // Floating-point division: with int / int the rate would never decay.
            alpha = initialAlpha * (1.0 - (double) epoch / epochs);
        }
    }

    /**
     * Converts edge weights into sampling periods in place. An edge of weight
     * {@code w} is sampled every {@code max / w} epochs. Edges lighter than
     * {@code max / epochs} get 0, which means they are never sampled.
     *
     * @param weights the edge weights; overwritten in place.
     * @param epochs the number of epochs.
     * @return {@code weights}, now holding the epochs per sample.
     */
    private static SparseMatrix computeEpochPerSample(SparseMatrix weights, int epochs) {
        double max = weights.nonzeros().mapToDouble(entry -> entry.x).max().orElse(0.0);
        double threshold = max / epochs;
        weights.nonzeros().forEach(entry -> entry.update(entry.x < threshold ? 0.0 : max / entry.x));
        return weights;
    }

    /**
     * Clips a gradient component to {@code [-4, 4]}.
     *
     * @param x the gradient component.
     * @return the clipped value.
     */
    private static double clamp(double x) {
        return Math.min(GRADIENT_CLIP, Math.max(x, -GRADIENT_CLIP));
    }
}
