/*
 * Decompiled with CFR 0.152.
 */
package smile.base.mlp;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import smile.base.mlp.ActivationFunction;
import smile.base.mlp.Cost;
import smile.base.mlp.HiddenLayerBuilder;
import smile.base.mlp.InputLayer;
import smile.base.mlp.LayerBuilder;
import smile.base.mlp.OutputFunction;
import smile.base.mlp.OutputLayerBuilder;
import smile.math.MathEx;
import smile.math.matrix.Matrix;

/**
 * Base class for a layer of a multilayer perceptron.
 *
 * <p>A layer owns a weight matrix with {@code n} rows and {@code p} columns
 * and a bias vector of length {@code n}. All per-sample working buffers
 * (output, gradients, moments, updates, dropout mask) are held in
 * {@link ThreadLocal}s and are {@code transient}; they are recreated by
 * {@code init()} after construction with weights and after deserialization.
 *
 * <p>The class also hosts static factory methods that create layer builders,
 * including {@link #of(int, int, String)} which parses a textual network
 * specification.
 */
public abstract class Layer implements Serializable {
    private static final long serialVersionUID = 2L;

    /** Regex fragment for the (signed) neuron count in a layer specification. */
    private static final String INTEGER_REGEX = "[-+]?\\d{1,9}";

    /** Regex fragment for a floating point argument in a layer specification. */
    private static final String DOUBLE_REGEX = "[-+]?[0-9]*\\.?[0-9]+(?:[eE][-+]?[0-9]+)?";

    /**
     * Pattern of a single layer specification:
     * {@code name(neurons[, dropout[, param]])}. Compiled once because
     * {@link Pattern} is immutable and safe to share.
     */
    private static final Pattern LAYER_REGEX = Pattern.compile(String.format(
            "(\\w+)\\((%s)(,\\s*(%s))?(,\\s*(%s))?\\)", INTEGER_REGEX, DOUBLE_REGEX, DOUBLE_REGEX));

    /** The number of neurons in this layer (rows of the weight matrix). */
    protected final int n;
    /** The number of input variables (columns of the weight matrix). */
    protected final int p;
    /** The dropout rate, in {@code [0, 1)}. Dropout is active only when positive. */
    protected final double dropout;
    /** The {@code n x p} weight matrix. Left unset by the weightless constructor. */
    protected Matrix weight;
    /** The bias vector. Left unset by the weightless constructor. */
    protected double[] bias;
    /** Per-thread output vector of length {@code n}. */
    protected transient ThreadLocal<double[]> output;
    /** Per-thread gradient with respect to the output, length {@code n}. */
    protected transient ThreadLocal<double[]> outputGradient;
    /** Per-thread accumulated weight gradient, {@code n x p}. */
    protected transient ThreadLocal<Matrix> weightGradient;
    /** Per-thread accumulated bias gradient, length {@code n}. */
    protected transient ThreadLocal<double[]> biasGradient;
    /** Per-thread first moment buffer of the weight gradient (not used within this class). */
    protected transient ThreadLocal<Matrix> weightGradientMoment1;
    /** Per-thread running average of the squared weight gradient. */
    protected transient ThreadLocal<Matrix> weightGradientMoment2;
    /** Per-thread first moment buffer of the bias gradient (not used within this class). */
    protected transient ThreadLocal<double[]> biasGradientMoment1;
    /** Per-thread running average of the squared bias gradient. */
    protected transient ThreadLocal<double[]> biasGradientMoment2;
    /** Per-thread previous weight update, used for momentum. */
    protected transient ThreadLocal<Matrix> weightUpdate;
    /** Per-thread previous bias update, used for momentum. */
    protected transient ThreadLocal<double[]> biasUpdate;
    /** Per-thread dropout mask (1 = retained, 0 = dropped); only created when dropout is positive. */
    protected transient ThreadLocal<byte[]> mask;

