package smile.base.cart;

import java.math.BigInteger;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
import smile.data.Tuple;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.math.MathEx;

/* loaded from: input_file:smile/base/cart/InternalNode.class */
/**
 * Base class for an internal (non-leaf) node of a CART tree. Each internal node has
 * exactly two children, a "true" child and a "false" child. Concrete subclasses define
 * the split test ({@link #branch(Tuple)}), prediction routing and the textual form of
 * the split condition.
 */
public abstract class InternalNode implements Node {
    /** Number of samples under this node; initialized to the sum of the children's sizes. */
    int size;
    /** Child for the "true" outcome of the split (presumably chosen when branch(...) is true). */
    Node trueChild;
    /** Child for the "false" outcome of the split. */
    Node falseChild;
    /** Feature index supplied at construction; exposed via {@link #feature()}. */
    int feature;
    /** Split score supplied at construction; exposed via {@link #score()}. */
    double score;
    /** Deviance of this node supplied at construction; exposed via {@link #deviance()}. */
    double deviance;

    /**
     * Creates an internal node.
     *
     * @param i the feature index of the split.
     * @param d the split score.
     * @param d2 the deviance of this node.
     * @param node the "true" child.
     * @param node2 the "false" child.
     */
    public InternalNode(int i, double d, double d2, Node node, Node node2) {
        this.size = node.size() + node2.size();
        this.feature = i;
        this.score = d;
        this.deviance = d2;
        this.trueChild = node;
        this.falseChild = node2;
    }

    @Override
    public abstract LeafNode predict(Tuple tuple);

    /**
     * Evaluates this node's split test on a tuple.
     *
     * @param tuple the input tuple.
     * @return the outcome of the split test (presumably {@code true} selects the true child).
     */
    public abstract boolean branch(Tuple tuple);

    /**
     * Returns a node of the same split with its children replaced.
     *
     * @param node the new "true" child.
     * @param node2 the new "false" child.
     * @return the replacement node.
     */
    public abstract InternalNode replace(Node node, Node node2);

    /** Returns the "true" child. */
    public Node trueChild() {
        return this.trueChild;
    }

    /** Returns the "false" child. */
    public Node falseChild() {
        return this.falseChild;
    }

    /** Returns the feature index supplied at construction. */
    public int feature() {
        return this.feature;
    }

    /** Returns the split score supplied at construction. */
    public double score() {
        return this.score;
    }

    @Override
    public int size() {
        return this.size;
    }

    /** Returns the total number of leaves in both subtrees. */
    @Override
    public int leaves() {
        return this.trueChild.leaves() + this.falseChild.leaves();
    }

    @Override
    public double deviance() {
        return this.deviance;
    }

    /** Returns one plus the depth of the deeper child. */
    @Override
    public int depth() {
        return Math.max(this.trueChild.depth(), this.falseChild.depth()) + 1;
    }

    /**
     * Recursively merges both subtrees, then collapses this node into a single leaf when
     * both children are leaves of the same kind with the same output.
     *
     * <p>For two {@code DecisionNode}s, the merged leaf gets the element-wise sum of the
     * children's counts. For two {@code RegressionNode}s, it gets the combined size, the
     * size-weighted mean of the children's means and the sum of their impurities.
     *
     * @return the collapsed leaf, or {@code this} (with merged children) when no collapse applies.
     */
    @Override
    public Node merge() {
        this.trueChild = this.trueChild.merge();
        this.falseChild = this.falseChild.merge();

        if (this.trueChild instanceof DecisionNode && this.falseChild instanceof DecisionNode) {
            DecisionNode left = (DecisionNode) this.trueChild;
            DecisionNode right = (DecisionNode) this.falseChild;
            if (left.output() == right.output()) {
                int[] leftCount = left.count();
                int[] rightCount = right.count();
                int[] merged = new int[leftCount.length];
                for (int c = 0; c < merged.length; c++) {
                    merged[c] = leftCount[c] + rightCount[c];
                }
                return new DecisionNode(merged);
            }
        } else if (this.trueChild instanceof RegressionNode && this.falseChild instanceof RegressionNode) {
            RegressionNode left = (RegressionNode) this.trueChild;
            RegressionNode right = (RegressionNode) this.falseChild;
            if (left.output() == right.output()) {
                // Use the children's current sizes (as the constructor does) rather than
                // the cached this.size, so the weights always match the merged leaves.
                int leftSize = this.trueChild.size();
                int rightSize = this.falseChild.size();
                int total = leftSize + rightSize;
                double mean = (leftSize * left.mean() + rightSize * right.mean()) / total;
                return new RegressionNode(total, left.output(), mean, left.impurity() + right.impurity());
            }
        }
        return this;
    }

