package io.intino.sumus.time.filters;

import io.intino.sumus.time.Filter;

import static java.lang.Double.NaN;
import static java.lang.Double.isNaN;

/**
 * Factory of {@link Filter}s that rescale a series of values.
 *
 * <p>Two normalizations are provided: a min-max rescaling to a target interval ({@link #of})
 * and a z-score standardization ({@link #standard}). In both, {@code NaN} entries are ignored
 * when the series statistics are computed and stay {@code NaN} in the output. The input array
 * is never modified; a new array of the same length is returned.
 */
public class Normalizer {

	/**
	 * Creates a filter that linearly rescales a series so that its smallest non-NaN value maps to
	 * {@code min} and its largest non-NaN value maps to {@code max}.
	 *
	 * @param min value the series minimum is mapped to
	 * @param max value the series maximum is mapped to
	 * @return a filter backed by {@link MinMax#execute(double[])}
	 */
	public static Filter of(double min, double max) {
		return new Normalizer.MinMax(min, max)::execute;
	}

	/**
	 * Creates a filter that standardizes a series to z-scores: each value has the series mean
	 * subtracted and is divided by the population standard deviation.
	 *
	 * @return the z-score filter
	 */
	public static Filter standard() {
		return Normalizer::z;
	}

	/**
	 * Min-max rescaling of a series to the interval defined by {@code min} and {@code max}.
	 */
	public static class MinMax {
		private final double min;
		private final double max;

		/**
		 * @param min value the series minimum is mapped to
		 * @param max value the series maximum is mapped to
		 */
		public MinMax(double min, double max) {
			this.min = min;
			this.max = max;
		}

		/**
		 * Rescales {@code values} linearly into the target interval.
		 *
		 * <p>When all non-NaN values are equal the series has no spread; every non-NaN value is
		 * then mapped to {@code min} instead of dividing by zero (which would turn the whole
		 * result into {@code NaN}).
		 *
		 * @param values series to rescale; not modified
		 * @return a new array of the same length; {@code NaN} inputs stay {@code NaN}
		 */
		public double[] execute(double[] values) {
			double[] range = rangeOf(values);
			double span = range[1] - range[0];
			// A zero span would make the factor infinite (or NaN) and 0 * factor is NaN,
			// so a constant series uses a zero factor and collapses onto min.
			double factor = span == 0 ? 0 : (this.max - this.min) / span;
			double[] result = new double[values.length];
			for (int i = 0; i < values.length; i++)
				result[i] = isNaN(values[i]) ? NaN : (values[i] - range[0]) * factor + this.min;
			return result;
		}

		/**
		 * Finds the smallest and largest non-NaN values.
		 *
		 * @return a two-element array {@code {lowest, highest}}; when there is no non-NaN value
		 *         it holds the initial sentinels {@code {Double.MAX_VALUE, -Double.MAX_VALUE}}
		 */
		private double[] rangeOf(double[] values) {
			double[] range = new double[] {Double.MAX_VALUE, -Double.MAX_VALUE};
			for (double value : values) {
				if (isNaN(value)) continue;
				if (value < range[0]) range[0] = value;
				if (value > range[1]) range[1] = value;
			}
			return range;
		}

	}

	/**
	 * Computes the z-score of every value in the series.
	 *
	 * <p>When the standard deviation is exactly zero (all non-NaN values equal the mean) each
	 * non-NaN value yields {@code 0} rather than the {@code NaN} produced by {@code 0 / 0}.
	 *
	 * @param values series to standardize; not modified
	 * @return a new array of the same length; {@code NaN} inputs stay {@code NaN}
	 */
	private static double[] z(double[] values) {
		double m = mean(values);
		double sd = standardDeviation(values, m);

		int length = values.length;
		double[] normalizedValues = new double[length];
		for (int i = 0; i < length; i++) {
			if (isNaN(values[i])) {
				normalizedValues[i] = NaN;
				continue;
			}
			// No spread: every value sits on the mean, so its z-score is 0.
			normalizedValues[i] = sd == 0 ? 0 : (values[i] - m) / sd;
		}
		return normalizedValues;
	}

	/**
	 * Arithmetic mean of the non-NaN values; {@code NaN} when there are none.
	 */
	private static double mean(double[] values) {
		double sum = 0;
		int count = 0;
		for (double value : values) {
			if (isNaN(value)) continue;
			count++;
			sum += value;
		}
		return sum/count;
	}

	/**
	 * Population standard deviation (divides by the count, not count - 1) of the non-NaN
	 * values around the given mean; {@code NaN} when there are none.
	 *
	 * @param m mean of the non-NaN values
	 */
	private static double standardDeviation(double[] values, double m) {
		double sum = 0;
		int count = 0;
		for (double value : values) {
			if (isNaN(value)) continue;
			count++;
			sum += sqr(value - m);
		}
		return Math.sqrt(sum/count);
	}

	/** Returns {@code v} squared. */
	private static double sqr(double v) {
		return v * v;
	}


}
