Skip to content

Commit d64a7c0

Browse files
committed
hnsw without parallel
1 parent 9fa6369 commit d64a7c0

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

nuq/nuq_classifier.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -220,10 +220,7 @@ def fit(self, X, y):
220220

221221
return self
222222

223-
def predict_proba_single_(self, X, return_uncertainty=False):
224-
idx, _ = self.index.knn_query(X, k=self.n_neighbors)
225-
idx = idx[0, :]
226-
223+
def predict_proba_single_(self, idx, X, return_uncertainty=False):
227224
classes_cur, encoded = np.unique(self.y_[idx], return_inverse=True)
228225

229226
log_kernel = self.log_kernel_(self.X_[idx, :], X)
@@ -331,13 +328,18 @@ def predict_proba(self, X, return_uncertainty=False):
331328
check_is_fitted(self)
332329
X = check_array(X)
333330

331+
if self.method == "hnsw":
332+
I, _ = self.index.knn_query(X, k=self.n_neighbors)
333+
else:
334+
raise ValueError(f"Unsupported method: {self.method}")
335+
334336
classes = []
335337
ps = []
336338
log_sigma2_totals = []
337339

338340
res = Parallel(n_jobs=self.n_jobs)(
339341
delayed(self.predict_proba_single_)(
340-
X[i, :], return_uncertainty=return_uncertainty
342+
I[i, :], X[i, :], return_uncertainty=return_uncertainty
341343
)
342344
for i in tqdm(range(X.shape[0]), disable=not self.verbose)
343345
)

0 commit comments

Comments
 (0)