package smile.math;

import java.util.Arrays;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.matrix.Matrix;

/* loaded from: input_file:smile/math/LevenbergMarquardt.class */
/**
 * Nonlinear least-squares curve fitting with a damped (Levenberg–Marquardt
 * style) iteration.
 *
 * <p>The model is a {@link DifferentiableMultivariateFunction} evaluated at a
 * vector whose first {@code p.length} entries are the parameters and whose
 * trailing entries are the independent variable(s) of one observation. Each
 * iteration:
 * <ol>
 * <li>builds the Jacobian with respect to the parameters;</li>
 * <li>scales its columns to unit norm;</li>
 * <li>takes its SVD {@code J = U S V'};</li>
 * <li>tries steps {@code V * diag(1 / sqrt(s_j^2 + lambda)) * U' r} (rescaled
 * back) for increasing damping {@code lambda}, until one reduces the SSE by the
 * relative amount {@code stol}.</li>
 * </ol>
 */
public class LevenbergMarquardt {
    private static final Logger logger = LoggerFactory.getLogger(LevenbergMarquardt.class);

    /** Relative SSE reduction used by the short {@code fit} overloads. */
    private static final double DEFAULT_STOL = 1.0E-4;

    /** Maximum number of iterations used by the short {@code fit} overloads. */
    private static final int DEFAULT_MAX_ITER = 20;

    /**
     * Multipliers applied to the current damping parameter. They are tried in
     * this order (least damping first) until a trial step reaches the SSE goal.
     */
    private static final double[] LAMBDA_FACTORS = {0.1, 1.0, 1.0E2, 1.0E4, 1.0E6};

    /** Lower bound on any trial damping parameter. */
    private static final double MIN_LAMBDA = 1.0E-7;

    /** The fitted parameters. */
    public final double[] parameters;
    /** The model values at the fitted parameters, one per observation. */
    public final double[] fittedValues;
    /** The residuals {@code y[i] - fittedValues[i]}. */
    public final double[] residuals;
    /** The sum of squared residuals at the fitted parameters. */
    public final double sse;

    /**
     * Creates a fit result. The arrays are stored as-is (not copied).
     *
     * @param parameters the fitted parameters.
     * @param fittedValues the model values at the fitted parameters.
     * @param residuals the residuals.
     * @param sse the sum of squared residuals.
     */
    LevenbergMarquardt(double[] parameters, double[] fittedValues, double[] residuals, double sse) {
        this.parameters = parameters;
        this.fittedValues = fittedValues;
        this.residuals = residuals;
        this.sse = sse;
    }

    /**
     * Fits a model with one independent variable, using {@code stol = 1E-4}
     * and at most 20 iterations.
     *
     * @param func the model; see {@link #fit(DifferentiableMultivariateFunction, double[], double[], double[], double, int)}.
     * @param x the independent variable, one value per observation.
     * @param y the observed responses.
     * @param p the initial guess of the parameters (not modified).
     * @return the fit result.
     */
    public static LevenbergMarquardt fit(DifferentiableMultivariateFunction func, double[] x, double[] y, double[] p) {
        return fit(func, x, y, p, DEFAULT_STOL, DEFAULT_MAX_ITER);
    }

    /**
     * Fits a model with one independent variable.
     *
     * @param func the model. It is evaluated at a vector of length
     *             {@code p.length + 1}: the parameters followed by {@code x[i]}.
     *             Its {@code g} method must also write the partial derivatives
     *             with respect to the parameters into the supplied gradient
     *             array of length {@code p.length}.
     * @param x the independent variable, one value per observation.
     * @param y the observed responses; must have the same length as {@code x}.
     * @param p the initial guess of the parameters (not modified).
     * @param stol the relative SSE reduction an iteration must achieve for the
     *             fit to keep iterating; must be positive.
     * @param maxIter the maximum number of iterations; must be positive.
     * @return the fit result.
     * @throws IllegalArgumentException if {@code stol} or {@code maxIter} is not
     *         positive, {@code x} is empty, or {@code x} and {@code y} differ in length.
     */
    public static LevenbergMarquardt fit(DifferentiableMultivariateFunction func, double[] x, double[] y, double[] p, double stol, int maxIter) {
        checkOptions(stol, maxIter);
        checkSizes(x.length, y.length);
        return optimize(func, x.length, 1, (i, buf, offset) -> buf[offset] = x[i], y, p, stol, maxIter);
    }

