@@ -220,10 +220,7 @@ def fit(self, X, y):
220
220
221
221
return self
222
222
223
- def predict_proba_single_ (self , X , return_uncertainty = False ):
224
- idx , _ = self .index .knn_query (X , k = self .n_neighbors )
225
- idx = idx [0 , :]
226
-
223
+ def predict_proba_single_ (self , idx , X , return_uncertainty = False ):
227
224
classes_cur , encoded = np .unique (self .y_ [idx ], return_inverse = True )
228
225
229
226
log_kernel = self .log_kernel_ (self .X_ [idx , :], X )
@@ -331,13 +328,18 @@ def predict_proba(self, X, return_uncertainty=False):
331
328
check_is_fitted (self )
332
329
X = check_array (X )
333
330
331
+ if self .method == "hnsw" :
332
+ I , _ = self .index .knn_query (X , k = self .n_neighbors )
333
+ else :
334
+ raise ValueError (f"Unsupported method: { self .method } " )
335
+
334
336
classes = []
335
337
ps = []
336
338
log_sigma2_totals = []
337
339
338
340
res = Parallel (n_jobs = self .n_jobs )(
339
341
delayed (self .predict_proba_single_ )(
340
- X [i , :], return_uncertainty = return_uncertainty
342
+ I [ i , :], X [i , :], return_uncertainty = return_uncertainty
341
343
)
342
344
for i in tqdm (range (X .shape [0 ]), disable = not self .verbose )
343
345
)
0 commit comments