package io.intino.sumus.chronos;

import io.intino.sumus.chronos.filters.*;
import io.intino.sumus.chronos.models.descriptive.timeseries.Distribution;

import java.time.Instant;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.stream.Stream;

import static io.intino.sumus.chronos.Magnitude.Model.Default;
import static java.lang.Double.NaN;
import static java.lang.Double.isNaN;
import static java.util.Arrays.*;
import static java.util.stream.IntStream.iterate;
import static java.util.stream.IntStream.range;

/**
 * A series of {@code (instant, value)} samples that share a single {@link Magnitude.Model}.
 *
 * <p>{@link #instants} and {@link #values} are parallel arrays of equal length. Lookups by instant use
 * {@link Arrays#binarySearch(Object[], Object)}, so {@code instants} is expected to be sorted in ascending
 * order. NOTE(review): the constructor does not verify that ordering. {@code NaN} values are skipped by
 * {@link #sum()} and {@link #average()}.
 */
public class TimeSeries implements Iterable<TimeSeries.Point> {
	/** Shared empty series using the {@code Default} model. */
	public static final TimeSeries Null = new TimeSeries(Default, new Instant[0], new double[0]);
	public final Magnitude.Model model;
	public final Instant[] instants;
	public final double[] values;
	/** Lazily computed and cached by {@link #distribution()}. */
	private Distribution distribution;

	/**
	 * Creates a series over the given parallel arrays. The arrays are stored as-is, not copied.
	 *
	 * @param model    magnitude model (unit, symbol, min, max) describing the values
	 * @param instants sample instants; expected to be sorted ascending
	 * @param values   sample values; must have the same length as {@code instants}
	 */
	public TimeSeries(Magnitude.Model model, Instant[] instants, double[] values) {
		assert instants.length == values.length;
		this.instants = instants;
		this.values = values;
		this.model = model;
	}

	/** @return {@code true} when the series has no samples */
	public boolean isEmpty() {
		return length() == 0;
	}

	/** @return the number of samples */
	public int length() {
		return instants.length;
	}

	/**
	 * Returns whether this series' last instant is strictly before the first instant of {@code timeSeries}.
	 * Returns {@code false} when either series is empty.
	 */
	public boolean isBefore(TimeSeries timeSeries) {
		if (isEmpty() || timeSeries.isEmpty()) return false;
		return isEndOfTimes(timeSeries.firstInstant());
	}

	/**
	 * Returns whether this series' first instant is strictly after the last instant of {@code timeSeries}.
	 * Returns {@code false} when either series is empty.
	 */
	public boolean isAfter(TimeSeries timeSeries) {
		if (isEmpty() || timeSeries.isEmpty()) return false;
		return isBeginningOfTimes(timeSeries.lastInstant());
	}

	/** @return the model's unit */
	public String unit() {
		return model.unit;
	}

	/** @return the model's symbol */
	public String symbol() {
		return model.symbol;
	}

	/** @return the model's declared minimum ({@code model.min}), not the minimum of the sampled values */
	public double min() {
		return model.min;
	}

	/** @return the model's declared maximum ({@code model.max}), not the maximum of the sampled values */
	public double max() {
		return model.max;
	}

	/** @return the first point, or {@code null} when the series is empty */
	public Point first() {
		return point(0);
	}

	/** @return the last point, or {@code null} when the series is empty */
	public Point last() {
		return point(lastIndex());
	}

	/**
	 * Returns the point sampled exactly at {@code instant}. When there is no such sample, it returns the first
	 * point after {@code instant}.
	 *
	 * @return the matching point, or {@code null} when {@code instant} is after the last sample
	 */
	public Point at(Instant instant) {
		return point(indexOf(instant));
	}

	/** @return a series with at most the first {@code length} samples */
	public TimeSeries head(int length) {
		return sub(0, limitHigh(length));
	}

	/** @return a series with at most the last {@code length} samples */
	public TimeSeries tail(int length) {
		return sub(limitLow(length() - length), length());
	}

	/**
	 * @return the samples at or after {@code instant}; empty when {@code instant} is after the last sample or the
	 * series is empty
	 */
	public TimeSeries from(Instant instant) {
		if (isEmpty() || isEndOfTimes(instant)) return empty();
		return sub(indexOf(instant), length());
	}

	/**
	 * Returns a window anchored at the first sample at or after {@code instant}.
	 * <ul>
	 *   <li>When {@code length > 0}, the window holds up to {@code length} samples starting at that sample.</li>
	 *   <li>Otherwise it holds up to {@code -length} samples ending at (and including) that sample.</li>
	 * </ul>
	 */
	public TimeSeries from(Instant instant, int length) {
		if (isEmpty() || isEndOfTimes(instant)) return empty();
		int from = indexOf(instant);
		return length > 0 ?
				sub(from, limitHigh(from + length)) :
				sub(limitLow(from + 1 + length), limitHigh(from + 1));
	}

