Skip to content

Commit

Permalink
Multi-GPU support with dask (#179)
Browse files Browse the repository at this point in the history
* add first functions

* add hvg part1

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* reset to main for hvg

* add support for hvg

* first pass pca

* pca update

* fix bug with csc matrix

* add dask to docs

* add tests

* update names

* get docs to work

* remove client from sparse calc

* need dask for docs

* add scale

* int64 updates

* For main branch

* test docs

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix import

* fix rebase

* (fix): use `to_delayed` and `from_delayed` to submit gram matrix jobs  (#210)

* (fix): use `to_delayed` and `from_delayed` to submit gram matrix jobs

* (refactor): use `map_blocks`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): correct first dimension size

* (fix): add `x` as arg

* (fix): `ncols` usages

* (fix): try mapping block from x

* (fix): matrix creation + cleaner `map_blocks`

* (fix): `len(blocks)` -> `num_blocks`

* (fix): don't need `dask.delayed` decorator

* (fix): try some debugging

* (fix): use `cp.sum`

* (fix): revert to `to_delayed`

* (refactor): use `n_cols`

* (fix): remove `client`

* (fix): `client` doc

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Severin Dicks <[email protected]>

* (fix): use `map_blocks` for job submission in `_get_target_sum_dask` + `_second_pass_qc` (#211)

* (fix): use `to_delayed` and `from_delayed` to submit gram matrix jobs

* (refactor): use `map_blocks`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): correct first dimension size

* (fix): change normalization sum

* (fix): add `x` as arg

* (fix): `ncols` usages

* (fix): try mapping block from x

* (fix): use `X` directly to `map_blocks`

* (fix): matrix creation + cleaner `map_blocks`

* (fix): `len(blocks)` -> `num_blocks`

* (fix): don't need `dask.delayed` decorator

* (fix): try some debugging

* (fix): use `cp.sum`

* (fix): revert to `to_delayed`

* (refactor): use `n_cols`

* (fix): remove `num_blocks`

* (fix): need to specify `drop_axis` for reduction

* (fix): `map_blocks` for `_second_pass_qc_dask`

* (fix): remove `client`

* (chore): remove `client`

* (chore): remove in tests

* (fix): `client` doc

* (fix): return client to test context

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Severin Dicks <[email protected]>

* (fix): remove `extract_partitions` from `mean`/`var` calculation (#221)

* (fix): use `to_delayed` and `from_delayed` to submit gram matrix jobs

* (refactor): use `map_blocks`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): correct first dimension size

* (fix): change normalization sum

* (fix): add `x` as arg

* (fix): `ncols` usages

* (fix): try mapping block from x

* (fix): use `X` directly to `map_blocks`

* (fix): matrix creation + cleaner `map_blocks`

* (fix): `len(blocks)` -> `num_blocks`

* (fix): don't need `dask.delayed` decorator

* (fix): try some debugging

* (fix): use `cp.sum`

* (fix): revert to `to_delayed`

* (refactor): use `n_cols`

* (fix): remove `num_blocks`

* (fix): need to specify `drop_axis` for reduction

* (fix): try splitting mean/var directly

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): `x` to `X`

* (fix) correct import

* (fix): `delayed` decorator

* (fix): get axis right

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): ravel mean/var

* (fix): too many mistakes to count

* (feat): same for other axis

* (fix): resolve all small tolerance differences

* (fix): try splitting first

* (fix): `compute` once

* (fix): stack `mean`/`var`

* (fix): use `float64` for mean-var

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): remove unnecessary cast

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): add `dask` dep for major axis

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (refactor): use cleaner `zeros`

* (fix): revert other `rtol`

* (fix): remove last `extract_partitions`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): `map_blocks` for `_second_pass_qc_dask`

* (fix): remove `client`

* (fix): remove other instances

* (fix): remove `client`

* (chore): remove `client`

* (chore): remove in tests

* (fix): `client` doc

* (fix): remove more `client`

* (fix): return client to test context

* (chore): re-add client

* (fix): oops

