Skip to content

Commit

Permalink
Check python brute-force knn inputs (#1537)
Browse files Browse the repository at this point in the history
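The input validation code wasn't being triggered for the Python brute-force kNN (bfknn) API, so column-major inputs silently produced invalid output. This commit adds the missing input checks so that such inputs raise an error instead.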

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1537
  • Loading branch information
benfred authored May 19, 2023
1 parent 0154e8e commit db96e8b
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 19 deletions.
14 changes: 3 additions & 11 deletions python/pylibraft/pylibraft/common/device_ndarray.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -89,12 +89,8 @@ def c_contiguous(self):
"""
Is the current device_ndarray laid out in row-major format?
"""
array_interface = self.ndarray_.__array_interface__
strides = self.strides
return (
strides is None
or array_interface["strides"][1] == self.dtype.itemsize
)
return strides is None or strides[1] == self.dtype.itemsize

@property
def f_contiguous(self):
Expand Down Expand Up @@ -125,11 +121,7 @@ def strides(self):
Strides of the current device_ndarray instance
"""
array_interface = self.ndarray_.__array_interface__
return (
None
if "strides" not in array_interface
else array_interface["strides"]
)
return array_interface.get("strides")

@property
def __cuda_array_interface__(self):
Expand Down
6 changes: 6 additions & 0 deletions python/pylibraft/pylibraft/neighbors/brute_force.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ from pylibraft.distance.distance_type cimport DistanceType
# TODO: Centralize this

from pylibraft.distance.pairwise_distance import DISTANCE_TYPES
from pylibraft.neighbors.common import _check_input_array

from pylibraft.common.cpp.mdspan cimport (
device_matrix_view,
Expand Down Expand Up @@ -143,6 +144,11 @@ def knn(dataset, queries, k=None, indices=None, distances=None,
raise ValueError("Argument k must be specified if both indices "
"and distances arg is None")

# we require c-contiguous (rowmajor) inputs here
_check_input_array(dataset_cai, [np.dtype("float32")])
_check_input_array(queries_cai, [np.dtype("float32")],
exp_cols=dataset_cai.shape[1])

n_queries = queries_cai.shape[0]

if indices is None:
Expand Down
23 changes: 19 additions & 4 deletions python/pylibraft/pylibraft/test/test_brute_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,8 @@
],
)
@pytest.mark.parametrize("inplace", [True, False])
@pytest.mark.parametrize("order", ["F", "C"])
@pytest.mark.parametrize("dtype", [np.float32])
def test_knn(
n_index_rows, n_query_rows, n_cols, k, inplace, metric, order, dtype
):
def test_knn(n_index_rows, n_query_rows, n_cols, k, inplace, metric, dtype):
index = np.random.random_sample((n_index_rows, n_cols)).astype(dtype)
queries = np.random.random_sample((n_query_rows, n_cols)).astype(dtype)

Expand Down Expand Up @@ -94,3 +91,21 @@ def test_knn(
np.testing.assert_allclose(
cpu_ordered[:k], gpu_dists, atol=1e-4, rtol=1e-4
)


def test_knn_check_col_major_inputs():
    """
    Column-major (order="F") inputs to knn must raise a ValueError rather
    than return incorrect results; the row-major arrays produced below
    must still be accepted.
    """
    cp = pytest.importorskip("cupy")

    n_cols = 32
    index_shape = (128, n_cols)
    query_shape = (16, n_cols)
    row_major_index = cp.random.random_sample(index_shape, dtype="float32")
    row_major_queries = cp.random.random_sample(query_shape, dtype="float32")

    # An F-ordered index must be rejected.
    with pytest.raises(ValueError):
        col_major_index = cp.asarray(row_major_index, order="F")
        knn(col_major_index, row_major_queries, k=4)

    # An F-ordered query matrix must be rejected as well.
    with pytest.raises(ValueError):
        col_major_queries = cp.asarray(row_major_queries, order="F")
        knn(row_major_index, col_major_queries, k=4)

    # The unmodified arrays must pass validation without raising.
    knn(row_major_index, row_major_queries, k=4)
5 changes: 1 addition & 4 deletions python/pylibraft/pylibraft/test/test_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
from pylibraft.common import DeviceResources, Stream, device_ndarray
from pylibraft.distance import pairwise_distance

try:
import cupy
except ImportError:
pytest.skip(reason="cupy not installed.")
cupy = pytest.importorskip("cupy")


@pytest.mark.parametrize("stream", [cupy.cuda.Stream().ptr, Stream()])
Expand Down

0 comments on commit db96e8b

Please sign in to comment.