	/**
	 * Returns the samples in the half-open range {@code [instant, to)}. When {@code to} is after the last
	 * sample, the range extends to the end of the series. Returns an empty series when the range is empty or
	 * inverted.
	 */
	public TimeSeries from(Instant instant, Instant to) {
		if (isEmpty() || isEndOfTimes(instant)) return empty();
		if (isEndOfTimes(to)) return from(instant);
		return sub(indexOf(instant), indexOf(to));
	}

	/**
	 * Returns the samples strictly before {@code instant}. When {@code instant} is after the last sample, or the
	 * series is empty, the whole series ({@code this}) is returned.
	 */
	public TimeSeries to(Instant instant) {
		if (isEmpty() || isEndOfTimes(instant)) return this;
		return sub(0, indexOf(instant));
	}

	/** @return a point view of {@code index}, or {@code null} when the index is out of range */
	private Point point(int index) {
		return isInRange(index) ? new Point(index) : null;
	}

	/** Requires a non-empty series. */
	private boolean isBeginningOfTimes(Instant instant) {
		return firstInstant().isAfter(instant);
	}

	/** Requires a non-empty series. */
	private boolean isEndOfTimes(Instant instant) {
		return lastInstant().isBefore(instant);
	}

	private boolean isInRange(int index) {
		return index >= 0 && index < length();
	}


	private int lastIndex() {
		return length() - 1;
	}

	private int limitLow(int offset) {
		return Math.max(0, offset);
	}

	private int limitHigh(int offset) {
		return Math.min(offset, length());
	}

	private Instant firstInstant() {
		return instants[0];
	}

	private Instant lastInstant() {
		return instants[lastIndex()];
	}

	/**
	 * Copies the samples in index range {@code [from, to)} into a new series that keeps this series' model.
	 * Out-of-bounds, empty or inverted ranges yield {@link #empty()}. Guarding {@code from >= to} here is what
	 * stops {@code copyOfRange} from throwing.
	 */
	private TimeSeries sub(int from, int to) {
		if (from < 0 || to > length() || from >= to) return empty();
		return new TimeSeries(model, instants(from, to), values(from, to));
	}

	/** @return a new empty series using the {@code Default} model */
	public static TimeSeries empty() {
		return new TimeSeries(Default, new Instant[0], new double[0]);
	}

	/**
	 * Returns the index of {@code instant} when it is sampled. Otherwise it returns the insertion point: the
	 * index of the first sample after {@code instant}, or {@link #length()} if none.
	 */
	private int indexOf(Instant instant) {
		int index = binarySearch(instants, instant);
		// binarySearch encodes a miss as -(insertionPoint) - 1.
		return index >= 0 ? index : -index - 1;
	}

	private Instant[] instants(int from, int to) {
		return copyOfRange(instants, from, to);
	}

	private double[] values(int from, int to) {
		return copyOfRange(values, from, to);
	}

	/** Iterates over the points in index order; {@code next()} throws once the series is exhausted. */
	@Override
	public Iterator<Point> iterator() {
		return new Iterator<>() {
			int index = 0;

			@Override
			public boolean hasNext() {
				return index < instants.length;
			}

			@Override
			public Point next() {
				if (!hasNext()) throw new NoSuchElementException();
				return new Point(index++);
			}
		};
	}

	/** @return the sum of all non-NaN values ({@code 0} when there are none) */
	public double sum() {
		return stream(values).filter(v -> !isNaN(v)).sum();
	}

	/** @return the mean of all non-NaN values, or {@code 0} when there are none */
	public double average() {
		return stream(values).filter(v -> !isNaN(v)).average().orElse(0.);
	}

	/** @return the probability of {@code value} according to {@link #distribution()} */
	public double probabilityOf(double value) {
		return distribution().probabilityOf(value);
	}

	/** @return the probability of the point's value, or {@code NaN} when {@code point} is {@code null} */
	public double probabilityAt(Point point) {
		return point != null ? probabilityOf(point.value()) : NaN;
	}


	/** A view of one sample of the enclosing series, addressed by index. */
	public class Point {
		private final int index;

		public Point(int index) {
			this.index = index;
		}

		public Instant instant() {
			return instants[index];
		}

		public double value() {
			return values[index];
		}

		/** @return this point followed by every later point, in order */
		public Stream<Point> forward() {
			return range(index, length()).mapToObj(TimeSeries.this::point);
		}

		/** @return this point followed by every earlier point, in reverse order */
		public Stream<Point> backward() {
			return iterate(index, i -> i >= 0, i -> i - 1).mapToObj(TimeSeries.this::point);
		}