* (fiX): oops x 2

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Severin Dicks <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove client from hvg

* remove client

* remove client from scale

* update to fast transform

* (fix): `normalize_total` -> `log1p` -> `pca` with sparse (#217)

* (fix): use `to_delayed` and `from_delayed` to submit gram matrix jobs

* (refactor): use `map_blocks`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): correct first dimension size

* (fix): change normalization sum

* (fix): add `x` as arg

* (fix): `ncols` usages

* (fix): try mapping block from x

* (fix): use `X` directly to `map_blocks`

* (fix): matrix creation + cleaner `map_blocks`

* (fix): `len(blocks)` -> `num_blocks`

* (fix): don't need `dask.delayed` decorator

* (fix): try some debugging

* (fix): use `cp.sum`

* (fix): revert to `to_delayed`

* (refactor): use `n_cols`

* (fix): remove `num_blocks`

* (fix): need to specify `drop_axis` for reduction

* (chore): add full pipline

* (chore): use `cusparse`

* (fix): use `scipy_sparse`

* (fix): initialization

* (fix): try splitting mean/var directly

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): `x` to `X`

* (fix) correct import

* (fix): `delayed` decorator

* (fix): get axis right

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): ravel mean/var

* (fix): too many mistakes to count

* (feat): same for other axis

* (fix): resolve all small tolerance differences

* (fix): try splitting first

* (fix): `compute` once

* (fix): stack `mean`/`var`

* (fix): use `float64` for mean-var

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): remove unnecessary cast

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): add `dask` dep for major axis

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (refactor): use cleaner `zeros`

* (fix): revert other `rtol`

* (fix): remove last `extract_partitions`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* (fix): `map_blocks` for `_second_pass_qc_dask`

* (fix): remove `client`

* (fix): remove other instances

* (fix): remove `client`

* (chore): remove `client`

* (chore): remove in tests

* (fix): `client` doc

* (fix): remove more `client`

* (fix): return client to test context

* (chore): re-add client

* (fix): oops

* (fiX): oops x 2

* (chore): add dense test, which also doesn't work?

* (fix): corect filtering

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Severin Dicks <[email protected]>

* fix taskgraph

