14
14
import os
15
15
import warnings
16
16
17
+ import numpy as np
18
+
17
19
from tempfile import TemporaryDirectory
18
20
19
21
import pysvs
@@ -79,12 +81,50 @@ def _setup(self, loader: pysvs.VectorDataLoader):
79
81
}),
80
82
]
81
83
84
+ # Ensure that passing 1-dimensional queries works and produces the same results as
85
+ # query batches.
86
+ def _test_single_query (
87
+ self ,
88
+ vamana : pysvs .Vamana ,
89
+ queries
90
+ ):
91
+
92
+ I_full , D_full = vamana .search (queries , 10 );
93
+
94
+ I_single = []
95
+ D_single = []
96
+ for i in range (queries .shape [0 ]):
97
+ query = queries [i , :]
98
+ self .assertTrue (query .ndim == 1 )
99
+ I , D = vamana .search (query , 10 )
100
+
101
+ self .assertTrue (I .ndim == 2 )
102
+ self .assertTrue (D .ndim == 2 )
103
+ self .assertTrue (I .shape == (1 , 10 ))
104
+ self .assertTrue (D .shape == (1 , 10 ))
105
+
106
+ I_single .append (I )
107
+ D_single .append (D )
108
+
109
+ I_single_concat = np .concatenate (I_single , axis = 0 )
110
+ D_single_concat = np .concatenate (D_single , axis = 0 )
111
+ self .assertTrue (np .array_equal (I_full , I_single_concat ))
112
+ self .assertTrue (np .array_equal (D_full , D_single_concat ))
113
+
114
+ # Throw an error on 3-dimensional inputs.
115
+ queries_3d = queries [:, :, np .newaxis ]
116
+ with self .assertRaises (Exception ) as context :
117
+ vamana .search (queries_3d , 10 )
118
+
119
+ self .assertTrue ("only accept numpy vectors or matrices" in str (context .exception ))
120
+
82
121
def _test_basic_inner (
83
122
self ,
84
123
vamana : pysvs .Vamana ,
85
124
recall_dict ,
86
125
num_threads : int ,
87
126
skip_thread_test : bool = False ,
127
+ test_single_query : bool = False ,
88
128
):
89
129
# Make sure that the number of threads is propagated correctly.
90
130
self .assertEqual (vamana .num_threads , num_threads )
@@ -129,6 +169,9 @@ def _test_basic_inner(
129
169
if not DEBUG :
130
170
self .assertTrue (isapprox (recall , expected_recall , epsilon = 0.0005 ))
131
171
172
+ if test_single_query :
173
+ self ._test_single_query (vamana , queries )
174
+
132
175
# Disable visited set.
133
176
self .visited_set_enabled = False
134
177
@@ -158,6 +201,7 @@ def _test_basic(self, loader, recall_dict):
158
201
self ._test_basic_inner (vamana , recall_dict , num_threads )
159
202
160
203
# Test saving and reloading.
204
+ is_first = True
161
205
with TemporaryDirectory () as tempdir :
162
206
configdir = os .path .join (tempdir , "config" )
163
207
graphdir = os .path .join (tempdir , "graph" )
@@ -179,8 +223,13 @@ def _test_basic(self, loader, recall_dict):
179
223
180
224
reloaded .num_threads = num_threads
181
225
self ._test_basic_inner (
182
- reloaded , recall_dict , num_threads , skip_thread_test = True
226
+ reloaded ,
227
+ recall_dict ,
228
+ num_threads ,
229
+ skip_thread_test = True ,
230
+ test_single_query = is_first ,
183
231
)
232
+ is_first = False
184
233
185
234
def test_basic (self ):
186
235
# Load the index from files.
0 commit comments