/*
 * Decompiled with CFR 0.152.
 */
package smile.clustering;

import java.util.Arrays;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.clustering.CentroidClustering;
import smile.math.MathEx;
import smile.math.distance.HammingDistance;
import smile.util.IntSet;

/**
 * Centroid-based clustering of categorical data encoded as {@code int} vectors.
 *
 * <p>Observations are compared with {@link HammingDistance}, and each cluster
 * centroid holds, per attribute, the most frequent value (the mode) among the
 * observations currently assigned to that cluster.
 */
public class KModes extends CentroidClustering<int[], int[]> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(KModes.class);

    /**
     * Creates a clustering result.
     *
     * @param distortion the total distortion of the clustering.
     * @param centroids  the centroid (mode vector) of each cluster.
     * @param y          the cluster label of each observation.
     */
    public KModes(double distortion, int[][] centroids, int[] y) {
        // int[][] already is T[] for T = int[]; no cast is needed (or possible,
        // since T is not a type variable in scope here).
        super(distortion, centroids, y);
    }

    /**
     * Returns the Hamming distance between two observations.
     *
     * @param x an observation.
     * @param y another observation.
     * @return the value of {@code HammingDistance.d(x, y)}.
     */
    @Override
    protected double distance(int[] x, int[] y) {
        return HammingDistance.d(x, y);
    }

    /**
     * Fits a model with at most 100 iterations.
     *
     * @param data the observations, one row per observation.
     * @param k    the number of clusters; must be at least 2.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code k < 2} or {@code data} is empty.
     */
    public static KModes fit(int[][] data, int k) {
        return KModes.fit(data, k, 100);
    }

    /**
     * Fits a model, iterating until the distortion stops decreasing or
     * {@code maxIter} iterations have run.
     *
     * @param data    the observations, one row per observation. The number of
     *                attributes is taken from the first row.
     * @param k       the number of clusters; must be at least 2.
     * @param maxIter the maximum number of iterations; must be positive.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code k < 2}, {@code maxIter <= 0},
     *                                  or {@code data} is empty.
     */
    public static KModes fit(int[][] data, int k, int maxIter) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (data.length == 0) {
            // Fail with a clear message instead of an index error on data[0].
            throw new IllegalArgumentException("Empty data: at least one observation is required");
        }

        int n = data.length;
        int d = data[0].length;

        // Build one codec per attribute from a private copy of that column.
        Codec[] codec = IntStream.range(0, d).parallel().mapToObj(j -> {
            int[] column = new int[n];
            for (int i = 0; i < n; ++i) {
                column[i] = data[i][j];
            }
            return new Codec(column);
        }).toArray(Codec[]::new);

        int[] y = new int[n];
        int[][] medoids = new int[k][];
        int[][] centroids = new int[k][d];

        // seed() is inherited; it is given the medoid and label arrays to fill
        // and its returned values are summed into the initial distortion.
        double distortion = MathEx.sum(KModes.seed(data, medoids, y, HammingDistance::d));
        logger.info(String.format("Distortion after initialization: %d", (int)distortion));

        // Start with a large positive improvement so the first iteration always runs.
        double diff = Integer.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > 0.0; ++iter) {
            KModes.updateCentroids(centroids, data, y, codec);
            double wcss = KModes.assign(y, data, centroids, HammingDistance::d);
            logger.info(String.format("Distortion after %3d iterations: %d", iter, (int)wcss));
            diff = distortion - wcss;
            distortion = wcss;
        }

        // If the loop stopped on maxIter while still improving, the labels were
        // changed after the last centroid update, so refresh the centroids.
        if (diff > 0.0) {
            KModes.updateCentroids(centroids, data, y, codec);
        }
        return new KModes(distortion, centroids, y);
    }

    /**
     * Recomputes every centroid as the per-attribute mode of its members.
     *
     * @param centroids the centroids to overwrite, one row per cluster.
     * @param data      the observations; only its length is used here, the
     *                  values are read from the codecs.
     * @param y         the current cluster label of each observation.
     * @param codec     one codec per attribute.
     */
    private static void updateCentroids(int[][] centroids, int[][] data, int[] y, Codec[] codec) {
        int n = data.length;
        int k = centroids.length;
        int d = centroids[0].length;

        // Each task writes only to its own centroid row.
        IntStream.range(0, k).parallel().forEach(cluster -> {
            int[] centroid = centroids[cluster];
            for (int j = 0; j < d; ++j) {
                Codec column = codec[j];
                if (column.k == 1) {
                    // Constant attribute: its only value is the mode of every
                    // cluster. Leaving the array default (0) here would be wrong
                    // whenever that value is non-zero and would inflate the
                    // distortion by n for this attribute.
                    centroid[j] = column.valueOf(0);
                    continue;
                }

                // Histogram of the encoded values among this cluster's members.
                int[] count = new int[column.k];
                int[] x = column.x;
                for (int i = 0; i < n; ++i) {
                    if (y[i] == cluster) {
                        count[x[i]]++;
                    }
                }
                centroid[j] = column.valueOf(MathEx.whichMax(count));
            }
        });
    }

    /**
     * Encoding of one attribute column as indices {@code 0 .. k-1} into its
     * sorted distinct values, so that per-value counts can live in an array.
     */
    private static final class Codec {
        /** The number of distinct values in the column. */
        public final int k;
        /** The column, holding encoded indices after construction. */
        public final int[] x;
        /** The mapping between original values and their indices. */
        public final IntSet encoder;

        /**
         * Builds the codec and, if needed, re-encodes {@code x} in place.
         *
         * @param x the column values; must be non-empty. The array is kept
         *          and may be overwritten with encoded indices.
         */
        public Codec(int[] x) {
            int[] y = MathEx.unique(x);
            Arrays.sort(y);
            this.x = x;
            this.k = y.length;
            this.encoder = new IntSet(y);

            // Sorted distinct values that start at 0 and end at k-1 are exactly
            // 0..k-1, so the column is already its own encoding.
            if (y[0] != 0 || y[this.k - 1] != this.k - 1) {
                int n = x.length;
                for (int i = 0; i < n; ++i) {
                    x[i] = this.encoder.indexOf(x[i]);
                }
            }
        }

        /**
         * Maps an encoded index back to the original attribute value.
         *
         * @param i the encoded index.
         * @return the original value, as returned by the encoder.
         */
        public int valueOf(int i) {
            return this.encoder.valueOf(i);
        }
    }
}

