From d19f87d4faea0e19f0ce26c6fd9535430e1f9755 Mon Sep 17 00:00:00 2001
From: Marius Lange
Date: Tue, 9 Jul 2024 15:51:06 +0200
Subject: [PATCH] Update tests

---
NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch; hunk line counts match the `@@`
headers, but context indentation and blank-line positions should be verified
against the tree before applying. The hunk that only added a commented-out
`pdb.set_trace()` line to `_mixins.py` was dropped, so the post-image blob
hash on that file's `index` line is presumably stale.

 src/moscot/problems/space/_mixins.py |  2 +-
 tests/_utils.py                      | 18 +++++++++++++++---
 tests/conftest.py                    |  2 +-
 tests/problems/space/test_mixins.py  | 11 +++++++++--
 4 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py
index e9c7f8af4..ce075c6f3 100644
--- a/src/moscot/problems/space/_mixins.py
+++ b/src/moscot/problems/space/_mixins.py
@@ -454,7 +454,7 @@ def correlate(  # type: ignore[misc]
             # initialize a dict of group masks
             if groupby:
                 groups = adata_sp.obs[groupby].cat.categories
-                group_masks = {group: adata_sp.obs[groupby] == group for group in groups}
+                group_masks = {group: (adata_sp.obs[groupby]).values == group for group in groups}
                 corrs[key] = {}
             else:
                 group_masks = {"all": np.ones(adata_sp.shape[0], dtype=bool)}
diff --git a/tests/_utils.py b/tests/_utils.py
index 2cfefcafe..d2d1960b9 100644
--- a/tests/_utils.py
+++ b/tests/_utils.py
@@ -1,6 +1,7 @@
 from typing import Any, List, Optional, Tuple, Type, Union
 
 import numpy as np
+import pandas as pd
 from scipy.sparse import csr_matrix
 
 from anndata import AnnData
@@ -46,10 +47,21 @@ def _ones(self, n: int) -> ArrayLike:
         return np.ones(n)
 
 
-def _make_adata(grid: ArrayLike, n: int, seed) -> List[AnnData]:
+def _make_adata(grid: ArrayLike, n: int, seed, cat_key: str = "covariate", num_categories: int = 3) -> List[AnnData]:
     rng = np.random.RandomState(seed)
-    X = rng.normal(size=(100, 60))
-    return [AnnData(X=csr_matrix(X), obsm={"spatial": grid.copy()}) for _ in range(n)]
+    n_cells = 100
+    X = rng.normal(size=(n_cells, 60))
+
+    # generate a categorical variable
+    categories = [f"cat_{i+1}" for i in range(num_categories)]
+    categorical_data = rng.choice(categories, size=n_cells)
+
+    adatas = []
+    for _ in range(n):
+        obs_df = pd.DataFrame({cat_key: pd.Categorical(categorical_data)})
+        adatas.append(AnnData(X=csr_matrix(X), obs=obs_df, obsm={"spatial": grid.copy()}))
+
+    return adatas
 
 
 def _adata_spatial_split(adata: AnnData) -> Tuple[AnnData, AnnData]:
diff --git a/tests/conftest.py b/tests/conftest.py
index ce26bfaab..95f13f5dc 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -185,7 +185,7 @@ def adata_space_rotate() -> AnnData:
 @pytest.fixture()
 def adata_mapping() -> AnnData:
     grid = _make_grid(10)
-    adataref, adata1, adata2 = _make_adata(grid, n=3, seed=17)
+    adataref, adata1, adata2 = _make_adata(grid, n=3, seed=17, cat_key="covariate", num_categories=3)
     sc.pp.pca(adataref, n_comps=30)
     return ad.concat([adataref, adata1, adata2], label="batch", join="outer", index_unique="-")
 
diff --git a/tests/problems/space/test_mixins.py b/tests/problems/space/test_mixins.py
index c105683a9..a6c1bc74e 100644
--- a/tests/problems/space/test_mixins.py
+++ b/tests/problems/space/test_mixins.py
@@ -127,18 +127,25 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo
 class TestSpatialMappingAnalysisMixin:
     @pytest.mark.parametrize("sc_attr", [{"attr": "X"}, {"attr": "obsm", "key": "X_pca"}])
     @pytest.mark.parametrize("var_names", ["0", [str(i) for i in range(20)]])
+    @pytest.mark.parametrize("groupby", [None, "covariate"])
     def test_analysis(
         self,
         adata_mapping: AnnData,
         sc_attr: Dict[str, str],
         var_names: Optional[List[Optional[str]]],
+        groupby: Optional[str],
     ):
         adataref, adatasp = _adata_spatial_split(adata_mapping)
         mp = MappingProblem(adataref, adatasp).prepare(batch_key="batch", sc_attr=sc_attr).solve()
 
-        corr = mp.correlate(var_names)
+        corr = mp.correlate(var_names, groupby=groupby)
         imp = mp.impute()
-        pd.testing.assert_series_equal(*list(corr.values()))
+
+        if groupby:
+            for key in adata_mapping.obs[groupby].cat.categories:
+                pd.testing.assert_series_equal(*[corr[problem][key] for problem in corr])
+        else:
+            pd.testing.assert_series_equal(*list(corr.values()))
         assert imp.shape == adatasp.shape
 
     def test_correspondence(