package smile.regression;

import java.util.Arrays;
import java.util.Properties;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.data.type.StructType;
import smile.math.MathEx;
import smile.math.blas.UPLO;
import smile.math.matrix.Matrix;

/* loaded from: input_file:smile/regression/RidgeRegression.class */
/**
 * Ridge (L2-penalized) linear regression.
 *
 * <p>The predictors are standardized with their column means and standard
 * deviations, the penalized weighted normal equations
 * {@code (X'WX + diag(lambda)) b = X'Wy + lambda * beta0} are solved through a
 * Cholesky decomposition, and the coefficients are then mapped back to the
 * original scale of the predictors.
 */
public class RidgeRegression {
    /**
     * Fits a ridge regression model with the default hyperparameters.
     *
     * @param formula the model formula.
     * @param dataFrame the training data.
     * @return the fitted linear model.
     */
    public static LinearModel fit(Formula formula, DataFrame dataFrame) {
        return fit(formula, dataFrame, new Properties());
    }

    /**
     * Fits a ridge regression model whose shrinkage parameter is read from the
     * property {@code smile.ridge.lambda} (default {@code "1"}).
     *
     * @param formula the model formula.
     * @param dataFrame the training data.
     * @param properties the hyperparameters.
     * @return the fitted linear model.
     * @throws NumberFormatException if the property value is not a parsable double.
     */
    public static LinearModel fit(Formula formula, DataFrame dataFrame, Properties properties) {
        double lambda = Double.parseDouble(properties.getProperty("smile.ridge.lambda", "1"));
        return fit(formula, dataFrame, lambda);
    }

    /**
     * Fits a ridge regression model with unit sample weights, a single
     * shrinkage parameter shared by all predictors, and a zero prior
     * (generalized ridge {@code beta0}) for every coefficient.
     *
     * @param formula the model formula.
     * @param dataFrame the training data.
     * @param lambda the shrinkage parameter applied to every predictor.
     * @return the fitted linear model.
     */
    public static LinearModel fit(Formula formula, DataFrame dataFrame, double lambda) {
        double[] weights = new double[dataFrame.size()];
        Arrays.fill(weights, 1.0);
        return fit(formula, dataFrame, weights, new double[] {lambda}, new double[] {0.0});
    }

    /**
     * Fits a weighted, generalized ridge regression model.
     *
     * @param formula the model formula.
     * @param dataFrame the training data.
     * @param weights the sample weights, one strictly positive value per row.
     * @param lambda the shrinkage parameters: either a single value applied to
     *               every predictor or one non-negative value per predictor.
     * @param beta0 the coefficient priors added to the right-hand side as
     *              {@code lambda[j] * beta0[j]}: either a single value or one
     *              value per predictor.
     * @return the fitted linear model.
     * @throws IllegalArgumentException if a vector has the wrong length, a
     *         weight is not strictly positive (including NaN), a lambda is
     *         negative or NaN, or a predictor column is constant.
     */
    public static LinearModel fit(Formula formula, DataFrame dataFrame, double[] weights, double[] lambda, double[] beta0) {
        Formula expanded = formula.expand(dataFrame.schema());
        StructType schema = expanded.bind(dataFrame.schema());
        Matrix x = expanded.matrix(dataFrame, false);
        double[] y = expanded.y(dataFrame).toDoubleArray();

        int n = x.nrow();
        int p = x.ncol();

        if (weights.length != n) {
            throw new IllegalArgumentException(String.format("Invalid weights vector size: %d != %d", weights.length, n));
        }
        for (int i = 0; i < n; i++) {
            // Written as a negated positive test so that NaN is rejected too.
            if (!(weights[i] > 0.0)) {
                throw new IllegalArgumentException(String.format("Invalid weights[%d] = %f", i, weights[i]));
            }
        }

        lambda = broadcast(lambda, p, "lambda");
        for (int j = 0; j < p; j++) {
            // Negated so that NaN fails the check along with negative values.
            if (!(lambda[j] >= 0.0)) {
                throw new IllegalArgumentException(String.format("Invalid lambda[%d] = %f", j, lambda[j]));
            }
        }

        beta0 = broadcast(beta0, p, "beta0");

        double[] center = x.colMeans();
        double[] sd = x.colSds();
        for (int j = 0; j < sd.length; j++) {
            // A constant column cannot be standardized (division by zero).
            if (MathEx.isZero(sd[j])) {
                throw new IllegalArgumentException(String.format("The column '%s' is constant", x.colName(j)));
            }
        }

        Matrix scaledX = x.scale(center, sd);

        // X'W: the transposed standardized design matrix with each sample
        // (column here) multiplied by its weight.
        Matrix xtw = new Matrix(p, n);
        for (int j = 0; j < p; j++) {
            for (int i = 0; i < n; i++) {
                xtw.set(j, i, weights[i] * scaledX.get(i, j));
            }
        }

        // Right-hand side: X'Wy + lambda * beta0.
        double[] rhs = xtw.mv(y);
        for (int j = 0; j < p; j++) {
            rhs[j] += lambda[j] * beta0[j];
        }

        // Left-hand side: X'WX + diag(lambda).
        // NOTE(review): uplo(LOWER) presumably flags the matrix as symmetric
        // for the Cholesky routine, and cholesky(true) presumably factorizes
        // in place - confirm against the Matrix API.
        Matrix xtx = xtw.mm(scaledX);
        xtx.uplo(UPLO.LOWER);
        xtx.addDiag(lambda);
        double[] coefficients = xtx.cholesky(true).solve(rhs);

        // Undo the standardization of the predictors.
        for (int j = 0; j < p; j++) {
            coefficients[j] /= sd[j];
        }

        double intercept = MathEx.mean(y) - MathEx.dot(coefficients, center);
        return new LinearModel(expanded, schema, x, y, coefficients, intercept);
    }

    /**
     * Expands a single-element vector to {@code size} copies of its value, or
     * checks that a longer vector already has exactly {@code size} elements.
     * The input array is never modified.
     */
    private static double[] broadcast(double[] values, int size, String name) {
        if (values.length == 1) {
            double[] expanded = new double[size];
            Arrays.fill(expanded, values[0]);
            return expanded;
        }
        if (values.length != size) {
            throw new IllegalArgumentException(String.format("Invalid %s vector size: %d != %d", name, values.length, size));
        }
        return values;
    }
}
