From 80c3b74245a734aa91b02b01117ef9c3f6fadb07 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 18 Apr 2024 12:01:20 +0200 Subject: [PATCH] fix translation and xy being set when alpha is 1 then improve tests --- src/moscot/base/problems/problem.py | 2 ++ .../problems/cross_modality/_translation.py | 17 ++++++------- .../test_translation_problem.py | 24 ++++++++++++------- 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index e2b47cabb..b249a46f6 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -417,6 +417,8 @@ def solve( raise ValueError( "`alpha` must be 1.0 for quadratic problems without `xy` supplied. See `FGWProblem` class." ) + if alpha == 1.0 and self.xy is not None: + raise ValueError("Unable to solve a quadratic problem with `alpha = 1` and `xy` supplied.") self._solver = solver_class(**init_kwargs) diff --git a/src/moscot/problems/cross_modality/_translation.py b/src/moscot/problems/cross_modality/_translation.py index ff91f015d..7410c8db2 100644 --- a/src/moscot/problems/cross_modality/_translation.py +++ b/src/moscot/problems/cross_modality/_translation.py @@ -164,14 +164,15 @@ def prepare( xy = {} # type: ignore[var-annotated] else: xy, kwargs = handle_joint_attr(joint_attr, kwargs) - _, dim_src = getattr(self.adata_src, xy["x_attr"])[xy["x_key"]].shape - _, dim_tgt = getattr(self.adata_tgt, xy["y_attr"])[xy["y_key"]].shape - if dim_src != dim_tgt: - raise ValueError( - f"The dimensions of `joint_attr` do not match. " - f"The joint attribute in the source distribution has dimension {dim_src}, " - f"while the joint attribute in the target distribution has dimension {dim_tgt}." - ) + if "x_key" in xy and "y_key" in xy: + _, dim_src = getattr(self.adata_src, xy["x_attr"])[xy["x_key"]].shape + _, dim_tgt = getattr(self.adata_tgt, xy["y_attr"])[xy["y_key"]].shape + if dim_src != dim_tgt: + raise ValueError( + f"The dimensions of `joint_attr` do not match. " + f"The joint attribute in the source distribution has dimension {dim_src}, " + f"while the joint attribute in the target distribution has dimension {dim_tgt}." + ) xy, x, y = handle_cost( xy=xy, x=self._src_attr, y=self._tgt_attr, cost=cost, cost_kwargs=cost_kwargs, **kwargs # type: ignore[arg-type] ) diff --git a/tests/problems/cross_modality/test_translation_problem.py b/tests/problems/cross_modality/test_translation_problem.py index a90845d75..f5199877a 100644 --- a/tests/problems/cross_modality/test_translation_problem.py +++ b/tests/problems/cross_modality/test_translation_problem.py @@ -1,3 +1,4 @@ +from contextlib import nullcontext from typing import Any, Literal, Mapping, Optional, Tuple import pytest @@ -82,14 +83,16 @@ def test_prepare_external_star_policy( assert tp[prob_key].xy.data_src.shape == tp[prob_key].xy.data_tgt.shape == (n_obs, xy_n_vars) @pytest.mark.parametrize( - ("epsilon", "alpha", "rank", "initializer"), + ("epsilon", "alpha", "rank", "initializer", "joint_attr", "expect_fail"), [ - (1e-2, 0.9, -1, None), - (2, 0.5, -1, "random"), - (2, 1.0, -1, "rank2"), - (2, 0.1, -1, None), - (2, 1.0, -1, None), - (1.3, 1.0, -1, "random"), + (1e-2, 0.9, -1, None, {"attr": "obsm", "key": "X_pca"}, False), + (2, 0.5, -1, "random", None, True), + (2, 0.5, -1, "random", {"attr": "X"}, False), + (2, 1.0, -1, "rank2", None, False), + (2, 1.0, -1, "rank2", {"attr": "obsm", "key": "X_pca"}, True), + (2, 0.1, -1, None, {"attr": "obsm", "key": "X_pca"}, False), + (2, 1.0, -1, None, None, False), + (1.3, 1.0, -1, "random", None, False), ], ) @pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}]) @@ -103,6 +106,8 @@ def test_solve_balanced( src_attr: Mapping[str, str], tgt_attr: Mapping[str, str], initializer: Optional[Literal["random", "rank2"]], + joint_attr: Optional[Mapping[str, str]], + expect_fail: bool, ): adata_src, adata_tgt = adata_translation_split kwargs = {} @@ -111,9 +116,10 @@ def test_solve_balanced( kwargs["initializer"] = initializer tp = TranslationProblem(adata_src, adata_tgt) - joint_attr = None if alpha == 1.0 else {"attr": "obsm", "key": "X_pca"} tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr, joint_attr=joint_attr) - tp = tp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs) + context = pytest.raises(ValueError, match="alpha") if expect_fail else nullcontext() + with context: + tp = tp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs) for key, subsol in tp.solutions.items(): assert isinstance(subsol, BaseSolverOutput)