package io.intino.sumus.chronos.models.descriptive.sequence;

import static java.util.Arrays.stream;
import static java.util.stream.IntStream.range;

/**
 * Counts first-order transitions between consecutive tokens of a {@link Sequence} and derives
 * row-normalised transition probabilities and multi-step walks from those counts.
 *
 * <p>Rows and columns of every matrix exposed here are indexed by token; the token of a state
 * name is obtained through {@code sequence.indexOf(state)}.
 */
public class TransitionGraph {
	private final Sequence sequence;
	/** {@code transitions[from][to]} = number of times token {@code to} directly followed {@code from}. */
	private final int[][] transitions;
	/** Token passed to the previous {@link #put(int)} call; -1 until the first token is recorded. */
	private int last = -1;

	/**
	 * Builds the graph by iterating the points of {@code sequence} in order and counting every
	 * pair of consecutive tokens.
	 *
	 * @param sequence the sequence whose symbols define the states and whose points are counted
	 * @return a graph holding the transition counts of {@code sequence}
	 */
	public static TransitionGraph of(Sequence sequence) {
		TransitionGraph graph = new TransitionGraph(sequence.symbols());
		for (Sequence.Point point : sequence)
			graph.put(point.token());
		return graph;
	}

	/** Records {@code value} as the next token, incrementing the count for (previous token -> value). */
	private void put(int value) {
		// The very first token has no predecessor, so there is nothing to count yet.
		if (last >= 0) transitions[last][value]++;
		last = value;
	}

	private TransitionGraph(String[] states) {
		this.sequence = Sequence.of(states);
		this.transitions = new int[states.length][states.length];
	}

	/** @return the number of states, i.e. the length of {@code sequence.symbols()} */
	public int size() {
		return sequence.symbols().length;
	}

	/** @return the state names as reported by {@code sequence.symbols()} */
	public String[] states() {
		return sequence.symbols();
	}

	/**
	 * Share of all counted transitions that leave each state.
	 *
	 * @return for every state, (outgoing transitions of that state) / (all transitions); all
	 *         zeros when no transition has been counted
	 */
	private double[] stateProbabilities() {
		int n = size();
		double[] result = new double[n];
		double total = 0;
		for (int i = 0; i < n; i++)
			for (int j = 0; j < n; j++) {
				result[i] += transitions[i][j];
				total += transitions[i][j];
			}
		// Nothing counted: keep the all-zero vector instead of dividing by zero.
		if (total == 0) return result;
		for (int i = 0; i < n; i++) result[i] = result[i] / total;
		return result;
	}

	/**
	 * @return the transition count matrix; this is the internal array itself, not a copy, so
	 *         callers must not modify it
	 */
	public int[][] transitions() {
		return transitions;
	}

	/**
	 * @param state a state name
	 * @return the row of outgoing transition counts of {@code state} (the internal row, not a copy)
	 */
	public int[] transitions(String state) {
		return transitions[sequence.indexOf(state)];
	}

	/**
	 * @return a freshly allocated matrix whose row {@code i} holds the probabilities of moving
	 *         from state {@code i} to each state
	 */
	public double[][] transitionProbabilities() {
		int n = size();
		double[][] result = new double[n][];
		for (int row = 0; row < n; row++)
			result[row] = transitionProbabilities(row);
		return result;
	}

	/**
	 * @param state a state name
	 * @return the probabilities of moving from {@code state} to each state
	 */
	public double[] transitionProbabilities(String state) {
		return transitionProbabilities(sequence.indexOf(state));
	}

	/**
	 * Normalises one row of counts into probabilities. A row without any outgoing transition is
	 * treated as absorbing: probability 1 of staying in the same state.
	 */
	private double[] transitionProbabilities(int row) {
		int n = size();
		double[] result = new double[n];
		double sum = sum(row);
		if (sum > 0) {
			for (int i = 0; i < n; i++)
				result[i] = transitions[row][i] / sum;
		} else {
			// No observed exit from this state: model it as a self-loop so the row still sums to 1.
			result[row] = 1.;
		}
		return result;
	}

	/**
	 * Propagates the initial state distribution through {@code steps} applications of the
	 * transition probability matrix.
	 *
	 * @param steps number of transitions to apply; 0 returns the initial distribution unchanged
	 * @return the state distribution after {@code steps} transitions
	 * @throws IllegalArgumentException if {@code steps} is negative
	 */
	public double[] walk(int steps) {
		if (steps < 0) throw new IllegalArgumentException("steps must not be negative: " + steps);
		int n = size();
		// FIXME (carried over from the original code): the initial distribution is derived from
		// outgoing transition counts only; the original author flagged this as needing review.
		double[] stateProbabilities = stateProbabilities();
		double[][] randomWalks = power(transitionProbabilities(), steps);
		double[] result = new double[n];
		for (int i = 0; i < n; i++)
			for (int j = 0; j < n; j++)
				result[j] += stateProbabilities[i] * randomWalks[i][j];
		return result;
	}

	/**
	 * Raises a square matrix to a non-negative integer power by repeated multiplication.
	 *
	 * @param matrix a size() x size() matrix
	 * @param steps the exponent; 0 yields the identity matrix, 1 yields {@code matrix} itself
	 */
	private double[][] power(double[][] matrix, int steps) {
		// P^0 is the identity; previously this case fell through and returned P^1.
		if (steps == 0) return identity(matrix.length);
		double[][] result = matrix;
		for (int i = 1; i < steps; i++)
			result = multiply(result, matrix);
		return result;
	}

	/** @return the n x n identity matrix */
	private static double[][] identity(int n) {
		double[][] result = new double[n][n];
		for (int i = 0; i < n; i++) result[i][i] = 1.;
		return result;
	}

	/**
	 * Multiplies two size() x size() matrices. Rows of the product are computed in parallel;
	 * each task writes only to its own row of the result.
	 */
	private double[][] multiply(double[][] a, double[][] b) {
		int n = size();
		double[][] c = new double[n][n];
		range(0, n).parallel().forEach(i -> {
			for (int k = 0; k < n; k++)
				for (int j = 0; j < n; j++)
					c[i][j] += a[i][k] * b[k][j];
		});
		return c;
	}

	/** @return the total number of transitions counted out of state {@code row} */
	private int sum(int row) {
		return stream(transitions[row]).sum();
	}

}