    /**
     * Creates a layer without weights or bias, whose input and output sizes
     * are both {@code n}. Only the output buffer and, if dropout is positive,
     * the dropout mask are initialized.
     *
     * @param n the number of neurons.
     * @param dropout the dropout rate.
     * @throws IllegalArgumentException if {@code dropout < 0} or {@code dropout >= 1}.
     */
    Layer(int n, double dropout) {
        if (dropout < 0.0 || dropout >= 1.0) {
            throw new IllegalArgumentException("Invalid dropout rate: " + dropout);
        }
        this.n = n;
        this.p = n;
        this.dropout = dropout;
        this.output = ThreadLocal.withInitial(() -> new double[n]);
        if (dropout > 0.0) {
            this.mask = ThreadLocal.withInitial(() -> new byte[n]);
        }
    }

    /**
     * Creates a layer with randomly initialized weights and no dropout.
     *
     * @param n the number of neurons.
     * @param p the number of input variables.
     */
    public Layer(int n, int p) {
        this(n, p, 0.0);
    }

    /**
     * Creates a layer with randomly initialized weights and a zero bias.
     * The weights come from {@code Matrix.rand} with the bounds
     * {@code -sqrt(6 / (n + p))} and {@code sqrt(6 / (n + p))}.
     *
     * @param n the number of neurons.
     * @param p the number of input variables.
     * @param dropout the dropout rate.
     * @throws IllegalArgumentException if {@code dropout < 0} or {@code dropout >= 1}.
     */
    public Layer(int n, int p, double dropout) {
        this(Matrix.rand(n, p, -Math.sqrt(6.0 / (double) (n + p)), Math.sqrt(6.0 / (double) (n + p))),
             new double[n], dropout);
    }

    /**
     * Creates a layer from the given weights and bias, without dropout.
     *
     * @param weight the weight matrix.
     * @param bias the bias vector.
     */
    public Layer(Matrix weight, double[] bias) {
        this(weight, bias, 0.0);
    }

    /**
     * Creates a layer from the given weights and bias. The layer size is
     * taken from the matrix dimensions. Both arguments are stored by
     * reference, not copied.
     *
     * @param weight the weight matrix; rows are neurons, columns are inputs.
     * @param bias the bias vector.
     * @param dropout the dropout rate.
     * @throws IllegalArgumentException if {@code dropout < 0} or {@code dropout >= 1}.
     */
    public Layer(Matrix weight, double[] bias, double dropout) {
        if (dropout < 0.0 || dropout >= 1.0) {
            throw new IllegalArgumentException("Invalid dropout rate: " + dropout);
        }
        this.n = weight.nrow();
        this.p = weight.ncol();
        this.weight = weight;
        this.bias = bias;
        this.dropout = dropout;
        init();
    }

    /**
     * Restores the serialized state and recreates the transient
     * thread-local buffers.
     *
     * @param in the input stream.
     * @throws IOException if an I/O error occurs.
     * @throws ClassNotFoundException if the class of a serialized object cannot be found.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        init();
    }

    /**
     * (Re)creates all thread-local working buffers. Each buffer is
     * allocated lazily, on first access from a given thread.
     */
    private void init() {
        output = ThreadLocal.withInitial(() -> new double[n]);
        outputGradient = ThreadLocal.withInitial(() -> new double[n]);
        weightGradient = ThreadLocal.withInitial(() -> new Matrix(n, p));
        biasGradient = ThreadLocal.withInitial(() -> new double[n]);
        weightGradientMoment1 = ThreadLocal.withInitial(() -> new Matrix(n, p));
        weightGradientMoment2 = ThreadLocal.withInitial(() -> new Matrix(n, p));
        biasGradientMoment1 = ThreadLocal.withInitial(() -> new double[n]);
        biasGradientMoment2 = ThreadLocal.withInitial(() -> new double[n]);
        weightUpdate = ThreadLocal.withInitial(() -> new Matrix(n, p));
        biasUpdate = ThreadLocal.withInitial(() -> new double[n]);
        if (dropout > 0.0) {
            mask = ThreadLocal.withInitial(() -> new byte[n]);
        }
    }

