package smile.base.svm;

import java.lang.reflect.Array;
import java.util.stream.IntStream;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.kernel.MercerKernel;

/* loaded from: input_file:smile/base/svm/OCSVM.class */
/**
 * One-class support vector machine trainer based on sequential minimal
 * optimization (SMO).
 *
 * <p>{@link #fit} precomputes the full {@code n x n} kernel matrix. It starts
 * {@code round(nu * n)} randomly chosen dual variables at the upper bound
 * {@code C = 1 / round(nu * n)} and all others at zero. It then repeatedly
 * updates one pair of dual variables until {@code omax - omin <= tol}, where
 * {@code omax} / {@code omin} are the extreme values of
 * {@code O = K * alpha} over the eligible variables.
 *
 * <p>The training data, kernel matrix and solver state are stored in instance
 * fields without synchronization. Do not call {@code fit} concurrently on the
 * same instance.
 *
 * @param <T> the type of the training instances.
 */
public class OCSVM<T> {
    private static final Logger logger = LoggerFactory.getLogger(OCSVM.class);

    /** Replacement for non-positive curvature terms so SMO never divides by zero. */
    private static final double TAU = 1E-12;

    /** Kernel used to build the kernel matrix and the returned model. */
    private final MercerKernel<T> kernel;
    /** Parameter in (0, 1] that sets how many dual variables start at C. */
    private final double nu;
    /** Convergence tolerance on {@code omax - omin}; also used in the intercept. */
    private final double tol;
    /** Upper bound of every dual variable: {@code 1 / round(nu * n)}. */
    private double C;
    /** Training instances of the current fit. */
    private T[] x;
    /** Threshold estimate maintained by the solver; the intercept is {@code -(rho - tol)}. */
    private double rho;
    /** Dual variables. */
    private double[] alpha;
    /** {@code O[i] = sum_j K[i][j] * alpha[j]}, kept up to date after every SMO step. */
    private double[] O;
    /** Precomputed kernel matrix: {@code K[i][j] = kernel.k(x[i], x[j])}. */
    private double[][] K;
    /** Index of the smallest O among variables with alpha &lt; C, or -1 if none. */
    private int svmin = -1;
    /** Index of the largest O among variables with alpha &gt; 0, or -1 if none. */
    private int svmax = -1;
    /** {@code O[svmin]}, or {@code Double.MAX_VALUE} when {@code svmin == -1}. */
    private double omin = Double.MAX_VALUE;
    /** {@code O[svmax]}, or {@code -Double.MAX_VALUE} when {@code svmax == -1}. */
    private double omax = -Double.MAX_VALUE;

    /**
     * Constructor.
     *
     * @param kernel the Mercer kernel.
     * @param nu     the parameter in (0, 1] controlling how many dual variables
     *               start at the upper bound.
     * @param tol    the tolerance of the convergence test; must be positive.
     * @throws IllegalArgumentException if {@code nu} is not in (0, 1] or
     *         {@code tol} is not positive (NaN is rejected for both).
     */
    public OCSVM(MercerKernel<T> kernel, double nu, double tol) {
        // Negated range tests so that NaN is rejected too.
        if (!(nu > 0.0 && nu <= 1.0)) {
            throw new IllegalArgumentException("Invalid nu: " + nu);
        }
        if (!(tol > 0.0)) {
            throw new IllegalArgumentException("Invalid tolerance of convergence test:" + tol);
        }
        this.kernel = kernel;
        this.nu = nu;
        this.tol = tol;
    }

