package smile.classification;

import com.sun.jna.platform.win32.COM.tlb.imp.TlbConst;
import java.util.Arrays;
import java.util.Properties;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.base.cart.CART;
import smile.base.cart.Loss;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.formula.Formula;
import smile.data.type.DataTypes;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.data.vector.BaseVector;
import smile.feature.importance.SHAP;
import smile.math.MathEx;
import smile.regression.RegressionTree;
import smile.util.IntSet;
import smile.util.Strings;

/* loaded from: input_file:smile/classification/GradientTreeBoost.class */
public class GradientTreeBoost extends AbstractClassifier<Tuple> implements DataFrameClassifier, SHAP<Tuple> {
    /** Serialization version identifier. */
    private static final long serialVersionUID = 2;
    /** Class-wide SLF4J logger; the trainers use it to report progress. */
    private static final Logger logger = LoggerFactory.getLogger((Class<?>) GradientTreeBoost.class);
    /** The model formula; used to extract the predictors from input tuples and data frames. */
    private final Formula formula;
    /** The number of classes: fixed to 2 by the binary constructors, the forest's outer length otherwise. */
    private final int k;
    /** The regression trees of a binary model; never assigned (so {@code null}) by the multi-class constructors. */
    private RegressionTree[] trees;
    /**
     * The regression trees of a multi-class model, indexed as {@code forest[class][iteration]};
     * never assigned (so {@code null}) by the binary constructors.
     */
    private RegressionTree[][] forest;
    /** The variable importance supplied at construction; the trainers pass the sum of every tree's importance. */
    private final double[] importance;
    /** The intercept that seeds the binary decision score; zero for multi-class models. */
    private double b;
    /** The factor applied to every tree's output, both while training and while predicting. */
    private final double shrinkage;

    /**
     * Creates a binary classification model that uses {@code IntSet.of(2)} as its class label set.
     *
     * @param formula    the model formula.
     * @param trees      the regression trees, in boosting order.
     * @param b          the intercept added to the decision score.
     * @param shrinkage  the factor applied to every tree's output.
     * @param importance the variable importance.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[] trees, double b, double shrinkage, double[] importance) {
        this(formula, trees, b, shrinkage, importance, IntSet.of(2));
    }

    /**
     * Creates a binary classification model with an explicit class label set.
     * The number of classes is fixed to 2 and the multi-class forest is left unset.
     *
     * @param formula    the model formula.
     * @param trees      the regression trees, in boosting order.
     * @param b          the intercept added to the decision score.
     * @param shrinkage  the factor applied to every tree's output.
     * @param importance the variable importance.
     * @param labels     the class label set handed to the superclass.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[] trees, double b, double shrinkage, double[] importance, IntSet labels) {
        super(labels);
        this.formula = formula;
        this.k = 2;
        this.trees = trees;
        this.b = b;
        this.shrinkage = shrinkage;
        this.importance = importance;
    }

    /**
     * Creates a multi-class model that uses {@code IntSet.of(forest.length)} as its class label set.
     *
     * @param formula    the model formula.
     * @param forest     the regression trees, indexed as {@code forest[class][iteration]}.
     * @param shrinkage  the factor applied to every tree's output.
     * @param importance the variable importance.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[][] forest, double shrinkage, double[] importance) {
        this(formula, forest, shrinkage, importance, IntSet.of(forest.length));
    }

    /**
     * Creates a multi-class model with an explicit class label set.
     * The number of classes is the outer length of {@code forest}; the binary
     * tree array is left unset and the intercept is zero.
     *
     * @param formula    the model formula.
     * @param forest     the regression trees, indexed as {@code forest[class][iteration]}.
     * @param shrinkage  the factor applied to every tree's output.
     * @param importance the variable importance.
     * @param labels     the class label set handed to the superclass.
     */
    public GradientTreeBoost(Formula formula, RegressionTree[][] forest, double shrinkage, double[] importance, IntSet labels) {
        super(labels);
        this.formula = formula;
        this.k = forest.length;
        this.forest = forest;
        // Multi-class scores have no intercept term.
        this.b = 0.0;
        this.shrinkage = shrinkage;
        this.importance = importance;
    }

    /**
     * Fits a gradient tree boosting model with the default hyper-parameters,
     * i.e. with an empty {@link Properties} object.
     *
     * @param formula the model formula.
     * @param data    the training data.
     * @return the fitted model.
     */
    public static GradientTreeBoost fit(Formula formula, DataFrame data) {
        Properties defaults = new Properties();
        return fit(formula, data, defaults);
    }