    /**
     * Returns the number of neurons in this layer.
     *
     * @return the output dimension.
     */
    public int getOutputSize() {
        return n;
    }

    /**
     * Returns the number of input variables of this layer.
     *
     * @return the input dimension.
     */
    public int getInputSize() {
        return p;
    }

    /**
     * Returns the calling thread's output buffer.
     *
     * @return the output vector (the live buffer, not a copy).
     */
    public double[] output() {
        return output.get();
    }

    /**
     * Returns the calling thread's output gradient buffer.
     *
     * @return the output gradient vector (the live buffer, not a copy).
     */
    public double[] gradient() {
        return outputGradient.get();
    }

    /**
     * Computes the layer output for an input vector: the bias is copied
     * into the output buffer, {@code weight * x} is added to it, and the
     * result is passed to {@link #transform(double[])}.
     *
     * @param x the input vector.
     */
    public void propagate(double[] x) {
        double[] out = output.get();
        System.arraycopy(bias, 0, out, 0, n);
        weight.mv(1.0, x, 1.0, out);
        transform(out);
    }

    /**
     * Applies dropout to the calling thread's output buffer. Each neuron is
     * dropped when a draw from {@code MathEx.random()} is below the dropout
     * rate; retained outputs are scaled by {@code 1 / (1 - dropout)}. The
     * decisions are recorded in the thread-local mask. Does nothing when the
     * dropout rate is not positive.
     */
    public void propagateDropout() {
        if (dropout > 0.0) {
            double[] out = output.get();
            byte[] retained = mask.get();
            double scale = 1.0 / (1.0 - dropout);
            for (int i = 0; i < n; i++) {
                byte retain = (byte) (MathEx.random() < dropout ? 0 : 1);
                retained[i] = retain;
                out[i] *= (double) retain * scale;
            }
        }
    }

    /**
     * Applies this layer's transformation to the given vector. Called by
     * {@link #propagate(double[])} with the output buffer holding
     * {@code weight * x + bias}.
     *
     * @param x the vector passed by {@code propagate}.
     */
    public abstract void transform(double[] x);

    /**
     * Propagates the errors back to a lower layer.
     *
     * @param lowerLayerGradient presumably the gradient buffer of the lower
     *                           layer to fill in; see the concrete subclasses
     *                           for the exact contract.
     */
    public abstract void backpropagate(double[] lowerLayerGradient);

    /**
     * Applies the dropout mask recorded by {@link #propagateDropout()} to the
     * calling thread's output gradient, scaling retained entries by
     * {@code 1 / (1 - dropout)}. Does nothing when the dropout rate is not
     * positive.
     *
     * <p>The method name is misspelled ("backpopagate") and is kept as is so
     * existing callers keep working.
     */
    public void backpopagateDropout() {
        if (dropout > 0.0) {
            double[] gradient = outputGradient.get();
            byte[] retained = mask.get();
            double scale = 1.0 / (1.0 - dropout);
            for (int i = 0; i < n; i++) {
                gradient[i] *= (double) retained[i] * scale;
            }
        }
    }

    /**
     * Computes the parameter gradient from a single sample and immediately
     * applies it to the weights and bias.
     *
     * @param x the input vector.
     * @param learningRate the factor applied to the gradient before it is
     *                     added to the parameters.
     * @param momentum the momentum factor; used only when strictly between 0 and 1.
     * @param decay the weight decay factor; the weight matrix (not the bias)
     *              is multiplied by it only when it is strictly between 0.9 and 1.
     */
    public void computeGradientUpdate(double[] x, double learningRate, double momentum, double decay) {
        double[] gradient = outputGradient.get();

        if (momentum > 0.0 && momentum < 1.0) {
            Matrix deltaWeight = weightUpdate.get();
            double[] deltaBias = biasUpdate.get();

            // deltaWeight = momentum * deltaWeight + learningRate * gradient * x'
            deltaWeight.mul(momentum);
            deltaWeight.add(learningRate, gradient, x);
            weight.add(deltaWeight);

            for (int i = 0; i < n; i++) {
                double delta = momentum * deltaBias[i] + learningRate * gradient[i];
                deltaBias[i] = delta;
                bias[i] += delta;
            }
        } else {
            weight.add(learningRate, gradient, x);
            for (int i = 0; i < n; i++) {
                bias[i] += learningRate * gradient[i];
            }
        }

        if (decay > 0.9 && decay < 1.0) {
            weight.mul(decay);
        }
    }

