package io.intino.sumus.chronos.models.descriptive.sequence;

import java.util.*;

import static java.util.Arrays.stream;
import static java.util.stream.IntStream.range;

/**
 * N-gram frequency model built from the points of a {@link Sequence}.
 *
 * On construction, every point of the sequence contributes two tuples to a frequency table:
 * one holding {@code size - 1} tokens and one holding {@code size} tokens. Tuples are stored
 * most-recent-token first: position 0 holds the token of the point itself, position 1 the token
 * of {@code point.prev()}, and so on. The probability of a tuple is estimated as
 * {@code count(tuple) / count(tuple.tail())}.
 */
public class NGram {
	/** Value stored in a tuple position for which no token is available. */
	private static final int NO_TOKEN = -1;

	private final Sequence sequence;
	private final int size;
	private final Map<Tuple, Integer> count;

	/**
	 * Static factory equivalent to {@link #NGram(Sequence, int)}.
	 *
	 * @param sequence the sequence whose points are counted
	 * @param size     number of tokens per n-gram; must be at least 1
	 * @return a model populated from {@code sequence}
	 * @throws IllegalArgumentException if {@code size} is less than 1
	 */
	public static NGram of(Sequence sequence, int size) {
		return new NGram(sequence, size);
	}

	/**
	 * Builds the frequency table by visiting every point of {@code sequence}.
	 *
	 * @param sequence the sequence whose points are counted
	 * @param size     number of tokens per n-gram; must be at least 1
	 * @throws IllegalArgumentException if {@code size} is less than 1
	 */
	public NGram(Sequence sequence, int size) {
		// A non-positive size would otherwise fail later with an obscure
		// NegativeArraySizeException / IllegalArgumentException from array allocation.
		if (size < 1) throw new IllegalArgumentException("size must be >= 1 but was " + size);
		this.sequence = sequence;
		this.size = size;
		this.count = new HashMap<>();
		this.add(sequence);
	}

	/**
	 * Estimates the probability of the given tokens as the product of the probabilities of every
	 * window of {@code size} consecutive tokens. Each token is translated with
	 * {@code sequence.indexOf}.
	 *
	 * @param tokens the tokens, in the order they are expected to occur
	 * @return the product of the window probabilities; {@code 1} when fewer than {@code size}
	 *         tokens are given (no window exists), and {@code 0} as soon as one window has a
	 *         context that was never counted
	 */
	public double probabilityOf(String... tokens) {
		return probabilityOf(valuesOf(tokens));
	}

	/** Translates each token into its integer value through {@code sequence.indexOf}. */
	private int[] valuesOf(String[] tokens) {
		return stream(tokens).mapToInt(sequence::indexOf).toArray();
	}

	/** Multiplies the probabilities of all windows of {@code size} tokens, last window first. */
	private double probabilityOf(int[] tokens) {
		double probability = 1;
		for (int end = tokens.length; end >= size; end--)
			probability *= probabilityOf(new Tuple(reverse(tokens, end - 1, size)));
		return probability;
	}

	/**
	 * Copies {@code size} tokens walking backwards from index {@code from}, so the result is in
	 * the same most-recent-first order as the stored tuples. Positions that would fall before
	 * the start of {@code tokens} are filled with {@link #NO_TOKEN}.
	 */
	private int[] reverse(int[] tokens, int from, int size) {
		int[] data = new int[size];
		for (int i = 0, j = from; i < size; i++, j--)
			data[i] = j >= 0 ? tokens[j] : NO_TOKEN;
		return data;
	}

	/**
	 * Returns {@code count(tuple) / count(tuple.tail())}, or {@code 0} when the tail was never
	 * counted. The guard avoids the {@code NaN} / {@code Infinity} that a floating-point division
	 * by zero would produce and that would then poison every product built on top of it.
	 */
	private double probabilityOf(Tuple tuple) {
		int contextCount = count(tuple.tail());
		if (contextCount == 0) return 0;
		return (double) count(tuple) / contextCount;
	}

	/** Number of times {@code tuple} was counted; {@code 0} if it was never seen. */
	private int count(Tuple tuple) {
		return count.getOrDefault(tuple, 0);
	}