    /**
     * Fits a model with several independent variables, using
     * {@code stol = 1E-4} and at most 20 iterations.
     *
     * @param func the model; see {@link #fit(DifferentiableMultivariateFunction, double[][], double[], double[], double, int)}.
     * @param x the independent variables, one row per observation.
     * @param y the observed responses.
     * @param p the initial guess of the parameters (not modified).
     * @return the fit result.
     */
    public static LevenbergMarquardt fit(DifferentiableMultivariateFunction func, double[][] x, double[] y, double[] p) {
        return fit(func, x, y, p, DEFAULT_STOL, DEFAULT_MAX_ITER);
    }

    /**
     * Fits a model with several independent variables.
     *
     * @param func the model. It is evaluated at a vector of length
     *             {@code p.length + x[0].length}: the parameters followed by
     *             the row {@code x[i]}. Its {@code g} method must also write
     *             the partial derivatives with respect to the parameters into
     *             the supplied gradient array of length {@code p.length}.
     * @param x the independent variables, one row per observation; all rows
     *          must have the same length.
     * @param y the observed responses; must have the same length as {@code x}.
     * @param p the initial guess of the parameters (not modified).
     * @param stol the relative SSE reduction an iteration must achieve for the
     *             fit to keep iterating; must be positive.
     * @param maxIter the maximum number of iterations; must be positive.
     * @return the fit result.
     * @throws IllegalArgumentException if {@code stol} or {@code maxIter} is not
     *         positive, {@code x} is empty or ragged, or {@code x} and {@code y}
     *         differ in length.
     */
    public static LevenbergMarquardt fit(DifferentiableMultivariateFunction func, double[][] x, double[] y, double[] p, double stol, int maxIter) {
        checkOptions(stol, maxIter);
        checkSizes(x.length, y.length);
        final int k = x[0].length;
        for (int row = 1; row < x.length; row++) {
            if (x[row].length != k) {
                throw new IllegalArgumentException(String.format("x[%d] has %d variables, expected %d", row, x[row].length, k));
            }
        }
        return optimize(func, x.length, k, (i, buf, offset) -> System.arraycopy(x[i], 0, buf, offset, k), y, p, stol, maxIter);
    }

    /** Copies the independent variable(s) of observation {@code i} into {@code buf} starting at {@code offset}. */
    @FunctionalInterface
    private interface DataLoader {
        void load(int i, double[] buf, int offset);
    }

