Skip to content

Commit

Permalink
Add anndata to gpu for dask (#312)
Browse files Browse the repository at this point in the history
* update anndata to XPU for dask

* update test
  • Loading branch information
Intron7 authored Dec 19, 2024
1 parent c6861aa commit 8e3fef1
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ markers = [
[tool.hatch.build]
# exclude big files that don’t need to be installed
exclude = [
"src/rapids_singlecell/_testing.py",
"tests",
"docs",
"notebooks"
Expand Down
23 changes: 23 additions & 0 deletions src/rapids_singlecell/_compat.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import cupy as cp
import numpy as np
from cupyx.scipy.sparse import csr_matrix
from dask.array import Array as DaskArray # noqa: F401
from scipy.sparse import csc_matrix as csc_matrix_cpu
from scipy.sparse import csr_matrix as csr_matrix_cpu


def _meta_dense(dtype):
Expand All @@ -11,3 +14,23 @@ def _meta_dense(dtype):

def _meta_sparse(dtype):
    """Build a small GPU CSR matrix of the given dtype for use as a dask ``meta``.

    Parameters
    ----------
    dtype
        Data type of the single stored value.

    Returns
    -------
    A ``cupyx.scipy.sparse.csr_matrix`` created from a one-element CuPy array
    holding ``1.0``.
    """
    seed = cp.array((1.0,), dtype=dtype)
    return csr_matrix(seed)


# NOTE(review): `_meta_dense` is also defined earlier in this module; this later
# definition is the one that stays bound. Consider dropping one of the two.
def _meta_dense(dtype):
    """Build an empty dense GPU array of the given dtype for use as a dask ``meta``.

    Parameters
    ----------
    dtype
        Data type of the returned array.

    Returns
    -------
    A zero-length, one-dimensional CuPy array.
    """
    empty_shape = [0]
    return cp.zeros(empty_shape, dtype=dtype)


# NOTE(review): `_meta_sparse` is also defined earlier in this module with the
# same body; this later definition is the one that stays bound. Consider
# dropping one of the two.
def _meta_sparse(dtype):
    """Build a small GPU CSR matrix of the given dtype for use as a dask ``meta``.

    Parameters
    ----------
    dtype
        Data type of the single stored value.

    Returns
    -------
    A ``cupyx.scipy.sparse.csr_matrix`` created from a one-element CuPy array
    holding ``1.0``.
    """
    seed = cp.array((1.0,), dtype=dtype)
    return csr_matrix(seed)


def _meta_dense_cpu(dtype):
return np.zeros([0], dtype=dtype)


def _meta_sparse_csr_cpu(dtype):
return csr_matrix_cpu(np.array((1.0,), dtype=dtype))


def _meta_sparse_csc_cpu(dtype):
return csc_matrix_cpu(np.array((1.0,), dtype=dtype))
22 changes: 22 additions & 0 deletions src/rapids_singlecell/get/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,21 @@
import numpy as np
from cupyx.scipy.sparse import csc_matrix as csc_matrix_gpu
from cupyx.scipy.sparse import csr_matrix as csr_matrix_gpu
from dask.array import Array as DaskArray
from scanpy.get import _get_obs_rep, _set_obs_rep
from scipy.sparse import csc_matrix as csc_matrix_cpu
from scipy.sparse import csr_matrix as csr_matrix_cpu
from scipy.sparse import isspmatrix_csc as isspmatrix_csc_cpu
from scipy.sparse import isspmatrix_csr as isspmatrix_csr_cpu

from rapids_singlecell._compat import (
_meta_dense,
_meta_dense_cpu,
_meta_sparse,
_meta_sparse_csc_cpu,
_meta_sparse_csr_cpu,
)

if TYPE_CHECKING:
from anndata import AnnData

Expand Down Expand Up @@ -79,6 +88,11 @@ def X_to_GPU(X: CPU_ARRAY_TYPE, warning: str = "X") -> GPU_ARRAY_TYPE:
"""
if isinstance(X, GPU_ARRAY_TYPE):
pass
elif isinstance(X, DaskArray):
if isinstance(X._meta, csc_matrix_cpu):
pass
meta = _meta_sparse if isinstance(X._meta, csr_matrix_cpu) else _meta_dense
X = X.map_blocks(X_to_GPU, meta=meta(X.dtype))
elif isspmatrix_csr_cpu(X):
X = csr_matrix_gpu(X)
elif isspmatrix_csc_cpu(X):
Expand Down Expand Up @@ -146,6 +160,14 @@ def X_to_CPU(X: GPU_ARRAY_TYPE) -> CPU_ARRAY_TYPE:
X
Matrix or array to transfer to the host memory
"""
if isinstance(X, DaskArray):
    # Choose the host-side `meta` factory that mirrors the device-side block
    # container, so the resulting dask array advertises the type its blocks
    # will actually have after the transfer.
    if isinstance(X._meta, csr_matrix_gpu):
        meta = _meta_sparse_csr_cpu
    elif isinstance(X._meta, csc_matrix_gpu):
        meta = _meta_sparse_csc_cpu
    else:
        meta = _meta_dense_cpu
    # Apply the host transfer block by block. The mapped function must be
    # X_to_CPU: mapping the GPU-transfer function here would not move the
    # blocks to the host, leaving them out of sync with the CPU `meta` above.
    X = X.map_blocks(X_to_CPU, meta=meta(X.dtype))
if isinstance(X, GPU_ARRAY_TYPE):
X = X.get()
else:
Expand Down
59 changes: 59 additions & 0 deletions tests/dask/test_get.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from __future__ import annotations

import cupy as cp
import numpy as np
import pytest
from scanpy.datasets import pbmc3k_processed
from scipy import sparse

import rapids_singlecell as rsc
from rapids_singlecell._testing import (
as_dense_cupy_dask_array,
as_sparse_cupy_dask_array,
)


@pytest.mark.parametrize("data_kind", ["sparse", "dense"])
def test_get_anndata(client, data_kind):
    """Check that a dask-backed AnnData stays in sync with an in-memory one.

    The same dataset is loaded twice: once with ``X`` held directly on the GPU
    and once with ``X`` wrapped as a dask array of GPU blocks. Both objects are
    compared as created, again after ``anndata_to_CPU``, and once more after
    ``anndata_to_GPU``. Each comparison checks the container type against the
    dask ``_meta`` and the computed values against the in-memory values.
    """
    reference = pbmc3k_processed()
    lazy = reference.copy()

    if data_kind == "sparse":
        reference.X = rsc.get.X_to_GPU(
            sparse.csr_matrix(reference.X.astype(np.float64))
        )
        lazy.X = as_sparse_cupy_dask_array(lazy.X.astype(np.float64))
    elif data_kind == "dense":
        reference.X = cp.array(reference.X.astype(np.float64))
        lazy.X = as_dense_cupy_dask_array(lazy.X.astype(np.float64))
    else:
        raise ValueError(f"Unknown data_kind {data_kind}")

    def _assert_in_sync():
        # The dask `_meta` must be the same container type as the eager data.
        assert type(reference.X) is type(lazy.X._meta)
        if data_kind == "sparse":
            # Sparse containers are densified before the element-wise check.
            cp.testing.assert_array_equal(
                reference.X.toarray(), lazy.X.compute().toarray()
            )
        else:
            cp.testing.assert_array_equal(reference.X, lazy.X.compute())

    _assert_in_sync()

    # Round trip: device -> host -> device, verifying after every transfer.
    for transfer in (rsc.get.anndata_to_CPU, rsc.get.anndata_to_GPU):
        transfer(lazy)
        transfer(reference)
        _assert_in_sync()

0 comments on commit 8e3fef1

Please sign in to comment.