    /**
     * Fits a gradient tree boosting model with hyper-parameters read from properties.
     * Recognised keys and their defaults:
     * <ul>
     * <li>{@code smile.gradient_boost.trees} (500)</li>
     * <li>{@code smile.gradient_boost.max_depth} (20)</li>
     * <li>{@code smile.gradient_boost.max_nodes} (6)</li>
     * <li>{@code smile.gradient_boost.node_size} (5)</li>
     * <li>{@code smile.gradient_boost.shrinkage} (0.05)</li>
     * <li>{@code smile.gradient_boost.sampling_rate} (0.7)</li>
     * </ul>
     *
     * @param formula the model formula.
     * @param data    the training data.
     * @param params  the hyper-parameters; missing keys fall back to the defaults above.
     * @return the fitted model.
     * @throws NumberFormatException if a property value cannot be parsed as a number.
     */
    public static GradientTreeBoost fit(Formula formula, DataFrame data, Properties params) {
        // Parsed in the same order as the target method's arguments.
        int ntrees = Integer.parseInt(params.getProperty("smile.gradient_boost.trees", "500"));
        int maxDepth = Integer.parseInt(params.getProperty("smile.gradient_boost.max_depth", "20"));
        int maxNodes = Integer.parseInt(params.getProperty("smile.gradient_boost.max_nodes", "6"));
        int nodeSize = Integer.parseInt(params.getProperty("smile.gradient_boost.node_size", "5"));
        double shrinkage = Double.parseDouble(params.getProperty("smile.gradient_boost.shrinkage", "0.05"));
        double subsample = Double.parseDouble(params.getProperty("smile.gradient_boost.sampling_rate", "0.7"));
        return fit(formula, data, ntrees, maxDepth, maxNodes, nodeSize, shrinkage, subsample);
    }

    /**
     * Fits a gradient tree boosting model.
     * The formula is expanded against the data schema, the predictors and the
     * response are extracted, and training is dispatched to the binary or the
     * multi-class trainer depending on the number of distinct class labels.
     *
     * @param formula   the model formula.
     * @param data      the training data.
     * @param ntrees    the number of boosting iterations; must be at least 1.
     * @param maxDepth  forwarded to every {@link RegressionTree} (property {@code max_depth}).
     * @param maxNodes  forwarded to every {@link RegressionTree} (property {@code max_nodes}).
     * @param nodeSize  forwarded to every {@link RegressionTree} (property {@code node_size}).
     * @param shrinkage the factor applied to every tree's output; must be in (0, 1].
     * @param subsample the per-class sampling fraction; must be in (0, 1].
     * @return the fitted model.
     * @throws IllegalArgumentException if {@code ntrees}, {@code shrinkage} or
     *         {@code subsample} is out of range (NaN included).
     */
    public static GradientTreeBoost fit(Formula formula, DataFrame data, int ntrees, int maxDepth, int maxNodes, int nodeSize, double shrinkage, double subsample) {
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
        }
        // The range checks are negated positives so that NaN, for which every
        // comparison is false, is rejected instead of slipping through.
        if (!(shrinkage > 0.0 && shrinkage <= 1.0)) {
            throw new IllegalArgumentException("Invalid shrinkage: " + shrinkage);
        }
        if (!(subsample > 0.0 && subsample <= 1.0)) {
            throw new IllegalArgumentException("Invalid sampling fraction: " + subsample);
        }

        Formula expanded = formula.expand(data.schema());
        DataFrame x = expanded.x(data);
        BaseVector y = expanded.y(data);
        int[][] order = CART.order(x);
        ClassLabels codec = ClassLabels.fit((BaseVector<?, ?, ?>) y);

