diff --git a/scanpy/neighbors/__init__.py b/scanpy/neighbors/__init__.py index c2977682e8..e0395dfeaa 100644 --- a/scanpy/neighbors/__init__.py +++ b/scanpy/neighbors/__init__.py @@ -315,7 +315,9 @@ def compute_neighbors_umap( return knn_indices, knn_dists, forest -def compute_neighbors_rapids(X: np.ndarray, n_neighbors: int): +def compute_neighbors_rapids( + X: np.ndarray, n_neighbors: int, metric: _Metric = 'euclidean' +): """Compute nearest neighbors using RAPIDS cuml. Parameters @@ -324,6 +326,9 @@ def compute_neighbors_rapids(X: np.ndarray, n_neighbors: int): The data to compute nearest neighbors for. n_neighbors The number of neighbors to use. + metric + The metric to use to compute distances in high dimensional space. + This string must match a valid predefined metric in RAPIDS cuml. Returns ------- @@ -331,7 +336,7 @@ def compute_neighbors_rapids(X: np.ndarray, n_neighbors: int): """ from cuml.neighbors import NearestNeighbors - nn = NearestNeighbors(n_neighbors=n_neighbors) + nn = NearestNeighbors(n_neighbors=n_neighbors, metric=metric) X_contiguous = np.ascontiguousarray(X, dtype=np.float32) nn.fit(X_contiguous) knn_distsq, knn_indices = nn.kneighbors(X_contiguous) @@ -755,10 +760,6 @@ def compute_neighbors( logg.warning(f'n_obs too small: adjusting to `n_neighbors = {n_neighbors}`') if method == 'umap' and not knn: raise ValueError('`method = \'umap\' only with `knn = True`.') - if method == 'rapids' and metric != 'euclidean': - raise ValueError( - "`method` 'rapids' only supports the 'euclidean' `metric`." - ) if method not in {'umap', 'gauss', 'rapids'}: raise ValueError("`method` needs to be 'umap', 'gauss', or 'rapids'.") if self._adata.shape[0] >= 10000 and not knn: @@ -782,7 +783,9 @@ def compute_neighbors( else: self._distances = _distances elif method == 'rapids': - knn_indices, knn_distances = compute_neighbors_rapids(X, n_neighbors) + knn_indices, knn_distances = compute_neighbors_rapids( + X, n_neighbors, metric=metric + ) else: # non-euclidean case and approx nearest neighbors if X.shape[0] < 4096: