/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
import java.util.function.ToDoubleFunction;
import smile.data.Dataset;
import smile.data.Instance;

/**
 * A regression model: maps an instance of type {@code T} to a real-valued prediction.
 *
 * <p>Implementations must provide {@link #predict(Object)}. The default
 * {@link #update(Object, double)} throws {@link UnsupportedOperationException}
 * with the message {@code "update a batch learner"}; models that support
 * incremental updates are expected to override it.
 *
 * @param <T> the type of model input object.
 */
public interface Regression<T> extends ToDoubleFunction<T>, Serializable {
    /**
     * Predicts the dependent variable of a single instance.
     *
     * @param x the instance.
     * @return the predicted value.
     */
    double predict(T x);

    /**
     * Delegates to {@link #predict(Object)} so a model can be used wherever a
     * {@link ToDoubleFunction} is expected.
     *
     * @param x the instance.
     * @return the predicted value.
     */
    @Override
    default double applyAsDouble(T x) {
        return predict(x);
    }

    /**
     * Predicts each element of an array.
     *
     * @param x the instances.
     * @return the predicted values, in the same order as {@code x}.
     */
    default double[] predict(T[] x) {
        return Arrays.stream(x).mapToDouble(this::predict).toArray();
    }

    /**
     * Predicts each element of a list.
     *
     * @param x the instances.
     * @return the predicted values, in the list's iteration order.
     */
    default double[] predict(List<T> x) {
        return x.stream().mapToDouble(this::predict).toArray();
    }

    /**
     * Predicts each element of a dataset.
     *
     * @param x the instances.
     * @return the predicted values, in the order produced by {@code x.stream()}.
     */
    default double[] predict(Dataset<T> x) {
        return x.stream().mapToDouble(this::predict).toArray();
    }

    /**
     * Probes whether this model accepts incremental updates.
     *
     * <p>The probe calls {@code update(null, 0.0)}:
     * <ul>
     *   <li>if that throws the {@link UnsupportedOperationException} raised by the
     *       default {@link #update(Object, double)} (message
     *       {@code "update a batch learner"}), returns {@code false};</li>
     *   <li>if it throws any other exception (including an
     *       {@code UnsupportedOperationException} with a different or null
     *       message), returns {@code true}, on the reasoning that
     *       {@code update} was overridden;</li>
     *   <li>if it completes normally, returns {@code false}.</li>
     * </ul>
     *
     * <p>NOTE(review): the probe really invokes {@code update}, so an override that
     * tolerates a {@code null} instance will be called with one. Implementations
     * that know the answer should override this method instead of relying on it.
     *
     * @return the result of the probe described above.
     */
    default boolean online() {
        try {
            update(null, 0.0);
        } catch (UnsupportedOperationException e) {
            // Constant-first comparison: getMessage() may be null for a custom UOE,
            // which previously escaped as a NullPointerException.
            return !"update a batch learner".equals(e.getMessage());
        } catch (Exception e) {
            return true;
        }
        return false;
    }

    /**
     * Updates the model with a single training sample.
     *
     * <p>The default always throws; {@link #online()} recognises batch learners
     * by this exact exception message, so overrides must not reuse it.
     *
     * @param x the training instance.
     * @param y the response variable.
     * @throws UnsupportedOperationException always, in this default implementation.
     */
    default void update(T x, double y) {
        throw new UnsupportedOperationException("update a batch learner");
    }

    /**
     * Updates the model with a batch of samples by calling
     * {@link #update(Object, double)} once per index, in order.
     *
     * @param x the training instances.
     * @param y the response variables; must have the same length as {@code x}.
     * @throws IllegalArgumentException if {@code x.length != y.length}.
     */
    default void update(T[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("Input vector x of size %d not equal to length %d of y", x.length, y.length));
        }
        for (int i = 0; i < x.length; ++i) {
            update(x[i], y[i]);
        }
    }

    /**
     * Updates the model with every sample of a dataset, passing each sample's
     * {@code x()} and {@code y()} to {@link #update(Object, double)}.
     *
     * @param batch the training samples.
     */
    default void update(Dataset<Instance<T>> batch) {
        batch.stream().forEach(sample -> update(sample.x(), sample.y()));
    }

    /**
     * Returns a model whose prediction is the arithmetic mean of the given
     * models' predictions.
     *
     * <p>The ensemble's {@link #online()} is computed once, at construction, as
     * "every member reports online". Its {@link #update(Object, double)} forwards
     * the sample to each member in array order; if a member throws, members
     * after it are not updated.
     *
     * @param models the member models; must contain at least one model.
     * @param <T> the type of model input object.
     * @return the ensemble model.
     * @throws IllegalArgumentException if {@code models} is empty (the mean of
     *         zero predictions is undefined and previously yielded NaN).
     */
    @SafeVarargs
    static <T> Regression<T> ensemble(final Regression<T>... models) {
        if (models.length == 0) {
            throw new IllegalArgumentException("ensemble requires at least one model");
        }
        // Defensive copy: later writes to the caller's array must not alter the ensemble.
        final Regression<T>[] members = models.clone();

        return new Regression<T>() {
            private final boolean online = Arrays.stream(members).allMatch(Regression::online);

            @Override
            public boolean online() {
                return online;
            }

            @Override
            public double predict(T x) {
                double sum = 0.0;
                for (Regression<T> model : members) {
                    sum += model.predict(x);
                }
                return sum / members.length;
            }

            @Override
            public void update(T x, double y) {
                for (Regression<T> model : members) {
                    model.update(x, y);
                }
            }
        };
    }

    /**
     * Builds a regression model from training data.
     *
     * @param <T> the type of model input object.
     * @param <M> the type of model produced.
     */
    interface Trainer<T, M extends Regression<T>> {
        /**
         * Fits a model with an empty {@link Properties} set, presumably leaving
         * every hyper-parameter at the implementation's default.
         *
         * @param x the training instances.
         * @param y the response variables.
         * @return the fitted model.
         */
        default M fit(T[] x, double[] y) {
            return fit(x, y, new Properties());
        }

        /**
         * Fits a model.
         *
         * @param x the training instances.
         * @param y the response variables.
         * @param params the hyper-parameters; their keys are implementation-specific.
         * @return the fitted model.
         */
        M fit(T[] x, double[] y, Properties params);
    }
}

