package smile.regression;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
import java.util.function.ToDoubleFunction;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import smile.data.Dataset;
import smile.data.Instance;

/**
 * Regression model that maps an input object of type {@code T} to a real-valued response.
 *
 * <p>Implementations supply {@link #predict(Object)}. The batch {@code predict} overloads and
 * the {@link ToDoubleFunction} adapter are derived from it. The batch {@code update} overloads
 * delegate to {@link #update(Object, double)}, which by default rejects updates (batch learner).
 *
 * @param <T> the type of model input object.
 */
public interface Regression<T> extends ToDoubleFunction<T>, Serializable {

    /**
     * Trainer that fits a regression model from training samples and responses.
     *
     * @param <T> the type of model input object.
     * @param <M> the type of model produced.
     */
    interface Trainer<T, M extends Regression<T>> {
        /**
         * Fits a model using an empty {@link Properties} object.
         *
         * @param x the training samples.
         * @param y the response values.
         * @return the fitted model.
         */
        default M fit(T[] x, double[] y) {
            return fit(x, y, new Properties());
        }

        /**
         * Fits a model.
         *
         * @param x the training samples.
         * @param y the response values.
         * @param params implementation-specific training properties
         *               (presumably hyperparameters — confirm per trainer).
         * @return the fitted model.
         */
        M fit(T[] x, double[] y, Properties params);
    }

    /**
     * Predicts the response for a single input.
     *
     * @param x the input object.
     * @return the predicted value.
     */
    double predict(T x);

    /**
     * Adapts this model to {@link ToDoubleFunction}; identical to {@link #predict(Object)}.
     *
     * @param x the input object.
     * @return the predicted value.
     */
    @Override
    default double applyAsDouble(T x) {
        return predict(x);
    }

    /**
     * Predicts the responses for an array of inputs.
     *
     * @param x the input objects.
     * @return the predicted values, in the same order as {@code x}.
     */
    default double[] predict(T[] x) {
        return Arrays.stream(x).mapToDouble(this::predict).toArray();
    }

    /**
     * Predicts the responses for a list of inputs.
     *
     * @param x the input objects.
     * @return the predicted values, in list order.
     */
    default double[] predict(List<T> x) {
        return x.stream().mapToDouble(this::predict).toArray();
    }

    /**
     * Predicts the responses for every element of the dataset's stream.
     *
     * @param x the dataset of input objects.
     * @return the predicted values, in stream order.
     */
    default double[] predict(Dataset<T> x) {
        return x.stream().mapToDouble(this::predict).toArray();
    }

    /**
     * Returns true if this model supports online updates.
     *
     * <p>The check probes {@link #update(Object, double)} with a {@code null} input and a
     * response of {@code 0.0}:
     * <ul>
     *   <li>the default {@code UnsupportedOperationException("update a batch learner")}
     *       yields {@code false};</li>
     *   <li>any other {@code UnsupportedOperationException}, or any other exception,
     *       yields {@code true};</li>
     *   <li>a call that completes normally yields {@code false}.</li>
     * </ul>
     *
     * <p>NOTE(review): the probe invokes the real {@code update}; an implementation that accepts
     * a {@code null} input could change its state here — confirm implementations reject it.
     *
     * @return true if this is an online learner.
     */
    default boolean online() {
        try {
            update(null, 0.0);
            return false;
        } catch (UnsupportedOperationException e) {
            // Constant-first comparison: getMessage() may legitimately be null.
            return !"update a batch learner".equals(e.getMessage());
        } catch (Exception ignored) {
            // Any other failure (e.g. an NPE on the null probe) means update is implemented.
            return true;
        }
    }

    /**
     * Updates the model with a single sample. The default implementation rejects the call,
     * marking the model as a batch learner.
     *
     * @param x the input object.
     * @param y the response value.
     * @throws UnsupportedOperationException always, unless overridden.
     */
    default void update(T x, double y) {
        throw new UnsupportedOperationException("update a batch learner");
    }

    /**
     * Updates the model with a batch of samples, one {@link #update(Object, double)} call
     * per sample in index order.
     *
     * @param x the input objects.
     * @param y the response values.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length.
     */
    default void update(T[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format(
                    "Input vector x of size %d not equal to length %d of y", x.length, y.length));
        }
        for (int i = 0; i < x.length; i++) {
            update(x[i], y[i]);
        }
    }

    /**
     * Updates the model with every instance of the dataset's stream, passing each
     * instance's {@code x()} and {@code y()} to {@link #update(Object, double)}.
     *
     * @param batch the training instances.
     */
    default void update(Dataset<Instance<T>> batch) {
        batch.stream().forEach(instance -> update(instance.x(), instance.y()));
    }

    /**
     * Returns an ensemble whose prediction is the arithmetic mean of the members'
     * predictions. The ensemble is online only if every member is online, evaluated once
     * at creation; {@code update} forwards each sample to every member.
     *
     * @param models the member models; the array is copied, so later changes to it do not
     *               affect the ensemble.
     * @param <T> the type of model input object.
     * @return the ensemble model.
     * @throws IllegalArgumentException if no models are given.
     */
    @SafeVarargs
    static <T> Regression<T> ensemble(final Regression<T>... models) {
        if (models.length == 0) {
            throw new IllegalArgumentException("ensemble requires at least one model");
        }
        // Private copy: the varargs array belongs to the caller and may be mutated later.
        final Regression<T>[] members = models.clone();
        final boolean allOnline = Arrays.stream(members).allMatch(Regression::online);

        return new Regression<T>() {
            @Override
            public boolean online() {
                return allOnline;
            }

            @Override
            public double predict(T x) {
                double sum = 0.0;
                for (Regression<T> model : members) {
                    sum += model.predict(x);
                }
                return sum / members.length;
            }

            @Override
            public void update(T x, double y) {
                for (Regression<T> model : members) {
                    model.update(x, y);
                }
            }
        };
    }
}