    /**
     * Trains the one-class SVM on the given instances.
     *
     * @param x the training instances; must be non-empty and satisfy
     *          {@code round(nu * x.length) >= 1}.
     * @return a kernel machine whose instances are the support vectors
     *         ({@code alpha > 0}), whose weights are their alpha values and
     *         whose intercept is {@code -(rho - tol)}.
     * @throws IllegalArgumentException if {@code x} is empty or
     *         {@code round(nu * x.length)} is zero.
     */
    public KernelMachine<T> fit(T[] x) {
        int n = x.length;
        if (n == 0) {
            throw new IllegalArgumentException("Empty training data");
        }
        // Number of dual variables initialised at the upper bound C.
        int nbound = (int) Math.round(nu * n);
        if (nbound == 0) {
            // Otherwise C = 1 / 0 = Infinity, no variable ever becomes positive,
            // and the model would have no support vectors and an infinite intercept.
            throw new IllegalArgumentException("nu * n rounds to zero: nu = " + nu + ", n = " + n);
        }

        this.x = x;
        K = new double[n][n];
        // Each row is written by exactly one task, so rows are filled in parallel.
        IntStream.range(0, n).parallel().forEach(i -> {
            T xi = x[i];
            double[] Ki = K[i];
            for (int j = 0; j < n; j++) {
                Ki[j] = kernel.k(xi, x[j]);
            }
        });

        C = 1.0 / nbound;
        int[] index = MathEx.permutate(n);
        alpha = new double[n];
        for (int i = 0; i < nbound; i++) {
            alpha[index[i]] = C;
        }

        O = new double[n];
        rho = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < n; i++) {
            double[] Ki = K[i];
            for (int j = 0; j < n; j++) {
                O[i] += Ki[j] * alpha[j];
            }
            // rho starts as the largest O over the initially positive variables.
            if (alpha[i] > 0.0 && rho < O[i]) {
                rho = O[i];
            }
        }
        minmax();

        // Report progress every min(n, 1000) SMO iterations.
        int phase = Math.min(n, 1000);
        for (int iter = 1; smo(tol); iter++) {
            if (iter % phase == 0) {
                logger.info("{} SMO iterations", iter);
            }
        }

        int nsv = 0;
        int nbsv = 0;
        for (double a : alpha) {
            if (a > 0.0) {
                nsv++;
                // Exact comparison is intended: clipping assigns exactly C.
                if (a == C) {
                    nbsv++;
                }
            }
        }

        // Array.newInstance keeps the runtime component type of x, so this cast is safe.
        @SuppressWarnings("unchecked")
        T[] vectors = (T[]) Array.newInstance(x.getClass().getComponentType(), nsv);
        double[] weight = new double[nsv];
        for (int i = 0, m = 0; i < n; i++) {
            if (alpha[i] > 0.0) {
                vectors[m] = x[i];
                weight[m++] = alpha[i];
            }
        }
        double b = -(rho - tol);

