From fc857f9bd1cdc194b6b9158225fed24afad85cfc Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 9 Apr 2024 15:08:26 +0200 Subject: [PATCH] add tests where alpha is used --- .../cross_modality/test_translation_problem.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tests/problems/cross_modality/test_translation_problem.py b/tests/problems/cross_modality/test_translation_problem.py index 527f854b2..a90845d75 100644 --- a/tests/problems/cross_modality/test_translation_problem.py +++ b/tests/problems/cross_modality/test_translation_problem.py @@ -82,8 +82,15 @@ def test_prepare_external_star_policy( assert tp[prob_key].xy.data_src.shape == tp[prob_key].xy.data_tgt.shape == (n_obs, xy_n_vars) @pytest.mark.parametrize( - ("epsilon", "rank", "initializer"), - [(1e-2, -1, None), (2, -1, "random"), (2, -1, "rank2"), (2, -1, None)], + ("epsilon", "alpha", "rank", "initializer"), + [ + (1e-2, 0.9, -1, None), + (2, 0.5, -1, "random"), + (2, 1.0, -1, "rank2"), + (2, 0.1, -1, None), + (2, 1.0, -1, None), + (1.3, 1.0, -1, "random"), + ], ) @pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}]) @pytest.mark.parametrize("tgt_attr", ["emb_tgt", {"attr": "obsm", "key": "emb_tgt"}]) @@ -91,6 +98,7 @@ def test_solve_balanced( self, adata_translation_split: Tuple[AnnData, AnnData], epsilon: float, + alpha: float, rank: int, src_attr: Mapping[str, str], tgt_attr: Mapping[str, str], @@ -103,8 +111,9 @@ def test_solve_balanced( kwargs["initializer"] = initializer tp = TranslationProblem(adata_src, adata_tgt) - tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr) - tp = tp.solve(epsilon=epsilon, rank=rank, **kwargs) + joint_attr = None if alpha == 1.0 else {"attr": "obsm", "key": "X_pca"} + tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr, joint_attr=joint_attr) + tp = tp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs) for key, subsol in tp.solutions.items(): assert isinstance(subsol, BaseSolverOutput)