/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import smile.classification.AbstractClassifier;
import smile.math.MathEx;
import smile.stat.distribution.Distribution;
import smile.util.IntSet;

public class NaiveBayes
extends AbstractClassifier<double[]> {
    private static final long serialVersionUID = 2L;
    /** Number of classes. */
    private final int k;
    /** Number of input features (length of an input vector). */
    private final int p;
    /** Prior probability of each class; private copy of the constructor argument. */
    private final double[] priori;
    /** prob[c][j] is the conditional distribution of feature j given class c. */
    private final Distribution[][] prob;

    /**
     * Creates a naive Bayes classifier whose class labels are {@code 0 .. priori.length - 1}.
     *
     * @param priori   the prior probability of each class.
     * @param condprob the conditional distribution of each feature for each class,
     *                 indexed as {@code condprob[class][feature]}.
     * @throws IllegalArgumentException if the arguments are inconsistent (see the
     *                                  three-argument constructor).
     */
    public NaiveBayes(double[] priori, Distribution[][] condprob) {
        this(priori, condprob, IntSet.of(priori.length));
    }

    /**
     * Creates a naive Bayes classifier.
     *
     * <p>The arrays are copied. Later changes to the caller's arrays do not affect
     * this model.
     *
     * @param priori   the prior probability of each class. Every value must lie strictly
     *                 in (0, 1), and the values must sum to 1 within 1e-5.
     * @param condprob the conditional distribution of each feature for each class,
     *                 indexed as {@code condprob[class][feature]}. Every row must have
     *                 the same length.
     * @param labels   the class labels.
     * @throws IllegalArgumentException if {@code priori} and {@code condprob} disagree on
     *                                  the number of classes, a prior is outside (0, 1) or
     *                                  NaN, the priors do not sum to one, or the rows of
     *                                  {@code condprob} have different lengths.
     */
    public NaiveBayes(double[] priori, Distribution[][] condprob, IntSet labels) {
        super(labels);
        if (priori.length != condprob.length) {
            throw new IllegalArgumentException("The number of priori probabilities and that of the classes are not same.");
        }
        double sum = 0.0;
        for (double pr : priori) {
            // Negated range test so that NaN (for which every comparison is false)
            // is rejected as well.
            if (!(pr > 0.0 && pr < 1.0)) {
                throw new IllegalArgumentException("Invalid priori probability: " + pr);
            }
            sum += pr;
        }
        // Also guarantees priori.length > 0, so condprob[0] below is safe.
        if (Math.abs(sum - 1.0) > 1.0E-5) {
            throw new IllegalArgumentException("The sum of priori probabilities is not one: " + sum);
        }

        int dim = condprob[0].length;
        Distribution[][] rows = new Distribution[condprob.length][];
        for (int c = 0; c < condprob.length; c++) {
            if (condprob[c].length != dim) {
                throw new IllegalArgumentException(String.format(
                        "Class %d has %d conditional distributions, but class 0 has %d.",
                        c, condprob[c].length, dim));
            }
            rows[c] = condprob[c].clone();
        }

        this.k = priori.length;
        this.p = dim;
        this.priori = priori.clone();
        this.prob = rows;
    }

    /**
     * Returns the prior probabilities of the classes.
     *
     * @return a copy of the prior probability array. Modifying it does not affect the model.
     */
    public double[] priori() {
        return this.priori.clone();
    }

    /**
     * Predicts the class label of an input vector.
     *
     * @param x the input vector, of length {@code p}.
     * @return the predicted class label.
     */
    @Override
    public int predict(double[] x) {
        return this.predict(x, new double[this.k]);
    }

    /** Returns {@code true}: this classifier can produce posterior probabilities. */
    @Override
    public boolean soft() {
        return true;
    }

    /**
     * Predicts the class label of an input vector and writes the posterior probabilities.
     *
     * <p>For each class c, the log of the unnormalized posterior is computed as
     * {@code log(priori[c]) + sum_j prob[c][j].logp(x[j])}. It is then turned into a
     * probability with the max-shift (log-sum-exp) trick to avoid underflow.
     *
     * <p>NOTE(review): if every class has log-likelihood negative infinity (the input lies
     * outside the support of all classes), the shift produces NaN posteriors. This matches
     * the original behavior.
     *
     * @param x          the input vector, of length {@code p}.
     * @param posteriori output buffer. On return, entries {@code 0 .. k-1} hold the posterior
     *                   probabilities. Any entries beyond index k-1 are left untouched.
     * @return the label of the class with the highest posterior probability.
     * @throws IllegalArgumentException if {@code x} has the wrong length or
     *                                  {@code posteriori} is shorter than the number of classes.
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d", x.length));
        }
        if (posteriori.length < this.k) {
            throw new IllegalArgumentException(String.format(
                    "Invalid posteriori vector size: %d, expected at least %d", posteriori.length, this.k));
        }

        // Unnormalized log posterior per class. Track the max over the first k entries
        // only, so stale data in an oversized caller buffer cannot interfere.
        double max = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < this.k; c++) {
            double logprob = Math.log(this.priori[c]);
            for (int j = 0; j < this.p; j++) {
                logprob += this.prob[c][j].logp(x[j]);
            }
            posteriori[c] = logprob;
            if (logprob > max) {
                max = logprob;
            }
        }

        // Shift by the max before exponentiating so the largest term is exp(0) = 1.
        double z = 0.0;
        for (int c = 0; c < this.k; c++) {
            posteriori[c] = Math.exp(posteriori[c] - max);
            z += posteriori[c];
        }

        // Normalize and pick the first index holding the largest probability.
        int best = 0;
        for (int c = 0; c < this.k; c++) {
            posteriori[c] /= z;
            if (posteriori[c] > posteriori[best]) {
                best = c;
            }
        }
        return this.classes.valueOf(best);
    }
}