    /**
     * Returns a textual description of this node's split condition.
     *
     * @param structType the data schema.
     * @param z {@code true} if the description is requested for the "true" child
     *          (the children's toString passes {@code this == parent.trueChild}).
     * @return the split description.
     */
    public abstract String toString(StructType structType, boolean z);

    /**
     * Appends text lines for this subtree to {@code list}. Both children's lines are
     * appended first (false subtree, then true subtree), followed by this node's line.
     *
     * <p>The line has the form: {@code depth} spaces, then {@code id) }, the split
     * description from the parent ({@code root} when there is no parent), size, deviance
     * and the fitted value. For a single-element count array (regression), the fitted value
     * is the mean response. Otherwise it is the majority class label followed by the
     * posterior probabilities in parentheses.
     *
     * @param structType the data schema.
     * @param structField the response field, used to render the class label.
     * @param internalNode the parent node, or {@code null} for the root.
     * @param i the depth of this node, used as the indentation width.
     * @param bigInteger the id of this node; children get ids {@code 2*id} (true) and {@code 2*id+1} (false).
     * @param list the output lines.
     * @return {@code [size]} when the children report a single count, otherwise the
     *         element-wise sum of the children's count arrays.
     */
    @Override
    public int[] toString(StructType structType, StructField structField, InternalNode internalNode, int i, BigInteger bigInteger, List<String> list) {
        BigInteger trueId = bigInteger.shiftLeft(1);
        BigInteger falseId = trueId.add(BigInteger.ONE);
        // The order matters: the false subtree's lines come before the true subtree's.
        int[] falseCount = this.falseChild.toString(structType, structField, this, i + 1, falseId, list);
        int[] trueCount = this.trueChild.toString(structType, structField, this, i + 1, trueId, list);

        int k = falseCount.length;
        boolean regression = k == 1;
        int[] count = new int[k];
        if (regression) {
            count[0] = this.size;
        } else {
            for (int c = 0; c < k; c++) {
                count[c] = falseCount[c] + trueCount[c];
            }
        }

        StringBuilder line = new StringBuilder();
        for (int d = 0; d < i; d++) {
            line.append(" ");
        }
        line.append(bigInteger).append(") ");
        line.append(internalNode == null ? "root" : internalNode.toString(structType, this == internalNode.trueChild)).append(" ");
        line.append(this.size).append(" ");
        line.append(String.format("%.5g", Double.valueOf(deviance()))).append(" ");
        if (regression) {
            line.append(String.format("%g", Double.valueOf(sumy() / this.size))).append(" ");
        } else {
            line.append(structField.toString(Integer.valueOf(MathEx.whichMax(count)))).append(" ");
            double[] posteriori = new double[k];
            DecisionNode.posteriori(count, posteriori);
            line.append(Arrays.stream(posteriori)
                    .mapToObj(p -> String.format("%.5g", Double.valueOf(p)))
                    .collect(Collectors.joining(" ", "(", ")")));
        }
        list.add(line.toString());
        return count;
    }

    /**
     * Returns the sum over all leaves of this subtree of {@code output * size}, the total
     * fitted response of a regression subtree.
     *
     * @throws IllegalStateException if a leaf is not a {@code RegressionNode}.
     */
    private double sumy() {
        return sumy(this.trueChild) + sumy(this.falseChild);
    }

    /** Returns the {@code output * size} total of one child subtree; see {@link #sumy()}. */
    private static double sumy(Node node) {
        if (node instanceof InternalNode) {
            return ((InternalNode) node).sumy();
        }
        if (node instanceof RegressionNode) {
            return ((RegressionNode) node).output() * node.size();
        }
        throw new IllegalStateException("Call sumy() on DecisionTree?");
    }
}
