/*
 * Decompiled with CFR 0.152.
 */
package smile.validation.metric;

import java.util.Arrays;
import java.util.Objects;
import java.util.function.DoubleBinaryOperator;
import smile.math.MathEx;
import smile.validation.metric.ClusteringMetric;
import smile.validation.metric.ContingencyTable;
import smile.validation.metric.MutualInformation;

/**
 * Normalized mutual information (NMI) between two labelings of the same
 * samples. The mutual information {@code I(y1; y2)} is divided by a normalizer
 * built from the entropies of the labelings; the normalizer is selected by
 * {@link Method}.
 *
 * <p>Degenerate inputs whose normalizer is zero are reported explicitly rather
 * than as {@code 0/0 = NaN}. If both labelings put every sample in a single
 * cluster, they agree perfectly and the score is {@code 1.0}. If only one
 * labeling is trivial, the mutual information is bounded by its zero entropy,
 * so the score is {@code 0.0}.
 */
public class NormalizedMutualInformation implements ClusteringMetric {
    private static final long serialVersionUID = 2L;

    /** NMI normalized by the joint entropy {@code H(y1, y2)}. */
    public static final NormalizedMutualInformation JOINT = new NormalizedMutualInformation(Method.JOINT);
    /** NMI normalized by {@code max(H(y1), H(y2))}. */
    public static final NormalizedMutualInformation MAX = new NormalizedMutualInformation(Method.MAX);
    /** NMI normalized by {@code min(H(y1), H(y2))}. */
    public static final NormalizedMutualInformation MIN = new NormalizedMutualInformation(Method.MIN);
    /** NMI normalized by the arithmetic mean {@code (H(y1) + H(y2)) / 2}. */
    public static final NormalizedMutualInformation SUM = new NormalizedMutualInformation(Method.SUM);
    /** NMI normalized by the geometric mean {@code sqrt(H(y1) * H(y2))}. */
    public static final NormalizedMutualInformation SQRT = new NormalizedMutualInformation(Method.SQRT);

    /** The normalization applied by {@link #score(int[], int[])}. */
    private final Method method;

    /**
     * Constructor.
     *
     * @param method the normalization method.
     * @throws NullPointerException if {@code method} is null.
     */
    public NormalizedMutualInformation(Method method) {
        // Fail fast here instead of with an NPE on the first score() call.
        this.method = Objects.requireNonNull(method, "method");
    }

    /**
     * Returns the normalized mutual information of two labelings using this
     * instance's normalization method.
     *
     * @param y1 the first labeling.
     * @param y2 the second labeling.
     * @return the normalized mutual information.
     */
    @Override
    public double score(int[] y1, int[] y2) {
        switch (method) {
            case JOINT:
                return joint(y1, y2);
            case MAX:
                return max(y1, y2);
            case MIN:
                return min(y1, y2);
            case SUM:
                return sum(y1, y2);
            case SQRT:
                return sqrt(y1, y2);
            default:
                // Only reachable if a constant is added to Method without a case here.
                throw new IllegalStateException("Unknown normalization method: " + method);
        }
    }

    /**
     * Returns the mutual information normalized by the joint entropy,
     * {@code I(y1; y2) / H(y1, y2)}.
     *
     * @param y1 the first labeling.
     * @param y2 the second labeling.
     * @return the normalized mutual information, or {@code 1.0} when the joint
     *         entropy is zero (both labelings are a single cluster).
     */
    public static double joint(int[] y1, int[] y2) {
        ContingencyTable contingency = new ContingencyTable(y1, y2);
        double n = contingency.n;
        double[] p1 = proportions(contingency.a, n);
        double[] p2 = proportions(contingency.b, n);

        // Joint entropy over the non-empty cells. The loops are bounded by the
        // marginal lengths (not table.length / table[i].length) so that only the
        // label-by-label cells are visited.
        int[][] count = contingency.table;
        double H = 0.0;
        for (int i = 0; i < p1.length; i++) {
            for (int j = 0; j < p2.length; j++) {
                if (count[i][j] > 0) {
                    double p = count[i][j] / n;
                    H -= p * Math.log(p);
                }
            }
        }

        // H == 0 means at most one non-empty cell: both labelings are a single
        // cluster and agree perfectly, so report 1 instead of 0/0 = NaN.
        if (H <= 0.0) {
            return 1.0;
        }

        double I = MutualInformation.of(contingency.n, p1, p2, contingency.table);
        return I / H;
    }

