Skip to content

Commit

Permalink
fix translation and xy being set when alpha is 1 then improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed Apr 18, 2024
1 parent 0ef9177 commit 80c3b74
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 17 deletions.
2 changes: 2 additions & 0 deletions src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,8 @@ def solve(
raise ValueError(
"`alpha` must be 1.0 for quadratic problems without `xy` supplied. See `FGWProblem` class."
)
if alpha == 1.0 and self.xy is not None:
raise ValueError("Unable to solve a quadratic problem with `alpha = 1` and `xy` supplied.")

self._solver = solver_class(**init_kwargs)

Expand Down
17 changes: 9 additions & 8 deletions src/moscot/problems/cross_modality/_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,15 @@ def prepare(
xy = {} # type: ignore[var-annotated]
else:
xy, kwargs = handle_joint_attr(joint_attr, kwargs)
_, dim_src = getattr(self.adata_src, xy["x_attr"])[xy["x_key"]].shape
_, dim_tgt = getattr(self.adata_tgt, xy["y_attr"])[xy["y_key"]].shape
if dim_src != dim_tgt:
raise ValueError(
f"The dimensions of `joint_attr` do not match. "
f"The joint attribute in the source distribution has dimension {dim_src}, "
f"while the joint attribute in the target distribution has dimension {dim_tgt}."
)
if "x_key" in xy and "y_key" in xy:
_, dim_src = getattr(self.adata_src, xy["x_attr"])[xy["x_key"]].shape
_, dim_tgt = getattr(self.adata_tgt, xy["y_attr"])[xy["y_key"]].shape
if dim_src != dim_tgt:
raise ValueError(
f"The dimensions of `joint_attr` do not match. "
f"The joint attribute in the source distribution has dimension {dim_src}, "
f"while the joint attribute in the target distribution has dimension {dim_tgt}."
)
xy, x, y = handle_cost(
xy=xy, x=self._src_attr, y=self._tgt_attr, cost=cost, cost_kwargs=cost_kwargs, **kwargs # type: ignore[arg-type]
)
Expand Down
24 changes: 15 additions & 9 deletions tests/problems/cross_modality/test_translation_problem.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import nullcontext
from typing import Any, Literal, Mapping, Optional, Tuple

import pytest
Expand Down Expand Up @@ -82,14 +83,16 @@ def test_prepare_external_star_policy(
assert tp[prob_key].xy.data_src.shape == tp[prob_key].xy.data_tgt.shape == (n_obs, xy_n_vars)

@pytest.mark.parametrize(
("epsilon", "alpha", "rank", "initializer"),
("epsilon", "alpha", "rank", "initializer", "joint_attr", "expect_fail"),
[
(1e-2, 0.9, -1, None),
(2, 0.5, -1, "random"),
(2, 1.0, -1, "rank2"),
(2, 0.1, -1, None),
(2, 1.0, -1, None),
(1.3, 1.0, -1, "random"),
(1e-2, 0.9, -1, None, {"attr": "obsm", "key": "X_pca"}, False),
(2, 0.5, -1, "random", None, True),
(2, 0.5, -1, "random", {"attr": "X"}, False),
(2, 1.0, -1, "rank2", None, False),
(2, 1.0, -1, "rank2", {"attr": "obsm", "key": "X_pca"}, True),
(2, 0.1, -1, None, {"attr": "obsm", "key": "X_pca"}, False),
(2, 1.0, -1, None, None, False),
(1.3, 1.0, -1, "random", None, False),
],
)
@pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}])
Expand All @@ -103,6 +106,8 @@ def test_solve_balanced(
src_attr: Mapping[str, str],
tgt_attr: Mapping[str, str],
initializer: Optional[Literal["random", "rank2"]],
joint_attr: Optional[Mapping[str, str]],
expect_fail: bool,
):
adata_src, adata_tgt = adata_translation_split
kwargs = {}
Expand All @@ -111,9 +116,10 @@ def test_solve_balanced(
kwargs["initializer"] = initializer

tp = TranslationProblem(adata_src, adata_tgt)
joint_attr = None if alpha == 1.0 else {"attr": "obsm", "key": "X_pca"}
tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr, joint_attr=joint_attr)
tp = tp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs)
context = pytest.raises(ValueError, match="alpha") if expect_fail else nullcontext()
with context:
tp = tp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs)

for key, subsol in tp.solutions.items():
assert isinstance(subsol, BaseSolverOutput)
Expand Down

0 comments on commit 80c3b74

Please sign in to comment.