From 8e3fef152aa312b910ebfc8196467a512ccd32a5 Mon Sep 17 00:00:00 2001 From: Severin Dicks <37635888+Intron7@users.noreply.github.com> Date: Thu, 19 Dec 2024 17:30:18 +0100 Subject: [PATCH] Add anndata to gpu for dask (#312) * update anndata to XPU for dask * update test --- pyproject.toml | 1 + src/rapids_singlecell/_compat.py | 23 +++++++++++ src/rapids_singlecell/get/_anndata.py | 22 ++++++++++ tests/dask/test_get.py | 59 +++++++++++++++++++++++++++ 4 files changed, 105 insertions(+) create mode 100644 tests/dask/test_get.py diff --git a/pyproject.toml b/pyproject.toml index f9425f95..5d41bb83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,6 +103,7 @@ markers = [ [tool.hatch.build] # exclude big files that don’t need to be installed exclude = [ + "src/rapids_singlecell/_testing.py", "tests", "docs", "notebooks" diff --git a/src/rapids_singlecell/_compat.py b/src/rapids_singlecell/_compat.py index f569504b..8c6f3ca6 100644 --- a/src/rapids_singlecell/_compat.py +++ b/src/rapids_singlecell/_compat.py @@ -1,8 +1,11 @@ from __future__ import annotations import cupy as cp +import numpy as np from cupyx.scipy.sparse import csr_matrix from dask.array import Array as DaskArray # noqa: F401 +from scipy.sparse import csc_matrix as csc_matrix_cpu +from scipy.sparse import csr_matrix as csr_matrix_cpu def _meta_dense(dtype): @@ -11,3 +14,23 @@ def _meta_dense(dtype): def _meta_sparse(dtype): return csr_matrix(cp.array((1.0,), dtype=dtype)) + + +def _meta_dense(dtype): + return cp.zeros([0], dtype=dtype) + + +def _meta_sparse(dtype): + return csr_matrix(cp.array((1.0,), dtype=dtype)) + + +def _meta_dense_cpu(dtype): + return np.zeros([0], dtype=dtype) + + +def _meta_sparse_csr_cpu(dtype): + return csr_matrix_cpu(np.array((1.0,), dtype=dtype)) + + +def _meta_sparse_csc_cpu(dtype): + return csc_matrix_cpu(np.array((1.0,), dtype=dtype)) diff --git a/src/rapids_singlecell/get/_anndata.py b/src/rapids_singlecell/get/_anndata.py index c42b52bf..b4d25005 100644 --- a/src/rapids_singlecell/get/_anndata.py +++ b/src/rapids_singlecell/get/_anndata.py @@ -7,12 +7,21 @@ import numpy as np from cupyx.scipy.sparse import csc_matrix as csc_matrix_gpu from cupyx.scipy.sparse import csr_matrix as csr_matrix_gpu +from dask.array import Array as DaskArray from scanpy.get import _get_obs_rep, _set_obs_rep from scipy.sparse import csc_matrix as csc_matrix_cpu from scipy.sparse import csr_matrix as csr_matrix_cpu from scipy.sparse import isspmatrix_csc as isspmatrix_csc_cpu from scipy.sparse import isspmatrix_csr as isspmatrix_csr_cpu +from rapids_singlecell._compat import ( + _meta_dense, + _meta_dense_cpu, + _meta_sparse, + _meta_sparse_csc_cpu, + _meta_sparse_csr_cpu, +) + if TYPE_CHECKING: from anndata import AnnData @@ -79,6 +88,11 @@ def X_to_GPU(X: CPU_ARRAY_TYPE, warning: str = "X") -> GPU_ARRAY_TYPE: """ if isinstance(X, GPU_ARRAY_TYPE): pass + elif isinstance(X, DaskArray): + if isinstance(X._meta, csc_matrix_cpu): + pass + meta = _meta_sparse if isinstance(X._meta, csr_matrix_cpu) else _meta_dense + X = X.map_blocks(X_to_GPU, meta=meta(X.dtype)) elif isspmatrix_csr_cpu(X): X = csr_matrix_gpu(X) elif isspmatrix_csc_cpu(X): @@ -146,6 +160,14 @@ def X_to_CPU(X: GPU_ARRAY_TYPE) -> CPU_ARRAY_TYPE: X Matrix or array to transfer to the host memory """ + if isinstance(X, DaskArray): + if isinstance(X._meta, csr_matrix_gpu): + meta = _meta_sparse_csr_cpu + elif isinstance(X._meta, csc_matrix_gpu): + meta = _meta_sparse_csc_cpu + else: + meta = _meta_dense_cpu + X = X.map_blocks(X_to_GPU, meta=meta(X.dtype)) if isinstance(X, GPU_ARRAY_TYPE): X = X.get() else: diff --git a/tests/dask/test_get.py b/tests/dask/test_get.py new file mode 100644 index 00000000..e6594fa7 --- /dev/null +++ b/tests/dask/test_get.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import cupy as cp +import numpy as np +import pytest +from scanpy.datasets import pbmc3k_processed +from scipy import sparse + +import rapids_singlecell as rsc +from rapids_singlecell._testing import ( + as_dense_cupy_dask_array, + as_sparse_cupy_dask_array, +) + + +@pytest.mark.parametrize("data_kind", ["sparse", "dense"]) +def test_get_anndata(client, data_kind): + adata = pbmc3k_processed() + dask_adata = adata.copy() + if data_kind == "sparse": + adata.X = rsc.get.X_to_GPU(sparse.csr_matrix(adata.X.astype(np.float64))) + dask_adata.X = as_sparse_cupy_dask_array(dask_adata.X.astype(np.float64)) + elif data_kind == "dense": + adata.X = cp.array(adata.X.astype(np.float64)) + dask_adata.X = as_dense_cupy_dask_array(dask_adata.X.astype(np.float64)) + else: + raise ValueError(f"Unknown data_kind {data_kind}") + + assert type(adata.X) is type(dask_adata.X._meta) + + if data_kind == "sparse": + cp.testing.assert_array_equal( + adata.X.toarray(), dask_adata.X.compute().toarray() + ) + else: + cp.testing.assert_array_equal(adata.X, dask_adata.X.compute()) + + rsc.get.anndata_to_CPU(dask_adata) + rsc.get.anndata_to_CPU(adata) + + assert type(adata.X) is type(dask_adata.X._meta) + + if data_kind == "sparse": + cp.testing.assert_array_equal( + adata.X.toarray(), dask_adata.X.compute().toarray() + ) + else: + cp.testing.assert_array_equal(adata.X, dask_adata.X.compute()) + rsc.get.anndata_to_GPU(dask_adata) + rsc.get.anndata_to_GPU(adata) + + assert type(adata.X) is type(dask_adata.X._meta) + + if data_kind == "sparse": + cp.testing.assert_array_equal( + adata.X.toarray(), dask_adata.X.compute().toarray() + ) + else: + cp.testing.assert_array_equal(adata.X, dask_adata.X.compute())