		@Override
		public String toString() {
			return instants[index] + ":" + value();
		}

		/** @return the following point, or {@code null} at the end */
		public Point next() {
			return step(1);
		}

		/** @return the preceding point, or {@code null} at the start */
		public Point prev() {
			return step(-1);
		}

		/** @return the point {@code value} positions away, or {@code null} when out of range */
		public Point step(int value) {
			return point(index + value);
		}
	}

	/**
	 * Returns the distribution of this series, built with {@code Distribution.of(this)} on first use and cached.
	 * NOTE(review): the cache is not invalidated if the public arrays are mutated afterwards.
	 */
	public Distribution distribution() {
		if (distribution == null) distribution = Distribution.of(this);
		return distribution;
	}

	/**
	 * Appends {@code timeSeries} to this series.
	 * <ul>
	 *   <li>An empty operand is ignored: the other operand is returned unchanged.</li>
	 *   <li>Otherwise, with assertions enabled, this series must end before {@code timeSeries} starts and both
	 *   must share the same model.</li>
	 * </ul>
	 */
	public TimeSeries concat(TimeSeries timeSeries) {
		if (timeSeries.isEmpty()) return this;
		if (isEmpty()) return timeSeries;
		assert isBefore(timeSeries) && model.equals(timeSeries.model);
		return new TimeSeries(model, concat(instants, timeSeries.instants), concat(values, timeSeries.values));
	}

	static double[] concat(double[] a, double[] b) {
		double[] result = new double[a.length + b.length];
		System.arraycopy(a, 0, result, 0, a.length);
		System.arraycopy(b, 0, result, a.length, b.length);
		return result;
	}

	static Instant[] concat(Instant[] a, Instant[] b) {
		Instant[] result = new Instant[a.length + b.length];
		System.arraycopy(a, 0, result, 0, a.length);
		System.arraycopy(b, 0, result, a.length, b.length);
		return result;
	}


	/** Element-wise sum over aligned instants; a NaN operand is treated as missing (the other value is kept). */
	public TimeSeries plus(TimeSeries series) {
		assert Arrays.equals(instants, series.instants);
		return create(model, (a, b) -> isNaN(a) ? b : isNaN(b) ? a : a + b, series);
	}

	/**
	 * Element-wise difference over aligned instants. As in {@link #plus}, a NaN operand is treated as missing:
	 * a missing minuend yields {@code -b}, not {@code b}.
	 */
	public TimeSeries minus(TimeSeries series) {
		assert Arrays.equals(instants, series.instants);
		return create(model, (a, b) -> isNaN(a) ? -b : isNaN(b) ? a : a - b, series);
	}

	/** @return this series multiplied by {@code -1} */
	public TimeSeries negate() {
		return times(-1);
	}

	/** @return every value multiplied by {@code factor}, keeping the model */
	public TimeSeries times(double factor) {
		return create(model, a -> a * factor);
	}

	/** @return the element-wise square of this series */
	public TimeSeries square() {
		return times(this);
	}

	/** Element-wise product over aligned instants; NaN propagates. */
	public TimeSeries times(TimeSeries timeSeries) {
		assert Arrays.equals(instants, timeSeries.instants);
		return create(model, (a, b) -> a * b, timeSeries);
	}

	/**
	 * Element-wise quotient over aligned instants.
	 * <ul>
	 *   <li>A NaN numerator yields NaN.</li>
	 *   <li>A NaN denominator keeps the numerator unchanged.</li>
	 * </ul>
	 * NOTE(review): confirm the NaN-denominator rule is intended.
	 */
	public TimeSeries dividedBy(TimeSeries series) {
		assert Arrays.equals(instants, series.instants);
		return create(model, (a, b) -> isNaN(b) ? a : a / b, series);
	}

	/** @return {@code 1 / value} per sample ({@code NaN} for zero), with symbol and unit prefixed by "1/" */
	public TimeSeries inverse() {
		Magnitude.Model inverseModel = model.symbol("1/" + model.symbol).unit("1/" + model.unit);
		return create(inverseModel, value -> value == 0 ? NaN : 1 / value);
	}

	/** Applies the {@link Differential} filter; the result uses the {@code Default} model. */
	public TimeSeries differential() {
		return execute(Default, new Differential());
	}

	/** Divides every value by {@code model.max}, updating the model with {@code "max=1:symbol:unit"}. */
	public TimeSeries ratio() {
		return create(model.updateWith("max=1:symbol:unit"), value -> value / model.max);
	}

	/** Multiplies every value by 100, updating the model with {@code "max=100:unit:symbol=%"}. */
	public TimeSeries percentage() {
		return create(model.updateWith("max=100:unit:symbol=%"), value -> value * 100);
	}

