Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sanitize to score_genes #3262

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 21 additions & 13 deletions src/scanpy/tools/_score_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def score_genes(
adata: AnnData,
gene_list: Sequence[str] | pd.Index[str],
*,
ctrl_as_ref: bool = True,
sanitize: bool = False,
ctrl_size: int = 50,
gene_pool: Sequence[str] | pd.Index[str] | None = None,
n_bins: int = 25,
Expand All @@ -91,9 +91,10 @@ def score_genes(
The annotated data matrix.
gene_list
The list of gene names used for score calculation.
ctrl_as_ref
Allow the algorithm to use the control genes as reference.
Will be changed to `False` in scanpy 2.0.
sanitize
Ensure that bins are even-sized,
and disallow the use of control genes as reference.
Will be changed to `True` in scanpy 2.0.
ctrl_size
Number of reference genes to be sampled from each bin. If `len(gene_list)` is not too
low, you can set `ctrl_size=len(gene_list)`.
Expand Down Expand Up @@ -150,7 +151,7 @@ def score_genes(
for r_genes in _score_genes_bins(
gene_list,
gene_pool,
ctrl_as_ref=ctrl_as_ref,
sanitize=sanitize,
ctrl_size=ctrl_size,
n_bins=n_bins,
get_subset=get_subset,
Expand All @@ -159,8 +160,8 @@ def score_genes(

if len(control_genes) == 0:
msg = "No control genes found in any cut."
if ctrl_as_ref:
msg += " Try setting `ctrl_as_ref=False`."
if not sanitize:
msg += " Try setting `sanitize=True`."
raise RuntimeError(msg)

means_list, means_control = (
Expand Down Expand Up @@ -227,7 +228,7 @@ def _score_genes_bins(
gene_list: pd.Index[str],
gene_pool: pd.Index[str],
*,
ctrl_as_ref: bool,
sanitize: bool,
ctrl_size: int,
n_bins: int,
get_subset: _GetSubset,
Expand All @@ -237,13 +238,20 @@ def _score_genes_bins(
# Sometimes (and I don’t know how) missing data may be there, with NaNs for missing entries
obs_avg = obs_avg[np.isfinite(obs_avg)]

n_items = int(np.round(len(obs_avg) / (n_bins - 1)))
obs_cut = obs_avg.rank(method="min") // n_items
keep_ctrl_in_obs_cut = False if ctrl_as_ref else obs_cut.index.isin(gene_list)
if sanitize:
obs_avg.sort_values(ascending=True, inplace=True)
n_items = int(np.ceil(len(obs_avg) / (n_bins)))
rank = np.repeat(np.arange(n_bins), n_items)[: len(obs_avg)]
obs_cut = pd.Series(rank, index=obs_avg.index)
keep_ctrl_in_obs_cut = ~obs_cut.index.isin(gene_list)
else:
n_items = int(np.round(len(obs_avg) / (n_bins - 1)))
obs_cut = obs_avg.rank(method="min") // n_items
keep_ctrl_in_obs_cut = True

# now pick `ctrl_size` genes from every cut
for cut in np.unique(obs_cut.loc[gene_list]):
r_genes: pd.Index[str] = obs_cut[(obs_cut == cut) & ~keep_ctrl_in_obs_cut].index
r_genes: pd.Index[str] = obs_cut[(obs_cut == cut) & keep_ctrl_in_obs_cut].index
if len(r_genes) == 0:
msg = (
f"No control genes for {cut=}. You might want to increase "
Expand All @@ -252,7 +260,7 @@ def _score_genes_bins(
logg.warning(msg)
if ctrl_size < len(r_genes):
r_genes = r_genes.to_series().sample(ctrl_size).index
if ctrl_as_ref: # otherwise `r_genes` is already filtered
if not sanitize: # otherwise `r_genes` is already filtered
r_genes = r_genes.difference(gene_list)
yield r_genes

Expand Down
38 changes: 30 additions & 8 deletions tests/test_score_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from scipy.sparse import csr_matrix

import scanpy as sc
from scanpy.tools._score_genes import _check_score_genes_args, _score_genes_bins
from testing.scanpy._helpers.data import paul15

if TYPE_CHECKING:
Expand Down Expand Up @@ -275,18 +276,39 @@ def test_no_control_gene():
sc.tl.score_genes(adata, adata.var_names[:1], ctrl_size=1)


@pytest.mark.parametrize(
"ctrl_as_ref", [True, False], ids=["ctrl_as_ref", "no_ctrl_as_ref"]
)
def test_gene_list_is_control(*, ctrl_as_ref: bool):
@pytest.mark.parametrize("sanitize", [True, False], ids=["sanitize", "no_sanitize"])
def test_gene_list_is_control(*, sanitize: bool):
np.random.seed(0)
adata = sc.datasets.blobs(n_variables=10, n_observations=100, n_centers=20)
adata.var_names = "g" + adata.var_names
with (
pytest.raises(RuntimeError, match=r"No control genes found in any cut")
if ctrl_as_ref
else nullcontext()
nullcontext()
if sanitize
else pytest.raises(RuntimeError, match=r"No control genes found in any cut")
):
sc.tl.score_genes(
adata, gene_list="g3", ctrl_size=1, n_bins=5, ctrl_as_ref=ctrl_as_ref
adata, gene_list="g3", ctrl_size=1, n_bins=5, sanitize=sanitize
)


@pytest.mark.parametrize("sanitize", [True, False], ids=["sanitize", "no_sanitize"])
def test_score_genes(*, sanitize: bool):
adata = AnnData(TODO) # noqa: F821
gene_list = adata.var_names
gene_pool = None
gene_list, gene_pool, get_subset = _check_score_genes_args(
adata, gene_list, gene_pool, use_raw=False, layer=None
)

bins = list(
_score_genes_bins(
gene_list,
gene_pool,
sanitize=sanitize,
ctrl_size=50,
n_bins=25,
get_subset=get_subset,
)
)

assert sanitize == (0 not in map(len, bins))
Loading