From 63de968cee40030bdfb7f5d9deb04639d0fa37f6 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 13 Mar 2024 23:19:43 +0100 Subject: [PATCH] added some tests and refactored a line --- tests/_utils.py | 14 ++++++++ tests/conftest.py | 36 ++++++++++--------- tests/problems/conftest.py | 12 +++++++ tests/problems/generic/test_fgw_problem.py | 16 +++++++++ tests/problems/generic/test_gw_problem.py | 15 +++++++- .../problems/generic/test_sinkhorn_problem.py | 15 +++++--- tests/problems/time/test_lineage_problem.py | 3 +- 7 files changed, 87 insertions(+), 24 deletions(-) diff --git a/tests/_utils.py b/tests/_utils.py index c6ae8c54e..2cfefcafe 100644 --- a/tests/_utils.py +++ b/tests/_utils.py @@ -67,6 +67,20 @@ def _make_grid(grid_size: int) -> ArrayLike: return np.vstack([X1.ravel(), X2.ravel()]).T +def _assert_marginals_set(adata_time, problem, key, marginal_keys): + """Helper function to check if marginals are set correctly""" + adata_time0 = adata_time[key[0] == adata_time.obs["time"]] + adata_time1 = adata_time[key[1] == adata_time.obs["time"]] + if marginal_keys[0] is not None: # check if marginal keys are set + a = adata_time0.obs[marginal_keys[0]].values + b = adata_time1.obs[marginal_keys[1]].values + assert np.allclose(problem[key].a, a) + assert np.allclose(problem[key].b, b) + else: # otherwise check if marginals are uniform + assert np.allclose(problem[key].a, 1.0 / adata_time0.shape[0]) + assert np.allclose(problem[key].b, 1.0 / adata_time1.shape[0]) + + class Problem(CompoundProblem[Any, OTProblem]): @property def _base_problem_type(self) -> Type[B]: diff --git a/tests/conftest.py b/tests/conftest.py index 20a267524..df755069c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -119,15 +119,32 @@ def adata_y(y: Geom_t) -> AnnData: return AnnData(X=np.asarray(y, dtype=float), obsm={"X_pca": pc}) +def creat_prob(n: int, *, uniform: bool = False, seed: Optional[int] = None) -> Geom_t: + rng = np.random.RandomState(seed) + a = np.ones((n,)) if uniform else np.abs(rng.normal(size=(n,))) + a /= np.sum(a) + return jnp.asarray(a) + + @pytest.fixture() def adata_time() -> AnnData: rng = np.random.RandomState(42) - adatas = [AnnData(X=csr_matrix(rng.normal(size=(96, 60)))) for _ in range(3)] + + adatas = [ + AnnData( + X=csr_matrix(rng.normal(size=(96, 60))), + obs={ + "left_marginals_balanced": creat_prob(96, seed=42), + "right_marginals_balanced": creat_prob(96, seed=42), + }, + ) + for _ in range(3) + ] adata = ad.concat(adatas, label="time", index_unique="-") adata.obs["time"] = pd.to_numeric(adata.obs["time"]).astype("category") adata.obs["batch"] = rng.choice((0, 1, 2), len(adata)) - adata.obs["left_marginals"] = np.ones(len(adata)) - adata.obs["right_marginals"] = np.ones(len(adata)) + adata.obs["left_marginals_unbalanced"] = np.ones(len(adata)) + adata.obs["right_marginals_unbalanced"] = np.ones(len(adata)) adata.obs["celltype"] = rng.choice(["A", "B", "C"], size=len(adata)) # genes from mouse/human proliferation/apoptosis genes = ["ANLN", "ANP32E", "ATAD2", "Mcm4", "Smc4", "Gtse1", "ADD1", "AIFM3", "ANKH", "Ercc5", "Serpinb5", "Inhbb"] @@ -139,19 +156,6 @@ def adata_time() -> AnnData: return adata -def create_marginals(n: int, m: int, *, uniform: bool = False, seed: Optional[int] = None) -> Geom_t: - rng = np.random.RandomState(seed) - if uniform: - a, b = np.ones((n,)), np.ones((m,)) - else: - a = np.abs(rng.normal(size=(n,))) - b = np.abs(rng.normal(size=(m,))) - a /= np.sum(a) - b /= np.sum(b) - - return jnp.asarray(a), jnp.asarray(b) - - @pytest.fixture() def gt_temporal_adata() -> AnnData: adata = _gt_temporal_adata.copy() diff --git a/tests/problems/conftest.py b/tests/problems/conftest.py index d5d876e25..fd2c76c4d 100644 --- a/tests/problems/conftest.py +++ b/tests/problems/conftest.py @@ -58,6 +58,18 @@ def adata_time_with_tmap(adata_time: AnnData) -> AnnData: return adata +# keys for marginals +@pytest.fixture( + params=[ + (None, None), + ("left_marginals_balanced", "right_marginals_balanced"), + ], + ids=["default", "balanced"], +) +def marginal_keys(request): + return request.param + + sinkhorn_args_1 = { "epsilon": 0.7, "tau_a": 1.0, diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index f3488f2c0..6193f2683 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -24,6 +24,7 @@ from moscot.base.output import BaseSolverOutput from moscot.base.problems import OTProblem from moscot.problems.generic import FGWProblem +from tests._utils import _assert_marginals_set from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -67,6 +68,21 @@ def test_prepare(self, adata_space_rotate: AnnData, policy): assert key in expected_keys[policy] assert isinstance(problem[key], OTProblem) + @pytest.mark.fast() + def test_prepare_marginals(self, adata_time: AnnData, marginal_keys): + problem = FGWProblem(adata=adata_time) + problem = problem.prepare( + key="time", + policy="sequential", + joint_attr="X_pca", + x_attr="X_pca", + y_attr="X_pca", + a=marginal_keys[0], + b=marginal_keys[1], + ) + for key in problem: + _assert_marginals_set(adata_time, problem, key, marginal_keys) + def test_solve_balanced(self, adata_space_rotate: AnnData): eps = 0.5 adata_space_rotate = adata_space_rotate[adata_space_rotate.obs["batch"].isin(("0", "1"))].copy() diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py index 924150773..07a8dab14 100644 --- a/tests/problems/generic/test_gw_problem.py +++ b/tests/problems/generic/test_gw_problem.py @@ -23,6 +23,7 @@ from moscot.base.output import BaseSolverOutput from moscot.base.problems import OTProblem from moscot.problems.generic import GWProblem +from tests._utils import _assert_marginals_set from tests.problems.conftest import ( geometry_args, gw_args_1, @@ -37,7 +38,10 @@ class TestGWProblem: @pytest.mark.fast() - @pytest.mark.parametrize("policy", ["sequential", "star"]) + @pytest.mark.parametrize( + "policy", + ["sequential", "star"], + ) def test_prepare(self, adata_space_rotate: AnnData, policy): expected_keys = { "sequential": [("0", "1"), ("1", "2")], @@ -181,6 +185,15 @@ def test_prepare_costs(self, adata_time: AnnData, cost_str: str, cost_inst: Any, problem = problem.solve(max_iterations=2) + @pytest.mark.fast() + def test_prepare_marginals(self, adata_time: AnnData, marginal_keys): + problem = GWProblem(adata=adata_time) + problem = problem.prepare( + a=marginal_keys[0], b=marginal_keys[1], key="time", policy="sequential", x_attr="X_pca", y_attr="X_pca" + ) + for key in problem: + _assert_marginals_set(adata_time, problem, key, marginal_keys) + @pytest.mark.fast() @pytest.mark.parametrize( ("cost_str", "cost_inst", "cost_kwargs"), diff --git a/tests/problems/generic/test_sinkhorn_problem.py b/tests/problems/generic/test_sinkhorn_problem.py index b0e55c535..8a7b582d2 100644 --- a/tests/problems/generic/test_sinkhorn_problem.py +++ b/tests/problems/generic/test_sinkhorn_problem.py @@ -22,6 +22,7 @@ from moscot.base.output import BaseSolverOutput from moscot.base.problems import OTProblem from moscot.problems.generic import SinkhornProblem +from tests._utils import _assert_marginals_set from tests.problems.conftest import ( geometry_args, lin_prob_args, @@ -37,15 +38,14 @@ class TestSinkhornProblem: @pytest.mark.fast() @pytest.mark.parametrize("policy", ["sequential", "star"]) - def test_prepare(self, adata_time: AnnData, policy): + def test_prepare(self, adata_time: AnnData, policy, marginal_keys): expected_keys = {"sequential": [(0, 1), (1, 2)], "star": [(1, 0), (2, 0)]} problem = SinkhornProblem(adata=adata_time) - assert len(problem) == 0 assert problem.problems == {} assert problem.solutions == {} - problem = problem.prepare(key="time", policy=policy, reference=0) + problem = problem.prepare(key="time", policy=policy, reference=0, a=marginal_keys[0], b=marginal_keys[1]) assert isinstance(problem.problems, dict) assert len(problem.problems) == len(expected_keys[policy]) @@ -54,16 +54,21 @@ def test_prepare(self, adata_time: AnnData, policy): assert key in expected_keys[policy] assert isinstance(problem[key], OTProblem) - def test_solve_balanced(self, adata_time: AnnData): + _assert_marginals_set(adata_time, problem, key, marginal_keys) + + def test_solve_balanced(self, adata_time: AnnData, marginal_keys): eps = 0.5 expected_keys = [(0, 1), (1, 2)] problem = SinkhornProblem(adata=adata_time) - problem = problem.prepare(key="time") + problem = problem.prepare(key="time", a=marginal_keys[0], b=marginal_keys[1]) problem = problem.solve(epsilon=eps) for key, subsol in problem.solutions.items(): assert isinstance(subsol, BaseSolverOutput) assert key in expected_keys + assert subsol.converged + assert np.allclose(subsol.a, problem[key].a) + assert np.allclose(subsol.b, problem[key].b) @pytest.mark.fast() @pytest.mark.parametrize( diff --git a/tests/problems/time/test_lineage_problem.py b/tests/problems/time/test_lineage_problem.py index 394adc2ef..fed68cf6f 100644 --- a/tests/problems/time/test_lineage_problem.py +++ b/tests/problems/time/test_lineage_problem.py @@ -57,9 +57,8 @@ def test_solve_balanced(self, adata_time_barcodes: AnnData): ) problem = problem.solve(epsilon=eps) - for key, subsol in problem.solutions.items(): + for _, subsol in problem.solutions.items(): assert isinstance(subsol, BaseSolverOutput) - assert key == key def test_solve_unbalanced(self, adata_time_barcodes: AnnData): taus = [9e-1, 1e-2]