diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index 4f225da2a1..2d4dc9c2be 100644 --- a/src/scanpy/tools/_umap.py +++ b/src/scanpy/tools/_umap.py @@ -56,6 +56,7 @@ def umap( key_added: str | None = None, neighbors_key: str = "neighbors", copy: bool = False, + parallel: bool = False, ) -> AnnData | None: """\ Embed the neighborhood graph using UMAP :cite:p:`McInnes2018`. @@ -146,7 +147,8 @@ def umap( :attr:`~anndata.AnnData.obsp`\\ ``[.uns[neighbors_key]['connectivities_key']]`` for connectivities. copy Return a copy instead of writing to adata. - + parallel + Whether to run the computation using numba parallel. Running in parallel is non-deterministic. Returns ------- Returns `None` if `copy=False`, else returns an `AnnData` object. Sets the following fields: @@ -214,6 +216,12 @@ def umap( # for the init condition in the UMAP embedding default_epochs = 500 if neighbors["connectivities"].shape[0] <= 10000 else 200 n_epochs = default_epochs if maxiter is None else maxiter + if parallel and random_state is not None: + warnings.warn( + "Parallel execution was expected to be disabled when both `parallel=True` and `random_state` are set, " + "to ensure reproducibility. However, parallel execution still seems to occur, which may lead to " + "non-deterministic results." + ) X_umap, _ = simplicial_set_embedding( data=X, graph=neighbors["connectivities"].tocoo(), @@ -232,6 +240,7 @@ def umap( densmap_kwds={}, output_dens=False, verbose=settings.verbosity > 3, + parallel=parallel, ) elif method == "rapids": msg = ( diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 692157a084..c4377535c2 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -1,5 +1,7 @@ from __future__ import annotations +import warnings + import numpy as np import pytest from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_raises @@ -88,3 +90,28 @@ def test_diffmap(): sc.tl.diffmap(pbmc, random_state=1234) d3 = pbmc.obsm["X_diffmap"].copy() assert_raises(AssertionError, assert_array_equal, d1, d3) + + +@pytest.mark.parametrize( + ("random_state", "expect_warning"), + [ + pytest.param(42, True, id="random_state_int"), + pytest.param(np.random.RandomState(42), True, id="random_state_RandomState"), + pytest.param(None, True, id="random_state_None"), + ], +) +def test_umap_parallel_randomstate(random_state, expect_warning): + pbmc = pbmc68k_reduced()[:100, :].copy() + + if expect_warning: + with pytest.warns( + UserWarning, match="Parallel execution was expected to be disabled" + ): + sc.tl.umap(pbmc, parallel=True, random_state=random_state) + else: + # This is case is currently not in use because of lmcinnes/umap#1155 + with warnings.catch_warnings(record=True) as record: + warnings.simplefilter("always") + sc.tl.umap(pbmc, parallel=True, random_state=random_state) + + assert len(record) == 0