Skip to content

Commit b7e5488

Browse files
author
Mark Hildebrand
authored
[pysvs] Allow 1-dimensional queries for search. (#17)
Previously, it was assumed that all queries were given in a two-dimensional numpy array. This relaxes that requirement so that a one-dimensional argument is treated as a single query.
1 parent f010178 commit b7e5488

File tree

3 files changed

+99
-5
lines changed

3 files changed

+99
-5
lines changed

bindings/python/src/common.h

+38
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ template <typename T> std::span<const T> as_span(const py_contiguous_array_t<T>&
7171
return std::span<const T>(array.data(), array.size());
7272
}
7373

74+
/// Tag type used to select the ``data_view`` overload that also accepts numpy vectors.
struct AllowVectorsTag {};
75+
76+
/// A property to pass to ``data_view`` to interpret a numpy vector as a 2D array with
77+
/// the size of the first dimension equal to one.
78+
inline constexpr AllowVectorsTag allow_vectors{};
79+
7480
///
7581
/// Create a read-only data view over a numpy array.
7682
///
@@ -86,6 +92,38 @@ data_view(const pybind11::array_t<Eltype, pybind11::array::c_style>& data) {
8692
);
8793
}
8894

95+
///
96+
/// Create a read-only data view over a numpy matrix or vector.
97+
///
98+
/// @tparam Eltype The element type of the array.
99+
///
100+
/// @param data The numpy array to alias.
101+
/// @param property Indicate that it is okay to promote numpy vectors to matrices.
/// @returns For a 1-dimensional array, a view built with extents ``1`` and
///     ``data.shape(0)``; for a 2-dimensional array, a view built with extents
///     ``data.shape(0)`` and ``data.shape(1)``. The view is constructed from the array's
///     data pointer (presumably ``data`` must outlive the view -- confirm).
/// @throws Raises via ``ANNEXCEPTION`` if the array has any other number of dimensions.
102+
///
103+
template <typename Eltype>
104+
svs::data::ConstSimpleDataView<Eltype> data_view(
105+
const pybind11::array_t<Eltype, pybind11::array::c_style>& data,
106+
AllowVectorsTag SVS_UNUSED(property)
107+
) {
108+
size_t ndims = data.ndim();
109+
// If this is a vector, interpret it as a batch of queries with size 1.
110+
// The type requirement `pybind11::array::c_style` means that the underlying data is
111+
// contiguous, so we can construct a view from its pointer.
112+
if (ndims == 1) {
113+
return svs::data::ConstSimpleDataView<Eltype>(
114+
data.template unchecked<1>().data(0), 1, data.shape(0)
115+
);
116+
}
117+
118+
// Anything that is neither a vector (handled above) nor a matrix is rejected.
if (ndims != 2) {
119+
throw ANNEXCEPTION("This function can only accept numpy vectors or matrices.");
120+
}
121+
122+
// Matrix case: alias the array starting at element (0, 0) with its own two extents.
return svs::data::ConstSimpleDataView<Eltype>(
123+
data.template unchecked<2>().data(0, 0), data.shape(0), data.shape(1)
124+
);
125+
}
126+
89127
///
90128
/// Create a read-write MatrixView over a numpy array.
91129
///

bindings/python/src/manager.h

+11-4
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ pybind11::tuple py_search(
3030
pybind11::array_t<QueryType, pybind11::array::c_style> queries,
3131
size_t n_neighbors
3232
) {
33-
const size_t n_queries = queries.shape(0);
34-
const auto query_data = data_view(queries);
33+
const auto query_data = data_view(queries, allow_vectors);
34+
size_t n_queries = query_data.size();
3535
auto result_idx = numpy_matrix<size_t>(n_queries, n_neighbors);
3636
auto result_dists = numpy_matrix<float>(n_queries, n_neighbors);
3737
svs::QueryResultView<size_t> q_result(
@@ -54,8 +54,12 @@ void add_search_specialization(pybind11::class_<Manager>& py_manager) {
5454
Perform a search to return the `n_neighbors` approximate nearest neighbors to the query.
5555
5656
Args:
57-
queries: Numpy Matrix representing the query batch. Individual queries are assumed to
58-
the rows of the matrix. Returned results will have a position-wise correspondence
57+
queries: Numpy Vector or Matrix representing the queries.
58+
59+
If the argument is a vector, it will be treated as a single query.
60+
61+
If the argument is a matrix, individual queries are assumed to be the rows of the
62+
matrix. Returned results will have a position-wise correspondence
5963
with the queries. That is, the `N`-th row of the returned IDs and distances will
6064
correspond to the `N`-th row in the query matrix.
6165
@@ -64,6 +68,9 @@ Perform a search to return the `n_neighbors` approximate nearest neighbors to th
6468
Returns:
6569
A tuple `(I, D)` where `I` contains the `n_neighbors` approximate (or exact) nearest
6670
neighbors to the queries and `D` contains the approximate distances.
71+
72+
Note: This form is returned regardless of whether the given query was a vector or a
73+
matrix.
6774
)"
6875
);
6976
}