    /**
     * Accumulates the parameter gradient of a single sample into the calling
     * thread's gradient buffers. The parameters themselves are not changed;
     * see {@link #update(int, double, double, double, double, double)}.
     *
     * @param x the input vector.
     */
    public void computeGradient(double[] x) {
        double[] gradient = outputGradient.get();
        Matrix gradWeight = weightGradient.get();
        double[] gradBias = biasGradient.get();

        gradWeight.add(1.0, gradient, x);
        for (int i = 0; i < n; i++) {
            gradBias[i] += gradient[i];
        }
    }

    /**
     * Applies the gradients accumulated by {@link #computeGradient(double[])}
     * on the calling thread to the weights and bias, then resets the
     * accumulators to zero.
     *
     * <p>When {@code rho} is strictly between 0 and 1, the accumulated
     * gradients are first averaged over {@code m} and then divided
     * element-wise by the square root of {@code epsilon} plus a running
     * average of their squares; the step size is then {@code learningRate}.
     * Otherwise the step size is {@code learningRate / m}.
     *
     * @param m the size of the mini-batch.
     * @param learningRate the learning rate.
     * @param momentum the momentum factor; used only when strictly between 0 and 1.
     * @param decay the weight decay factor; the weight matrix (not the bias)
     *              is multiplied by it only when it is strictly between 0.9 and 1.
     * @param rho the discounting factor of the running average of squared
     *            gradients; used only when strictly between 0 and 1.
     * @param epsilon a constant added inside the square root, which keeps the
     *                divisor away from zero.
     */
    public void update(int m, double learningRate, double momentum, double decay, double rho, double epsilon) {
        Matrix gradWeight = weightGradient.get();
        double[] gradBias = biasGradient.get();
        double eta = learningRate / (double) m;

        if (rho > 0.0 && rho < 1.0) {
            // The gradients are averaged explicitly below, so the step size
            // is the plain learning rate in this branch.
            eta = learningRate;
            gradWeight.div(m);
            for (int i = 0; i < n; i++) {
                gradBias[i] /= (double) m;
            }

            Matrix rmsWeight = weightGradientMoment2.get();
            double[] rmsBias = biasGradientMoment2.get();
            double rho1 = 1.0 - rho;

            // Update the running averages of the squared gradients.
            for (int j = 0; j < p; j++) {
                for (int i = 0; i < n; i++) {
                    rmsWeight.set(i, j, rho * rmsWeight.get(i, j) + rho1 * MathEx.pow2(gradWeight.get(i, j)));
                }
            }
            for (int i = 0; i < n; i++) {
                rmsBias[i] = rho * rmsBias[i] + rho1 * MathEx.pow2(gradBias[i]);
            }

            // Normalize the gradients by the root of the running averages.
            for (int j = 0; j < p; j++) {
                for (int i = 0; i < n; i++) {
                    gradWeight.div(i, j, Math.sqrt(epsilon + rmsWeight.get(i, j)));
                }
            }
            for (int i = 0; i < n; i++) {
                gradBias[i] /= Math.sqrt(epsilon + rmsBias[i]);
            }
        }

        if (momentum > 0.0 && momentum < 1.0) {
            Matrix deltaWeight = weightUpdate.get();
            double[] deltaBias = biasUpdate.get();

            deltaWeight.add(momentum, eta, gradWeight);
            for (int i = 0; i < n; i++) {
                deltaBias[i] = momentum * deltaBias[i] + eta * gradBias[i];
            }

            weight.add(deltaWeight);
            MathEx.add(bias, deltaBias);
        } else {
            weight.add(eta, gradWeight);
            for (int i = 0; i < n; i++) {
                bias[i] += eta * gradBias[i];
            }
        }

        if (decay > 0.9 && decay < 1.0) {
            weight.mul(decay);
        }

        // Reset the accumulators for the next mini-batch.
        gradWeight.fill(0.0);
        Arrays.fill(gradBias, 0.0);
    }

