package smile.math.matrix;

import java.io.Serializable;
import java.nio.DoubleBuffer;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.blas.BLAS;
import smile.math.blas.LAPACK;
import smile.math.blas.Layout;
import smile.math.blas.Transpose;
import smile.math.blas.UPLO;

/* loaded from: input_file:smile/math/matrix/BandMatrix.class */
/**
 * A band matrix held in compact band storage.
 *
 * <p>Only the {@code kl} sub-diagonals, the main diagonal and the {@code ku}
 * super-diagonals are stored. The backing array {@link #AB} has
 * {@code ld * n} entries with {@code ld = kl + ku + 1}; element
 * {@code (i, j)} inside the band lives at {@code AB[j * ld + ku + i - j]}
 * (see {@link #get(int, int)}). {@link #layout()} always reports
 * {@link Layout#COL_MAJOR}.
 *
 * <p>When {@link #uplo(UPLO)} has been set to a non-null value the matrix is
 * treated as symmetric and the symmetric band BLAS/LAPACK routines are used.
 */
public class BandMatrix extends IMatrix {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(BandMatrix.class);

    /** The band storage, {@code ld * n} entries. */
    final double[] AB;
    /** The number of rows. */
    final int m;
    /** The number of columns. */
    final int n;
    /** The number of sub-diagonals. */
    final int kl;
    /** The number of super-diagonals. */
    final int ku;
    /** The leading dimension of the band storage, {@code kl + ku + 1}. */
    final int ld;
    /** Non-null when the matrix is flagged as symmetric. */
    UPLO uplo = null;

    /**
     * The Cholesky factorization of a symmetric band matrix, as produced by
     * {@link BandMatrix#cholesky()}.
     */
    public static class Cholesky implements Serializable {
        private static final long serialVersionUID = 2L;

        /** The band matrix holding the Cholesky factor. */
        public final BandMatrix lu;

        /**
         * Wraps an already factorized band matrix.
         *
         * @param lu the band matrix holding the Cholesky factor.
         * @throws UnsupportedOperationException if the matrix is not square.
         */
        public Cholesky(BandMatrix lu) {
            if (lu.nrow() != lu.ncol()) {
                throw new UnsupportedOperationException("Cholesky constructor on a non-square matrix");
            }
            this.lu = lu;
        }

        /**
         * Returns the determinant, computed as the squared product of the
         * diagonal entries of the factor.
         *
         * @return the determinant.
         */
        public double det() {
            double product = 1.0;
            for (int i = 0; i < lu.n; i++) {
                product *= lu.get(i, i);
            }
            return product * product;
        }

        /**
         * Returns the log-determinant, computed as twice the sum of the
         * logarithms of the diagonal entries of the factor.
         *
         * @return the log-determinant.
         */
        public double logdet() {
            final int size = lu.n;
            double sum = 0.0;
            for (int i = 0; i < size; i++) {
                sum += Math.log(lu.get(i, i));
            }
            return 2.0 * sum;
        }

        /**
         * Returns the inverse, obtained by solving against the identity.
         *
         * @return the inverse matrix.
         */
        public Matrix inverse() {
            Matrix inv = Matrix.eye(lu.n);
            solve(inv);
            return inv;
        }

        /**
         * Solves {@code A * x = b} for a single right-hand side.
         *
         * @param b the right-hand side vector.
         * @return the backing array of the column matrix built from {@code b}
         *         after the in-place solve.
         */
        public double[] solve(double[] b) {
            Matrix x = Matrix.column(b);
            solve(x);
            return x.A;
        }

        /**
         * Solves {@code A * X = B} in place; {@code B} is overwritten by the
         * LAPACK {@code pbtrs} call.
         *
         * @param B the right-hand side matrix.
         * @throws IllegalArgumentException if the row dimensions differ or
         *         the storage layouts are inconsistent.
         * @throws ArithmeticException if LAPACK reports a non-zero code.
         */
        public void solve(Matrix B) {
            if (B.m != lu.m) {
                throw new IllegalArgumentException(String.format("Row dimensions do not agree: A is %d x %d, but B is %d x %d", lu.m, lu.n, B.m, B.n));
            }
            // B's raw storage is handed to LAPACK together with lu.layout(),
            // so a differently laid out B would be silently misread.
            if (lu.layout() != B.layout()) {
                throw new IllegalArgumentException("The matrix layout is inconsistent.");
            }

            int bandwidth = lu.uplo == UPLO.LOWER ? lu.kl : lu.ku;
            int info = LAPACK.engine.pbtrs(lu.layout(), lu.uplo, lu.n, bandwidth, B.n, lu.AB, lu.ld, B.A, B.ld);
            if (info != 0) {
                logger.error("LAPACK PBTRS error code: {}", info);
                throw new ArithmeticException("LAPACK PBTRS error code: " + info);
            }
        }
    }