        if (codec.k == 2) {
            return train2(expanded, x, codec, order, ntrees, maxDepth, maxNodes, nodeSize, shrinkage, subsample);
        }
        return traink(expanded, x, codec, order, ntrees, maxDepth, maxNodes, nodeSize, shrinkage, subsample);
    }

    /**
     * Returns the model formula.
     *
     * @return the formula this model was constructed with.
     */
    @Override // smile.classification.DataFrameClassifier, smile.feature.importance.TreeSHAP
    public Formula formula() {
        return this.formula;
    }

    /**
     * Returns the schema reported by the first tree of the ensemble: the first
     * binary tree when the binary tree array is set, otherwise the first tree
     * of the first class.
     *
     * @return the schema of the first tree.
     */
    @Override // smile.classification.DataFrameClassifier
    public StructType schema() {
        RegressionTree first;
        if (this.trees != null) {
            first = this.trees[0];
        } else {
            first = this.forest[0][0];
        }
        return first.schema();
    }

    /**
     * Returns the variable importance held by this model.
     * The internal array itself is returned, not a copy, so changes made by
     * the caller are visible to the model.
     *
     * @return the variable importance array.
     */
    public double[] importance() {
        return this.importance;
    }

    /**
     * Trains a binary model.
     * Each iteration draws a fresh per-class sample, fits one regression tree
     * against the logistic loss, and adds the shrunken tree output to the
     * array obtained from {@code loss.residual()}.
     *
     * @param formula   the expanded model formula, stored in the resulting model.
     * @param x         the predictors.
     * @param codec     the encoded class labels.
     * @param order     the result of {@code CART.order(x)}, shared by all trees.
     * @param ntrees    the number of boosting iterations.
     * @param maxDepth  forwarded to every {@link RegressionTree}.
     * @param maxNodes  forwarded to every {@link RegressionTree}.
     * @param nodeSize  forwarded to every {@link RegressionTree}.
     * @param shrinkage the factor applied to every tree's output.
     * @param subsample the per-class sampling fraction.
     * @return the fitted binary model.
     */
    private static GradientTreeBoost train2(Formula formula, DataFrame x, ClassLabels codec, int[][] order, int ntrees, int maxDepth, int maxNodes, int nodeSize, double shrinkage, double subsample) {
        final int n = x.nrow();
        final int[] y = codec.y;

        // Number of training samples per class; drives the stratified sampling.
        final int[] nc = new int[codec.k];
        for (int i = 0; i < n; i++) {
            nc[y[i]]++;
        }

        final Loss loss = Loss.logistic(y);
        final double b = loss.intercept(null);
        // Updated in place below. NOTE(review): presumably the loss reads this
        // array back when the next tree is fitted - confirm against Loss.
        final double[] h = loss.residual();

        final StructField field = new StructField("residual", DataTypes.DoubleType);
        final RegressionTree[] trees = new RegressionTree[ntrees];
        final int[] permutation = IntStream.range(0, n).toArray();
        final int[] samples = new int[n];

        for (int t = 0; t < ntrees; t++) {
            sampling(samples, permutation, nc, y, subsample);
            logger.info("Training {} tree", Strings.ordinal(t + 1));

            RegressionTree tree = new RegressionTree(x, loss, field, maxDepth, maxNodes, nodeSize, x.ncol(), samples, order);
            trees[t] = tree;

            for (int i = 0; i < n; i++) {
                h[i] += shrinkage * tree.predict(x.get(i));
            }
        }

        // The model's importance is the element-wise sum over all trees.
        final double[] importance = new double[x.ncol()];
        for (RegressionTree tree : trees) {
            double[] imp = tree.importance();
            for (int j = 0; j < imp.length; j++) {
                importance[j] += imp[j];
            }
        }

        return new GradientTreeBoost(formula, trees, b, shrinkage, importance, codec.classes);
    }

    /**
     * Trains a multi-class model with one sequence of regression trees per class.
     * Each iteration first turns the current per-class scores into posteriori
     * probabilities with a softmax, then, for every class, draws a fresh
     * per-class sample, fits one regression tree against that class's loss,
     * and adds the shrunken tree output to that class's scores.
     *
     * @param formula   the expanded model formula, stored in the resulting model.
     * @param x         the predictors.
     * @param codec     the encoded class labels.
     * @param order     the result of {@code CART.order(x)}, shared by all trees.
     * @param ntrees    the number of boosting iterations (trees per class).
     * @param maxDepth  forwarded to every {@link RegressionTree}.
     * @param maxNodes  forwarded to every {@link RegressionTree}.
     * @param nodeSize  forwarded to every {@link RegressionTree}.
     * @param shrinkage the factor applied to every tree's output.
     * @param subsample the per-class sampling fraction.
     * @return the fitted multi-class model.
     */
    private static GradientTreeBoost traink(Formula formula, DataFrame x, ClassLabels codec, int[][] order, int ntrees, int maxDepth, int maxNodes, int nodeSize, double shrinkage, double subsample) {
        final int n = x.nrow();
        final int k = codec.k;
        final int[] y = codec.y;

        // Number of training samples per class; drives the stratified sampling.
        final int[] nc = new int[k];
        for (int i = 0; i < n; i++) {
            nc[y[i]]++;
        }

        final StructField field = new StructField("residual", DataTypes.DoubleType);
        final RegressionTree[][] forest = new RegressionTree[k][ntrees];
        // Posteriori probabilities, indexed [sample][class]; shared with every loss.
        final double[][] p = new double[n][k];
        // Per-class scores, indexed [class][sample]; each row is the array
        // returned by that class's loss, so the rows are filled in below.
        final double[][] h = new double[k][];
        final Loss[] loss = new Loss[k];
        for (int c = 0; c < k; c++) {
            loss[c] = Loss.logistic(c, k, y, p);
            h[c] = loss[c].residual();
        }

        final int[] permutation = IntStream.range(0, n).toArray();
        final int[] samples = new int[n];

        for (int t = 0; t < ntrees; t++) {
            logger.info("Training {} tree", Strings.ordinal(t + 1));

            // Refresh the probabilities from the scores of the previous iteration.
            for (int i = 0; i < n; i++) {
                for (int c = 0; c < k; c++) {
                    p[i][c] = h[c][i];
                }
                MathEx.softmax(p[i]);
            }

            for (int c = 0; c < k; c++) {
                sampling(samples, permutation, nc, y, subsample);

                RegressionTree tree = new RegressionTree(x, loss[c], field, maxDepth, maxNodes, nodeSize, x.ncol(), samples, order);
                forest[c][t] = tree;

                double[] hc = h[c];
                for (int i = 0; i < n; i++) {
                    hc[i] += shrinkage * tree.predict(x.get(i));
                }
            }
        }

        // The model's importance is the element-wise sum over all trees of all classes.
        final double[] importance = new double[x.ncol()];
        for (RegressionTree[] classTrees : forest) {
            for (RegressionTree tree : classTrees) {
                double[] imp = tree.importance();
                for (int j = 0; j < imp.length; j++) {
                    importance[j] += imp[j];
                }
            }
        }

        return new GradientTreeBoost(formula, forest, shrinkage, importance, codec.classes);
    }

    /**
     * Draws a stratified sample without replacement.
     * The permutation is shuffled in place; then, for every class, the first
     * {@code round(nc[class] * subsample)} members of that class in
     * permutation order are selected.
     *
     * @param samples     output flags, one per sample: 1 if selected, 0 otherwise.
     *                    Cleared before sampling.
     * @param permutation the sample indices; shuffled in place on every call.
     * @param nc          the number of samples in each class.
     * @param y           the class index of each sample.
     * @param subsample   the fraction of each class to select.
     */
    private static void sampling(int[] samples, int[] permutation, int[] nc, int[] y, double subsample) {
        final int n = samples.length;
        Arrays.fill(samples, 0);
        MathEx.permutate(permutation);

        for (int label = 0; label < nc.length; label++) {
            final int quota = (int) Math.round(nc[label] * subsample);
            int picked = 0;
            int pos = 0;
            // Walk the shuffled indices until the class quota is met.
            while (pos < n && picked < quota) {
                int row = permutation[pos];
                pos++;
                if (y[row] == label) {
                    samples[row] = 1;
                    picked++;
                }
            }
        }
    }

    /**
     * Returns the total number of regression trees, i.e. the length of the
     * array returned by {@link #trees()}. For a multi-class model this counts
     * the trees of all classes.
     *
     * @return the number of trees in the model.
     */
    public int size() {
        RegressionTree[] all = trees();
        return all.length;
    }

    /**
     * Returns the regression trees of the model.
     * A binary model returns its internal array directly (not a copy); a
     * multi-class model returns a new array with the trees of all classes
     * concatenated class by class.
     *
     * @return the regression trees.
     */
    public RegressionTree[] trees() {
        if (this.trees != null) {
            return this.trees;
        }
        return Arrays.stream(this.forest)
                .flatMap(Arrays::stream)
                .toArray(RegressionTree[]::new);
    }

    /**
     * Shrinks the model to its first {@code ntrees} boosting iterations.
     * A binary model keeps its first {@code ntrees} trees; a multi-class model
     * keeps the first {@code ntrees} trees of every class. The variable
     * importance is not recomputed.
     *
     * @param ntrees the new number of boosting iterations.
     * @throws IllegalArgumentException if {@code ntrees} is less than 1 or
     *         larger than the current number of iterations.
     */
    public void trim(int ntrees) {
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }

        // Branch on which tree store is populated rather than on k == 2:
        // a model built from a two-class forest has k == 2 but no binary trees.
        if (this.trees != null) {
            if (ntrees > this.trees.length) {
                throw new IllegalArgumentException("The new model size is larger than the current size.");
            }
            if (ntrees < this.trees.length) {
                this.trees = Arrays.copyOf(this.trees, ntrees);
            }
            return;
        }

        final int current = this.forest[0].length;
        if (ntrees > current) {
            throw new IllegalArgumentException("The new model size is larger than the current one.");
        }
        if (ntrees < current) {
            for (int c = 0; c < this.forest.length; c++) {
                this.forest[c] = Arrays.copyOf(this.forest[c], ntrees);
            }
        }
    }

    /**
     * Predicts the class label of an instance.
     * A binary model sums the intercept and the shrunken output of every tree
     * and predicts class index 1 when the sum is positive, 0 otherwise. A
     * multi-class model sums the shrunken tree outputs per class and predicts
     * the class with the largest sum (the first one on ties). The class index
     * is mapped to a label through {@code classes.valueOf}.
     *
     * @param x the instance to classify.
     * @return the predicted class label.
     */
    @Override // smile.classification.Classifier
    public int predict(Tuple x) {
        Tuple xt = this.formula.x(x);

        // Branch on which tree store is populated rather than on k == 2:
        // a model built from a two-class forest has k == 2 but no binary trees.
        if (this.trees != null) {
            double score = this.b;
            for (RegressionTree tree : this.trees) {
                score += this.shrinkage * tree.predict(xt);
            }
            return this.classes.valueOf(score > 0.0 ? 1 : 0);
        }

        double bestScore = Double.NEGATIVE_INFINITY;
        int best = -1;
        for (int c = 0; c < this.k; c++) {
            double score = 0.0;
            for (RegressionTree tree : this.forest[c]) {
                score += this.shrinkage * tree.predict(xt);
            }
            if (score > bestScore) {
                bestScore = score;
                best = c;
            }
        }
        return this.classes.valueOf(best);
    }

    /**
     * Always returns {@code true}.
     * NOTE(review): presumably this advertises that
     * {@code predict(Tuple, double[])} can fill in posteriori probabilities -
     * confirm against the Classifier interface.
     *
     * @return {@code true}.
     */
    @Override // smile.classification.Classifier
    public boolean soft() {
        return true;
    }

    /**
     * Predicts the class label of an instance and fills in the posteriori
     * probabilities.
     * For a binary model with decision score {@code s}, class index 0 gets
     * {@code 1 / (1 + exp(2s))} and class index 1 gets the complement. For a
     * multi-class model the per-class scores are turned into probabilities
     * with a softmax, shifted by the largest score before exponentiation.
     *
     * @param x          the instance to classify.
     * @param posteriori output array of length {@code k}; overwritten with the
     *                   posteriori probabilities, indexed by class index.
     * @return the predicted class label.
     * @throws IllegalArgumentException if {@code posteriori.length} is not the
     *         number of classes.
     */
    @Override // smile.classification.Classifier
    public int predict(Tuple x, double[] posteriori) {
        if (posteriori.length != this.k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", Integer.valueOf(posteriori.length), Integer.valueOf(this.k)));
        }

        Tuple xt = this.formula.x(x);

        // Branch on which tree store is populated rather than on k == 2:
        // a model built from a two-class forest has k == 2 but no binary trees.
        if (this.trees != null) {
            double score = this.b;
            for (RegressionTree tree : this.trees) {
                score += this.shrinkage * tree.predict(xt);
            }
            posteriori[0] = 1.0 / (1.0 + Math.exp(2.0 * score));
            posteriori[1] = 1.0 - posteriori[0];
            return this.classes.valueOf(score > 0.0 ? 1 : 0);
        }

        double bestScore = Double.NEGATIVE_INFINITY;
        int best = -1;
        for (int c = 0; c < this.k; c++) {
            posteriori[c] = 0.0;
            for (RegressionTree tree : this.forest[c]) {
                posteriori[c] += this.shrinkage * tree.predict(xt);
            }
            if (posteriori[c] > bestScore) {
                bestScore = posteriori[c];
                best = c;
            }
        }

        // Softmax; subtracting the largest score keeps exp() from overflowing.
        double sum = 0.0;
        for (int c = 0; c < this.k; c++) {
            posteriori[c] = Math.exp(posteriori[c] - bestScore);
            sum += posteriori[c];
        }
        for (int c = 0; c < this.k; c++) {
            posteriori[c] /= sum;
        }
        return this.classes.valueOf(best);
    }

    /**
     * Computes staged predictions on a data set.
     * Row {@code i} of the result holds the predictions made with only the
     * first {@code i + 1} boosting iterations, so the last row agrees with
     * {@link #predict(Tuple)} applied to every sample.
     *
     * @param data the data set to predict; the predictors are extracted with
     *             the model formula.
     * @return the predicted class labels, indexed as {@code [iteration][sample]}.
     */
    public int[][] test(DataFrame data) {
        DataFrame x = this.formula.x(data);
        int n = x.nrow();
        int ntrees = this.trees != null ? this.trees.length : this.forest[0].length;
        int[][] prediction = new int[ntrees][n];

        // Branch on which tree store is populated rather than on k == 2:
        // a model built from a two-class forest has k == 2 but no binary trees.
        if (this.trees != null) {
            for (int j = 0; j < n; j++) {
                Tuple xj = x.get(j);
                // Start from the intercept, exactly as predict(Tuple) does.
                double score = this.b;
                for (int i = 0; i < ntrees; i++) {
                    score += this.shrinkage * this.trees[i].predict(xj);
                    // Map the class index to its label, as predict(Tuple) does.
                    prediction[i][j] = this.classes.valueOf(score > 0.0 ? 1 : 0);
                }
            }
        } else {
            double[] scores = new double[this.k];
            for (int j = 0; j < n; j++) {
                Tuple xj = x.get(j);
                Arrays.fill(scores, 0.0);
                for (int i = 0; i < ntrees; i++) {
                    for (int c = 0; c < this.k; c++) {
                        scores[c] += this.shrinkage * this.forest[c][i].predict(xj);
                    }
                    // Map the class index to its label, as predict(Tuple) does.
                    prediction[i][j] = this.classes.valueOf(MathEx.whichMax(scores));
                }
            }
        }
        return prediction;
    }

    /**
     * Computes SHAP values over a data set.
     * The formula is bound to the data schema first; the rows are then handed,
     * as a parallel stream, to the stream-based {@code shap} overload.
     *
     * @param data the data set.
     * @return the SHAP values produced by the stream-based overload.
     */
    public double[] shap(DataFrame data) {
        this.formula.bind(data.schema());
        Stream<Tuple> rows = data.stream().parallel();
        return shap(rows);
    }

    /**
     * Computes the SHAP values of a single instance.
     * The result has {@code p * k} entries, where {@code p} is the number of
     * predictors; the value of predictor {@code j} for class {@code c} is at
     * index {@code j * k + c}. Each entry is the sum of the per-tree SHAP
     * values divided by the number of boosting iterations. For a binary model
     * both class slots of a predictor receive the same value.
     *
     * @param x the instance to explain.
     * @return the SHAP values, predictor-major and class-minor.
     */
    @Override // smile.feature.importance.SHAP
    public double[] shap(Tuple x) {
        Tuple xt = this.formula.x(x);
        final int p = xt.length();
        final double[] phi = new double[p * this.k];
        final int ntrees;

        if (this.trees == null) {
            ntrees = this.forest[0].length;
            for (int c = 0; c < this.k; c++) {
                for (RegressionTree tree : this.forest[c]) {
                    double[] contribution = tree.shap(xt);
                    for (int j = 0; j < p; j++) {
                        phi[j * this.k + c] += contribution[j];
                    }
                }
            }
        } else {
            ntrees = this.trees.length;
            for (RegressionTree tree : this.trees) {
                double[] contribution = tree.shap(xt);
                for (int j = 0; j < p; j++) {
                    // The single tree sequence feeds both class slots.
                    phi[2 * j] += contribution[j];
                    phi[2 * j + 1] += contribution[j];
                }
            }
        }

        // Average over the boosting iterations.
        for (int i = 0; i < phi.length; i++) {
            phi[i] /= ntrees;
        }
        return phi;
    }
}