    /**
     * Returns a hidden layer builder for the named activation function.
     * The name is matched case-insensitively against {@code relu},
     * {@code sigmoid}, {@code tanh}, {@code linear} and {@code leaky}.
     *
     * @param activation the name of the activation function.
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @param param the optional activation parameter; only used by
     *              {@code leaky}, and ignored there when it is {@code NaN}.
     * @return the hidden layer builder.
     * @throws IllegalArgumentException if the activation name is not supported.
     */
    public static HiddenLayerBuilder builder(String activation, int neurons, double dropout, double param) {
        switch (activation.toLowerCase(Locale.ROOT)) {
            case "relu":
                return rectifier(neurons, dropout);
            case "sigmoid":
                return sigmoid(neurons, dropout);
            case "tanh":
                return tanh(neurons, dropout);
            case "linear":
                return linear(neurons, dropout);
            case "leaky":
                return Double.isNaN(param) ? leaky(neurons, dropout) : leaky(neurons, dropout, param);
            default:
                throw new IllegalArgumentException("Unsupported activation function: " + activation);
        }
    }

    /**
     * Returns an input layer builder without dropout.
     *
     * @param neurons the number of neurons.
     * @return the layer builder.
     */
    public static LayerBuilder input(int neurons) {
        return input(neurons, 0.0);
    }

    /**
     * Returns an input layer builder.
     *
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @return the layer builder.
     */
    public static LayerBuilder input(int neurons, double dropout) {
        return new LayerBuilder(neurons, dropout) {
            // The argument p is ignored: the input layer is built from the
            // builder's own neuron count.
            @Override
            public InputLayer build(int p) {
                return new InputLayer(this.neurons, this.dropout);
            }
        };
    }

    /**
     * Returns a hidden layer builder with the linear activation function and no dropout.
     *
     * @param neurons the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder linear(int neurons) {
        return linear(neurons, 0.0);
    }

    /**
     * Returns a hidden layer builder with the linear activation function.
     *
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder linear(int neurons, double dropout) {
        return new HiddenLayerBuilder(neurons, dropout, ActivationFunction.linear());
    }

    /**
     * Returns a hidden layer builder with the rectifier activation function and no dropout.
     *
     * @param neurons the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder rectifier(int neurons) {
        return rectifier(neurons, 0.0);
    }

    /**
     * Returns a hidden layer builder with the rectifier activation function.
     *
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder rectifier(int neurons, double dropout) {
        return new HiddenLayerBuilder(neurons, dropout, ActivationFunction.rectifier());
    }

    /**
     * Returns a hidden layer builder with the leaky rectifier activation
     * function (default parameter) and no dropout.
     *
     * @param neurons the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder leaky(int neurons) {
        // Delegate to the leaky overload; this previously built a plain
        // rectifier layer by mistake.
        return leaky(neurons, 0.0);
    }

    /**
     * Returns a hidden layer builder with the leaky rectifier activation
     * function and its default parameter.
     *
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder leaky(int neurons, double dropout) {
        return new HiddenLayerBuilder(neurons, dropout, ActivationFunction.leaky());
    }

    /**
     * Returns a hidden layer builder with the leaky rectifier activation function.
     *
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @param a the parameter passed to {@code ActivationFunction.leaky(double)}.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder leaky(int neurons, double dropout, double a) {
        return new HiddenLayerBuilder(neurons, dropout, ActivationFunction.leaky(a));
    }

    /**
     * Returns a hidden layer builder with the sigmoid activation function and no dropout.
     *
     * @param neurons the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder sigmoid(int neurons) {
        return sigmoid(neurons, 0.0);
    }

    /**
     * Returns a hidden layer builder with the sigmoid activation function.
     *
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder sigmoid(int neurons, double dropout) {
        return new HiddenLayerBuilder(neurons, dropout, ActivationFunction.sigmoid());
    }

    /**
     * Returns a hidden layer builder with the hyperbolic tangent activation
     * function and no dropout.
     *
     * @param neurons the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder tanh(int neurons) {
        return tanh(neurons, 0.0);
    }

    /**
     * Returns a hidden layer builder with the hyperbolic tangent activation function.
     *
     * @param neurons the number of neurons.
     * @param dropout the dropout rate.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder tanh(int neurons, double dropout) {
        return new HiddenLayerBuilder(neurons, dropout, ActivationFunction.tanh());
    }

    /**
     * Returns an output layer builder with the mean squared error cost function.
     *
     * @param neurons the number of neurons.
     * @param output the output function.
     * @return the layer builder.
     */
    public static OutputLayerBuilder mse(int neurons, OutputFunction output) {
        return new OutputLayerBuilder(neurons, output, Cost.MEAN_SQUARED_ERROR);
    }

