/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.Arrays;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.DBSCAN;
import smile.clustering.KMeans;
import smile.clustering.PartitionClustering;
import smile.math.MathEx;

/**
 * DENCLUE-style density clustering.
 *
 * <p>The density is a Gaussian kernel estimate built on {@code m} representative
 * samples (the centroids of a k-means run on the data). Each observation is moved
 * by iterative hill-climbing towards a local maximum of that density (its density
 * attractor). The attractors are then grouped with DBSCAN. The DBSCAN radius is the
 * mean of the per-observation "last two step lengths" of hill-climbing.
 */
public class DENCLUE
extends PartitionClustering {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(DENCLUE.class);
    /** Relative density-change threshold that stops hill-climbing. */
    private final double tol;
    /** Standard deviation of the Gaussian kernel. */
    private final double sigma;
    /** Density attractor of each training observation. */
    public final double[][] attractors;
    /** Sum of the last two hill-climbing step lengths of each training observation. */
    private final double[] radius;
    /** Representative samples (k-means centroids) the density is estimated from. */
    private final double[][] samples;

    /**
     * Constructor.
     *
     * @param k the number of clusters.
     * @param attractors the density attractor of each training observation.
     * @param radius the sum of the last two hill-climbing step lengths of each observation.
     * @param samples the representative samples used by the kernel density estimate.
     * @param sigma the standard deviation of the Gaussian kernel.
     * @param y the cluster label of each training observation.
     * @param tol the relative tolerance of the hill-climbing procedure.
     */
    public DENCLUE(int k, double[][] attractors, double[] radius, double[][] samples, double sigma, int[] y, double tol) {
        super(k, y);
        this.attractors = attractors;
        this.radius = radius;
        this.samples = samples;
        this.sigma = sigma;
        this.tol = tol;
    }

    /**
     * Clusters data with the default tolerance {@code 0.01} and
     * {@code minPts = max(10, n / 200)}.
     *
     * @param data the observations.
     * @param sigma the standard deviation of the Gaussian kernel.
     * @param m the number of representative samples selected by k-means.
     * @return the fitted model.
     */
    public static DENCLUE fit(double[][] data, double sigma, int m) {
        int n = data.length;
        return DENCLUE.fit(data, sigma, m, 0.01, Math.max(10, n / 200));
    }

    /**
     * Clusters data.
     *
     * @param data the observations.
     * @param sigma the standard deviation of the Gaussian kernel; must be positive.
     * @param m the number of representative samples selected by k-means; in {@code [1, data.length]}.
     * @param tol the relative density-change tolerance of hill-climbing; must be positive.
     * @param minPts the DBSCAN minimum number of neighbors of a core attractor.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code sigma}, {@code m} or {@code tol} is invalid.
     * @throws IllegalStateException if hill-climbing produced non-finite attractors or radii.
     */
    public static DENCLUE fit(double[][] data, double sigma, int m, double tol, int minPts) {
        // Negated comparison so that NaN is rejected as well.
        if (!(sigma > 0.0)) {
            throw new IllegalArgumentException("Invalid standard deviation of Gaussian kernel: " + sigma);
        }
        if (m <= 0 || m > data.length) {
            throw new IllegalArgumentException("Invalid number of selected samples: " + m);
        }
        // A non-positive tolerance can keep "diff > tol" true forever in climb().
        if (tol <= 0.0) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        logger.info("Select {} samples by k-means", m);
        KMeans kmeans = KMeans.fit(data, m);
        double[][] samples = (double[][]) kmeans.centroids;
        int n = data.length;
        int d = data[0].length;
        double[][] attractors = new double[n][d];
        // For each observation, the lengths of its last two hill-climbing steps.
        double[][] steps = new double[n][2];
        logger.info("Hill-climbing of density function for each observation");
        IntStream.range(0, n).parallel().forEach(i -> climb(data[i], attractors[i], steps[i], samples, sigma, tol));
        if (Arrays.stream(attractors).flatMapToDouble(Arrays::stream).anyMatch(ai -> !Double.isFinite(ai))) {
            throw new IllegalStateException("Attractors contains NaN/infinity. sigma is likely too small.");
        }
        double[] radius = Arrays.stream(steps).mapToDouble(step -> step[0] + step[1]).toArray();
        double r = MathEx.mean(radius);
        if (!Double.isFinite(r)) {
            throw new IllegalStateException("The average of last steps of hill-climbing is NaN/infinity. sigma is likely too small.");
        }
        logger.info("Clustering attractors with DBSCAN (radius = {})", r);
        DBSCAN<double[]> dbscan = DBSCAN.fit(attractors, minPts, r);
        return new DENCLUE(dbscan.k, attractors, radius, samples, sigma, dbscan.y, tol);
    }

    /**
     * Classifies a new observation.
     *
     * <p>The observation is hill-climbed to its own attractor. The label of the first
     * training observation whose attractor lies within the sum of both reach radii
     * is returned. That label is whatever DBSCAN assigned, possibly its outlier label.
     *
     * @param x the observation.
     * @return the cluster label, or {@code Integer.MAX_VALUE} if no training attractor is close enough.
     * @throws IllegalArgumentException if {@code x} has the wrong dimension.
     */
    public int predict(double[] x) {
        int d = this.attractors[0].length;
        if (x.length != d) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, d));
        }
        double[] attractor = new double[d];
        double[] step = new double[2];
        climb(x, attractor, step, this.samples, this.sigma, this.tol);
        double r = step[0] + step[1];
        for (int i = 0; i < this.attractors.length; ++i) {
            if (MathEx.distance(this.attractors[i], attractor) < this.radius[i] + r) {
                return this.y[i];
            }
        }
        return Integer.MAX_VALUE;
    }

    /**
     * Hill-climbs {@code x} to a local maximum of the Gaussian kernel density.
     *
     * <p>Each iteration replaces the current point by the kernel-weighted mean of
     * the samples. Iteration continues for at least {@code step.length} rounds, and
     * then while the relative change of the density exceeds {@code tol}.
     *
     * @param x the starting point; not modified.
     * @param attractor output: the final point reached.
     * @param step output: a ring buffer of the most recent step lengths.
     * @param samples the samples of the kernel density estimate.
     * @param sigma the standard deviation of the Gaussian kernel.
     * @param tol the relative density-change tolerance.
     * @return the kernel density estimate at the last point whose weights were evaluated.
     */
    private static double climb(double[] x, double[] attractor, double[] step, double[][] samples, double sigma, double tol) {
        int m = samples.length;
        int d = x.length;
        int k = step.length;
        double gamma = -0.5 / (sigma * sigma);
        double[] current = x.clone();
        double[] w = new double[m];
        double weightSum = 0.0;
        double prevWeightSum = Double.NaN;
        double diff = Double.MAX_VALUE;
        for (int iter = 0; iter < k || diff > tol; ++iter) {
            for (int i = 0; i < m; ++i) {
                w[i] = Math.exp(gamma * MathEx.squaredDistance(current, samples[i]));
            }
            Arrays.fill(attractor, 0.0);
            for (int i = 0; i < m; ++i) {
                double wi = w[i];
                double[] xi = samples[i];
                for (int j = 0; j < d; ++j) {
                    attractor[j] += wi * xi[j];
                }
            }
            weightSum = MathEx.sum(w);
            for (int j = 0; j < d; ++j) {
                attractor[j] /= weightSum;
            }
            // The relative change is invariant to the constant 1 / (m * h), so it is
            // measured on the raw kernel sum. Using the normalized density instead
            // breaks when h overflows or underflows in high dimension, producing NaN
            // and a premature stop.
            diff = iter == 0 ? Double.MAX_VALUE : Math.abs(weightSum - prevWeightSum) / prevWeightSum;
            prevWeightSum = weightSum;
            step[iter % k] = MathEx.distance(attractor, current);
            System.arraycopy(attractor, 0, current, 0, d);
        }
        // Normalization constant of a d-dimensional isotropic Gaussian: (2*pi*sigma^2)^(d/2).
        double h = Math.pow(2.0 * Math.PI * sigma * sigma, d / 2.0);
        return weightSum / (m * h);
    }
}

