/*
 * Decompiled with CFR 0.152.
 */
package smile.sequence;

/**
 * A rectangular lattice of {@link Cell}s addressed as {@code table[row][column]},
 * together with three passes over it: {@link #forward(double[])},
 * {@link #backward()} and {@link #gradient(double[], int[])}.
 *
 * <p>The passes only read {@code expf} and only write {@code alpha}, {@code beta}
 * and {@code residual}; the {@code expf} values must be filled in by the caller
 * beforehand. All three passes read {@code table[0]} unconditionally, so they
 * throw {@link ArrayIndexOutOfBoundsException} on a lattice with zero rows.
 *
 * <p>NOTE(review): judging by the package name this is presumably the trellis of
 * a linear-chain sequence model, with rows being sequence positions, columns
 * being states/classes and {@code expf} holding exponentiated potentials --
 * confirm against the calling code.
 */
class Trellis {
    /** The lattice; the first index selects the row, the second the column. */
    public final Cell[][] table;

    /**
     * Allocates an {@code n x k} lattice and populates every slot with a fresh
     * {@link Cell} whose arrays have length {@code k}.
     *
     * @param n number of rows
     * @param k number of columns, also the length of each cell's arrays
     */
    public Trellis(int n, int k) {
        this.table = new Cell[n][k];
        for (Cell[] cells : this.table) {
            for (int col = 0; col < k; ++col) {
                cells[col] = new Cell(k);
            }
        }
    }

    /**
     * Forward pass: fills {@code alpha} of every cell, row by row from the top.
     *
     * <p>Row 0 is seeded with {@code expf[0]} of each cell. For every later row,
     * {@code alpha} of a cell is the dot product of that cell's {@code expf} with
     * the (already normalised) {@code alpha} values of the row above. After each
     * row is computed it is divided by its own sum, and that sum is stored in
     * {@code scaling} at the row's index.
     *
     * @param scaling output array receiving one normaliser per row; must have at
     *                least {@code table.length} elements
     */
    public void forward(double[] scaling) {
        final int rows = this.table.length;
        final int cols = this.table[0].length;

        // Seed the first row straight from expf[0].
        Cell[] current = this.table[0];
        for (int s = 0; s < cols; ++s) {
            Cell cell = current[s];
            cell.alpha = cell.expf[0];
        }
        normalizeAlpha(current, cols, scaling, 0);

        for (int t = 1; t < rows; ++t) {
            current = this.table[t];
            Cell[] previous = this.table[t - 1];
            for (int s = 0; s < cols; ++s) {
                Cell cell = current[s];
                cell.alpha = 0.0;
                for (int p = 0; p < cols; ++p) {
                    cell.alpha += cell.expf[p] * previous[p].alpha;
                }
            }
            normalizeAlpha(current, cols, scaling, t);
        }
    }

    /**
     * Sums {@code alpha} over the first {@code cols} cells of {@code row}, records
     * the sum in {@code scaling[t]} and divides each of those {@code alpha}
     * values by it. No guard against a zero sum is applied.
     */
    private static void normalizeAlpha(Cell[] row, int cols, double[] scaling, int t) {
        double total = 0.0;
        for (int s = 0; s < cols; ++s) {
            total += row[s].alpha;
        }
        scaling[t] = total;
        for (int s = 0; s < cols; ++s) {
            row[s].alpha /= total;
        }
    }

    /**
     * Backward pass: fills {@code beta} of every cell, row by row from the bottom.
     *
     * <p>The last row is set to 1. For every earlier row, {@code beta} of column
     * {@code s} is the sum over the columns {@code j} of the row below of
     * {@code below[j].expf[s] * below[j].beta}. Each computed row is then divided
     * by its own sum; unlike {@link #forward(double[])} the normalisers are not
     * kept.
     */
    public void backward() {
        final int lastRow = this.table.length - 1;
        final int cols = this.table[0].length;

        Cell[] bottom = this.table[lastRow];
        for (int s = 0; s < cols; ++s) {
            bottom[s].beta = 1.0;
        }

        for (int t = lastRow - 1; t >= 0; --t) {
            Cell[] current = this.table[t];
            Cell[] below = this.table[t + 1];

            for (int s = 0; s < cols; ++s) {
                Cell cell = current[s];
                cell.beta = 0.0;
                for (int j = 0; j < cols; ++j) {
                    Cell successor = below[j];
                    cell.beta += successor.expf[s] * successor.beta;
                }
            }

            double total = 0.0;
            for (int s = 0; s < cols; ++s) {
                total += current[s].beta;
            }
            for (int s = 0; s < cols; ++s) {
                current[s].beta /= total;
            }
        }
    }

    /**
     * Fills {@code residual} of every cell with "indicator minus normalised
     * weight". Reads the {@code alpha} and {@code beta} values currently stored
     * in the cells, so {@link #forward(double[])} and {@link #backward()} are
     * expected to have been run first.
     *
     * <p>Row 0 writes only {@code residual[0]}: the indicator is 1 when
     * {@code label[0]} equals the column, and the weight is
     * {@code expf[0] * beta} divided by the sum of that product over the row.
     *
     * <p>Row {@code t >= 1} writes {@code residual[j]} for every {@code j}: the
     * indicator is 1 when {@code label[t]} equals the column and
     * {@code label[t - 1]} equals {@code j}; the weight is
     * {@code expf[j] * above[j].alpha * beta} divided by
     * {@code scaling[t] * sum(alpha * beta)} taken over row {@code t}.
     *
     * @param scaling per-row normalisers as produced by {@link #forward(double[])};
     *                index 0 is not read
     * @param label   one column index per row
     */
    public void gradient(double[] scaling, int[] label) {
        final int rows = this.table.length;
        final int cols = this.table[0].length;

        Cell[] current = this.table[0];
        double z = 0.0;
        for (int s = 0; s < cols; ++s) {
            z += current[s].expf[0] * current[s].beta;
        }
        for (int s = 0; s < cols; ++s) {
            Cell cell = current[s];
            double indicator = label[0] == s ? 1.0 : 0.0;
            cell.residual[0] = indicator - cell.expf[0] * cell.beta / z;
        }

        for (int t = 1; t < rows; ++t) {
            current = this.table[t];
            Cell[] above = this.table[t - 1];

            z = 0.0;
            for (int s = 0; s < cols; ++s) {
                z += current[s].alpha * current[s].beta;
            }
            // Undo the per-row scaling that forward() applied to alpha.
            z *= scaling[t];

            for (int s = 0; s < cols; ++s) {
                Cell cell = current[s];
                for (int j = 0; j < cols; ++j) {
                    boolean observed = label[t] == s && label[t - 1] == j;
                    double weight = cell.expf[j] * above[j].alpha * cell.beta / z;
                    cell.residual[j] = (observed ? 1.0 : 0.0) - weight;
                }
            }
        }
    }

    /** One slot of the lattice; a plain mutable holder for the per-cell quantities. */
    public static class Cell {
        /** Written by {@link Trellis#forward(double[])}; starts at 1. */
        public double alpha = 1.0;
        /** Written by {@link Trellis#backward()}; starts at 1. */
        public double beta = 1.0;
        /** Written by {@link Trellis#gradient(double[], int[])}. */
        public double[] residual;
        /** Input weights read by all three passes; never written by {@link Trellis}. */
        public double[] expf;

        /**
         * Creates a cell whose {@code residual} and {@code expf} arrays are
         * zero-filled and of length {@code k}.
         *
         * @param k length of both arrays
         */
        public Cell(int k) {
            this.expf = new double[k];
            this.residual = new double[k];
        }
    }
}