    /**
     * Returns an output layer builder with the likelihood cost function.
     *
     * @param neurons the number of neurons.
     * @param output the output function.
     * @return the layer builder.
     */
    public static OutputLayerBuilder mle(int neurons, OutputFunction output) {
        return new OutputLayerBuilder(neurons, output, Cost.LIKELIHOOD);
    }

    /**
     * Parses a network specification into layer builders.
     *
     * <p>The specification is a {@code |}-separated list of layers of the
     * form {@code name(neurons[, dropout[, param]])}, for example
     * {@code "input(20, 0.2)|relu(64, 0.5)|leaky(32, 0.0, 0.1)"}. If the
     * first layer is not named {@code input}, an input layer of {@code p}
     * neurons is inserted before it. An output layer chosen from {@code k}
     * is appended: linear with mean squared error when {@code k < 2}, a
     * single sigmoid unit when {@code k == 2}, and softmax with {@code k}
     * units otherwise.
     *
     * @param k the number of classes; a value below 2 selects the single
     *          linear output unit.
     * @param p the number of input variables, used when the specification
     *          does not start with an input layer.
     * @param spec the network specification.
     * @return the layer builders, in network order.
     * @throws IllegalArgumentException if a layer does not match the expected
     *         syntax or names an unsupported activation function.
     */
    public static LayerBuilder[] of(int k, int p, String spec) {
        String[] layers = spec.split("\\|");
        ArrayList<LayerBuilder> builders = new ArrayList<>();

        for (int i = 0; i < layers.length; i++) {
            Matcher m = LAYER_REGEX.matcher(layers[i]);
            if (!m.matches()) {
                throw new IllegalArgumentException("Invalid layer: " + layers[i]);
            }

            String activation = m.group(1);
            int neurons = Integer.parseInt(m.group(2));
            // Groups 3 and 5 are the optional ", value" parts; groups 4 and 6
            // are the numbers inside them.
            double dropout = m.group(3) != null ? Double.parseDouble(m.group(4)) : 0.0;
            double param = m.group(5) != null ? Double.parseDouble(m.group(6)) : Double.NaN;

            if (i == 0) {
                if (activation.equalsIgnoreCase("input")) {
                    builders.add(input(neurons, dropout));
                    continue;
                }
                // No explicit input layer: derive one from the input dimension.
                builders.add(input(p));
            }
            builders.add(builder(activation, neurons, dropout, param));
        }

        if (k < 2) {
            builders.add(mse(1, OutputFunction.LINEAR));
        } else if (k == 2) {
            builders.add(mle(1, OutputFunction.SIGMOID));
        } else {
            builders.add(mle(k, OutputFunction.SOFTMAX));
        }

        return builders.toArray(new LayerBuilder[0]);
    }
}

