Skip to content

Commit

Permalink
Merge branch 'main' into add/sparse-geodesic
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen authored Apr 16, 2024
2 parents d85fa68 + 0ef9177 commit afcc83f
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 4 deletions.
15 changes: 15 additions & 0 deletions src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,21 @@ def solve(
"""
solver_class = backends.get_solver(self.problem_kind, backend=backend, return_class=True)
init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs)
# if linear problem, then alpha is 0.0 by default
# if quadratic problem, then alpha is 1.0 by default
alpha = call_kwargs.get("alpha", 0.0 if self.problem_kind == "linear" else 1.0)
if alpha < 0.0 or alpha > 1.0:
raise ValueError("Expected `alpha` to be in the range `[0, 1]`, found `{alpha}`.")
if self.problem_kind == "linear" and (alpha != 0.0 or not (self.x is None or self.y is None)):
raise ValueError("Unable to solve a linear problem with `alpha != 0` or `x` and `y` supplied.")
if self.problem_kind == "quadratic":
if self.x is None or self.y is None:
raise ValueError("Unable to solve a quadratic problem without `x` and `y` supplied.")
if alpha != 1.0 and self.xy is None: # means FGW case
raise ValueError(
"`alpha` must be 1.0 for quadratic problems without `xy` supplied. See `FGWProblem` class."
)

self._solver = solver_class(**init_kwargs)

self._solution = self._solver( # type: ignore[misc]
Expand Down
6 changes: 4 additions & 2 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def prepare(

def solve(
self,
alpha: float = 1.0,
alpha: float = 0.5,
epsilon: float = 1e-3,
tau_a: float = 1.0,
tau_b: float = 1.0,
Expand Down Expand Up @@ -622,7 +622,7 @@ def solve(
Parameters
----------
alpha
Parameter in :math:`(0, 1]` that interpolates between the :term:`quadratic term` and
Parameter in :math:`(0, 1)` that interpolates between the :term:`quadratic term` and
the :term:`linear term`. :math:`\alpha = 1` corresponds to the pure :term:`Gromov-Wasserstein` problem while
:math:`\alpha \to 0` corresponds to the pure :term:`linear problem`.
epsilon
Expand Down Expand Up @@ -672,6 +672,8 @@ def solve(
- :attr:`solutions` - the :term:`OT` solutions for each subproblem.
- :attr:`stage` - set to ``'solved'``.
"""
if alpha == 1.0:
raise ValueError("The `FGWProblem` is equivalent to the `GWProblem` when `alpha=1.0`.")
return CompoundProblem.solve(
self, # type: ignore[return-value, arg-type]
alpha=alpha,
Expand Down
12 changes: 10 additions & 2 deletions tests/problems/cross_modality/test_translation_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,14 @@ def test_prepare_external_star_policy(

@pytest.mark.parametrize(
("epsilon", "alpha", "rank", "initializer"),
[(1e-2, 0.9, -1, None), (2, 0.5, -1, "random"), (2, 0.5, -1, "rank2"), (2, 0.1, -1, None)],
[
(1e-2, 0.9, -1, None),
(2, 0.5, -1, "random"),
(2, 1.0, -1, "rank2"),
(2, 0.1, -1, None),
(2, 1.0, -1, None),
(1.3, 1.0, -1, "random"),
],
)
@pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}])
@pytest.mark.parametrize("tgt_attr", ["emb_tgt", {"attr": "obsm", "key": "emb_tgt"}])
Expand All @@ -104,7 +111,8 @@ def test_solve_balanced(
kwargs["initializer"] = initializer

tp = TranslationProblem(adata_src, adata_tgt)
tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr)
joint_attr = None if alpha == 1.0 else {"attr": "obsm", "key": "X_pca"}
tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr, joint_attr=joint_attr)
tp = tp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs)

for key, subsol in tp.solutions.items():
Expand Down
1 change: 1 addition & 0 deletions tests/problems/space/test_mapping_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def test_solve_balanced(
return # TODO(@MUCDK) fix after refactoring
mp = MappingProblem(adataref, adatasp)
mp = mp.prepare(batch_key="batch", sc_attr=sc_attr, var_names=var_names)
alpha = alpha if mp.filtered_vars is not None else 1.0
mp = mp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs)

for prob_key in mp:
Expand Down

0 comments on commit afcc83f

Please sign in to comment.