/*
 * Decompiled with CFR 0.152.
 */
package smile.base.mlp;

import java.util.Objects;
import smile.base.mlp.Cost;
import smile.base.mlp.Layer;
import smile.base.mlp.OutputFunction;

/**
 * The output layer of a multilayer perceptron. It pairs an output
 * (activation) function with the cost function used to compute the error
 * gradient at the network output.
 */
public class OutputLayer
extends Layer {
    private static final long serialVersionUID = 2L;
    /** The cost function used when computing the output gradient. */
    private final Cost cost;
    /** The output function applied to this layer's output. */
    private final OutputFunction activation;

    /**
     * Constructor.
     *
     * @param n presumably the number of units in this layer; it is passed to
     *          {@code Layer} and reported by {@link #toString()}.
     * @param p passed through to {@code Layer}; presumably the input
     *          dimension (TODO confirm against {@code Layer}).
     * @param activation the output function; must not be null.
     * @param cost the cost function; must not be null.
     * @throws NullPointerException if {@code activation} or {@code cost} is null.
     * @throws IllegalArgumentException if the output function is incompatible
     *         with the cost function: softmax with mean squared error, or
     *         linear with likelihood.
     */
    public OutputLayer(int n, int p, OutputFunction activation, Cost cost) {
        super(n, p);
        // Fail fast: a null activation would otherwise pass the compatibility
        // checks below and only blow up later in transform()/toString().
        Objects.requireNonNull(activation, "activation");
        Objects.requireNonNull(cost, "cost");
        if (cost == Cost.MEAN_SQUARED_ERROR && activation == OutputFunction.SOFTMAX) {
            throw new IllegalArgumentException("Softmax output function is not allowed with mean squared error cost function");
        }
        if (cost == Cost.LIKELIHOOD && activation == OutputFunction.LINEAR) {
            throw new IllegalArgumentException("Linear output function is not allowed with likelihood cost function");
        }
        this.activation = activation;
        this.cost = cost;
    }

    /**
     * Returns a summary of the form {@code ACTIVATION(n) | cost}.
     *
     * @return the string representation of this layer.
     */
    public String toString() {
        return String.format("%s(%d) | %s", this.activation.name(), this.n, this.cost);
    }

    /**
     * Returns the cost function of this layer.
     *
     * @return the cost function.
     */
    public Cost cost() {
        return this.cost;
    }

    /**
     * Applies the output function to {@code x}. {@code f} returns nothing,
     * so it presumably transforms {@code x} in place.
     *
     * @param x the layer's pre-activation values.
     */
    @Override
    public void transform(double[] x) {
        this.activation.f(x);
    }

    /**
     * Propagates this layer's output gradient to the layer below by calling
     * {@code weight.tv(outputGradient, lowerLayerGradient)}. This is presumably
     * the transposed weight matrix times the gradient, written into
     * {@code lowerLayerGradient}.
     *
     * @param lowerLayerGradient the buffer receiving the lower layer's gradient.
     */
    @Override
    public void backpropagate(double[] lowerLayerGradient) {
        this.weight.tv((double[])this.outputGradient.get(), lowerLayerGradient);
    }

    /**
     * Computes the error gradient at this layer's output for one sample and
     * stores it in the output-gradient buffer.
     *
     * <p>The raw error {@code target - output} is handed to the output
     * function's {@code g} together with the cost, which presumably adjusts
     * it in place. The result is then scaled by the sample weight.
     *
     * @param target the expected output vector; its length must equal the
     *               length of the layer output.
     * @param weight the sample weight. It is applied only when it is positive
     *               and not 1; weights {@code <= 0} leave the gradient
     *               unscaled. Note that this parameter shadows the inherited
     *               {@code weight} matrix used by {@link #backpropagate}.
     * @throws IllegalArgumentException if {@code target} has the wrong length.
     */
    public void computeOutputGradient(double[] target, double weight) {
        double[] output = (double[])this.output.get();
        double[] outputGradient = (double[])this.outputGradient.get();
        int n = output.length;
        if (target.length != n) {
            throw new IllegalArgumentException(String.format("Invalid target vector size: %d, expected: %d", target.length, n));
        }
        for (int i = 0; i < n; ++i) {
            outputGradient[i] = target[i] - output[i];
        }
        this.activation.g(this.cost, outputGradient, output);
        // NOTE(review): non-positive weights are ignored rather than zeroing
        // the sample's contribution. Confirm this is intended.
        if (weight > 0.0 && weight != 1.0) {
            for (int i = 0; i < n; ++i) {
                outputGradient[i] *= weight;
            }
        }
    }
}

