/*
 * Decompiled with CFR 0.152.
 */
package smile.neighbor;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import smile.math.MathEx;
import smile.neighbor.KNNSearch;
import smile.neighbor.Neighbor;
import smile.neighbor.NeighborBuilder;
import smile.neighbor.RNNSearch;
import smile.sort.HeapSelect;

/**
 * A KD-tree over {@code double[]} keys supporting nearest-neighbor, k-nearest-neighbor
 * and fixed-radius queries.
 *
 * <p>The tree does not copy its input: it keeps references to the key and data arrays
 * passed to the constructor and maintains a permutation ({@code index}) of the row
 * numbers so that every node covers a contiguous slice of that permutation.
 *
 * <p>All queries skip a stored key that is the <em>same array instance</em> as the query
 * (reference equality), so querying with one of the stored keys returns its neighbors
 * rather than the key itself.
 *
 * @param <E> the type of the data objects associated with the keys.
 */
public class KDTree<E> implements KNNSearch<double[], E>, RNNSearch<double[], E>, Serializable {
    private static final long serialVersionUID = 2L;
    /** The keys of the data objects (referenced, not copied). */
    private final double[][] keys;
    /** The data objects, parallel to {@link #keys} (referenced, not copied). */
    private final E[] data;
    /** The root node of the tree. */
    private final Node root;
    /** Permutation of row numbers; each node owns a contiguous slice of it. */
    private final int[] index;

    /**
     * Builds a KD-tree from the given keys and their associated data objects.
     *
     * @param key the keys; the dimensionality is taken from {@code key[0]}.
     * @param data the data objects, one per key.
     * @throws IllegalArgumentException if the two arrays differ in length or are empty.
     */
    public KDTree(double[][] key, E[] data) {
        if (key.length != data.length) {
            throw new IllegalArgumentException("The array size of keys and data are different.");
        }
        // Without at least one key the dimensionality below cannot be determined.
        if (key.length == 0) {
            throw new IllegalArgumentException("The key array is empty.");
        }
        this.keys = key;
        this.data = data;
        int n = key.length;
        this.index = new int[n];
        for (int i = 0; i < n; ++i) {
            this.index[i] = i;
        }
        int d = key[0].length;
        // Scratch buffers reused by every recursive buildNode call.
        double[] lowerBound = new double[d];
        double[] upperBound = new double[d];
        this.root = buildNode(0, n, lowerBound, upperBound);
    }

    /**
     * Builds a KD-tree in which every key is also its own data object.
     *
     * @param data the points to index.
     * @return the KD-tree over {@code data}.
     */
    public static KDTree<double[]> of(double[][] data) {
        return new KDTree<>(data, data);
    }

    @Override
    public String toString() {
        return "KD-Tree";
    }

    /**
     * Recursively builds the node covering {@code index[begin .. end)}.
     *
     * <p>The node is split on the dimension with the largest extent, at the midpoint of
     * that extent. The slice of {@code index} is partitioned in place so that rows whose
     * split coordinate is below the cutoff come first.
     *
     * @param begin the first position (inclusive) of the slice of {@code index}.
     * @param end the last position (exclusive) of the slice of {@code index}.
     * @param lowerBound scratch buffer overwritten with the per-dimension minimum.
     * @param upperBound scratch buffer overwritten with the per-dimension maximum.
     * @return the node (leaf or internal) covering the slice.
     */
    private Node buildNode(int begin, int end, double[] lowerBound, double[] upperBound) {
        int d = keys[0].length;
        Node node = new Node();
        node.count = end - begin;
        node.index = begin;

        // Compute the bounding box of the points in this slice.
        double[] key = keys[index[begin]];
        System.arraycopy(key, 0, lowerBound, 0, d);
        System.arraycopy(key, 0, upperBound, 0, d);
        for (int i = begin + 1; i < end; ++i) {
            key = keys[index[i]];
            for (int j = 0; j < d; ++j) {
                double c = key[j];
                if (lowerBound[j] > c) {
                    lowerBound[j] = c;
                }
                if (upperBound[j] < c) {
                    upperBound[j] = c;
                }
            }
        }

        // Split on the widest dimension, at the middle of its range.
        double maxRadius = -1.0;
        for (int i = 0; i < d; ++i) {
            double radius = (upperBound[i] - lowerBound[i]) / 2.0;
            if (radius > maxRadius) {
                maxRadius = radius;
                node.split = i;
                node.cutoff = (upperBound[i] + lowerBound[i]) / 2.0;
            }
        }

        // All points (nearly) coincide: nothing to split, keep this node as a leaf.
        if (MathEx.isZero(maxRadius, 1.0E-8)) {
            node.upper = null;
            node.lower = null;
            return node;
        }

        // Partition the slice in place: rows below the cutoff move to the front.
        int lo = begin;
        int hi = end - 1;
        int size = 0;
        while (lo <= hi) {
            boolean loInPlace = keys[index[lo]][node.split] < node.cutoff;
            boolean hiInPlace = keys[index[hi]][node.split] >= node.cutoff;
            if (!loInPlace && !hiInPlace) {
                // Both ends are on the wrong side; swapping fixes both at once.
                int temp = index[lo];
                index[lo] = index[hi];
                index[hi] = temp;
                loInPlace = true;
                hiInPlace = true;
            }
            if (loInPlace) {
                ++lo;
                ++size;
            }
            if (hiInPlace) {
                --hi;
            }
        }

        // A degenerate partition (one side empty) cannot be refined further.
        if (size == 0 || size == node.count) {
            node.upper = null;
            node.lower = null;
            return node;
        }

        node.lower = buildNode(begin, begin + size, lowerBound, upperBound);
        node.upper = buildNode(begin + size, end, lowerBound, upperBound);
        return node;
    }

