/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.Classifier;

/**
 * Platt scaling: calibrates the real-valued score {@code s} of a binary
 * classifier into the posterior probability of the positive class,
 * {@code P(y > 0 | s) = 1 / (1 + exp(alpha * s + beta))}.
 * <p>
 * The sigmoid parameters are fitted by minimizing the cross-entropy between
 * the sigmoid output and smoothed targets (one pseudo-count per class instead
 * of hard 0/1 labels). The minimizer is a Newton method with a small ridge on
 * the Hessian and a backtracking (Armijo) line search. This is the
 * numerically stable formulation of Lin, Lin and Weng.
 * <p>
 * Instances are immutable.
 */
public class PlattScaling
implements Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(PlattScaling.class);
    /** Slope of the sigmoid with respect to the classifier score. */
    private final double alpha;
    /** Intercept of the sigmoid. */
    private final double beta;

    /**
     * Constructs a Platt scaling transform with the given sigmoid parameters.
     *
     * @param alpha the slope applied to the classifier score.
     * @param beta  the intercept.
     */
    public PlattScaling(double alpha, double beta) {
        this.alpha = alpha;
        this.beta = beta;
    }

    /**
     * Returns the calibrated probability of the positive class for a score,
     * i.e. {@code 1 / (1 + exp(alpha * y + beta))}.
     *
     * @param y the raw classifier score (not a class label).
     * @return the posterior probability of the positive class, in [0, 1] for finite input.
     */
    public double scale(double y) {
        double fApB = y * alpha + beta;
        // Evaluate the sigmoid on the branch where exp() cannot overflow.
        if (fApB >= 0.0) {
            double e = Math.exp(-fApB);
            return e / (1.0 + e);
        }
        return 1.0 / (1.0 + Math.exp(fApB));
    }

    /**
     * Fits the Platt scaling with at most 100 Newton iterations.
     *
     * @param scores the classifier scores of the calibration samples.
     * @param y      the labels; values {@code > 0} are treated as the positive class.
     * @return the fitted transform.
     * @throws IllegalArgumentException if {@code scores} and {@code y} differ in length.
     */
    public static PlattScaling fit(double[] scores, int[] y) {
        return PlattScaling.fit(scores, y, 100);
    }

    /**
     * Fits the Platt scaling by regularized maximum likelihood.
     * <p>
     * If the line search cannot find a sufficient decrease, an error is logged
     * and the current parameters are returned. If {@code maxIters} is reached
     * before the gradient tolerance is met, a warning is logged.
     *
     * @param scores   the classifier scores of the calibration samples.
     * @param y        the labels; values {@code > 0} are treated as the positive class.
     * @param maxIters the maximum number of Newton iterations.
     * @return the fitted transform.
     * @throws IllegalArgumentException if {@code scores} and {@code y} differ in length.
     */
    public static PlattScaling fit(double[] scores, int[] y, int maxIters) {
        if (scores.length != y.length) {
            throw new IllegalArgumentException(
                "The sizes of scores and labels don't match: " + scores.length + " != " + y.length);
        }

        int n = scores.length;
        double prior1 = 0.0; // number of positive samples
        double prior0 = 0.0; // number of non-positive samples
        for (int label : y) {
            if (label > 0) {
                prior1 += 1.0;
            } else {
                prior0 += 1.0;
            }
        }

        final double minStep = 1.0E-10; // smallest step before the line search gives up
        final double sigma = 1.0E-12;   // ridge on the Hessian diagonal keeps it invertible
        final double eps = 1.0E-5;      // convergence tolerance on the gradient components

        // Smoothed targets: one pseudo-count per class instead of hard 0/1
        // labels, which avoids overfitting when a class is small or separable.
        double hiTarget = (prior1 + 1.0) / (prior1 + 2.0);
        double loTarget = 1.0 / (prior0 + 2.0);
        double[] t = new double[n];
        for (int i = 0; i < n; i++) {
            t[i] = y[i] > 0 ? hiTarget : loTarget;
        }

        // Start from a flat sigmoid whose value matches the smoothed class prior.
        double alpha = 0.0;
        double beta = Math.log((prior0 + 1.0) / (prior1 + 1.0));
        double fval = objective(scores, t, alpha, beta);

        int iter;
        for (iter = 0; iter < maxIters; iter++) {
            // Gradient g and (ridge-regularized) Hessian H of the objective.
            double h11 = sigma;
            double h22 = sigma;
            double h21 = 0.0;
            double g1 = 0.0;
            double g2 = 0.0;
            for (int i = 0; i < n; i++) {
                double fApB = scores[i] * alpha + beta;
                double p;
                double q;
                // p = sigmoid output, q = 1 - p, each computed without exp() overflow.
                if (fApB >= 0.0) {
                    double e = Math.exp(-fApB);
                    p = e / (1.0 + e);
                    q = 1.0 / (1.0 + e);
                } else {
                    double e = Math.exp(fApB);
                    p = 1.0 / (1.0 + e);
                    q = e / (1.0 + e);
                }
                double d2 = p * q;
                h11 += scores[i] * scores[i] * d2;
                h22 += d2;
                h21 += scores[i] * d2;
                double d1 = t[i] - p;
                g1 += scores[i] * d1;
                g2 += d1;
            }

            if (Math.abs(g1) < eps && Math.abs(g2) < eps) {
                break;
            }

            // Newton direction -H^{-1} g using the closed-form 2x2 inverse.
            double det = h11 * h22 - h21 * h21;
            double dA = -(h22 * g1 - h21 * g2) / det;
            double dB = -(-h21 * g1 + h11 * g2) / det;
            double gd = g1 * dA + g2 * dB;

            // Backtracking line search: halve the step until the Armijo
            // sufficient-decrease condition holds.
            double stepSize = 1.0;
            while (stepSize >= minStep) {
                double newA = alpha + stepSize * dA;
                double newB = beta + stepSize * dB;
                double newf = objective(scores, t, newA, newB);
                if (newf < fval + 1.0E-4 * stepSize * gd) {
                    alpha = newA;
                    beta = newB;
                    fval = newf;
                    break;
                }
                stepSize /= 2.0;
            }

            if (stepSize < minStep) {
                logger.error("Line search fails.");
                break;
            }
        }

        if (iter >= maxIters) {
            logger.warn("Reaches maximal iterations");
        }
        return new PlattScaling(alpha, beta);
    }

    /**
     * Computes the cross-entropy between the sigmoid
     * {@code 1 / (1 + exp(a * s + b))} and the targets {@code t}.
     * <p>
     * Each term is rewritten depending on the sign of {@code a * s + b} so that
     * exp() is only evaluated on non-positive arguments. {@code log1p} keeps
     * full precision when that exponential is tiny.
     *
     * @param scores the classifier scores.
     * @param t      the (smoothed) targets, same length as {@code scores}.
     * @param a      the candidate slope.
     * @param b      the candidate intercept.
     * @return the objective value.
     */
    private static double objective(double[] scores, double[] t, double a, double b) {
        double f = 0.0;
        for (int i = 0; i < scores.length; i++) {
            double fApB = scores[i] * a + b;
            if (fApB >= 0.0) {
                f += t[i] * fApB + Math.log1p(Math.exp(-fApB));
            } else {
                f += (t[i] - 1.0) * fApB + Math.log1p(Math.exp(fApB));
            }
        }
        return f;
    }

    /**
     * Fits the Platt scaling on the scores that a classifier assigns to a
     * calibration set.
     *
     * @param model the classifier whose {@code score} output is calibrated.
     * @param x     the calibration samples.
     * @param y     the labels; values {@code > 0} are treated as the positive class.
     * @param <T>   the type of input samples.
     * @return the fitted transform.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    public static <T> PlattScaling fit(Classifier<T> model, T[] x, int[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(
                "The sizes of samples and labels don't match: " + x.length + " != " + y.length);
        }

        int n = y.length;
        double[] scores = new double[n];
        for (int i = 0; i < n; i++) {
            scores[i] = model.score(x[i]);
        }
        return PlattScaling.fit(scores, y);
    }
}