    /**
     * The LU factorization of a band matrix, as produced by
     * {@link BandMatrix#lu()}.
     *
     * <p>The wrapped band matrix is allocated by {@link BandMatrix#lu()} with
     * twice the sub-diagonals of the source matrix, which is why this class
     * passes {@code lu.kl / 2} to LAPACK.
     */
    public static class LU implements Serializable {
        private static final long serialVersionUID = 2L;

        /** The band matrix holding the LU factors. */
        public final BandMatrix lu;
        /** The pivot indices filled in by {@code gbtrf}; compared 1-based in {@link #det()}. */
        public final int[] ipiv;
        /** The code returned by {@code gbtrf}; a positive value is treated as singular. */
        public final int info;

        /**
         * Constructor.
         *
         * @param lu   the band matrix holding the LU factors.
         * @param ipiv the pivot indices.
         * @param info the factorization return code.
         */
        public LU(BandMatrix lu, int[] ipiv, int info) {
            this.lu = lu;
            this.ipiv = ipiv;
            this.info = info;
        }

        /**
         * Returns true if the factorization reported a positive code.
         *
         * @return true if the matrix is singular.
         */
        public boolean isSingular() {
            return info > 0;
        }

        /**
         * Returns the determinant: the product of the diagonal of the factor,
         * negated once for every row whose pivot differs from itself.
         *
         * @return the determinant.
         * @throws IllegalArgumentException if the matrix is not square.
         */
        public double det() {
            final int rows = lu.m;
            final int cols = lu.n;
            if (rows != cols) {
                throw new IllegalArgumentException(String.format("The matrix is not square: %d x %d", rows, cols));
            }

            // Offset of the main diagonal inside each stored column.
            final int diagonal = lu.kl / 2 + lu.ku;
            double d = 1.0;
            for (int j = 0; j < cols; j++) {
                d *= lu.AB[j * lu.ld + diagonal];
            }
            for (int j = 0; j < cols; j++) {
                if (j + 1 != ipiv[j]) {
                    d = -d;
                }
            }
            return d;
        }

        /**
         * Returns the inverse, obtained by solving against the identity.
         *
         * @return the inverse matrix.
         */
        public Matrix inverse() {
            Matrix inv = Matrix.eye(lu.n);
            solve(inv);
            return inv;
        }

        /**
         * Solves {@code A * x = b} for a single right-hand side.
         *
         * @param b the right-hand side vector.
         * @return the backing array of the column matrix built from {@code b}
         *         after the in-place solve.
         */
        public double[] solve(double[] b) {
            Matrix x = Matrix.column(b);
            solve(x);
            return x.A;
        }

        /**
         * Solves {@code A * X = B} in place; {@code B} is overwritten by the
         * LAPACK {@code gbtrs} call.
         *
         * @param B the right-hand side matrix.
         * @throws IllegalArgumentException if the matrix is not square, the
         *         row dimensions differ, or the layouts are inconsistent.
         * @throws RuntimeException if the matrix is singular.
         * @throws ArithmeticException if LAPACK reports a non-zero code.
         */
        public void solve(Matrix B) {
            if (lu.m != lu.n) {
                throw new IllegalArgumentException(String.format("The matrix is not square: %d x %d", lu.m, lu.n));
            }
            if (B.m != lu.m) {
                throw new IllegalArgumentException(String.format("Row dimensions do not agree: A is %d x %d, but B is %d x %d", lu.m, lu.n, B.m, B.n));
            }
            if (lu.layout() != B.layout()) {
                throw new IllegalArgumentException("The matrix layout is inconsistent.");
            }
            if (info > 0) {
                throw new RuntimeException("The matrix is singular.");
            }

            int code = LAPACK.engine.gbtrs(lu.layout(), Transpose.NO_TRANSPOSE, lu.n, lu.kl / 2, lu.ku, B.n, lu.AB, lu.ld, ipiv, B.A, B.ld);
            if (code != 0) {
                logger.error("LAPACK GBTRS error code: {}", code);
                throw new ArithmeticException("LAPACK GBTRS error code: " + code);
            }
        }
    }