    /**
     * Recursive nearest-neighbor search. Updates {@code neighbor} whenever a strictly
     * closer point is found.
     *
     * @param q the query key.
     * @param node the subtree to search.
     * @param neighbor holds the best candidate found so far (index and distance).
     */
    private void search(double[] q, Node node, NeighborBuilder<double[], E> neighbor) {
        if (node.isLeaf()) {
            int stop = node.index + node.count;
            for (int idx = node.index; idx < stop; ++idx) {
                int i = index[idx];
                // Skip the query object itself (reference equality).
                if (q == keys[i]) {
                    continue;
                }
                double distance = MathEx.distance(q, keys[i]);
                if (distance < neighbor.distance) {
                    neighbor.index = i;
                    neighbor.distance = distance;
                }
            }
        } else {
            double diff = q[node.split] - node.cutoff;
            Node nearer;
            Node further;
            if (diff < 0.0) {
                nearer = node.lower;
                further = node.upper;
            } else {
                nearer = node.upper;
                further = node.lower;
            }
            search(q, nearer, neighbor);
            // The far side can only help if the splitting plane is within the best distance.
            if (neighbor.distance >= Math.abs(diff)) {
                search(q, further, neighbor);
            }
        }
    }

    /**
     * Recursive k-nearest-neighbor search. The element returned by {@code heap.peek()}
     * is treated as the current worst candidate and is overwritten whenever a closer
     * point is found.
     *
     * @param q the query key.
     * @param node the subtree to search.
     * @param heap the candidates found so far.
     */
    private void search(double[] q, Node node, HeapSelect<NeighborBuilder<double[], E>> heap) {
        if (node.isLeaf()) {
            int stop = node.index + node.count;
            for (int idx = node.index; idx < stop; ++idx) {
                int i = index[idx];
                // Skip the query object itself (reference equality).
                if (q == keys[i]) {
                    continue;
                }
                double distance = MathEx.distance(q, keys[i]);
                NeighborBuilder<double[], E> datum = heap.peek();
                if (distance < datum.distance) {
                    datum.distance = distance;
                    datum.index = i;
                    heap.heapify();
                }
            }
        } else {
            double diff = q[node.split] - node.cutoff;
            Node nearer;
            Node further;
            if (diff < 0.0) {
                nearer = node.lower;
                further = node.upper;
            } else {
                nearer = node.upper;
                further = node.lower;
            }
            search(q, nearer, heap);
            // The far side can only help if the splitting plane is within the worst kept distance.
            if (heap.peek().distance >= Math.abs(diff)) {
                search(q, further, heap);
            }
        }
    }

