package smile.clustering;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;

/**
 * K-Means clustering of real-valued vectors using the squared distance
 * {@code MathEx.squaredDistance} (or a missing-value-tolerant variant for
 * {@link #lloyd(double[][], int, int, double)}).
 *
 * <p>Both training procedures start from the {@code seed} helper, which
 * produces initial labels and an initial distortion. They then iterate until
 * either {@code maxIter} iterations have run or the decrease in distortion
 * between two consecutive iterations is no greater than {@code tol}.
 */
public class KMeans extends CentroidClustering<double[], double[]> {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(KMeans.class);

    /**
     * Constructor.
     *
     * @param distortion the total distortion of the clustering.
     * @param centroids the cluster centroids, one row per cluster.
     * @param y the cluster label of each sample.
     */
    public KMeans(double distortion, double[][] centroids, int[] y) {
        super(distortion, centroids, y);
    }

    /**
     * Returns the distance used by this model: {@code MathEx.squaredDistance(x, y)}.
     *
     * @param x a vector.
     * @param y another vector.
     * @return the squared distance between {@code x} and {@code y}.
     */
    @Override // smile.clustering.CentroidClustering
    public double distance(double[] x, double[] y) {
        return MathEx.squaredDistance(x, y);
    }

    /**
     * Fits K-Means with at most 100 iterations and a convergence tolerance of 1E-4.
     *
     * @param data the samples, one row per sample.
     * @param k the number of clusters; must be at least 2.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code k < 2} or {@code data} is empty.
     */
    public static KMeans fit(double[][] data, int k) {
        return fit(data, k, 100, 1.0E-4d);
    }

    /**
     * Builds a {@code BBDTree} over {@code data} and fits K-Means with it.
     *
     * @param data the samples, one row per sample.
     * @param k the number of clusters; must be at least 2.
     * @param maxIter the maximum number of iterations; must be positive.
     * @param tol iteration stops once the distortion decreases by no more than this.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code k < 2}, {@code maxIter <= 0},
     *     or {@code data} is empty.
     */
    public static KMeans fit(double[][] data, int k, int maxIter, double tol) {
        // Validate before building the tree so bad arguments fail fast and cheaply.
        checkArguments(data, k, maxIter);
        return fit(new BBDTree(data), data, k, maxIter, tol);
    }

    /**
     * Fits K-Means, refining the centroids with {@code bbd.clustering(...)}.
     *
     * @param bbd a {@code BBDTree} over {@code data}; NOTE(review): it is assumed
     *     to have been built from the same samples — not checked here.
     * @param data the samples, one row per sample.
     * @param k the number of clusters; must be at least 2.
     * @param maxIter the maximum number of iterations; must be positive.
     * @param tol iteration stops once the distortion decreases by no more than this.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code k < 2}, {@code maxIter <= 0},
     *     or {@code data} is empty.
     */
    public static KMeans fit(BBDTree bbd, double[][] data, int k, int maxIter, double tol) {
        checkArguments(data, k, maxIter);

        int n = data.length;
        int d = data[0].length;

        // seed(...) writes the initial labels into y (updateCentroids below reads them);
        // its per-sample result is summed into the initial distortion.
        int[] y = new int[n];
        double[][] medoids = new double[k][];
        double distortion = MathEx.sum(seed(data, medoids, y, MathEx::squaredDistance));
        if (logger.isInfoEnabled()) {
            logger.info(String.format("Distortion after initialization: %.4f", distortion));
        }

        int[] size = new int[k];
        double[][] centroids = new double[k][d];
        updateCentroids(centroids, data, y, size);

        // Working buffer handed to bbd.clustering (presumably per-cluster coordinate sums).
        double[][] sums = new double[k][d];
        double diff = Double.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; iter++) {
            double wcss = bbd.clustering(centroids, sums, size, y);
            if (logger.isInfoEnabled()) {
                logger.info(String.format("Distortion after %3d iterations: %.4f", iter, wcss));
            }
            diff = distortion - wcss;
            distortion = wcss;
        }

        return new KMeans(distortion, centroids, y);
    }

    /**
     * Fits K-Means with Lloyd-style iterations, at most 100 of them, and a tolerance of 1E-4.
     *
     * @param data the samples, one row per sample; may contain missing values
     *     (presumably encoded as NaN — see {@code MathEx.squaredDistanceWithMissingValues}).
     * @param k the number of clusters; must be at least 2.
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code k < 2} or {@code data} is empty.
     */
    public static KMeans lloyd(double[][] data, int k) {
        return lloyd(data, k, 100, 1.0E-4d);
    }

    /**
     * Fits K-Means by alternating centroid updates and label assignments, using
     * {@code MathEx.squaredDistanceWithMissingValues} as the distance throughout.
     *
     * @param data the samples, one row per sample; may contain missing values
     *     (presumably encoded as NaN — see {@code MathEx.squaredDistanceWithMissingValues}).
     * @param k the number of clusters; must be at least 2.
     * @param maxIter the maximum number of iterations; must be positive.
     * @param tol iteration stops once the distortion decreases by no more than this.
     * @return the fitted model, whose {@code distance} is the missing-value-tolerant one.
     * @throws IllegalArgumentException if {@code k < 2}, {@code maxIter <= 0},
     *     or {@code data} is empty.
     */
    public static KMeans lloyd(double[][] data, int k, int maxIter, double tol) {
        checkArguments(data, k, maxIter);

        int n = data.length;
        int d = data[0].length;

        int[] y = new int[n];
        double[][] medoids = new double[k][];
        double distortion = MathEx.sum(seed(data, medoids, y, MathEx::squaredDistanceWithMissingValues));
        if (logger.isInfoEnabled()) {
            logger.info(String.format("Distortion after initialization: %.4f", distortion));
        }

        int[] size = new int[k];
        double[][] centroids = new double[k][d];
        // Per-cluster, per-dimension counters maintained by updateCentroidsWithMissingValues
        // (presumably the number of non-missing values contributing to each coordinate).
        int[][] counts = new int[k][d];

        double diff = Double.MAX_VALUE;
        for (int iter = 1; iter <= maxIter && diff > tol; iter++) {
            updateCentroidsWithMissingValues(centroids, data, y, size, counts);
            double wcss = assign(y, data, centroids, MathEx::squaredDistanceWithMissingValues);
            if (logger.isInfoEnabled()) {
                logger.info(String.format("Distortion after %3d iterations: %.4f", iter, wcss));
            }
            diff = distortion - wcss;
            distortion = wcss;
        }

        // If we still have diff > tol, the loop ended by hitting maxIter right after a
        // reassignment, so recompute the centroids to match the final labels.
        if (diff > tol) {
            updateCentroidsWithMissingValues(centroids, data, y, size, counts);
        }

        return new KMeans(distortion, centroids, y) {
            @Override // smile.clustering.KMeans
            public double distance(double[] a, double[] b) {
                return MathEx.squaredDistanceWithMissingValues(a, b);
            }
        };
    }

    /**
     * Validates the arguments shared by all training entry points.
     *
     * @throws IllegalArgumentException if {@code k < 2}, {@code maxIter <= 0},
     *     or {@code data} is empty.
     */
    private static void checkArguments(double[][] data, int k, int maxIter) {
        if (k < 2) {
            throw new IllegalArgumentException("Invalid number of clusters: " + k);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (data.length == 0) {
            throw new IllegalArgumentException("Empty data");
        }
    }
}