    /**
     * Creates a zero band matrix.
     *
     * @param m  the number of rows.
     * @param n  the number of columns.
     * @param kl the number of sub-diagonals.
     * @param ku the number of super-diagonals.
     * @throws IllegalArgumentException if a dimension is not positive, a
     *         bandwidth is negative, {@code kl >= m} or {@code ku >= n}.
     */
    public BandMatrix(int m, int n, int kl, int ku) {
        this(m, n, kl, ku, true);
    }

    /**
     * Shared constructor. Factor storage built by {@link #lu()} carries
     * {@code 2 * kl} sub-diagonal rows, which may legitimately reach or exceed
     * {@code m}; such callers disable the bandwidth upper-bound checks.
     *
     * @param m               the number of rows.
     * @param n               the number of columns.
     * @param kl              the number of sub-diagonals.
     * @param ku              the number of super-diagonals.
     * @param checkBandwidth  whether to enforce {@code kl < m} and {@code ku < n}.
     */
    private BandMatrix(int m, int n, int kl, int ku, boolean checkBandwidth) {
        if (m <= 0 || n <= 0) {
            throw new IllegalArgumentException(String.format("Invalid matrix size: %d x %d", m, n));
        }
        if (kl < 0 || ku < 0) {
            throw new IllegalArgumentException(String.format("Invalid subdiagonals or superdiagonals: kl = %d, ku = %d", kl, ku));
        }
        if (checkBandwidth) {
            if (kl >= m) {
                throw new IllegalArgumentException(String.format("Invalid subdiagonals %d >= %d", kl, m));
            }
            if (ku >= n) {
                throw new IllegalArgumentException(String.format("Invalid superdiagonals %d >= %d", ku, n));
            }
        }

        this.m = m;
        this.n = n;
        this.kl = kl;
        this.ku = ku;
        this.ld = kl + ku + 1;
        this.AB = new double[this.ld * n];
    }

    /**
     * Creates a band matrix from band rows.
     *
     * @param m  the number of rows.
     * @param n  the number of columns.
     * @param kl the number of sub-diagonals.
     * @param ku the number of super-diagonals.
     * @param AB the band data; {@code AB[i][j]} is read for every
     *           {@code 0 <= i < kl + ku + 1} and {@code 0 <= j < n} and
     *           copied into column {@code j}, row {@code i} of the storage.
     */
    public BandMatrix(int m, int n, int kl, int ku, double[][] AB) {
        this(m, n, kl, ku);
        for (int j = 0; j < n; j++) {
            final int base = j * ld;
            for (int i = 0; i < ld; i++) {
                this.AB[base + i] = AB[i][j];
            }
        }
    }

    /**
     * Returns a deep copy of this matrix. The symmetric flag is carried over
     * only when the matrix is square with {@code kl == ku}.
     *
     * @return the copy.
     */
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public BandMatrix m1453clone() {
        // This instance already exists, so its dimensions need no re-validation;
        // skipping the bandwidth limits also lets LU factor storage be copied.
        BandMatrix copy = new BandMatrix(m, n, kl, ku, false);
        System.arraycopy(AB, 0, copy.AB, 0, AB.length);
        if (m == n && kl == ku) {
            copy.uplo(uplo);
        }
        return copy;
    }

    @Override // smile.math.matrix.IMatrix
    public int nrow() {
        return m;
    }

    @Override // smile.math.matrix.IMatrix
    public int ncol() {
        return n;
    }

    /**
     * Returns the number of stored entries, i.e. the length of the band
     * storage array.
     */
    @Override // smile.math.matrix.IMatrix
    public long size() {
        return AB.length;
    }

    /**
     * Returns the number of sub-diagonals.
     *
     * @return the number of sub-diagonals.
     */
    public int kl() {
        return kl;
    }

    /**
     * Returns the number of super-diagonals.
     *
     * @return the number of super-diagonals.
     */
    public int ku() {
        return ku;
    }

    /**
     * Returns the storage layout, which is always column major.
     *
     * @return {@link Layout#COL_MAJOR}.
     */
    public Layout layout() {
        return Layout.COL_MAJOR;
    }