    /**
     * Returns the mutual information normalized by the larger marginal entropy,
     * {@code I(y1; y2) / max(H(y1), H(y2))}.
     *
     * @param y1 the first labeling.
     * @param y2 the second labeling.
     * @return the normalized mutual information.
     */
    public static double max(int[] y1, int[] y2) {
        return normalizedByEntropies(y1, y2, Math::max);
    }

    /**
     * Returns the mutual information normalized by the arithmetic mean of the
     * marginal entropies, {@code 2 * I(y1; y2) / (H(y1) + H(y2))}.
     *
     * @param y1 the first labeling.
     * @param y2 the second labeling.
     * @return the normalized mutual information.
     */
    public static double sum(int[] y1, int[] y2) {
        return normalizedByEntropies(y1, y2, (h1, h2) -> (h1 + h2) / 2.0);
    }

    /**
     * Returns the mutual information normalized by the geometric mean of the
     * marginal entropies, {@code I(y1; y2) / sqrt(H(y1) * H(y2))}.
     *
     * @param y1 the first labeling.
     * @param y2 the second labeling.
     * @return the normalized mutual information.
     */
    public static double sqrt(int[] y1, int[] y2) {
        return normalizedByEntropies(y1, y2, (h1, h2) -> Math.sqrt(h1 * h2));
    }

    /**
     * Returns the mutual information normalized by the smaller marginal entropy,
     * {@code I(y1; y2) / min(H(y1), H(y2))}.
     *
     * @param y1 the first labeling.
     * @param y2 the second labeling.
     * @return the normalized mutual information.
     */
    public static double min(int[] y1, int[] y2) {
        return normalizedByEntropies(y1, y2, Math::min);
    }

    /**
     * Divides the mutual information of two labelings by a mean of their
     * marginal entropies.
     *
     * @param y1 the first labeling.
     * @param y2 the second labeling.
     * @param mean combines {@code H(y1)} and {@code H(y2)} into the normalizer.
     * @return {@code I / mean(H(y1), H(y2))}. When the normalizer is zero, it
     *         returns {@code 1.0} if both entropies are zero and {@code 0.0}
     *         otherwise.
     */
    private static double normalizedByEntropies(int[] y1, int[] y2, DoubleBinaryOperator mean) {
        ContingencyTable contingency = new ContingencyTable(y1, y2);
        double n = contingency.n;
        double[] p1 = proportions(contingency.a, n);
        double[] p2 = proportions(contingency.b, n);
        double h1 = MathEx.entropy(p1);
        double h2 = MathEx.entropy(p2);

        double denominator = mean.applyAsDouble(h1, h2);
        if (denominator > 0.0) {
            double I = MutualInformation.of(contingency.n, p1, p2, contingency.table);
            return I / denominator;
        }

        // Zero normalizer would give 0/0 = NaN. Both trivial -> perfect agreement.
        // One trivial -> I <= min(H(y1), H(y2)) = 0, so the score is 0.
        return h1 == 0.0 && h2 == 0.0 ? 1.0 : 0.0;
    }

    /** Converts cluster-size counts into proportions of {@code n} samples. */
    private static double[] proportions(int[] counts, double n) {
        return Arrays.stream(counts).mapToDouble(c -> c / n).toArray();
    }

    @Override
    public String toString() {
        return String.format("NormalizedMutualInformation(%s)", method);
    }

    /** The normalizer applied to the mutual information. */
    public enum Method {
        /** Joint entropy {@code H(y1, y2)}. */
        JOINT,
        /** {@code max(H(y1), H(y2))}. */
        MAX,
        /** {@code min(H(y1), H(y2))}. */
        MIN,
        /** Arithmetic mean {@code (H(y1) + H(y2)) / 2}. */
        SUM,
        /** Geometric mean {@code sqrt(H(y1) * H(y2))}. */
        SQRT
    }
}