	/** Applies the {@link RateOfGrowth} filter; the result uses the {@code Default} model. */
	public TimeSeries rateOfGrowth() {
		return execute(Default, new RateOfGrowth());
	}

	/** @return the absolute value of every sample */
	public TimeSeries abs() {
		return create(model, Math::abs);
	}

	/** @return the natural logarithm of every sample (as {@link Math#log}) */
	public TimeSeries log() {
		return create(model, Math::log);
	}

	/** Applies {@code MovingAverage.of(observations)}; the result uses the {@code Default} model. */
	public TimeSeries movingAverage(int observations) {
		return execute(Default, MovingAverage.of(observations));
	}

	/** Applies {@code MovingAverage.of(smoothingFactor)}; the result uses the {@code Default} model. */
	public TimeSeries movingAverage(double smoothingFactor) {
		return execute(Default, MovingAverage.of(smoothingFactor));
	}

	/** Applies the {@link Denoise} filter in the given mode; the result uses the {@code Default} model. */
	public TimeSeries denoise(Denoise.Mode mode) {
		return execute(Default, new Denoise(mode));
	}

	/** Same as {@code normalize(0, 1)}. */
	public TimeSeries normalize() {
		return normalize(0, 1);
	}

	/** Applies {@code Normalizer.of(min, max)}; the result uses the {@code Default} model. */
	public TimeSeries normalize(double min, double max) {
		return execute(Default, Normalizer.of(min, max));
	}

	/** Applies {@code Normalizer.standard()}; the result uses the {@code Default} model. */
	public TimeSeries standardize() {
		return execute(Default, Normalizer.standard());
	}

	/**
	 * Weighted running average with weight {@code 1 / period}.
	 *
	 * @throws IllegalArgumentException if {@code period <= 0}
	 */
	public TimeSeries cumulativeMovingAverage(int period) {
		if (period <= 0) throw new IllegalArgumentException("period must be positive: " + period);
		return average(1. / period);
	}

	/**
	 * Exponential moving average with smoothing weight {@code 2 / (period + 1)}.
	 *
	 * @throws IllegalArgumentException if {@code period <= 0}
	 */
	public TimeSeries exponentialMovingAverage(int period) {
		if (period <= 0) throw new IllegalArgumentException("period must be positive: " + period);
		return average(2. / (period + 1));
	}

	/**
	 * Computes {@code r[i] = v[i] * weight + r[i - 1] * (1 - weight)}, seeded with {@code r[0] = average()}.
	 * An empty series yields an empty result. NOTE(review): a NaN sample propagates to every later result.
	 */
	private TimeSeries average(double weight) {
		double[] result = new double[values.length];
		if (result.length == 0) return new TimeSeries(Default, instants, result);
		result[0] = average();
		for (int i = 1; i < values.length; i++)
			result[i] = values[i] * weight + result[i - 1] * (1 - weight);
		return new TimeSeries(Default, instants, result);
	}

	/** Two series are equal when their instants and values are element-wise equal; the model is not compared. */
	@Override
	public boolean equals(Object o) {
		if (o == null || getClass() != o.getClass()) return false;
		return this == o || equals((TimeSeries) o);
	}

	private boolean equals(TimeSeries series) {
		return Arrays.equals(instants, series.instants) && Arrays.equals(values, series.values);
	}

	@Override
	public int hashCode() {
		int result = Arrays.hashCode(instants);
		result = 31 * result + Arrays.hashCode(values);
		return result;
	}

	/** @return one {@code instant<TAB>value} line per sample */
	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder();
		for (int i = 0; i < instants.length; i++) sb.append(instants[i]).append('\t').append(values[i]).append('\n');
		return sb.toString();
	}

	/** Runs {@code filter} over the values and wraps the result with the same instants and the given model. */
	private TimeSeries execute(Magnitude.Model model, Filter filter) {
		return new TimeSeries(model, instants, filter.calculate(values));
	}

	private TimeSeries create(Magnitude.Model model, Filter.UnaryOperator operator) {
		return execute(model, filterOf(operator));
	}

	private TimeSeries create(Magnitude.Model model, Filter.BinaryOperator operator, TimeSeries series) {
		return execute(model, filterOf(operator, series));
	}

	/** Lifts a per-sample operator to a filter over the whole value array. */
	private Filter filterOf(Filter.UnaryOperator operator) {
		return input -> {
			double[] result = new double[instants.length];
			setAll(result, i -> operator.calculate(input[i]));
			return result;
		};
	}

	/** Lifts a pairwise operator to a filter combining the input with {@code series.values} index by index. */
	private Filter filterOf(Filter.BinaryOperator operator, TimeSeries series) {
		return input -> {
			double[] result = new double[instants.length];
			setAll(result, i -> operator.calculate(input[i], series.values[i]));
			return result;
		};
	}


}