* (feat): use `map_blocks` in gram matrix calculation and and `mean_var` (#230)

* (feat): use `map_blocks` instead of `to_delayed` for `gram_matrix`

* (feat): same for `mean/var`

* (fix): correct shape

* (fix): add new axis to `_mean_var_dense_dask` + remove `dask.delayed`

* (fix): chunks shape

* update pca

* use lambda

* remove unused kernel

* test removed kernel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add outside compute (#245)

* update utils for lazy compute

* update utils

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* move test helpers

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update typing

* update normalize

* go back to lambda

* slim down tests

* run tests on rapids-24.08

* compress hvg tests

* remove .todelayed

* remove dask.delayed

* update qc

* Update src/rapids_singlecell/preprocessing/_pca.py

Co-authored-by: Philipp A. <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update src/rapids_singlecell/preprocessing/_scale.py

Co-authored-by: Philipp A. <[email protected]>

* add error

* update tree pca

* Update src/rapids_singlecell/preprocessing/_scale.py

Co-authored-by: Ilan Gold <[email protected]>

* add note

* dask import

* update qc names

* update

* update _check_gpu_X

* update docs

* docs update

* make sure dtype is correct PCA

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add update

* Update src/rapids_singlecell/preprocessing/_hvg.py

Co-authored-by: Philipp A. <[email protected]>

* add log1p wraper

* fix updating var with hvg multibatch

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: dicks1 <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ilan Gold <[email protected]>
Co-authored-by: Phil Schaf <[email protected]>
  • Loading branch information
5 people authored Dec 19, 2024
1 parent 01dd624 commit 4629c05
Show file tree
Hide file tree
Showing 26 changed files with 1,999 additions and 487 deletions.
12 changes: 11 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,16 @@

autosummary_generate = True
autodoc_member_order = "bysource"
autodoc_mock_imports = ["cudf", "cuml", "cugraph", "cupy", "cupyx", "pylibraft", "cuvs"]
autodoc_mock_imports = [
"cudf",
"cuml",
"cugraph",
"cupy",
"cupyx",
"pylibraft",
"dask",
"cuvs",
]
default_role = "literal"
napoleon_google_docstring = False
napoleon_numpy_docstring = True
Expand Down Expand Up @@ -108,6 +117,7 @@
"rmm": ("https://docs.rapids.ai/api/rmm/stable/", None),
"statsmodels": ("https://www.statsmodels.org/stable/", None),
"omnipath": ("https://omnipath.readthedocs.io/en/latest/", None),
"dask": ("https://docs.dask.org/en/stable/", None),
}

# List of patterns, relative to source directory, that match files and
Expand Down
5 changes: 4 additions & 1 deletion docs/release-notes/0.11.0.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
### 0.11.0 {small}`the-future`
### 0.10.11 {small}`2024-12-19`

```{rubric} Features
```
* Adds support for Multi-GPU out-of-core support through Dask {pr}`179` {smaller}`S Dicks, I Gold & P Angerer`
* use `cuvs` over `raft` for `pp.neighbors` for `rapids>=24.12`{pr}`304` {smaller}`S Dicks`
* switched to a different implementation of `pp.harmony_integrate` {pr}`308` {smaller}`S Dicks`
```{rubric} Performance
```

```{rubric} Bug fixes
```

```{rubric} Misc
```
* Update `get_random_state` for `scrublet `{pr}`301` {smaller}`S Dicks`
7 changes: 0 additions & 7 deletions docs/release-notes/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,27 +35,20 @@
## Version 0.9.0
```{include} /release-notes/0.9.6.md
```

```{include} /release-notes/0.9.5.md
```

```{include} /release-notes/0.9.4.md
```

```{include} /release-notes/0.9.3.md
```

```{include} /release-notes/0.9.2.md
```

```{include} /release-notes/0.9.1.md
```

```{include} /release-notes/0.9.0.md
```

## Version 0.8.0

```{include} /release-notes/0.8.1.md
```
```{include} /release-notes/0.8.0.md
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ doc = [
"scanpydoc[typehints,theme]>=0.9.4",
"readthedocs-sphinx-ext",
"sphinx_copybutton",
"dask",
"pytest",
]
test = [
Expand Down
13 changes: 13 additions & 0 deletions src/rapids_singlecell/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations

import cupy as cp
from cupyx.scipy.sparse import csr_matrix
from dask.array import Array as DaskArray # noqa: F401


def _meta_dense(dtype):
return cp.zeros([0], dtype=dtype)


def _meta_sparse(dtype):
return csr_matrix(cp.array((1.0,), dtype=dtype))
18 changes: 17 additions & 1 deletion src/rapids_singlecell/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

from typing import TYPE_CHECKING

import cupy as cp
import pytest
from anndata.tests.helpers import asarray
from anndata.tests.helpers import as_dense_dask_array, as_sparse_dask_array, asarray
from cupyx.scipy import sparse as cusparse
from scipy import sparse

if TYPE_CHECKING:
Expand Down Expand Up @@ -38,3 +40,17 @@ def param_with(
ARRAY_TYPES_MEM = tuple(
at for (strg, _), ats in MAP_ARRAY_TYPES.items() if strg == "mem" for at in ats
)


def as_sparse_cupy_dask_array(X):
da = as_sparse_dask_array(X)
da = da.rechunk((da.shape[0] // 2, da.shape[1]))
da = da.map_blocks(cusparse.csr_matrix, dtype=X.dtype)
return da


def as_dense_cupy_dask_array(X):
X = as_dense_dask_array(X)
X = X.map_blocks(cp.array)
X = X.rechunk((X.shape[0] // 2, X.shape[1]))
return X
7 changes: 7 additions & 0 deletions src/rapids_singlecell/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

from typing import TYPE_CHECKING, Union

import cupy as cp
import numpy as np
from cupyx.scipy.sparse import csc_matrix, csr_matrix
from dask.array import Array as DaskArray

AnyRandom = Union[int, np.random.RandomState, None] # noqa: UP007


ArrayTypes = Union[cp.ndarray, csc_matrix, csr_matrix] # noqa: UP007
ArrayTypesDask = Union[cp.ndarray, csc_matrix, csr_matrix, DaskArray] # noqa: UP007
2 changes: 1 addition & 1 deletion src/rapids_singlecell/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from ._neighbors import bbknn, neighbors
from ._normalize import log1p, normalize_pearson_residuals, normalize_total
from ._pca import pca
from ._qc import calculate_qc_metrics
from ._regress_out import regress_out
from ._scale import scale
from ._scrublet import scrublet, scrublet_simulate_doublets
from ._simple import (
calculate_qc_metrics,
filter_cells,
filter_genes,
filter_highly_variable,
Expand Down
53 changes: 39 additions & 14 deletions src/rapids_singlecell/preprocessing/_hvg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
import cupy as cp
import numpy as np
import pandas as pd
from cupyx.scipy.sparse import issparse, isspmatrix_csc
from cupyx.scipy.sparse import csr_matrix, issparse, isspmatrix_csc
from scanpy.get import _get_obs_rep

from ._simple import calculate_qc_metrics
from rapids_singlecell._compat import DaskArray, _meta_dense, _meta_sparse

from ._qc import _basic_qc
from ._utils import _check_gpu_X, _check_nonnegative_integers, _get_mean_var

if TYPE_CHECKING:
Expand Down Expand Up @@ -188,7 +190,11 @@ def highly_variable_genes(

if batch_key is None:
df = _highly_variable_genes_single_batch(
adata, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor
adata,
layer=layer,
cutoff=cutoff,
n_bins=n_bins,
flavor=flavor,
)
else:
df = _highly_variable_genes_batched(
Expand Down Expand Up @@ -260,6 +266,19 @@ def in_bounds(
)


def _hvg_expm1(X):
if isinstance(X, DaskArray):
meta = _meta_sparse if isinstance(X._meta, csr_matrix) else _meta_dense
X = X.map_blocks(_hvg_expm1, meta=meta(X.dtype))
else:
X = X.copy()
if issparse(X):
X = X.expm1()
else:
X = cp.expm1(X)
return X


def _highly_variable_genes_single_batch(
adata: AnnData,
*,
Expand All @@ -277,18 +296,18 @@ def _highly_variable_genes_single_batch(
`highly_variable`, `means`, `dispersions`, and `dispersions_norm`.
"""
X = _get_obs_rep(adata, layer=layer)

_check_gpu_X(X, allow_dask=True)
if hasattr(X, "_view_args"): # AnnData array view
# For compatibility with anndata<0.9
X = X.copy() # Doesn't actually copy memory, just removes View class wrapper
X = X.copy()

if flavor == "seurat":
X = X.copy()
if issparse(X):
X = X.expm1()
else:
X = cp.expm1(X)
X = _hvg_expm1(X)

mean, var = _get_mean_var(X, axis=0)
if isinstance(X, DaskArray):
import dask

mean, var = dask.compute(mean, var)
mean[mean == 0] = 1e-12
disp = var / mean
if flavor == "seurat": # logarithmized mean as in Seurat
Expand Down Expand Up @@ -415,12 +434,18 @@ def _highly_variable_genes_batched(
for batch in batches:
adata_subset = adata[adata.obs[batch_key] == batch]

calculate_qc_metrics(adata_subset, layer=layer)
filt = adata_subset.var["n_cells_by_counts"].to_numpy() > 0
X = _get_obs_rep(adata_subset, layer=layer)
_check_gpu_X(X, allow_dask=True)
_, _, _, n_cells_per_gene = _basic_qc(X=X)
filt = (n_cells_per_gene > 0).get()
adata_subset = adata_subset[:, filt]

hvg = _highly_variable_genes_single_batch(
adata_subset, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor
adata_subset,
layer=layer,
cutoff=cutoff,
n_bins=n_bins,
flavor=flavor,
)
hvg.reset_index(drop=False, inplace=True, names=["gene"])

Expand Down
102 changes: 102 additions & 0 deletions src/rapids_singlecell/preprocessing/_kernels/_qc_kernels_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from __future__ import annotations

from cuml.common.kernel_utils import cuda_kernel_factory

_sparse_qc_kernel_csr_dask_cells = r"""
(const int *indptr,const int *index,const {0} *data,
{0}* sums_cells, int* cell_ex,
int n_cells) {
int cell = blockDim.x * blockIdx.x + threadIdx.x;
if(cell >= n_cells){
return;
}
int start_idx = indptr[cell];
int stop_idx = indptr[cell+1];
{0} sums_cells_i = 0;
int cell_ex_i = 0;
for(int gene = start_idx; gene < stop_idx; gene++){
{0} value = data[gene];
int gene_number = index[gene];
sums_cells_i += value;
cell_ex_i += 1;
}
sums_cells[cell] = sums_cells_i;
cell_ex[cell] = cell_ex_i;
}
"""


_sparse_qc_kernel_csr_dask_genes = r"""
(const int *index,const {0} *data,
{0}* sums_genes, int* gene_ex,
int nnz) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
if(idx >= nnz){
return;
}
int minor_pos = index[idx];
atomicAdd(&sums_genes[minor_pos], data[idx]);
atomicAdd(&gene_ex[minor_pos], 1);
}
"""

_sparse_qc_kernel_dense_cells = r"""
(const {0} *data,
{0}* sums_cells, int* cell_ex,
int n_cells,int n_genes) {
int cell = blockDim.x * blockIdx.x + threadIdx.x;
int gene = blockDim.y * blockIdx.y + threadIdx.y;
if(cell >= n_cells || gene >=n_genes){
return;
}
long long int index = static_cast<long long int>(cell) * n_genes + gene;
{0} value = data[index];
if (value>0.0){
atomicAdd(&sums_cells[cell], value);
atomicAdd(&cell_ex[cell], 1);
}
}
"""

_sparse_qc_kernel_dense_genes = r"""
(const {0} *data,
{0}* sums_genes,int* gene_ex,
int n_cells,int n_genes) {
int cell = blockDim.x * blockIdx.x + threadIdx.x;
int gene = blockDim.y * blockIdx.y + threadIdx.y;
if(cell >= n_cells || gene >=n_genes){
return;
}
long long int index = static_cast<long long int>(cell) * n_genes + gene;
{0} value = data[index];
if (value>0.0){
atomicAdd(&sums_genes[gene], value);
atomicAdd(&gene_ex[gene], 1);
}
}
"""


def _sparse_qc_csr_dask_cells(dtype):
return cuda_kernel_factory(
_sparse_qc_kernel_csr_dask_cells, (dtype,), "_sparse_qc_kernel_csr_dask_cells"
)


def _sparse_qc_csr_dask_genes(dtype):
return cuda_kernel_factory(
_sparse_qc_kernel_csr_dask_genes, (dtype,), "_sparse_qc_kernel_csr_dask_genes"
)


def _sparse_qc_dense_cells(dtype):
return cuda_kernel_factory(
_sparse_qc_kernel_dense_cells, (dtype,), "_sparse_qc_kernel_dense_cells"
)


def _sparse_qc_dense_genes(dtype):
return cuda_kernel_factory(
_sparse_qc_kernel_dense_genes, (dtype,), "_sparse_qc_kernel_dense_genes"
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"""

_csr_scale_diff_kernel = r"""
(const int *indptr, const int *indices, {0} *data, const double * std, const int *mask, {0} clipper,int nrows) {
(const int *indptr, const int *indices, {0} *data, const {0} * std, const int *mask, {0} clipper,int nrows) {
int row = blockIdx.x;
if(row >= nrows){
Expand Down
Loading

0 comments on commit 4629c05

Please sign in to comment.