bindings/python/tests/test_vamana.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import os
1515
import warnings
1616

17+
import numpy as np
18+
1719
from tempfile import TemporaryDirectory
1820

1921
import pysvs
@@ -79,12 +81,50 @@ def _setup(self, loader: pysvs.VectorDataLoader):
7981
}),
8082
]
8183

84+
# Ensure that passing 1-dimensional queries works and produces the same results as
85+
# query batches. Also checks that a 3-dimensional input is rejected with an error.
86+
def _test_single_query(
87+
self,
88+
vamana: pysvs.Vamana,
89+
queries
90+
):
91+
92+
# Reference results: search the whole 2-D query batch in a single call.
I_full, D_full = vamana.search(queries, 10);
93+
94+
I_single = []
95+
D_single = []
96+
# Re-run every row on its own as a 1-D query and collect the per-query results.
for i in range(queries.shape[0]):
97+
query = queries[i, :]
98+
self.assertTrue(query.ndim == 1)
99+
I, D = vamana.search(query, 10)
100+
101+
# Even for a 1-D query, results must come back as 2-D arrays with a single row.
self.assertTrue(I.ndim == 2)
102+
self.assertTrue(D.ndim == 2)
103+
self.assertTrue(I.shape == (1, 10))
104+
self.assertTrue(D.shape == (1, 10))
105+
106+
I_single.append(I)
107+
D_single.append(D)
108+
109+
# Stack the single-row results and require an exact match with the batched search.
I_single_concat = np.concatenate(I_single, axis = 0)
110+
D_single_concat = np.concatenate(D_single, axis = 0)
111+
self.assertTrue(np.array_equal(I_full, I_single_concat))
112+
self.assertTrue(np.array_equal(D_full, D_single_concat))
113+
114+
# Throw an error on 3-dimensional inputs: add a trailing axis to the batch and check
# that the resulting exception carries the expected message fragment.
115+
queries_3d = queries[:, :, np.newaxis]
116+
with self.assertRaises(Exception) as context:
117+
vamana.search(queries_3d, 10)
118+
119+
self.assertTrue("only accept numpy vectors or matrices" in str(context.exception))
120+
82121
def _test_basic_inner(
83122
self,
84123
vamana: pysvs.Vamana,
85124
recall_dict,
86125
num_threads: int,
87126
skip_thread_test: bool = False,
127+
test_single_query: bool = False,
88128
):
89129
# Make sure that the number of threads is propagated correctly.
90130
self.assertEqual(vamana.num_threads, num_threads)
@@ -129,6 +169,9 @@ def _test_basic_inner(
129169
if not DEBUG:
130170
self.assertTrue(isapprox(recall, expected_recall, epsilon = 0.0005))
131171

172+
if test_single_query:
173+
self._test_single_query(vamana, queries)
174+
132175
# Disable visited set.
133176
self.visited_set_enabled = False
134177

@@ -158,6 +201,7 @@ def _test_basic(self, loader, recall_dict):
158201
self._test_basic_inner(vamana, recall_dict, num_threads)
159202

160203
# Test saving and reloading.
204+
is_first = True
161205
with TemporaryDirectory() as tempdir:
162206
configdir = os.path.join(tempdir, "config")
163207
graphdir = os.path.join(tempdir, "graph")
@@ -179,8 +223,13 @@ def _test_basic(self, loader, recall_dict):
179223

180224
reloaded.num_threads = num_threads
181225
self._test_basic_inner(
182-
reloaded, recall_dict, num_threads, skip_thread_test = True
226+
reloaded,
227+
recall_dict,
228+
num_threads,
229+
skip_thread_test = True,
230+
test_single_query = is_first,
183231
)
232+
is_first = False
184233

185234
def test_basic(self):
186235
# Load the index from files.

0 commit comments

Comments
 (0)