package smile.regression;

import com.sun.jna.platform.win32.COM.tlb.imp.TlbConst;
import java.util.Arrays;
import java.util.Properties;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.math.blas.Transpose;
import smile.math.matrix.IMatrix;
import smile.math.matrix.Matrix;

/* loaded from: input_file:smile/regression/LASSO.class */
public class LASSO {
    private static final Logger logger = LoggerFactory.getLogger((Class<?>) LASSO.class);

    /**
     * Matrix-free linear operator of dimension {@code 2p x 2p} (with {@code p = A.ncol()})
     * that also acts as its own preconditioner.
     *
     * <p>As implemented by {@link #mv(double[], double[])}, the operator has the block form
     * <pre>
     *   [ 2 * A'A + diag(d1)    diag(d2) ]
     *   [ diag(d2)              diag(d1) ]
     * </pre>
     * The arrays {@code d1}, {@code d2}, {@code prb} and {@code prs} are stored by reference
     * (no defensive copy), so updates made by the owner of those arrays are visible here.
     */
    public static class PCG extends IMatrix implements IMatrix.Preconditioner {
        /** The wrapped matrix A. */
        Matrix A;
        /** Cached product A'A; stays {@code null} when A has 10000 or more columns. */
        Matrix AtA;
        /** Number of columns of A; the operator itself is {@code 2p x 2p}. */
        int p;
        /** Diagonal of the two diagonal blocks on the main block-diagonal. */
        double[] d1;
        /** Diagonal of the two off-diagonal blocks. */
        double[] d2;
        /** Per-coordinate factor used by {@link #asolve(double[], double[])} for the lower half. */
        double[] prb;
        /** Per-coordinate divisor used by {@link #asolve(double[], double[])}. */
        double[] prs;
        /** Scratch buffer of length {@code A.nrow()} holding A*x. */
        double[] ax;
        /** Scratch buffer of length {@code p} holding A'A*x. */
        double[] atax;

        /**
         * Wraps the given matrix and diagonal arrays.
         *
         * @param matrix the matrix A
         * @param dArr   the array used as {@code d1}
         * @param dArr2  the array used as {@code d2}
         * @param dArr3  the array used as {@code prb}
         * @param dArr4  the array used as {@code prs}
         */
        PCG(Matrix matrix, double[] dArr, double[] dArr2, double[] dArr3, double[] dArr4) {
            final int rows = matrix.nrow();
            final int cols = matrix.ncol();

            A = matrix;
            d1 = dArr;
            d2 = dArr2;
            prb = dArr3;
            prs = dArr4;
            p = cols;
            ax = new double[rows];
            atax = new double[cols];

            // Only materialize A'A for moderately sized problems; otherwise mv()
            // falls back to two matrix-vector products per call.
            if (cols < 10000) {
                AtA = matrix.ata();
            }
        }

        /** @return {@code 2 * p} */
        @Override
        public int nrow() {
            return p + p;
        }

        /** @return {@code 2 * p} */
        @Override
        public int ncol() {
            return p + p;
        }

        /** @return whatever the wrapped matrix A reports as its size */
        @Override
        public long size() {
            return A.size();
        }

        /**
         * Computes {@code dArr2 = H * dArr} for the block operator described on the class.
         *
         * @param dArr  input vector; only its first {@code p} entries feed the A'A product,
         *              entries {@code p..2p-1} are read for the diagonal terms
         * @param dArr2 output vector of length at least {@code 2p}
         */
        @Override
        public void mv(double[] dArr, double[] dArr2) {
            if (AtA == null) {
                // No cached Gram matrix: evaluate A'(A x) in two passes.
                A.mv(dArr, ax);
                A.tv(ax, atax);
            } else {
                AtA.mv(dArr, atax);
            }

            for (int lo = 0, hi = p; lo < p; lo++, hi++) {
                dArr2[lo] = (2.0d * atax[lo]) + (d1[lo] * dArr[lo]) + (d2[lo] * dArr[hi]);
                dArr2[hi] = (d2[lo] * dArr[lo]) + (d1[lo] * dArr[hi]);
            }
        }

