From d09d0d12c2193ba36d928a744fbfadf0bfcd82c0 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Thu, 26 Sep 2024 10:47:03 +0200 Subject: [PATCH] WIP --- src/scanpy/tools/_score_genes.py | 34 +++++++++++++++++----------- tests/test_score_genes.py | 38 +++++++++++++++++++++++++------- 2 files changed, 51 insertions(+), 21 deletions(-) diff --git a/src/scanpy/tools/_score_genes.py b/src/scanpy/tools/_score_genes.py index a3909b7a28..c9bde464df 100644 --- a/src/scanpy/tools/_score_genes.py +++ b/src/scanpy/tools/_score_genes.py @@ -65,7 +65,7 @@ def score_genes( adata: AnnData, gene_list: Sequence[str] | pd.Index[str], *, - ctrl_as_ref: bool = True, + sanitize: bool = False, ctrl_size: int = 50, gene_pool: Sequence[str] | pd.Index[str] | None = None, n_bins: int = 25, @@ -91,9 +91,10 @@ def score_genes( The annotated data matrix. gene_list The list of gene names used for score calculation. - ctrl_as_ref - Allow the algorithm to use the control genes as reference. - Will be changed to `False` in scanpy 2.0. + sanitize + Ensure that bins are even-sized, + and disallow the use of control genes as reference. + Will be changed to `True` in scanpy 2.0. ctrl_size Number of reference genes to be sampled from each bin. If `len(gene_list)` is not too low, you can set `ctrl_size=len(gene_list)`. @@ -150,7 +151,7 @@ def score_genes( for r_genes in _score_genes_bins( gene_list, gene_pool, - ctrl_as_ref=ctrl_as_ref, + sanitize=sanitize, ctrl_size=ctrl_size, n_bins=n_bins, get_subset=get_subset, @@ -159,8 +160,8 @@ def score_genes( if len(control_genes) == 0: msg = "No control genes found in any cut." - if ctrl_as_ref: - msg += " Try setting `ctrl_as_ref=False`." + if not sanitize: + msg += " Try setting `sanitize=True`." raise RuntimeError(msg) means_list, means_control = ( @@ -227,7 +228,7 @@ def _score_genes_bins( gene_list: pd.Index[str], gene_pool: pd.Index[str], *, - ctrl_as_ref: bool, + sanitize: bool, ctrl_size: int, n_bins: int, get_subset: _GetSubset, @@ -237,13 +238,20 @@ def _score_genes_bins( # Sometimes (and I don’t know how) missing data may be there, with NaNs for missing entries obs_avg = obs_avg[np.isfinite(obs_avg)] - n_items = int(np.round(len(obs_avg) / (n_bins - 1))) - obs_cut = obs_avg.rank(method="min") // n_items - keep_ctrl_in_obs_cut = False if ctrl_as_ref else obs_cut.index.isin(gene_list) + if sanitize: + obs_avg.sort_values(ascending=True, inplace=True) + n_items = int(np.ceil(len(obs_avg) / (n_bins))) + rank = np.repeat(np.arange(n_bins), n_items)[: len(obs_avg)] + obs_cut = pd.Series(rank, index=obs_avg.index) + keep_ctrl_in_obs_cut = ~obs_cut.index.isin(gene_list) + else: + n_items = int(np.round(len(obs_avg) / (n_bins - 1))) + obs_cut = obs_avg.rank(method="min") // n_items + keep_ctrl_in_obs_cut = True # now pick `ctrl_size` genes from every cut for cut in np.unique(obs_cut.loc[gene_list]): - r_genes: pd.Index[str] = obs_cut[(obs_cut == cut) & ~keep_ctrl_in_obs_cut].index + r_genes: pd.Index[str] = obs_cut[(obs_cut == cut) & keep_ctrl_in_obs_cut].index if len(r_genes) == 0: msg = ( f"No control genes for {cut=}. You might want to increase " @@ -252,7 +260,7 @@ def _score_genes_bins( logg.warning(msg) if ctrl_size < len(r_genes): r_genes = r_genes.to_series().sample(ctrl_size).index - if ctrl_as_ref: # otherwise `r_genes` is already filtered + if not sanitize: # otherwise `r_genes` is already filtered r_genes = r_genes.difference(gene_list) yield r_genes diff --git a/tests/test_score_genes.py b/tests/test_score_genes.py index 4ac1b62224..30d36a53bd 100644 --- a/tests/test_score_genes.py +++ b/tests/test_score_genes.py @@ -11,6 +11,7 @@ from scipy.sparse import csr_matrix import scanpy as sc +from scanpy.tools._score_genes import _check_score_genes_args, _score_genes_bins from testing.scanpy._helpers.data import paul15 if TYPE_CHECKING: @@ -275,18 +276,39 @@ def test_no_control_gene(): sc.tl.score_genes(adata, adata.var_names[:1], ctrl_size=1) -@pytest.mark.parametrize( - "ctrl_as_ref", [True, False], ids=["ctrl_as_ref", "no_ctrl_as_ref"] -) -def test_gene_list_is_control(*, ctrl_as_ref: bool): +@pytest.mark.parametrize("sanitize", [True, False], ids=["sanitize", "no_sanitize"]) +def test_gene_list_is_control(*, sanitize: bool): np.random.seed(0) adata = sc.datasets.blobs(n_variables=10, n_observations=100, n_centers=20) adata.var_names = "g" + adata.var_names with ( - pytest.raises(RuntimeError, match=r"No control genes found in any cut") - if ctrl_as_ref - else nullcontext() + nullcontext() + if sanitize + else pytest.raises(RuntimeError, match=r"No control genes found in any cut") ): sc.tl.score_genes( - adata, gene_list="g3", ctrl_size=1, n_bins=5, ctrl_as_ref=ctrl_as_ref + adata, gene_list="g3", ctrl_size=1, n_bins=5, sanitize=sanitize ) + + +@pytest.mark.parametrize("sanitize", [True, False], ids=["sanitize", "no_sanitize"]) +def test_score_genes(*, sanitize: bool): + adata = AnnData(TODO) # noqa: F821 + gene_list = adata.var_names + gene_pool = None + gene_list, gene_pool, get_subset = _check_score_genes_args( + adata, gene_list, gene_pool, use_raw=False, layer=None + ) + + bins = list( + _score_genes_bins( + gene_list, + gene_pool, + sanitize=sanitize, + ctrl_size=50, + n_bins=25, + get_subset=get_subset, + ) + ) + + assert sanitize == (0 not in map(len, bins))