/*
 * Decompiled with CFR 0.152.
 */
package smile.feature.imputation;

import java.util.Arrays;
import smile.data.AbstractTuple;
import smile.data.DataFrame;
import smile.data.Tuple;
import smile.data.measure.NominalScale;
import smile.data.transform.Transform;
import smile.data.type.StructField;
import smile.data.type.StructType;
import smile.feature.imputation.SimpleImputer;
import smile.math.MathEx;
import smile.math.distance.Distance;
import smile.neighbor.KNNSearch;
import smile.neighbor.LinearSearch;
import smile.neighbor.Neighbor;

/**
 * Missing value imputation by k-nearest neighbors.
 *
 * <p>For a tuple with missing fields, the {@code k} nearest rows of the
 * training data are looked up and each missing field is replaced by the
 * mode (boolean, char and nominal fields) or the mean (other numeric fields)
 * of the neighbors' values for that field. Fields that are not missing are
 * passed through unchanged.
 */
public class KNNImputer
implements Transform {
    /** The number of nearest neighbors consulted for each imputation. */
    private final int k;
    /** Nearest neighbor search structure built over the training rows. */
    private final KNNSearch<Tuple, Tuple> knn;

    /**
     * Creates an imputer that uses a caller-supplied distance between rows.
     *
     * @param data the training data whose rows serve as neighbor candidates.
     * @param k the number of nearest neighbors to use.
     * @param distance the distance measure between two rows.
     */
    public KNNImputer(DataFrame data, int k, Distance<Tuple> distance) {
        this.k = k;
        this.knn = LinearSearch.of(data.toList(), distance);
    }

    /**
     * Creates an imputer whose distance is computed by
     * {@link MathEx#squaredDistanceWithMissingValues(double[], double[])}
     * over the given columns of the two rows.
     *
     * @param data the training data whose rows serve as neighbor candidates.
     * @param k the number of nearest neighbors to use.
     * @param columns the columns passed to {@code Tuple.toArray} when
     *                computing the distance.
     */
    public KNNImputer(DataFrame data, int k, String ... columns) {
        this(data, k, (Tuple x, Tuple y) -> {
            double[] xd = x.toArray(columns);
            double[] yd = y.toArray(columns);
            return MathEx.squaredDistanceWithMissingValues(xd, yd);
        });
    }

    /**
     * Returns a view of {@code x} in which missing fields are imputed from
     * the k nearest neighbors. The neighbor search runs once, up front; the
     * individual fields are imputed lazily when {@code get(i)} is called.
     *
     * @param x the tuple to impute.
     * @return a tuple with the same schema as {@code x}. A field stays
     *         {@code null} when no neighbor has a usable value for it or
     *         when its type is not handled.
     */
    @Override
    public Tuple apply(final Tuple x) {
        final StructType schema = x.schema();
        final Neighbor<Tuple, Tuple>[] neighbors = this.knn.search(x, this.k);
        return new AbstractTuple(){

            @Override
            public Object get(int i) {
                Object xi = x.get(i);
                if (!SimpleImputer.isMissing(xi)) {
                    return xi;
                }

                StructField field = schema.field(i);
                if (field.type.isBoolean()) {
                    int[] vector = KNNImputer.neighborInts(neighbors, i);
                    return vector.length == 0 ? null : Boolean.valueOf(MathEx.mode(vector) != 0);
                }
                if (field.type.isChar()) {
                    int[] vector = KNNImputer.neighborInts(neighbors, i);
                    return vector.length == 0 ? null : Character.valueOf((char)MathEx.mode(vector));
                }
                if (field.measure instanceof NominalScale) {
                    int[] vector = KNNImputer.neighborInts(neighbors, i);
                    return vector.length == 0 ? null : Integer.valueOf(MathEx.mode(vector));
                }
                if (field.type.isNumeric()) {
                    double[] vector = KNNImputer.neighborDoubles(neighbors, i);
                    return vector.length == 0 ? null : Double.valueOf(MathEx.mean(vector));
                }
                // Field type is neither categorical nor numeric: nothing to impute.
                return null;
            }

            @Override
            public StructType schema() {
                return schema;
            }
        };
    }

    /**
     * Collects the neighbors' values of field {@code i} as ints, dropping
     * entries equal to the {@code Integer.MIN_VALUE} sentinel.
     */
    private static int[] neighborInts(Neighbor<Tuple, Tuple>[] neighbors, int i) {
        int[] values = Arrays.stream(neighbors)
                .mapToInt(neighbor -> neighbor.key.getInt(i))
                .toArray();
        return MathEx.omit(values, Integer.MIN_VALUE);
    }

    /**
     * Collects the neighbors' values of field {@code i} as doubles, dropping
     * NaN entries and entries equal to the {@code Integer.MIN_VALUE} sentinel.
     */
    private static double[] neighborDoubles(Neighbor<Tuple, Tuple>[] neighbors, int i) {
        double[] values = Arrays.stream(neighbors)
                .mapToDouble(neighbor -> neighbor.key.getDouble(i))
                // NaN never compares equal to anything, so an equality-based
                // omit cannot remove it; filter explicitly or a single NaN
                // neighbor would turn the mean (the imputed value) into NaN.
                .filter(value -> !Double.isNaN(value))
                .toArray();
        return MathEx.omit(values, Integer.MIN_VALUE);
    }
}