        /**
         * Transpose product; simply forwards to {@link #mv(double[], double[])},
         * i.e. the operator is treated as symmetric.
         */
        @Override
        public void tv(double[] dArr, double[] dArr2) {
            mv(dArr, dArr2);
        }

        /**
         * Applies the preconditioner: for each coordinate pair {@code (i, i + p)} the input
         * is multiplied by the 2x2 matrix {@code [[d1, -d2], [-d2, prb]]} and divided by
         * {@code prs[i]}.
         *
         * @param dArr  right-hand side of length at least {@code 2p}
         * @param dArr2 receives the preconditioned vector
         */
        @Override
        public void asolve(double[] dArr, double[] dArr2) {
            for (int lo = 0, hi = p; lo < p; lo++, hi++) {
                dArr2[lo] = ((d1[lo] * dArr[lo]) - (d2[lo] * dArr[hi])) / prs[lo];
                dArr2[hi] = (((-d2[lo]) * dArr[lo]) + (prb[lo] * dArr[hi])) / prs[lo];
            }
        }

        /** Not supported by this operator. */
        @Override
        public void mv(Transpose transpose, double d, double[] dArr, double d2, double[] dArr2) {
            throw new UnsupportedOperationException();
        }

        /** Not supported by this operator. */
        @Override
        public void mv(double[] dArr, int i, int i2) {
            throw new UnsupportedOperationException();
        }

