/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.Arrays;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.CentroidClustering;
import smile.math.MathEx;
import smile.util.SparseArray;
import smile.util.Strings;

/**
 * Centroid clustering of sparse samples under the Jensen-Shannon divergence.
 *
 * <p>Samples are first assigned by the inherited {@code seed} procedure. Then each
 * sample in turn is moved to the centroid with the smallest Jensen-Shannon divergence,
 * and both affected centroids are updated immediately. Passes repeat until none
 * reassigns a sample or the iteration limit is reached.
 *
 * <p>NOTE(review): the divergence is presumably only meaningful when each row is a
 * non-negative, normalized distribution (e.g. term frequencies). This is not validated
 * here, so confirm it with callers.
 */
public class SIB extends CentroidClustering<double[], SparseArray> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(SIB.class);

    /**
     * Creates a fitted model.
     *
     * @param distortion the total Jensen-Shannon divergence of samples to their centroids.
     * @param centroids  one dense centroid vector per cluster.
     * @param y          the cluster label of each training sample.
     */
    public SIB(double distortion, double[][] centroids, int[] y) {
        // T is double[] for this subclass, so no cast is needed (the decompiled
        // "(T[])" referenced an undeclared type variable and did not compile).
        super(distortion, centroids, y);
    }

    /**
     * Returns the Jensen-Shannon divergence between a centroid and a sample.
     *
     * @param x a dense centroid.
     * @param y a sparse sample.
     * @return the divergence computed by {@code MathEx.JensenShannonDivergence}.
     */
    @Override
    protected double distance(double[] x, SparseArray y) {
        return MathEx.JensenShannonDivergence(x, y);
    }

    /**
     * Clusters the data with at most 100 reassignment passes.
     *
     * @param data the sparse samples.
     * @param k    the number of clusters; must be at least 2.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code k < 2}.
     */
    public static SIB fit(SparseArray[] data, int k) {
        return SIB.fit(data, k, 100);
    }

    /**
     * Clusters the data.
     *
     * @param data    the sparse samples.
     * @param k       the number of clusters; must be at least 2.
     * @param maxIter the maximum number of full reassignment passes; must be positive.
     *                Fitting stops early after a pass that reassigns no sample.
     * @return the fitted model holding the final centroids, labels and total divergence.
     * @throws IllegalArgumentException if {@code k < 2} or {@code maxIter <= 0}.
     */
    public static SIB fit(SparseArray[] data, int k, int maxIter) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid parameter k = " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        int n = data.length;
        // Dense dimension: one past the largest feature index present in any row.
        int d = 1 + Arrays.stream(data).flatMap(SparseArray::stream).mapToInt(e -> e.i).max().orElse(0);

        int[] y = new int[n];
        SparseArray[] medoids = new SparseArray[k];
        double distortion = MathEx.sum(SIB.seed(data, medoids, y, MathEx::JensenShannonDivergence));
        logger.info(String.format("Distortion after initialization: %.4f", distortion));

        int[] size = new int[k];
        double[][] centroids = new double[k][d];
        // Each task writes only its own cluster's entries of size and centroids,
        // so the tasks do not share mutable state.
        IntStream.range(0, k).parallel().forEach(cluster -> {
            double[] centroid = centroids[cluster];
            for (int i = 0; i < n; i++) {
                if (y[i] == cluster) {
                    size[cluster]++;
                    for (SparseArray.Entry e : data[i]) {
                        centroid[e.i] += e.x;
                    }
                }
            }
            // Seeding can leave a cluster empty (e.g. duplicate medoids). Dividing by
            // zero would turn its centroid into NaN, so keep it as the zero vector,
            // matching how an emptied cluster is treated in moveSample.
            if (size[cluster] > 0) {
                for (int j = 0; j < d; j++) {
                    centroid[j] /= size[cluster];
                }
            }
        });

        int reassignment = n;
        for (int iter = 1; iter <= maxIter && reassignment > 0; iter++) {
            reassignment = 0;
            for (int i = 0; i < n; i++) {
                int c = nearestCentroid(data[i], centroids, y[i]);
                if (c != y[i]) {
                    moveSample(data[i], y[i], c, centroids, size);
                    y[i] = c;
                    reassignment++;
                }
            }
            logger.info("Assignments of {} iterations: {}", Strings.ordinal(iter), reassignment);
        }

        distortion = IntStream.range(0, n).parallel().mapToDouble(i -> MathEx.JensenShannonDivergence(data[i], centroids[y[i]])).sum();
        logger.info(String.format("Final distortion: %.4f", distortion));
        return new SIB(distortion, centroids, y);
    }

    /**
     * Finds the centroid with the strictly smallest divergence to a sample.
     *
     * @param x         the sample.
     * @param centroids the current centroids.
     * @param current   the sample's current label. It is returned if no divergence is
     *                  below {@code Double.MAX_VALUE} (e.g. all divergences are NaN).
     * @return the index of the nearest centroid; ties keep the lowest index.
     */
    private static int nearestCentroid(SparseArray x, double[][] centroids, int current) {
        int best = current;
        double nearest = Double.MAX_VALUE;
        for (int j = 0; j < centroids.length; j++) {
            double divergence = MathEx.JensenShannonDivergence(x, centroids[j]);
            if (divergence < nearest) {
                nearest = divergence;
                best = j;
            }
        }
        return best;
    }

    /**
     * Moves a sample from one cluster to another. It updates both centroids, which are
     * kept as means, and both cluster sizes in place.
     *
     * @param x         the sample being moved.
     * @param from      its current cluster.
     * @param to        its new cluster; must differ from {@code from}.
     * @param centroids the centroids, updated in place.
     * @param size      the cluster sizes, updated in place.
     */
    private static void moveSample(SparseArray x, int from, int to, double[][] centroids, int[] size) {
        double[] src = centroids[from];
        double[] dst = centroids[to];
        int d = dst.length;

        // Turn both means back into sums so the sample can be added and removed directly.
        for (int j = 0; j < d; j++) {
            dst[j] *= size[to];
            src[j] *= size[from];
        }
        for (SparseArray.Entry e : x) {
            dst[e.i] += e.x;
            src[e.i] -= e.x;
            // Clamp negative residue (e.g. from floating-point cancellation) to zero.
            if (src[e.i] < 0.0) {
                src[e.i] = 0.0;
            }
        }

        size[from]--;
        size[to]++;

        for (int j = 0; j < d; j++) {
            dst[j] /= size[to];
        }
        // An emptied cluster keeps its residual sum instead of dividing by zero.
        if (size[from] > 0) {
            for (int j = 0; j < d; j++) {
                src[j] /= size[from];
            }
        }
    }
}

