[REVIEW] Add slow high-precision mode to KNN #3304

Merged
merged 10 commits on Jan 26, 2021
79 changes: 76 additions & 3 deletions python/cuml/neighbors/nearest_neighbors.pyx
@@ -33,6 +33,7 @@ from cuml.common.array_sparse import SparseCumlArray
from cuml.common.doc_utils import generate_docstring
from cuml.common.doc_utils import insert_into_docstring
from cuml.common.import_utils import has_scipy
from cuml.common.input_utils import input_to_cupy_array
from cuml.common import input_to_cuml_array
from cuml.neighbors.ann cimport *
from cuml.common.sparse_utils import is_sparse
@@ -423,7 +424,8 @@ class NearestNeighbors(Base):
X=None,
n_neighbors=None,
return_distance=True,
-        convert_dtype=True
+        convert_dtype=True,
+        two_pass_precision=False
) -> typing.Union[CumlArray, typing.Tuple[CumlArray, CumlArray]]:
"""
Query the GPU index for the k nearest neighbors of column vectors in X.
@@ -443,6 +445,25 @@
When set to True, the kneighbors method will automatically
convert the inputs to np.float32.

two_pass_precision : bool, optional (default = False)
    When set to True, a slow second pass will be used to improve the
    precision of results returned for searches using L2-derived
    metrics. FAISS uses the Euclidean distance decomposition trick to
    compute distances in this case, which may result in numerical
    errors for certain data. In particular, when several samples
    are close to the query sample (relative to typical inter-sample
    distances), numerical instability may cause the computed distance
    between the query and itself to be larger than the computed
    distance between the query and another sample. As a result, the
    query is not returned as its own nearest neighbor. If this
    flag is set to True, distances to the query vectors are
    recomputed with high precision for all retrieved samples, and the
    results are re-sorted accordingly. Note that for large values of k
    or large numbers of query vectors, this correction becomes
    impractical in terms of both runtime and memory. It should be used
    with care and only when strictly necessary (when precise results
    are critical and samples may be tightly clustered).

Returns
-------
distances : {}
@@ -453,10 +474,12 @@
The indices of the k-nearest neighbors for each column vector in X
"""

-        return self._kneighbors(X, n_neighbors, return_distance, convert_dtype)
+        return self._kneighbors(X, n_neighbors, return_distance, convert_dtype,
+                                two_pass_precision=two_pass_precision)
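
The docstring above refers to the Euclidean distance decomposition trick. As a rough NumPy illustration (not part of this diff), the sketch below shows how the expanded form ||x||^2 + ||y||^2 - 2<x, y> loses precision in float32 when two vectors are close together relative to their norms, which is exactly the failure mode two_pass_precision corrects:

```python
import numpy as np

rng = np.random.default_rng(0)

# Two nearly identical vectors with large norms
x = (rng.standard_normal(1000) * 100).astype(np.float32)
y = x + np.float32(1e-3) * rng.standard_normal(1000).astype(np.float32)

# Expanded form used by the fast GEMM-based kernels:
# ||x||^2 + ||y||^2 - 2 * <x, y>. The three terms are ~1e7 while the
# true squared distance is ~1e-3, so float32 cancellation destroys it.
expanded = np.dot(x, x) + np.dot(y, y) - 2.0 * np.dot(x, y)

# Direct form used by the second pass: sum((x - y)^2). The subtraction
# happens first, so no large terms ever need to cancel.
direct = np.sum((x - y) ** 2)

print(expanded)  # off by roughly the float32 ulp of ||x||^2; can be
                 # negative or larger than true neighbor distances
print(direct)    # close to the true value, ~1e-3
```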

def _kneighbors(self, X=None, n_neighbors=None, return_distance=True,
-                    convert_dtype=True, _output_type=None):
+                    convert_dtype=True, _output_type=None,
+                    two_pass_precision=False):
"""
Query the GPU index for the k nearest neighbors of column vectors in X.

@@ -482,6 +505,25 @@
When set to True, the class self.output_type is overwritten
and this method returns the output as a cumlarray

two_pass_precision : bool, optional (default = False)
    When set to True, a slow second pass will be used to improve the
    precision of results returned for searches using L2-derived
    metrics. FAISS uses the Euclidean distance decomposition trick to
    compute distances in this case, which may result in numerical
    errors for certain data. In particular, when several samples
    are close to the query sample (relative to typical inter-sample
    distances), numerical instability may cause the computed distance
    between the query and itself to be larger than the computed
    distance between the query and another sample. As a result, the
    query is not returned as its own nearest neighbor. If this
    flag is set to True, distances to the query vectors are
    recomputed with high precision for all retrieved samples, and the
    results are re-sorted accordingly. Note that for large values of k
    or large numbers of query vectors, this correction becomes
    impractical in terms of both runtime and memory. It should be used
    with care and only when strictly necessary (when precise results
    are critical and samples may be tightly clustered).

Returns
-------
distances: cupy ndarray
@@ -525,6 +567,37 @@ class NearestNeighbors(Base):
        out_type = _output_type \
            if _output_type is not None else self._get_output_type(X)

        if two_pass_precision:
            metric, expanded = self._build_metric_type(self.metric)
            metric_is_l2_based = (
                metric == MetricType.METRIC_L2 or
                (metric == MetricType.METRIC_Lp and self.p == 2)
            )

            # FAISS employs the imprecise expanded-distance computation
            # only for L2-based metrics
            if metric_is_l2_based:
                X = input_to_cupy_array(X).array
                I_cparr = I_ndarr.to_output('cupy')

                # Differences between each query vector and its k
                # retrieved neighbors; shape (n_queries, k, n_features)
                self_diff = X[I_cparr] - X[:, cp.newaxis, :]
                if expanded:
                    # Squared L2, matching the distances FAISS returns
                    precise_distances = cp.sum(
                        self_diff * self_diff, axis=2
                    )
                else:
                    precise_distances = cp.linalg.norm(self_diff, axis=2)

                # Re-sort the retrieved neighbors by the precisely
                # recomputed distances
                correct_order = cp.argsort(precise_distances, axis=1)

                D_cparr = cp.take_along_axis(precise_distances,
                                             correct_order,
                                             axis=1)
                I_cparr = cp.take_along_axis(I_cparr, correct_order, axis=1)

                D_ndarr = input_to_cuml_array(D_cparr).array
                I_ndarr = input_to_cuml_array(I_cparr).array

        I_ndarr = I_ndarr.to_output(out_type)
        D_ndarr = D_ndarr.to_output(out_type)
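
For reference, a minimal usage sketch of the new flag (not from the PR itself; the random matrix `X` and its shape are illustrative assumptions, mirroring the test added below):

```python
import numpy as np
from cuml.neighbors import NearestNeighbors

# Illustrative float32 feature matrix (assumed for this sketch)
X = np.random.default_rng(0).standard_normal((1000, 64)).astype(np.float32)

knn = NearestNeighbors(n_neighbors=3, metric="euclidean")
knn.fit(X)

# The second pass recomputes distances for the retrieved candidates in
# high precision and re-sorts them, so every indexed query vector is
# returned as its own nearest neighbor at (near-)zero distance.
dist, ind = knn.kneighbors(X, return_distance=True,
                           two_pass_precision=True)
```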

64 changes: 64 additions & 0 deletions python/cuml/test/test_nearest_neighbors.py
@@ -30,6 +30,7 @@
import cudf
import pandas as pd
import numpy as np
from numpy.testing import assert_array_equal, assert_allclose
from scipy.sparse import isspmatrix_csr

import sklearn
@@ -54,6 +55,69 @@ def valid_metrics(algo="brute", cuml_algo=None):
return [value for value in cuml_metrics if value in sklearn_metrics]


def metric_p_combinations():
    # Pair every supported metric with p=2; Minkowski-style metrics are
    # also paired with p=3 to exercise the non-Euclidean code path.
    for metric in valid_metrics():
        yield metric, 2
        if metric in ("minkowski", "lp"):
            yield metric, 3


@pytest.mark.parametrize("datatype", ["dataframe", "numpy"])
@pytest.mark.parametrize("metric_p", metric_p_combinations())
@pytest.mark.parametrize("nrows", [1000, stress_param(10000)])
@pytest.mark.skipif(not has_scipy(), reason="Skipping test_self_neighboring"
" because Scipy is missing")
def test_self_neighboring(datatype, metric_p, nrows):
    """Test that querying the index with its own vectors returns sensible
    results for those vectors

    For L2-derived metrics, this specifically exercises the slow
    high-precision mode used to correct for approximation errors in L2
    computation during NN searches.
    """
    ncols = 1000
    n_clusters = 10
    n_neighbors = 3

    metric, p = metric_p

    if not has_scipy():
        pytest.skip('Skipping test_self_neighboring because ' +
                    'Scipy is missing')

    X, y = make_blobs(n_samples=nrows, centers=n_clusters,
                      n_features=ncols, random_state=0)

if datatype == "dataframe":
X = cudf.DataFrame(X)

knn_cu = cuKNN(metric=metric, n_neighbors=n_neighbors)
knn_cu.fit(X)
neigh_dist, neigh_ind = knn_cu.kneighbors(X, n_neighbors=n_neighbors,
return_distance=True,
two_pass_precision=True)

    if datatype == 'dataframe':
        assert isinstance(neigh_ind, cudf.DataFrame)
        neigh_ind = neigh_ind.as_gpu_matrix().copy_to_host()
        neigh_dist = neigh_dist.as_gpu_matrix().copy_to_host()
    else:
        assert isinstance(neigh_ind, np.ndarray)

    # Each indexed vector must come back as its own nearest neighbor,
    # at (near-)zero distance
    neigh_ind = neigh_ind[:, 0]
    neigh_dist = neigh_dist[:, 0]

    assert_array_equal(
        neigh_ind,
        np.arange(0, neigh_dist.shape[0]),
    )
    assert_allclose(
        neigh_dist,
        np.zeros(neigh_dist.shape, dtype=neigh_dist.dtype),
        atol=1e-4
    )


@pytest.mark.parametrize("datatype", ["dataframe", "numpy"])
@pytest.mark.parametrize("nrows", [500, 1000, 10000])
@pytest.mark.parametrize("ncols", [128, 1024])