        logger.info("{} samples, {} support vectors, {} bounded", n, nsv, nbsv);
        return new KernelMachine<>(kernel, vectors, weight, b);
    }

    /**
     * Recomputes {@code svmin}/{@code omin} (smallest O among alpha &lt; C) and
     * {@code svmax}/{@code omax} (largest O among alpha &gt; 0). An index stays
     * -1 when no variable qualifies.
     */
    private void minmax() {
        svmin = -1;
        svmax = -1;
        omin = Double.MAX_VALUE;
        omax = -Double.MAX_VALUE;
        for (int i = 0; i < x.length; i++) {
            double oi = O[i];
            double ai = alpha[i];
            if (oi < omin && ai < C) {
                svmin = i;
                omin = oi;
            }
            if (oi > omax && ai > 0.0) {
                svmax = i;
                omax = oi;
            }
        }
    }

    /**
     * Performs one SMO step on the pair ({@code svmin}, {@code svmax}). When one
     * index of the pair is missing, its partner is chosen by
     * {@link #selectMax(int)} or {@link #selectMin(int)}.
     *
     * @param epsgr the convergence tolerance on {@code omax - omin}.
     * @return true if another step is needed; false once converged or when no
     *         valid pair exists.
     */
    private boolean smo(double epsgr) {
        int i = svmin;
        int j = svmax;
        if (i < 0 && j < 0) {
            // No anchor to search from (e.g. every O is NaN): nothing to optimize.
            return false;
        }
        if (j < 0) {
            j = selectMax(i);
        } else if (i < 0) {
            i = selectMin(j);
        }
        if (i < 0 || j < 0) {
            return false;
        }

        int n = x.length;
        double[] Ki = K[i];
        double[] Kj = K[j];
        double oldAlphaI = alpha[i];
        double oldAlphaJ = alpha[j];

        double curv = Ki[i] + Kj[j] - 2.0 * Ki[j];
        if (curv <= 0.0) {
            curv = TAU;
        }
        // Unconstrained step that keeps alpha[i] + alpha[j] unchanged.
        double delta = (O[i] - O[j]) / curv;
        double sum = alpha[i] + alpha[j];
        alpha[j] += delta;
        alpha[i] -= delta;

        // Clip back toward [0, C]; every branch keeps alpha[i] + alpha[j] == sum.
        if (sum > C) {
            if (alpha[i] > C) {
                alpha[i] = C;
                alpha[j] = sum - C;
            }
        } else if (alpha[j] < 0.0) {
            alpha[j] = 0.0;
            alpha[i] = sum;
        }
        if (sum > C) {
            if (alpha[j] > C) {
                alpha[j] = C;
                alpha[i] = sum - C;
            }
        } else if (alpha[i] < 0.0) {
            alpha[i] = 0.0;
            alpha[j] = sum;
        }

        // Incremental update of O for the two changed variables. Row i stands in
        // for column i here, which relies on K being symmetric.
        double dAlphaI = alpha[i] - oldAlphaI;
        double dAlphaJ = alpha[j] - oldAlphaJ;
        for (int v = 0; v < n; v++) {
            O[v] = O[v] + Ki[v] * dAlphaI + Kj[v] * dAlphaJ;
        }

        // NOTE: rho is the midpoint of omin/omax observed *before* this step's update.
        rho = (omax + omin) / 2.0;
        minmax();
        return omax - omin > epsgr;
    }

    /**
     * Second-order choice of a partner for a fixed {@code i}. Candidates are
     * variables with {@code alpha > 0} and {@code O[v] > O[i]}; the winner
     * maximizes {@code (O[v] - O[i])^2 / curvature}.
     *
     * @param i the fixed index.
     * @return the chosen index, or -1 if no candidate qualifies.
     */
    private int selectMax(int i) {
        double gi = O[i];
        double[] Ki = K[i];
        double Kii = Ki[i];
        double best = 0.0;
        int j = -1;
        for (int v = 0; v < x.length; v++) {
            if (O[v] > gi && alpha[v] > 0.0) {
                double gradient = O[v] - gi;
                double curv = Kii + K[v][v] - 2.0 * Ki[v];
                if (curv <= 0.0) {
                    curv = TAU;
                }
                double mu = gradient / curv;
                double gain = -gradient * mu;
                if (gain < best) {
                    best = gain;
                    j = v;
                }
            }
        }
        return j;
    }

    /**
     * Second-order choice of a partner for a fixed {@code j}. Candidates are
     * variables with {@code alpha < C} and {@code O[v] < O[j]}; the winner
     * maximizes {@code (O[j] - O[v])^2 / curvature}.
     *
     * @param j the fixed index.
     * @return the chosen index, or -1 if no candidate qualifies.
     */
    private int selectMin(int j) {
        double gj = O[j];
        double[] Kj = K[j];
        double Kjj = Kj[j];
        double best = 0.0;
        int i = -1;
        for (int v = 0; v < x.length; v++) {
            if (O[v] < gj && alpha[v] < C) {
                double gradient = gj - O[v];
                double curv = Kjj + K[v][v] - 2.0 * Kj[v];
                if (curv <= 0.0) {
                    curv = TAU;
                }
                double mu = gradient / curv;
                double gain = -gradient * mu;
                if (gain < best) {
                    best = gain;
                    i = v;
                }
            }
        }
        return i;
    }
}