        /** Not supported by this operator. */
        @Override
        public void tv(double[] dArr, int i, int i2) {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * Fits a LASSO model using the default hyper-parameters, i.e. with an empty
     * {@link Properties} object so that every setting falls back to its default.
     *
     * @param formula   the model formula
     * @param dataFrame the training data
     * @return the fitted linear model
     */
    public static LinearModel fit(Formula formula, DataFrame dataFrame) {
        final Properties defaults = new Properties();
        return fit(formula, dataFrame, defaults);
    }

    /**
     * Fits a LASSO model with hyper-parameters read from a {@link Properties} object.
     *
     * <p>Recognized keys and their defaults:
     * <ul>
     *   <li>{@code smile.lasso.lambda} - regularization parameter, default {@code 1}</li>
     *   <li>{@code smile.lasso.tolerance} - stopping tolerance, default {@code 1E-4}</li>
     *   <li>{@code smile.lasso.iterations} - maximum number of iterations, default {@code 1000}</li>
     * </ul>
     *
     * @param formula    the model formula
     * @param dataFrame  the training data
     * @param properties the hyper-parameters; missing keys use the defaults above
     * @return the fitted linear model
     * @throws NumberFormatException if a property value cannot be parsed as a number
     */
    public static LinearModel fit(Formula formula, DataFrame dataFrame, Properties properties) {
        // The default is spelled out as a literal rather than borrowed from an
        // unrelated library constant that merely happens to hold the same text.
        double lambda = Double.parseDouble(properties.getProperty("smile.lasso.lambda", "1"));
        double tolerance = Double.parseDouble(properties.getProperty("smile.lasso.tolerance", "1E-4"));
        int maxIterations = Integer.parseInt(properties.getProperty("smile.lasso.iterations", "1000"));
        return fit(formula, dataFrame, lambda, tolerance, maxIterations);
    }

    /**
     * Fits a LASSO model with the given regularization parameter, a tolerance of
     * {@code 1E-4} and at most {@code 1000} iterations.
     *
     * @param formula   the model formula
     * @param dataFrame the training data
     * @param d         the shrinkage/regularization parameter lambda
     * @return the fitted linear model
     */
    public static LinearModel fit(Formula formula, DataFrame dataFrame, double d) {
        final double defaultTolerance = 1.0E-4d;
        final int defaultMaxIterations = 1000;
        return fit(formula, dataFrame, d, defaultTolerance, defaultMaxIterations);
    }

    /**
     * Fits a LASSO model.
     *
     * <p>The design matrix is built without a bias column, scaled with its column
     * means and standard deviations, and handed to
     * {@link #train(Matrix, double[], double, double, int)}. The returned weights are then
     * divided by the column standard deviations, and the intercept is computed as
     * {@code mean(y) - w . colMeans}.
     *
     * @param formula   the model formula
     * @param dataFrame the training data
     * @param d         the shrinkage/regularization parameter lambda
     * @param d2        the stopping tolerance
     * @param i         the maximum number of iterations
     * @return the fitted linear model
     * @throws IllegalArgumentException if any predictor column has zero standard deviation
     */
    public static LinearModel fit(Formula formula, DataFrame dataFrame, double d, double d2, int i) {
        Formula expanded = formula.expand(dataFrame.schema());
        StructType schema = expanded.bind(dataFrame.schema());

        Matrix design = expanded.matrix(dataFrame, false);
        double[] response = expanded.y(dataFrame).toDoubleArray();

        double[] center = design.colMeans();
        double[] scale = design.colSds();

        // A constant column cannot be standardized (division by a zero deviation).
        int col = 0;
        while (col < scale.length) {
            if (MathEx.isZero(scale[col])) {
                throw new IllegalArgumentException(String.format("The column '%s' is constant", design.colName(col)));
            }
            col++;
        }

        Matrix standardized = design.scale(center, scale);
        double[] weights = train(standardized, response, d, d2, i);

        // Undo the scaling so the weights apply to the unscaled predictors.
        for (int j = 0; j < weights.length; j++) {
            weights[j] /= scale[j];
        }

        double intercept = MathEx.mean(response) - MathEx.dot(weights, center);
        return new LinearModel(expanded, schema, design, response, weights, intercept);
    }

    /**
     * Computes the LASSO weight vector for the given design matrix and response.
     *
     * <p>The response is centered internally, so no intercept is estimated here. Each outer
     * iteration evaluates a primal objective {@code ||Xw - y||^2 + lambda * ||w||_1} and a
     * dual bound, stops when {@code gap / dual < tolerance}, and otherwise solves a
     * {@code 2p x 2p} linear system with {@link PCG} followed by a backtracking line search
     * on a log-barrier objective. The structure appears to follow the l1_ls interior-point
     * method of Kim et al. (TODO confirm against upstream).
     *
     * @param matrix the design matrix X (NOTE(review): the caller in this file passes a
     *               centered/scaled matrix - confirm this is required)
     * @param dArr   the response values, one per row of {@code matrix}
     * @param d      the shrinkage/regularization parameter lambda; must be {@code >= 0}
     * @param d2     the relative duality-gap tolerance; must be {@code > 0}
     * @param i      the maximum iteration index; must be {@code > 0}. The loop counter runs
     *               from 0 through {@code i} inclusive.
     * @return the weight vector of length {@code matrix.ncol()}
     * @throws IllegalArgumentException if lambda is negative, or the tolerance or the
     *                                  iteration limit is not positive
     */
    public static double[] train(Matrix matrix, double[] dArr, double d, double d2, int i) {
        final double lambda = d;
        final double tol = d2;
        final int maxIter = i;

        if (lambda < 0.0) {
            throw new IllegalArgumentException("Invalid shrinkage/regularization parameter lambda = " + lambda);
        }
        if (tol <= 0.0) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        // Solver constants.
        final int mu = 2;                   // growth factor for the barrier parameter t
        final int maxLineSearchIter = 100;  // cap on backtracking steps
        final double alpha = 0.01;          // required fraction of the predicted change
        final double beta = 0.5;            // step shrink factor
        final double eta = 1e-3;            // scales the PCG tolerance with the duality gap
        final int pcgMaxIter = 5000;        // cap on PCG iterations
        final String progressFormat =
                "LASSO: primal and dual objective function value after %3d iterations: %.5g\t%.5g%n";

        // Set once PCG returns an error above its tolerance; never reset afterwards.
        boolean pcgMissedTolerance = false;

        final int n = matrix.nrow();
        final int p = matrix.ncol();

        // Centered response.
        final double[] y = new double[n];
        final double yMean = MathEx.mean(dArr);
        for (int k = 0; k < n; k++) {
            y[k] = dArr[k] - yMean;
        }

        double t = Math.min(Math.max(1.0, 1.0 / lambda), (2 * p) / 1e-3);
        double dualObj = Double.NEGATIVE_INFINITY;
        double s = Double.POSITIVE_INFINITY;     // last accepted line-search step

        final double[] w = new double[p];        // weights, start at 0
        final double[] u = new double[p];        // bounds with |w| < u, start at 1
        final double[] z = new double[n];        // residual X w - y
        final double[][] f = new double[2][p];   // constraints w - u and -w - u (kept negative)
        Arrays.fill(u, 1.0);
        for (int j = 0; j < p; j++) {
            f[0][j] = w[j] - u[j];
            f[1][j] = -w[j] - u[j];
        }

        final double[] newW = new double[p];
        final double[] newU = new double[p];
        final double[] newZ = new double[n];
        final double[][] newF = new double[2][p];

        final double[] dw = new double[p];
        final double[] du = new double[p];
        final double[] dwu = new double[2 * p];      // PCG solution (search direction)
        final double[] negGrad = new double[2 * p];  // PCG right-hand side

        // Presumably 2 * diag(X'X) for unit-variance columns - TODO confirm.
        final double[] diag = new double[p];
        Arrays.fill(diag, 2.0);

        final double[] nu = new double[n];
        final double[] xnu = new double[p];

        final double[] q1 = new double[p];
        final double[] q2 = new double[p];
        final double[] d1 = new double[p];
        final double[] dd2 = new double[p];

        final double[][] gradPhi = new double[2][p];
        final double[] prb = new double[p];
        final double[] prs = new double[p];

        // PCG keeps references to d1, dd2, prb and prs, which are refreshed every iteration.
        final PCG pcg = new PCG(matrix, d1, dd2, prb, prs);

        int iter = 0;
        for (; iter <= maxIter; iter++) {
            matrix.mv(w, z);
            for (int k = 0; k < n; k++) {
                z[k] -= y[k];
                nu[k] = 2.0 * z[k];
            }

            // Rescale nu so that ||X' nu||_inf <= lambda before evaluating the dual value.
            matrix.tv(nu, xnu);
            final double maxXnu = MathEx.normInf(xnu);
            if (maxXnu > lambda) {
                final double shrink = lambda / maxXnu;
                for (int k = 0; k < n; k++) {
                    nu[k] *= shrink;
                }
            }

            final double primalObj = MathEx.dot(z, z) + (lambda * MathEx.norm1(w));
            dualObj = Math.max((-0.25 * MathEx.dot(nu, nu)) - MathEx.dot(nu, y), dualObj);
            if (iter % 10 == 0) {
                logger.info(String.format(progressFormat, iter, primalObj, dualObj));
            }

            // Stopping criterion: relative duality gap.
            final double gap = primalObj - dualObj;
            if (gap / dualObj < tol) {
                logger.info(String.format(progressFormat, iter, primalObj, dualObj));
                break;
            }

            // Tighten the barrier only after a reasonably long step.
            if (s >= 0.5) {
                t = Math.max(Math.min(((2 * p) * mu) / gap, mu * t), t);
            }

            for (int j = 0; j < p; j++) {
                final double a = 1.0 / (u[j] + w[j]);
                final double b = 1.0 / (u[j] - w[j]);
                q1[j] = a;
                q2[j] = b;
                d1[j] = ((a * a) + (b * b)) / t;
                dd2[j] = ((a * a) - (b * b)) / t;
            }

            // Gradient of the barrier objective; PCG is given its negation.
            matrix.tv(z, gradPhi[0]);
            for (int j = 0; j < p; j++) {
                gradPhi[0][j] = (2.0 * gradPhi[0][j]) - ((q1[j] - q2[j]) / t);
                gradPhi[1][j] = lambda - ((q1[j] + q2[j]) / t);
                negGrad[j] = -gradPhi[0][j];
                negGrad[j + p] = -gradPhi[1][j];
            }

            // Quantities consumed by PCG.asolve.
            for (int j = 0; j < p; j++) {
                prb[j] = diag[j] + d1[j];
                prs[j] = (prb[j] * d1[j]) - (dd2[j] * dd2[j]);
            }

            double pcgTol = Math.min(0.1, (eta * gap) / Math.min(1.0, MathEx.norm(negGrad)));
            if (iter != 0 && !pcgMissedTolerance) {
                pcgTol *= 0.1;
            }
            if (pcg.solve(negGrad, dwu, pcg, pcgTol, 1, pcgMaxIter) > pcgTol) {
                pcgMissedTolerance = true;
            }

            for (int j = 0; j < p; j++) {
                dw[j] = dwu[j];
                du[j] = dwu[j + p];
            }

            // Backtracking line search on the barrier objective.
            final double phi = (MathEx.dot(z, z) + (lambda * MathEx.sum(u))) - (sumlogneg(f) / t);
            s = 1.0;
            // NOTE(review): this is negGrad . step; the reference l1_ls code appears to use
            // the un-negated gradient here, which would flip the sign of the right-hand side
            // of the sufficient-decrease test below. Left unchanged - confirm against upstream.
            final double gdx = MathEx.dot(negGrad, dwu);
            int lsIter = 0;
            for (; lsIter < maxLineSearchIter; lsIter++) {
                for (int j = 0; j < p; j++) {
                    newW[j] = w[j] + (s * dw[j]);
                    newU[j] = u[j] + (s * du[j]);
                    newF[0][j] = newW[j] - newU[j];
                    newF[1][j] = -newW[j] - newU[j];
                }

                // Only strictly feasible points are candidates (the log barrier needs f < 0).
                if (MathEx.max(newF) < 0.0) {
                    matrix.mv(newW, newZ);
                    for (int k = 0; k < n; k++) {
                        newZ[k] -= y[k];
                    }
                    final double newPhi =
                            (MathEx.dot(newZ, newZ) + (lambda * MathEx.sum(newU))) - (sumlogneg(newF) / t);
                    if (newPhi - phi <= alpha * s * gdx) {
                        break;
                    }
                }
                s = beta * s;
            }

            if (lsIter == maxLineSearchIter) {
                logger.error("LASSO: Too many iterations of line search.");
                break;
            }

            System.arraycopy(newW, 0, w, 0, p);
            System.arraycopy(newU, 0, u, 0, p);
            System.arraycopy(newF[0], 0, f[0], 0, p);
            System.arraycopy(newF[1], 0, f[1], 0, p);
        }

        // The counter only exceeds maxIter when the loop ran out without a break,
        // i.e. neither convergence nor a line-search failure ended it.
        if (iter > maxIter) {
            logger.error("LASSO: Too many iterations.");
        }
        return w;
    }

    /**
     * Returns the sum of {@code log(-x)} over every element {@code x} of the given
     * two-dimensional array, accumulated row by row in index order.
     *
     * @param dArr the values; an element that is not negative contributes
     *             {@code NaN} or negative infinity, as {@link Math#log(double)} dictates
     * @return the accumulated sum, or {@code 0.0} when there are no elements
     */
    private static double sumlogneg(double[][] dArr) {
        double total = 0.0d;
        for (int r = 0; r < dArr.length; r++) {
            final double[] row = dArr[r];
            for (int c = 0; c < row.length; c++) {
                total += Math.log(-row[c]);
            }
        }
        return total;
    }
}