    /**
     * Recursive fixed-radius search. Appends every point within {@code radius}
     * (inclusive) of {@code q} to {@code neighbors}.
     *
     * @param q the query key.
     * @param node the subtree to search.
     * @param radius the search radius.
     * @param neighbors the output list that matches are appended to.
     */
    private void search(double[] q, Node node, double radius, List<Neighbor<double[], E>> neighbors) {
        if (node.isLeaf()) {
            int stop = node.index + node.count;
            for (int idx = node.index; idx < stop; ++idx) {
                int i = index[idx];
                // Skip the query object itself (reference equality).
                if (q == keys[i]) {
                    continue;
                }
                double distance = MathEx.distance(q, keys[i]);
                if (distance <= radius) {
                    neighbors.add(new Neighbor<double[], E>(keys[i], data[i], i, distance));
                }
            }
        } else {
            double diff = q[node.split] - node.cutoff;
            Node nearer;
            Node further;
            if (diff < 0.0) {
                nearer = node.lower;
                further = node.upper;
            } else {
                nearer = node.upper;
                further = node.lower;
            }
            search(q, nearer, radius, neighbors);
            // The far side is relevant only if the splitting plane lies within the radius.
            if (radius >= Math.abs(diff)) {
                search(q, further, radius, neighbors);
            }
        }
    }

    /**
     * Returns the nearest neighbor of {@code q}, ignoring a stored key that is the same
     * array instance as {@code q}.
     *
     * @param q the query key.
     * @return the nearest neighbor.
     */
    @Override
    public Neighbor<double[], E> nearest(double[] q) {
        NeighborBuilder<double[], E> neighbor = new NeighborBuilder<>();
        search(q, root, neighbor);
        // NOTE(review): if no candidate was accepted (e.g. the only stored key is q itself),
        // neighbor.index keeps NeighborBuilder's default value - confirm it is a valid index.
        neighbor.key = keys[neighbor.index];
        neighbor.value = data[neighbor.index];
        return neighbor.toNeighbor();
    }

    /**
     * Returns the {@code k} nearest neighbors of {@code q}, ignoring a stored key that is
     * the same array instance as {@code q}. The heap is sorted before conversion.
     *
     * @param q the query key.
     * @param k the number of neighbors to search for.
     * @return the neighbors found.
     * @throws IllegalArgumentException if {@code k <= 0} or {@code k} exceeds the number of keys.
     */
    @Override
    @SuppressWarnings("unchecked") // generic array creation via Neighbor[]::new is unavoidable
    public Neighbor<double[], E>[] search(double[] q, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("Invalid k: " + k);
        }
        if (k > keys.length) {
            throw new IllegalArgumentException("Neighbor array length is larger than the dataset size");
        }
        HeapSelect<NeighborBuilder<double[], E>> heap = new HeapSelect<NeighborBuilder<double[], E>>(NeighborBuilder.class, k);
        // Pre-fill the heap with k placeholder candidates that the search overwrites.
        for (int i = 0; i < k; ++i) {
            heap.add(new NeighborBuilder<>());
        }
        search(q, root, heap);
        heap.sort();
        // NOTE(review): placeholders never overwritten (fewer than k eligible points) keep
        // NeighborBuilder's default index - confirm that default is a valid index.
        return Arrays.stream(heap.toArray()).map(neighbor -> {
            neighbor.key = keys[neighbor.index];
            neighbor.value = data[neighbor.index];
            return neighbor.toNeighbor();
        }).toArray(Neighbor[]::new);
    }

    /**
     * Appends to {@code neighbors} every stored point within {@code radius} (inclusive)
     * of {@code q}, ignoring a stored key that is the same array instance as {@code q}.
     *
     * @param q the query key.
     * @param radius the search radius; must be positive.
     * @param neighbors the output list that matches are appended to.
     * @throws IllegalArgumentException if {@code radius <= 0}.
     */
    @Override
    public void search(double[] q, double radius, List<Neighbor<double[], E>> neighbors) {
        if (radius <= 0.0) {
            throw new IllegalArgumentException("Invalid radius: " + radius);
        }
        search(q, root, radius, neighbors);
    }

    /**
     * A node of the KD-tree. A node with no children is a leaf and owns the rows
     * {@code index[index .. index + count)} of the enclosing tree's permutation.
     */
    static class Node
    implements Serializable {
        /** Number of rows covered by this node. */
        int count;
        /** Start position of this node's slice in the tree's permutation array. */
        int index;
        /** Dimension on which this node is split. */
        int split;
        /** Split value along {@link #split}. */
        double cutoff;
        /** Child holding rows whose split coordinate is below the cutoff. */
        Node lower;
        /** Child holding rows whose split coordinate is at or above the cutoff. */
        Node upper;

        Node() {
        }

        /**
         * Tells whether this node is a leaf.
         *
         * @return true if the node has neither a lower nor an upper child.
         */
        boolean isLeaf() {
            return this.lower == null && this.upper == null;
        }
    }
}

