Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TranslationProblem and xy being set when alpha is 1 then improve tests #689

Merged
merged 5 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,8 @@ def solve(
raise ValueError(
"`alpha` must be 1.0 for quadratic problems without `xy` supplied. See `FGWProblem` class."
)
if alpha == 1.0 and self.xy is not None:
raise ValueError("Unable to solve a quadratic problem with `alpha = 1` and `xy` supplied.")

self._solver = solver_class(**init_kwargs)

Expand Down
17 changes: 9 additions & 8 deletions src/moscot/problems/cross_modality/_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,14 +164,15 @@ def prepare(
xy = {} # type: ignore[var-annotated]
else:
xy, kwargs = handle_joint_attr(joint_attr, kwargs)
_, dim_src = getattr(self.adata_src, xy["x_attr"])[xy["x_key"]].shape
_, dim_tgt = getattr(self.adata_tgt, xy["y_attr"])[xy["y_key"]].shape
if dim_src != dim_tgt:
raise ValueError(
f"The dimensions of `joint_attr` do not match. "
f"The joint attribute in the source distribution has dimension {dim_src}, "
f"while the joint attribute in the target distribution has dimension {dim_tgt}."
)
if "x_key" in xy and "y_key" in xy:
_, dim_src = getattr(self.adata_src, xy["x_attr"])[xy["x_key"]].shape
_, dim_tgt = getattr(self.adata_tgt, xy["y_attr"])[xy["y_key"]].shape
if dim_src != dim_tgt:
raise ValueError(
f"The dimensions of `joint_attr` do not match. "
f"The joint attribute in the source distribution has dimension {dim_src}, "
f"while the joint attribute in the target distribution has dimension {dim_tgt}."
)
xy, x, y = handle_cost(
xy=xy, x=self._src_attr, y=self._tgt_attr, cost=cost, cost_kwargs=cost_kwargs, **kwargs # type: ignore[arg-type]
)
Expand Down
12 changes: 8 additions & 4 deletions tests/problems/base/test_general_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_simple_run(self, adata_x: AnnData, adata_y: AnnData):
xy={"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"},
x={"attr": "X"},
y={"attr": "X"},
).solve(epsilon=5e-1)
).solve(epsilon=5e-1, alpha=0.5)

assert isinstance(prob.solution, BaseSolverOutput)

Expand Down Expand Up @@ -64,7 +64,9 @@ def test_set_xy(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost_mat
prob.set_xy(cost_matrix, tag=tag)
assert isinstance(prob.xy.data_src, np.ndarray)
assert prob.xy.data_tgt is None
prob = prob.solve(epsilon=1.0, max_iterations=5) # TODO(@MUCDK) once fixed in OTT-JAX test for scale_cost
prob = prob.solve(
epsilon=1.0, max_iterations=5, alpha=0.5
) # TODO(@MUCDK) once fixed in OTT-JAX test for scale_cost
np.testing.assert_equal(prob.xy.data_src, cost_matrix.to_numpy())

@pytest.mark.parametrize("tag", ["cost_matrix", "kernel"])
Expand All @@ -83,7 +85,9 @@ def test_set_x(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost_matr
assert isinstance(prob.x.data_src, np.ndarray)
assert prob.x.data_tgt is None

prob = prob.solve(epsilon=1.0, max_iterations=5) # TODO(@MUCDK) once fixed in OTT-JAX test for scale_cost
prob = prob.solve(
epsilon=1.0, max_iterations=5, alpha=0.5
) # TODO(@MUCDK) once fixed in OTT-JAX test for scale_cost
np.testing.assert_equal(prob.x.data_src, cost_matrix.to_numpy())

@pytest.mark.parametrize("tag", ["cost_matrix", "kernel"])
Expand All @@ -102,7 +106,7 @@ def test_set_y(self, adata_x: AnnData, adata_y: AnnData, tag: Literal["cost_matr
assert isinstance(prob.y.data_src, np.ndarray)
assert prob.y.data_tgt is None

prob = prob.solve(epsilon=1.0, max_iterations=5)
prob = prob.solve(epsilon=1.0, max_iterations=5, alpha=0.5)
np.testing.assert_equal(prob.y.data_src, cost_matrix.to_numpy())

def test_set_xy_change_problem_kind(self, adata_x: AnnData, adata_y: AnnData):
Expand Down
4 changes: 2 additions & 2 deletions tests/problems/cross_modality/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ def test_translation_foo(
):
adata_src, adata_tgt = adata_translation_split
expected_keys = {(i, "ref") for i in adata_src.obs["batch"].cat.categories}

alpha = 1.0 if joint_attr is None else 0.5
tp = (
TranslationProblem(adata_src, adata_tgt)
.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr, joint_attr=joint_attr)
.solve()
.solve(alpha=alpha)
)
for src, tgt in expected_keys:
trans_forward = tp.translate(source=src, target=tgt, forward=True)
Expand Down
24 changes: 15 additions & 9 deletions tests/problems/cross_modality/test_translation_problem.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import nullcontext
from typing import Any, Literal, Mapping, Optional, Tuple

import pytest
Expand Down Expand Up @@ -82,14 +83,16 @@ def test_prepare_external_star_policy(
assert tp[prob_key].xy.data_src.shape == tp[prob_key].xy.data_tgt.shape == (n_obs, xy_n_vars)

@pytest.mark.parametrize(
("epsilon", "alpha", "rank", "initializer"),
("epsilon", "alpha", "rank", "initializer", "joint_attr", "expect_fail"),
[
(1e-2, 0.9, -1, None),
(2, 0.5, -1, "random"),
(2, 1.0, -1, "rank2"),
(2, 0.1, -1, None),
(2, 1.0, -1, None),
(1.3, 1.0, -1, "random"),
(1e-2, 0.9, -1, None, {"attr": "obsm", "key": "X_pca"}, False),
(2, 0.5, -1, "random", None, True),
(2, 0.5, -1, "random", {"attr": "X"}, False),
(2, 1.0, -1, "rank2", None, False),
(2, 1.0, -1, "rank2", {"attr": "obsm", "key": "X_pca"}, True),
(2, 0.1, -1, None, {"attr": "obsm", "key": "X_pca"}, False),
(2, 1.0, -1, None, None, False),
(1.3, 1.0, -1, "random", None, False),
],
)
@pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}])
Expand All @@ -103,6 +106,8 @@ def test_solve_balanced(
src_attr: Mapping[str, str],
tgt_attr: Mapping[str, str],
initializer: Optional[Literal["random", "rank2"]],
joint_attr: Optional[Mapping[str, str]],
expect_fail: bool,
):
adata_src, adata_tgt = adata_translation_split
kwargs = {}
Expand All @@ -111,9 +116,10 @@ def test_solve_balanced(
kwargs["initializer"] = initializer

tp = TranslationProblem(adata_src, adata_tgt)
joint_attr = None if alpha == 1.0 else {"attr": "obsm", "key": "X_pca"}
tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr, joint_attr=joint_attr)
tp = tp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs)
context = pytest.raises(ValueError, match="alpha") if expect_fail else nullcontext()
with context:
tp = tp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs)

for key, subsol in tp.solutions.items():
assert isinstance(subsol, BaseSolverOutput)
Expand Down
Loading