Skip to content

Commit

Permalink
Neighbors: Support additional RAPIDS cuML distance metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
maximz committed Jan 6, 2022
1 parent 60353cc commit 2c30784
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions scanpy/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,9 @@ def compute_neighbors_umap(
return knn_indices, knn_dists, forest


def compute_neighbors_rapids(X: np.ndarray, n_neighbors: int):
def compute_neighbors_rapids(
X: np.ndarray, n_neighbors: int, metric: _Metric = 'euclidean'
):
"""Compute nearest neighbors using RAPIDS cuml.
Parameters
Expand All @@ -324,14 +326,17 @@ def compute_neighbors_rapids(X: np.ndarray, n_neighbors: int):
The data to compute nearest neighbors for.
n_neighbors
The number of neighbors to use.
metric
The metric to use to compute distances in high dimensional space.
This string must match a valid predefined metric in RAPIDS cuml.
Returns
-------
**knn_indices**, **knn_dists** : np.arrays of shape (n_observations, n_neighbors)
"""
from cuml.neighbors import NearestNeighbors

nn = NearestNeighbors(n_neighbors=n_neighbors)
nn = NearestNeighbors(n_neighbors=n_neighbors, metric=metric)
X_contiguous = np.ascontiguousarray(X, dtype=np.float32)
nn.fit(X_contiguous)
knn_distsq, knn_indices = nn.kneighbors(X_contiguous)
Expand Down Expand Up @@ -755,10 +760,6 @@ def compute_neighbors(
logg.warning(f'n_obs too small: adjusting to `n_neighbors = {n_neighbors}`')
if method == 'umap' and not knn:
raise ValueError('`method = \'umap\' only with `knn = True`.')
if method == 'rapids' and metric != 'euclidean':
raise ValueError(
"`method` 'rapids' only supports the 'euclidean' `metric`."
)
if method not in {'umap', 'gauss', 'rapids'}:
raise ValueError("`method` needs to be 'umap', 'gauss', or 'rapids'.")
if self._adata.shape[0] >= 10000 and not knn:
Expand All @@ -782,7 +783,9 @@ def compute_neighbors(
else:
self._distances = _distances
elif method == 'rapids':
knn_indices, knn_distances = compute_neighbors_rapids(X, n_neighbors)
knn_indices, knn_distances = compute_neighbors_rapids(
X, n_neighbors, metric=metric
)
else:
# non-euclidean case and approx nearest neighbors
if X.shape[0] < 4096:
Expand Down

0 comments on commit 2c30784

Please sign in to comment.