    /**
     * Returns the leading dimension of the band storage.
     *
     * @return the leading dimension.
     */
    public int ld() {
        return ld;
    }

    /**
     * Returns true if the matrix has been flagged as symmetric.
     *
     * @return true if {@link #uplo()} is non-null.
     */
    public boolean isSymmetric() {
        return uplo != null;
    }

    /**
     * Sets (or clears, with {@code null}) the symmetric flag.
     *
     * @param uplo which triangle is referenced, or null for a general matrix.
     * @return this matrix.
     * @throws IllegalArgumentException if the matrix is not square or
     *         {@code kl != ku}.
     */
    public BandMatrix uplo(UPLO uplo) {
        if (m != n) {
            throw new IllegalArgumentException(String.format("The matrix is not square: %d x %d", m, n));
        }
        if (kl != ku) {
            throw new IllegalArgumentException(String.format("kl != ku: %d != %d", kl, ku));
        }
        this.uplo = uplo;
        return this;
    }

    /**
     * Returns the symmetric flag.
     *
     * @return the referenced triangle, or null for a general matrix.
     */
    public UPLO uplo() {
        return uplo;
    }

    /**
     * Compares with another band matrix using an element-wise tolerance of
     * {@code 1E-10}. Because of the tolerance this relation is not
     * transitive.
     */
    @Override
    public boolean equals(Object o) {
        if (!(o instanceof BandMatrix)) {
            return false;
        }
        return equals((BandMatrix) o, 1E-10);
    }

    /**
     * Hashes on the dimensions only. Element values are excluded because
     * {@link #equals(Object)} treats nearly equal values as equal, and equal
     * objects must share a hash code.
     */
    @Override
    public int hashCode() {
        return 31 * m + n;
    }

    /**
     * Returns true if both matrices have the same dimensions and every pair
     * of elements differs by a value that {@code MathEx.isZero} accepts for
     * the given tolerance.
     *
     * @param o   the matrix to compare with.
     * @param eps the tolerance passed to {@code MathEx.isZero}.
     * @return true if the matrices are equal within the tolerance.
     */
    public boolean equals(BandMatrix o, double eps) {
        if (m != o.m || n != o.n) {
            return false;
        }
        for (int j = 0; j < n; j++) {
            for (int i = 0; i < m; i++) {
                if (!MathEx.isZero(get(i, j) - o.get(i, j), eps)) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns element {@code (i, j)}; positions outside the band yield zero.
     */
    @Override // smile.math.matrix.IMatrix
    public double get(int i, int j) {
        boolean inBand = Math.max(0, j - ku) <= i && i <= Math.min(m - 1, j + kl);
        return inBand ? AB[j * ld + ku + i - j] : 0.0;
    }

    /**
     * Sets element {@code (i, j)}.
     *
     * @throws UnsupportedOperationException if the position is outside the band.
     */
    @Override // smile.math.matrix.IMatrix
    public void set(int i, int j, double x) {
        if (Math.max(0, j - ku) > i || i > Math.min(m - 1, j + kl)) {
            throw new UnsupportedOperationException(String.format("Set element at (%d, %d)", i, j));
        }
        AB[j * ld + ku + i - j] = x;
    }

    /**
     * Matrix-vector multiplication delegated to BLAS {@code sbmv} when the
     * matrix is flagged symmetric (the transpose argument is then unused),
     * otherwise to {@code gbmv}.
     */
    @Override // smile.math.matrix.IMatrix
    public void mv(Transpose trans, double alpha, double[] x, double beta, double[] y) {
        if (uplo != null) {
            BLAS.engine.sbmv(layout(), uplo, n, kl, alpha, AB, ld, x, 1, beta, y, 1);
        } else {
            BLAS.engine.gbmv(layout(), trans, m, n, kl, ku, alpha, AB, ld, x, 1, beta, y, 1);
        }
    }

    /**
     * Matrix-vector multiplication inside a shared work array: the input
     * vector is the {@code n} values starting at {@code inputOffset} and the
     * output vector is the {@code m} values starting at {@code outputOffset}.
     */
    @Override // smile.math.matrix.IMatrix
    public void mv(double[] work, int inputOffset, int outputOffset) {
        DoubleBuffer x = DoubleBuffer.wrap(work, inputOffset, n);
        DoubleBuffer y = DoubleBuffer.wrap(work, outputOffset, m);
        if (uplo != null) {
            BLAS.engine.sbmv(layout(), uplo, n, kl, 1.0, DoubleBuffer.wrap(AB), ld, x, 1, 0.0, y, 1);
        } else {
            BLAS.engine.gbmv(layout(), Transpose.NO_TRANSPOSE, m, n, kl, ku, 1.0, DoubleBuffer.wrap(AB), ld, x, 1, 0.0, y, 1);
        }
    }

    /**
     * Transposed matrix-vector multiplication inside a shared work array: the
     * input vector is the {@code m} values starting at {@code inputOffset}
     * and the output vector is the {@code n} values starting at
     * {@code outputOffset}.
     */
    @Override // smile.math.matrix.IMatrix
    public void tv(double[] work, int inputOffset, int outputOffset) {
        DoubleBuffer x = DoubleBuffer.wrap(work, inputOffset, m);
        DoubleBuffer y = DoubleBuffer.wrap(work, outputOffset, n);
        if (uplo != null) {
            BLAS.engine.sbmv(layout(), uplo, n, kl, 1.0, DoubleBuffer.wrap(AB), ld, x, 1, 0.0, y, 1);
        } else {
            BLAS.engine.gbmv(layout(), Transpose.TRANSPOSE, m, n, kl, ku, 1.0, DoubleBuffer.wrap(AB), ld, x, 1, 0.0, y, 1);
        }
    }

    /**
     * Computes the LU factorization with LAPACK {@code gbtrf}. This matrix is
     * left untouched; the factorization works on a copy.
     *
     * @return the LU factorization; a positive LAPACK code is stored in
     *         {@link LU#info} rather than thrown.
     * @throws ArithmeticException if LAPACK reports a negative code.
     */
    public LU lu() {
        // The factor storage has kl extra leading rows per column. The
        // unchecked constructor is required because 2 * kl may reach m for
        // small or wide-band matrices, which the public constructor rejects.
        BandMatrix factor = new BandMatrix(m, n, 2 * kl, ku, false);
        for (int j = 0; j < n; j++) {
            System.arraycopy(AB, j * ld, factor.AB, j * factor.ld + kl, ld);
        }

        int[] ipiv = new int[n];
        int info = LAPACK.engine.gbtrf(factor.layout(), factor.m, factor.n, factor.kl / 2, factor.ku, factor.AB, factor.ld, ipiv);
        if (info < 0) {
            logger.error("LAPACK GBTRF error code: {}", info);
            throw new ArithmeticException("LAPACK GBTRF error code: " + info);
        }
        return new LU(factor, ipiv, info);
    }

    /**
     * Computes the Cholesky factorization with LAPACK {@code pbtrf}. Only the
     * triangle selected by {@link #uplo()} is copied into the factor storage;
     * this matrix is left untouched.
     *
     * @return the Cholesky factorization.
     * @throws IllegalArgumentException if the matrix is not flagged symmetric.
     * @throws ArithmeticException if LAPACK reports a non-zero code.
     */
    public Cholesky cholesky() {
        if (uplo == null) {
            throw new IllegalArgumentException("The matrix is not symmetric");
        }

        final boolean lower = uplo == UPLO.LOWER;
        BandMatrix factor = new BandMatrix(m, n, lower ? kl : 0, lower ? 0 : ku);
        // Assigned directly: uplo(UPLO) would reject the one-sided band.
        factor.uplo = uplo;

        if (lower) {
            for (int j = 0; j < n; j++) {
                for (int i = 0; i <= kl; i++) {
                    factor.AB[j * factor.ld + i] = get(j + i, j);
                }
            }
        } else {
            for (int j = 0; j < n; j++) {
                for (int i = 0; i <= ku; i++) {
                    factor.AB[j * factor.ld + ku - i] = get(j - i, j);
                }
            }
        }

        int info = LAPACK.engine.pbtrf(factor.layout(), factor.uplo, factor.n, lower ? factor.kl : factor.ku, factor.AB, factor.ld);
        if (info != 0) {
            logger.error("LAPACK PBTRF error code: {}", info);
            throw new ArithmeticException("LAPACK PBTRF error code: " + info);
        }
        return new Cholesky(factor);
    }
}