    /** Validates the tolerance and iteration limit shared by all overloads. */
    private static void checkOptions(double stol, int maxIter) {
        // Written as !(stol > 0) so that NaN is rejected too; a NaN goal would never be met.
        if (!(stol > 0.0)) {
            throw new IllegalArgumentException("Invalid gradient tolerance: " + stol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
    }

    /** Validates that there is data and that x and y have matching sizes. */
    private static void checkSizes(int n, int m) {
        if (n == 0) {
            throw new IllegalArgumentException("Empty data");
        }
        if (n != m) {
            throw new IllegalArgumentException(String.format("The sizes of x and y don't match: %d != %d", n, m));
        }
    }

    /**
     * Runs the damped least-squares iteration shared by both public overloads.
     *
     * @param func the model.
     * @param n the number of observations.
     * @param k the number of independent variables per observation.
     * @param loader copies observation {@code i}'s variables into an evaluation buffer.
     * @param y the observed responses (length {@code n}).
     * @param p the initial parameters (not modified).
     * @param stol the relative SSE reduction required to keep iterating.
     * @param maxIter the maximum number of iterations.
     * @return the fit result.
     */
    private static LevenbergMarquardt optimize(DifferentiableMultivariateFunction func, int n, int k, DataLoader loader,
                                               double[] y, double[] p, double stol, int maxIter) {
        final int d = p.length;

        // Each buffer holds the d parameters followed by the k variables of one observation.
        double[] best = new double[d + k];  // best parameters found so far
        double[] base = new double[d + k];  // parameters at the start of the current iteration
        double[] trial = new double[d + k]; // candidate parameters for one damping value
        System.arraycopy(p, 0, best, 0, d);

        double[] residuals = new double[n];
        double[] gradient = new double[d];
        double[] ur = new double[d];   // U' * residuals
        double[] z = new double[d];    // damped step in the rotated (V) basis
        double[] step = new double[d]; // V * z, still in scaled coordinates
        double[] scale = new double[d];
        Arrays.fill(scale, 1.0);

        Matrix jacobian = new Matrix(n, d);
        double lambda = 1.0;

        for (int iter = 1; iter <= maxIter; iter++) {
            System.arraycopy(best, 0, base, 0, d);

            // Residuals, SSE and Jacobian at the current parameters.
            double sse = 0.0;
            for (int i = 0; i < n; i++) {
                loader.load(i, base, d);
                residuals[i] = y[i] - func.g(base, gradient);
                sse += residuals[i] * residuals[i];
                for (int j = 0; j < d; j++) {
                    jacobian.set(i, j, gradient[j]);
                }
            }

            double bestSSE = sse;
            double goal = (1.0 - stol) * sse;

            // Normalize each Jacobian column; zero columns are left unscaled.
            for (int j = 0; j < d; j++) {
                double norm2 = 0.0;
                for (int i = 0; i < n; i++) {
                    double v = jacobian.get(i, j);
                    norm2 += v * v;
                }
                scale[j] = norm2 > 0.0 ? 1.0 / Math.sqrt(norm2) : 1.0;
                for (int i = 0; i < n; i++) {
                    jacobian.mul(i, j, scale[j]);
                }
            }

            Matrix.SVD svd = jacobian.svd(true, true);
            double[] s = svd.s;
            svd.U.tv(residuals, ur);

            for (double factor : LAMBDA_FACTORS) {
                double trialLambda = Math.max(lambda * factor, MIN_LAMBDA);

                // Damp each singular direction separately: z_j = (U'r)_j / sqrt(s_j^2 + lambda)
                // (cf. Octave leasqr: se = sqrt(s.*s + epsL)). As lambda -> 0 this tends to the
                // Gauss-Newton step. Dividing by the single scalar sqrt(sum(s^2) + lambda) does not.
                // NOTE(review): s may be shorter than d when n < d; missing values are treated as 0.
                for (int j = 0; j < d; j++) {
                    double sj = j < s.length ? s[j] : 0.0;
                    z[j] = ur[j] / Math.sqrt(sj * sj + trialLambda);
                }
                svd.V.mv(z, step);
                for (int j = 0; j < d; j++) {
                    // Undo the column scaling to get the step in parameter space.
                    trial[j] = step[j] * scale[j] + base[j];
                }

                double trialSSE = 0.0;
                for (int i = 0; i < n; i++) {
                    loader.load(i, trial, d);
                    double r = y[i] - func.f(trial);
                    trialSSE += r * r;
                }

                if (trialSSE < bestSSE) {
                    System.arraycopy(trial, 0, best, 0, d);
                    bestSSE = trialSSE;
                }

                if (trialSSE <= goal) {
                    lambda = trialLambda;
                    break;
                }
            }

            if (logger.isInfoEnabled()) {
                logger.info(String.format("SSE after %3d iterations: %.5f", iter, bestSSE));
            }

            // Stop when the fit is (numerically) exact or no damping value reached the goal.
            // The best SSE is tested rather than the last trial's, so that a NaN trial cannot
            // keep the loop spinning.
            if (bestSSE < MathEx.EPSILON || bestSSE > goal) {
                logger.info(String.format("converges on SSE after %d iterations", iter));
                break;
            }
        }

        // Final fitted values and residuals at the best parameters.
        double[] params = Arrays.copyOf(best, d);
        double[] fitted = new double[n];
        double finalSSE = 0.0;
        for (int i = 0; i < n; i++) {
            loader.load(i, best, d);
            fitted[i] = func.f(best);
            residuals[i] = y[i] - fitted[i];
            finalSSE += residuals[i] * residuals[i];
        }
        return new LevenbergMarquardt(params, fitted, residuals, finalSSE);
    }
}