	private void add(Sequence sequence) {
		sequence.forEach(this::add);
	}

	/** Counts the tuples of {@code size - 1} and {@code size} tokens that end at {@code point}. */
	private void add(Sequence.Point point) {
		for (int length = size - 1; length <= size; length++)
			add(Tuple.of(point, length));
	}

	private void add(Tuple tuple) {
		count.merge(tuple, 1, Integer::sum);
	}

	/**
	 * Returns one suggestion per counted tuple that {@link Tuple#contains(int[]) contains} the
	 * given tokens. Each token is translated with {@code sequence.indexOf}.
	 *
	 * @param tokens the tokens to match against the counted tuples
	 * @return the suggestions, in the iteration order of the frequency table; never {@code null}
	 */
	public List<Suggestion> suggestionsFor(String... tokens) {
		return suggestionsFor(valuesOf(tokens));
	}

	private List<Suggestion> suggestionsFor(int[] tokens) {
		// NOTE(review): the tokens are matched against the leading positions of each tuple, which
		// include the head that is then reported as the suggested symbol - confirm this is the
		// intended alignment between query and stored tuples.
		List<Suggestion> result = new ArrayList<>();
		for (Tuple tuple : count.keySet()) {
			if (tuple.contains(tokens))
				result.add(new Suggestion(tuple));
		}
		return result;
	}

	/** A symbol together with the estimated probability of the tuple it was taken from. */
	public class Suggestion {
		/** Symbol that {@code sequence.symbol} returns for the head of the tuple. */
		public final String symbol;
		/** {@code count(tuple) / count(tuple.tail())}, or {@code 0} if the tail was never counted. */
		public final double probability;

		public Suggestion(Tuple tuple) {
			this.symbol = sequence.symbol(tuple.head());
			this.probability = probabilityOf(tuple);
		}

		@Override
		public String toString() {
			return "Suggestion{" +
					"symbol='" + symbol + '\'' +
					", probability=" + probability +
					'}';
		}
	}


	/**
	 * Fixed-length list of token values used as key of the frequency table. Equality and hash
	 * code are based on the array content. The array passed to the constructor is stored as is
	 * (not copied), so it must not be modified afterwards.
	 */
	public static class Tuple {
		final int[] tokens;

		public Tuple(int[] tokens) {
			this.tokens = tokens;
		}

		/** Returns a new tuple with every token except the first one. */
		public Tuple tail() {
			return new Tuple(Arrays.copyOfRange(tokens, 1, length()));
		}

		/** Returns the first token; fails on an empty tuple. */
		public int head() {
			return tokens[0];
		}

		private int length() {
			return tokens.length;
		}

		/**
		 * Tells whether this tuple is strictly longer than {@code tokens} and its leading
		 * positions are equal to them. An empty {@code tokens} is contained in every
		 * non-empty tuple.
		 */
		public boolean contains(int[] tokens) {
			if (tokens.length >= this.tokens.length) return false;
			return range(0, tokens.length)
					.allMatch(i -> tokens[i] == this.tokens[i]);
		}

		/**
		 * Builds a tuple of {@code size} tokens by walking backwards from {@code point} through
		 * {@code prev()}. When the chain ends before {@code size} tokens are collected, the
		 * remaining positions hold {@link NGram#NO_TOKEN}, the same padding used for query
		 * tuples, instead of {@code 0}, which could be mistaken for a real token value.
		 */
		static Tuple of(Sequence.Point point, int size) {
			int[] data = new int[size];
			Arrays.fill(data, NO_TOKEN);
			for (int i = 0; i < size && point != null; i++, point = point.prev())
				data[i] = point.token();
			return new Tuple(data);
		}

		@Override
		public boolean equals(Object o) {
			return o instanceof Tuple && Arrays.equals(tokens, ((Tuple) o).tokens);
		}

		@Override
		public int hashCode() {
			return Arrays.hashCode(tokens);
		}

		@Override
		public String toString() {
			return Arrays.toString(tokens);
		}
	}

}
