/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import smile.classification.AbstractClassifier;
import smile.classification.ClassLabels;
import smile.data.SparseDataset;
import smile.math.BFGS;
import smile.math.DifferentiableMultivariateFunction;
import smile.math.MathEx;
import smile.util.IntSet;
import smile.util.SparseArray;
import smile.validation.ModelSelection;

public abstract class SparseLogisticRegression
extends AbstractClassifier<SparseArray> {
    private static final long serialVersionUID = 2L;
    int p;
    int k;
    double L;
    double lambda;
    double eta = 0.1;

    /**
     * Initializes the state shared by the concrete model classes.
     *
     * @param p      stored in the {@code p} field; {@link #AIC()} counts
     *               {@code p + 1} parameters per weight vector.
     * @param L      stored in the {@code L} field and reported by {@link #loglikelihood()}.
     * @param lambda regularization factor stored in the {@code lambda} field.
     * @param labels class labels; passed to the superclass constructor, and
     *               {@code labels.size()} becomes the number of classes {@code k}.
     */
    public SparseLogisticRegression(int p, double L, double lambda, IntSet labels) {
        super(labels);
        this.lambda = lambda;
        this.L = L;
        this.p = p;
        this.k = labels.size();
    }

    /**
     * Fits a binomial model using an empty {@link Properties} object, so every
     * hyperparameter takes the default declared in
     * {@link #binomial(SparseDataset, int[], Properties)}.
     *
     * @param x training samples.
     * @param y training labels.
     * @return the fitted model.
     */
    public static Binomial binomial(SparseDataset x, int[] y) {
        Properties defaults = new Properties();
        return binomial(x, y, defaults);
    }

    /**
     * Fits a binomial model with hyperparameters read from {@code params}.
     *
     * <p>Keys and fallback values: {@code smile.logistic.lambda} (0.1),
     * {@code smile.logistic.tolerance} (1E-5) and
     * {@code smile.logistic.iterations} (500).
     *
     * @param x      training samples.
     * @param y      training labels.
     * @param params hyperparameter properties.
     * @return the fitted model.
     * @throws NumberFormatException if a property value cannot be parsed.
     */
    public static Binomial binomial(SparseDataset x, int[] y, Properties params) {
        // Arguments are evaluated left to right, so the parse order is
        // lambda, tolerance, iterations.
        return binomial(
                x,
                y,
                Double.parseDouble(params.getProperty("smile.logistic.lambda", "0.1")),
                Double.parseDouble(params.getProperty("smile.logistic.tolerance", "1E-5")),
                Integer.parseInt(params.getProperty("smile.logistic.iterations", "500")));
    }

    /**
     * Fits a two-class model by minimizing a {@link BinomialObjective} with
     * {@link BFGS#minimize}.
     *
     * <p>The weight vector has {@code x.ncol() + 1} entries; the last one is the
     * bias term (see {@code dot}). The negated minimum returned by the optimizer
     * is stored as the model's log-likelihood, and the online learning rate is
     * initialized to {@code 0.1 / x.size()}.
     *
     * @param x       training samples.
     * @param y       training labels, one per sample.
     * @param lambda  regularization factor, must be non-negative.
     * @param tol     optimizer tolerance, must be positive.
     * @param maxIter maximum number of optimizer iterations, must be positive.
     * @return the fitted model.
     * @throws IllegalArgumentException if the sizes of {@code x} and {@code y}
     *         differ, a hyperparameter is out of range, or the labels do not
     *         contain exactly two classes.
     */
    public static Binomial binomial(SparseDataset x, int[] y, double lambda, double tol, int maxIter) {
        final int n = x.size();
        if (n != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", n, y.length));
        }
        if (lambda < 0.0) {
            throw new IllegalArgumentException("Invalid regularization factor: " + lambda);
        }
        if (tol <= 0.0) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        final int features = x.ncol();
        final ClassLabels codec = ClassLabels.fit(y);
        if (codec.k != 2) {
            throw new IllegalArgumentException("Fits binomial model on multi-class data.");
        }

        // Train on the labels as re-encoded by the codec (codec.y).
        BinomialObjective objective = new BinomialObjective(x, codec.y, lambda);
        double[] weights = new double[features + 1];
        // NOTE(review): the literal 5 is BFGS.minimize's second argument,
        // presumably the L-BFGS memory size; confirm against the BFGS API.
        double logLikelihood = -BFGS.minimize(objective, 5, weights, tol, maxIter);

        Binomial model = new Binomial(weights, logLikelihood, lambda, codec.classes);
        model.setLearningRate(0.1 / (double) n);
        return model;
    }

    /**
     * Fits a multinomial model using an empty {@link Properties} object, so
     * every hyperparameter takes the default declared in
     * {@link #multinomial(SparseDataset, int[], Properties)}.
     *
     * @param x training samples.
     * @param y training labels.
     * @return the fitted model.
     */
    public static Multinomial multinomial(SparseDataset x, int[] y) {
        Properties defaults = new Properties();
        return multinomial(x, y, defaults);
    }

    /**
     * Fits a multinomial model with hyperparameters read from {@code params}.
     *
     * <p>Keys and fallback values: {@code smile.logistic.lambda} (0.1),
     * {@code smile.logistic.tolerance} (1E-5) and
     * {@code smile.logistic.iterations} (500).
     *
     * @param x      training samples.
     * @param y      training labels.
     * @param params hyperparameter properties.
     * @return the fitted model.
     * @throws NumberFormatException if a property value cannot be parsed.
     */
    public static Multinomial multinomial(SparseDataset x, int[] y, Properties params) {
        // Arguments are evaluated left to right, so the parse order is
        // lambda, tolerance, iterations.
        return multinomial(
                x,
                y,
                Double.parseDouble(params.getProperty("smile.logistic.lambda", "0.1")),
                Double.parseDouble(params.getProperty("smile.logistic.tolerance", "1E-5")),
                Integer.parseInt(params.getProperty("smile.logistic.iterations", "500")));
    }

    /**
     * Fits a model for more than two classes by minimizing a
     * {@link MultinomialObjective} with {@link BFGS#minimize}.
     *
     * <p>The optimizer works on one flat vector of {@code (k - 1) * (p + 1)}
     * weights, where {@code p = x.ncol()} and {@code k} is the number of classes.
     * Afterwards the vector is cut into {@code k - 1} rows of {@code p + 1}
     * weights each (the last entry of a row is the bias term). The negated
     * minimum is stored as the model's log-likelihood, and the online learning
     * rate is initialized to {@code 0.1 / x.size()}.
     *
     * @param x       training samples.
     * @param y       training labels, one per sample.
     * @param lambda  regularization factor, must be non-negative.
     * @param tol     optimizer tolerance, must be positive.
     * @param maxIter maximum number of optimizer iterations, must be positive.
     * @return the fitted model.
     * @throws IllegalArgumentException if the sizes of {@code x} and {@code y}
     *         differ, a hyperparameter is out of range, or the labels contain
     *         two or fewer classes.
     */
    public static Multinomial multinomial(SparseDataset x, int[] y, double lambda, double tol, int maxIter) {
        final int n = x.size();
        if (n != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", n, y.length));
        }
        if (lambda < 0.0) {
            throw new IllegalArgumentException("Invalid regularization factor: " + lambda);
        }
        if (tol <= 0.0) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        final int stride = x.ncol() + 1;
        final ClassLabels codec = ClassLabels.fit(y);
        if (codec.k <= 2) {
            throw new IllegalArgumentException("Fits multinomial model on binary class data.");
        }
        final int rows = codec.k - 1;

        // Train on the labels as re-encoded by the codec (codec.y).
        MultinomialObjective objective = new MultinomialObjective(x, codec.y, codec.k, lambda);
        double[] flat = new double[rows * stride];
        // NOTE(review): the literal 5 is BFGS.minimize's second argument,
        // presumably the L-BFGS memory size; confirm against the BFGS API.
        double logLikelihood = -BFGS.minimize(objective, 5, flat, tol, maxIter);

        // Row i of the weight matrix is the i-th consecutive slice of the flat vector.
        double[][] weights = new double[rows][stride];
        for (int i = 0; i < rows; i++) {
            System.arraycopy(flat, i * stride, weights[i], 0, stride);
        }

        Multinomial model = new Multinomial(weights, logLikelihood, lambda, codec.classes);
        model.setLearningRate(0.1 / (double) n);
        return model;
    }

    /**
     * Fits a model using an empty {@link Properties} object, so every
     * hyperparameter takes the default declared in
     * {@link #fit(SparseDataset, int[], Properties)}.
     *
     * @param x training samples.
     * @param y training labels.
     * @return the fitted model.
     */
    public static SparseLogisticRegression fit(SparseDataset x, int[] y) {
        Properties defaults = new Properties();
        return fit(x, y, defaults);
    }

    /**
     * Fits a model with hyperparameters read from {@code params}.
     *
     * <p>Keys and fallback values: {@code smile.logistic.lambda} (0.1),
     * {@code smile.logistic.tolerance} (1E-5) and
     * {@code smile.logistic.iterations} (500).
     *
     * @param x      training samples.
     * @param y      training labels.
     * @param params hyperparameter properties.
     * @return the fitted model.
     * @throws NumberFormatException if a property value cannot be parsed.
     */
    public static SparseLogisticRegression fit(SparseDataset x, int[] y, Properties params) {
        // Arguments are evaluated left to right, so the parse order is
        // lambda, tolerance, iterations.
        return fit(
                x,
                y,
                Double.parseDouble(params.getProperty("smile.logistic.lambda", "0.1")),
                Double.parseDouble(params.getProperty("smile.logistic.tolerance", "1E-5")),
                Integer.parseInt(params.getProperty("smile.logistic.iterations", "500")));
    }

    /**
     * Fits a binomial model when the labels contain exactly two classes and a
     * multinomial model otherwise.
     *
     * <p>The labels are encoded here only to count the classes; the chosen
     * factory method receives the original {@code y} and encodes it again.
     *
     * @param x       training samples.
     * @param y       training labels.
     * @param lambda  regularization factor.
     * @param tol     optimizer tolerance.
     * @param maxIter maximum number of optimizer iterations.
     * @return the fitted model.
     */
    public static SparseLogisticRegression fit(SparseDataset x, int[] y, double lambda, double tol, int maxIter) {
        final boolean twoClasses = ClassLabels.fit(y).k == 2;
        return twoClasses
                ? binomial(x, y, lambda, tol, maxIter)
                : multinomial(x, y, lambda, tol, maxIter);
    }

    /**
     * Computes the linear response of a single weight vector for a sparse sample.
     *
     * @param x sparse sample; each entry contributes its value times the weight
     *          at the entry's index.
     * @param w weight vector whose last element is the bias term.
     * @return the bias plus the sum over the sample's entries.
     */
    private static double dot(SparseArray x, double[] w) {
        final int bias = w.length - 1;
        double sum = w[bias];
        for (SparseArray.Entry entry : x) {
            sum += entry.x * w[entry.i];
        }
        return sum;
    }

    /**
     * Computes the linear response of the {@code j}-th weight vector stored
     * inside a flat weight array.
     *
     * <p>Vector {@code j} occupies {@code w[j * (p + 1)] .. w[j * (p + 1) + p]};
     * its last slot is the bias term.
     *
     * @param x sparse sample.
     * @param w flat array of consecutive weight vectors, each of length {@code p + 1}.
     * @param j index of the weight vector to use.
     * @param p number of non-bias weights per vector.
     * @return the bias plus the sum over the sample's entries.
     */
    private static double dot(SparseArray x, double[] w, int j, int p) {
        final int offset = j * (p + 1);
        double sum = w[offset + p];
        for (SparseArray.Entry entry : x) {
            sum += entry.x * w[offset + entry.i];
        }
        return sum;
    }

    /**
     * Always returns {@code true}.
     *
     * <p>NOTE(review): presumably this advertises that the
     * {@code predict(x, posteriori)} overloads of the subclasses are usable;
     * confirm against the contract of the overridden method.
     *
     * @return {@code true}.
     */
    @Override
    public boolean soft() {
        return true;
    }

    /**
     * Always returns {@code true}.
     *
     * <p>NOTE(review): presumably this advertises that the {@code update(x, y)}
     * methods of the subclasses are usable; confirm against the contract of the
     * overridden method.
     *
     * @return {@code true}.
     */
    @Override
    public boolean online() {
        return true;
    }

    /**
     * Sets the step size {@code eta} used by the {@code update} methods of the
     * concrete models.
     *
     * @param rate the new learning rate.
     * @throws IllegalArgumentException if {@code rate <= 0.0}.
     */
    public void setLearningRate(double rate) {
        if (rate <= 0.0) {
            String message = "Invalid learning rate: " + rate;
            throw new IllegalArgumentException(message);
        }
        eta = rate;
    }

    /**
     * Returns the current value of the {@code eta} field, i.e. the step size
     * last set through {@link #setLearningRate(double)} (0.1 if never set).
     *
     * @return the learning rate.
     */
    public double getLearningRate() {
        return this.eta;
    }

    /**
     * Returns the {@code L} value supplied to the constructor. The factory
     * methods of this class pass the negated minimum of the training objective.
     *
     * @return the stored log-likelihood value.
     */
    public double loglikelihood() {
        return this.L;
    }

    /**
     * Evaluates {@link ModelSelection#AIC} with the stored log-likelihood and a
     * parameter count of {@code (k - 1) * (p + 1)}: one weight vector of
     * {@code p + 1} entries for each class but one.
     *
     * @return the value returned by {@code ModelSelection.AIC}.
     */
    public double AIC() {
        final int parameters = (k - 1) * (p + 1);
        return ModelSelection.AIC(L, parameters);
    }

    /**
     * Two-class model backed by a single weight vector whose last element is
     * the bias term.
     */
    public static class Binomial extends SparseLogisticRegression {
        /** Weight vector of length {@code p + 1}; index {@code p} holds the bias. */
        private final double[] w;

        /**
         * Wraps an existing weight vector. The array is stored as is, not copied.
         *
         * @param w      weights; {@code w.length - 1} becomes {@code p}.
         * @param L      log-likelihood value to report.
         * @param lambda regularization factor used by {@link #update}.
         * @param labels class labels.
         */
        public Binomial(double[] w, double L, double lambda, IntSet labels) {
            super(w.length - 1, L, lambda, labels);
            this.w = w;
        }

        /**
         * Returns the internal weight array itself (no copy is made).
         *
         * @return the weights, bias last.
         */
        public double[] coefficients() {
            return w;
        }

        /**
         * Returns the logistic function of the linear response for {@code x}.
         */
        @Override
        public double score(SparseArray x) {
            return positiveProbability(x);
        }

        /**
         * Predicts class index 0 when the logistic response is below 0.5 and
         * class index 1 otherwise, then maps the index to its label.
         */
        @Override
        public int predict(SparseArray x) {
            return labelFor(positiveProbability(x));
        }

        /**
         * Same decision as {@link #predict(SparseArray)}; additionally writes
         * {@code 1 - f} to {@code posteriori[0]} and {@code f} to
         * {@code posteriori[1]}, where {@code f} is the logistic response.
         *
         * @throws IllegalArgumentException if {@code posteriori.length != k}.
         */
        @Override
        public int predict(SparseArray x, double[] posteriori) {
            if (posteriori.length != k) {
                throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
            }
            final double positive = positiveProbability(x);
            posteriori[0] = 1.0 - positive;
            posteriori[1] = positive;
            return labelFor(positive);
        }

        /**
         * Applies one gradient step for the sample {@code (x, y)}.
         *
         * <p>The bias and the weights at the sample's entry indices move by
         * {@code eta * err} (times the entry value), where {@code err} is the
         * class index of {@code y} minus {@code MathEx.sigmoid} of the linear
         * response. When {@code lambda > 0}, every non-bias weight is then
         * reduced by {@code eta * lambda} times itself.
         */
        @Override
        public void update(SparseArray x, int y) {
            final int target = classes.indexOf(y);
            final double response = SparseLogisticRegression.dot(x, w);
            final double step = eta * ((double) target - MathEx.sigmoid(response));

            w[p] += step;
            for (SparseArray.Entry entry : x) {
                w[entry.i] += step * entry.x;
            }

            if (lambda > 0.0) {
                final double shrink = eta * lambda;
                // The bias at index p is excluded from the shrinkage.
                for (int j = 0; j < p; j++) {
                    w[j] -= shrink * w[j];
                }
            }
        }

        /** Evaluates {@code 1 / (1 + exp(-dot(x, w)))}. */
        private double positiveProbability(SparseArray x) {
            final double response = SparseLogisticRegression.dot(x, w);
            return 1.0 / (1.0 + Math.exp(-response));
        }

        /** Maps a logistic response to a class label using the 0.5 threshold. */
        private int labelFor(double positive) {
            final int index = positive < 0.5 ? 0 : 1;
            return classes.valueOf(index);
        }
    }

    /**
     * Objective minimized when fitting the two-class model.
     *
     * <p>For every sample the term {@code MathEx.log1pe(wx) - y * wx} is
     * accumulated, where {@code wx} is the linear response of the sample; when
     * {@code lambda > 0} the penalty {@code 0.5 * lambda * sum(w[i]^2)} over the
     * non-bias weights is added.
     *
     * <p>The gradient is computed over fixed-size chunks of the data set. Each
     * chunk owns one row of {@link #gradients}, so chunks can be processed in
     * parallel without sharing an accumulator. Because those rows are reused on
     * every call, an instance must not be evaluated concurrently.
     */
    static class BinomialObjective
    implements DifferentiableMultivariateFunction {
        /** Training samples. */
        SparseDataset x;
        /** Training labels, used directly as the numeric target in the loss. */
        int[] y;
        /** Number of non-bias weights, taken from {@code x.ncol()}. */
        int p;
        /** Regularization factor. */
        double lambda;
        /** Number of samples per chunk, read from {@code smile.data.partition.size}. */
        int partitionSize;
        /** Number of chunks needed to cover the data set. */
        int partitions;
        /** One gradient accumulator of length {@code p + 1} per chunk. */
        double[][] gradients;

        /**
         * Creates the objective and allocates the per-chunk gradient buffers.
         *
         * @param x      training samples.
         * @param y      training labels.
         * @param lambda regularization factor.
         * @throws NumberFormatException    if the system property
         *         {@code smile.data.partition.size} is not an integer.
         * @throws IllegalArgumentException if that property is not positive.
         */
        BinomialObjective(SparseDataset x, int[] y, double lambda) {
            this.x = x;
            this.y = y;
            this.lambda = lambda;
            this.p = x.ncol();
            this.partitionSize = Integer.parseInt(System.getProperty("smile.data.partition.size", "1000"));
            // A zero size would divide by zero below, and a negative size would
            // produce empty chunks and therefore a silently zero gradient.
            if (this.partitionSize <= 0) {
                throw new IllegalArgumentException("Invalid smile.data.partition.size: " + this.partitionSize);
            }
            this.partitions = x.size() / this.partitionSize + (x.size() % this.partitionSize == 0 ? 0 : 1);
            this.gradients = new double[this.partitions][this.p + 1];
        }

        /**
         * Evaluates the objective at {@code w}.
         *
         * @param w weight vector, bias last.
         * @return the summed per-sample loss plus the penalty term.
         */
        @Override
        public double f(double[] w) {
            double f = IntStream.range(0, this.x.size()).parallel().mapToDouble(i -> {
                double wx = SparseLogisticRegression.dot((SparseArray)this.x.get(i), w);
                return MathEx.log1pe(wx) - (double)this.y[i] * wx;
            }).sum();
            if (this.lambda > 0.0) {
                double wnorm = 0.0;
                // The bias at index p is not penalized.
                for (int i = 0; i < this.p; ++i) {
                    wnorm += w[i] * w[i];
                }
                f += 0.5 * this.lambda * wnorm;
            }
            return f;
        }

        /**
         * Evaluates the objective at {@code w} and writes its gradient into {@code g}.
         *
         * @param w weight vector, bias last.
         * @param g output array; overwritten with the gradient.
         * @return the objective value at {@code w}.
         */
        @Override
        public double g(double[] w, double[] g) {
            double f = IntStream.range(0, this.partitions).parallel().mapToDouble(r -> {
                double[] gradient = this.gradients[r];
                Arrays.fill(gradient, 0.0);
                final int begin = r * this.partitionSize;
                final int end = this.partitionEnd(begin);
                return IntStream.range(begin, end).sequential().mapToDouble(i -> {
                    SparseArray xi = (SparseArray)this.x.get(i);
                    double wx = SparseLogisticRegression.dot(xi, w);
                    double err = (double)this.y[i] - MathEx.sigmoid(wx);
                    for (SparseArray.Entry e : xi) {
                        gradient[e.i] -= err * e.x;
                    }
                    gradient[this.p] -= err;
                    return MathEx.log1pe(wx) - (double)this.y[i] * wx;
                }).sum();
            }).sum();

            // Combine the per-chunk accumulators in chunk order.
            Arrays.fill(g, 0.0);
            for (double[] gradient : this.gradients) {
                for (int i = 0; i < g.length; ++i) {
                    g[i] += gradient[i];
                }
            }

            if (this.lambda > 0.0) {
                double wnorm = 0.0;
                // The bias at index p is not penalized.
                for (int i = 0; i < this.p; ++i) {
                    wnorm += w[i] * w[i];
                    g[i] += this.lambda * w[i];
                }
                f += 0.5 * this.lambda * wnorm;
            }
            return f;
        }

        /**
         * Returns the exclusive end index of the chunk starting at {@code begin},
         * clamped to the data set size. The sum is formed in {@code long} so a
         * very large partition size cannot wrap around to a negative end.
         */
        private int partitionEnd(int begin) {
            return (int) Math.min((long) begin + this.partitionSize, (long) this.x.size());
        }
    }

    /**
     * Multi-class model backed by {@code k - 1} weight vectors; the last class
     * has no vector of its own and is scored with a fixed response of 0.
     */
    public static class Multinomial extends SparseLogisticRegression {
        /** {@code k - 1} rows of {@code p + 1} weights; the last column is the bias. */
        private final double[][] w;

        /**
         * Wraps an existing weight matrix. The array is stored as is, not copied.
         *
         * @param w      weights; {@code w[0].length - 1} becomes {@code p}.
         * @param L      log-likelihood value to report.
         * @param lambda regularization factor used by {@link #update}.
         * @param labels class labels.
         */
        public Multinomial(double[][] w, double L, double lambda, IntSet labels) {
            super(w[0].length - 1, L, lambda, labels);
            this.w = w;
        }

        /**
         * Returns the internal weight matrix itself (no copy is made).
         *
         * @return the weights, one row per class except the last.
         */
        public double[][] coefficients() {
            return w;
        }

        /**
         * Predicts the label of {@code x} using a throwaway probability buffer.
         */
        @Override
        public int predict(SparseArray x) {
            double[] scratch = new double[k];
            return predict(x, scratch);
        }

        /**
         * Fills {@code posteriori} with the linear responses of the {@code k - 1}
         * weight vectors plus 0 for the last class, applies
         * {@code MathEx.softmax} in place, and returns the label of the entry
         * selected by {@code MathEx.whichMax}.
         *
         * @throws IllegalArgumentException if {@code posteriori.length != k}.
         */
        @Override
        public int predict(SparseArray x, double[] posteriori) {
            if (posteriori.length != k) {
                throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
            }
            final int last = k - 1;
            posteriori[last] = 0.0;
            for (int c = 0; c < last; c++) {
                posteriori[c] = SparseLogisticRegression.dot(x, w[c]);
            }
            MathEx.softmax(posteriori);
            int winner = MathEx.whichMax(posteriori);
            return classes.valueOf(winner);
        }

        /**
         * Applies one gradient step for the sample {@code (x, y)} to each of the
         * {@code k - 1} weight vectors.
         *
         * <p>For vector {@code c} the error is the indicator of
         * {@code y}'s class index being {@code c} minus the softmax output for
         * {@code c}. When {@code lambda > 0}, every non-bias weight is then
         * reduced by {@code eta * lambda} times itself.
         */
        @Override
        public void update(SparseArray x, int y) {
            final int target = classes.indexOf(y);
            final int last = k - 1;

            // prob[last] keeps its initial 0.0 as the response of the last class.
            double[] prob = new double[k];
            for (int c = 0; c < last; c++) {
                prob[c] = SparseLogisticRegression.dot(x, w[c]);
            }
            MathEx.softmax(prob);

            final double shrink = eta * lambda;
            for (int c = 0; c < last; c++) {
                double[] wc = w[c];
                double indicator = target == c ? 1.0 : 0.0;
                double step = eta * (indicator - prob[c]);

                wc[p] += step;
                for (SparseArray.Entry entry : x) {
                    wc[entry.i] += step * entry.x;
                }

                if (lambda > 0.0) {
                    // The bias at index p is excluded from the shrinkage.
                    for (int j = 0; j < p; j++) {
                        wc[j] -= shrink * wc[j];
                    }
                }
            }
        }
    }

    /**
     * Objective minimized when fitting the multi-class model.
     *
     * <p>The weights are one flat array of {@code k - 1} consecutive vectors of
     * length {@code p + 1} (bias last). For every sample the responses of those
     * vectors, plus 0 for the last class, are passed through
     * {@code MathEx.softmax}, and {@code -MathEx.log} of the entry for the
     * sample's label is accumulated. When {@code lambda > 0} the penalty
     * {@code 0.5 * lambda * sum(w^2)} over the non-bias weights is added.
     *
     * <p>Work is split into fixed-size chunks of the data set. Each chunk owns
     * one row of {@link #gradients} and {@link #posterioris}, so chunks can be
     * processed in parallel without sharing scratch space. Because those rows
     * are reused on every call, an instance must not be evaluated concurrently.
     */
    static class MultinomialObjective
    implements DifferentiableMultivariateFunction {
        /** Training samples. */
        SparseDataset x;
        /** Training labels, used as indices into the softmax output. */
        int[] y;
        /** Number of classes. */
        int k;
        /** Number of non-bias weights per vector, taken from {@code x.ncol()}. */
        int p;
        /** Regularization factor. */
        double lambda;
        /** Number of samples per chunk, read from {@code smile.data.partition.size}. */
        int partitionSize;
        /** Number of chunks needed to cover the data set. */
        int partitions;
        /** One gradient accumulator of length {@code (k - 1) * (p + 1)} per chunk. */
        double[][] gradients;
        /** One softmax scratch buffer of length {@code k} per chunk. */
        double[][] posterioris;

        /**
         * Creates the objective and allocates the per-chunk buffers.
         *
         * @param x      training samples.
         * @param y      training labels.
         * @param k      number of classes.
         * @param lambda regularization factor.
         * @throws NumberFormatException    if the system property
         *         {@code smile.data.partition.size} is not an integer.
         * @throws IllegalArgumentException if that property is not positive.
         */
        MultinomialObjective(SparseDataset x, int[] y, int k, double lambda) {
            this.x = x;
            this.y = y;
            this.k = k;
            this.lambda = lambda;
            this.p = x.ncol();
            this.partitionSize = Integer.parseInt(System.getProperty("smile.data.partition.size", "1000"));
            // A zero size would divide by zero below, and a negative size would
            // produce empty chunks and therefore a silently zero objective.
            if (this.partitionSize <= 0) {
                throw new IllegalArgumentException("Invalid smile.data.partition.size: " + this.partitionSize);
            }
            this.partitions = x.size() / this.partitionSize + (x.size() % this.partitionSize == 0 ? 0 : 1);
            this.gradients = new double[this.partitions][(k - 1) * (this.p + 1)];
            this.posterioris = new double[this.partitions][k];
        }

        /**
         * Evaluates the objective at {@code w}.
         *
         * @param w flat weight array.
         * @return the summed per-sample loss plus the penalty term.
         */
        @Override
        public double f(double[] w) {
            double f = IntStream.range(0, this.partitions).parallel().mapToDouble(r -> {
                double[] posteriori = this.posterioris[r];
                final int begin = r * this.partitionSize;
                final int end = this.partitionEnd(begin);
                return IntStream.range(begin, end).sequential().mapToDouble(i -> {
                    SparseArray xi = (SparseArray)this.x.get(i);
                    // The last class has no weight vector; its response is fixed at 0.
                    posteriori[this.k - 1] = 0.0;
                    for (int j = 0; j < this.k - 1; ++j) {
                        posteriori[j] = SparseLogisticRegression.dot(xi, w, j, this.p);
                    }
                    MathEx.softmax(posteriori);
                    return -MathEx.log(posteriori[this.y[i]]);
                }).sum();
            }).sum();
            if (this.lambda > 0.0) {
                double wnorm = 0.0;
                for (int i = 0; i < this.k - 1; ++i) {
                    int pos = i * (this.p + 1);
                    // The bias at offset p of each vector is not penalized.
                    for (int j = 0; j < this.p; ++j) {
                        double wi = w[pos + j];
                        wnorm += wi * wi;
                    }
                }
                f += 0.5 * this.lambda * wnorm;
            }
            return f;
        }

        /**
         * Evaluates the objective at {@code w} and writes its gradient into {@code g}.
         *
         * @param w flat weight array.
         * @param g output array; overwritten with the gradient.
         * @return the objective value at {@code w}.
         */
        @Override
        public double g(double[] w, double[] g) {
            double f = IntStream.range(0, this.partitions).parallel().mapToDouble(r -> {
                double[] posteriori = this.posterioris[r];
                double[] gradient = this.gradients[r];
                Arrays.fill(gradient, 0.0);
                final int begin = r * this.partitionSize;
                final int end = this.partitionEnd(begin);
                return IntStream.range(begin, end).sequential().mapToDouble(i -> {
                    SparseArray xi = (SparseArray)this.x.get(i);
                    // The last class has no weight vector; its response is fixed at 0.
                    posteriori[this.k - 1] = 0.0;
                    for (int j = 0; j < this.k - 1; ++j) {
                        posteriori[j] = SparseLogisticRegression.dot(xi, w, j, this.p);
                    }
                    MathEx.softmax(posteriori);
                    for (int j = 0; j < this.k - 1; ++j) {
                        double err = (this.y[i] == j ? 1.0 : 0.0) - posteriori[j];
                        int pos = j * (this.p + 1);
                        for (SparseArray.Entry e : xi) {
                            gradient[pos + e.i] -= err * e.x;
                        }
                        gradient[pos + this.p] -= err;
                    }
                    return -MathEx.log(posteriori[this.y[i]]);
                }).sum();
            }).sum();

            // Combine the per-chunk accumulators in chunk order.
            Arrays.fill(g, 0.0);
            for (double[] gradient : this.gradients) {
                for (int i = 0; i < g.length; ++i) {
                    g[i] += gradient[i];
                }
            }

            if (this.lambda > 0.0) {
                double wnorm = 0.0;
                for (int i = 0; i < this.k - 1; ++i) {
                    int pos = i * (this.p + 1);
                    // The bias at offset p of each vector is not penalized.
                    for (int j = 0; j < this.p; ++j) {
                        double wi = w[pos + j];
                        wnorm += wi * wi;
                        g[pos + j] += this.lambda * wi;
                    }
                }
                f += 0.5 * this.lambda * wnorm;
            }
            return f;
        }

        /**
         * Returns the exclusive end index of the chunk starting at {@code begin},
         * clamped to the data set size. The sum is formed in {@code long} so a
         * very large partition size cannot wrap around to a negative end.
         */
        private int partitionEnd(int begin) {
            return (int) Math.min((long) begin + this.partitionSize, (long) this.x.size());
        }
    }
}

