From 2662022c3f59af1881a6cfb05eeb28280c6faa53 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Sun, 31 Mar 2024 23:00:26 +0200 Subject: [PATCH] change implementation --- src/moscot/base/solver.py | 1 + src/moscot/problems/generic/_generic.py | 23 +++++++++++++++++------ 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/moscot/base/solver.py b/src/moscot/base/solver.py index 996aeeaaf..ce0fe48f9 100644 --- a/src/moscot/base/solver.py +++ b/src/moscot/base/solver.py @@ -210,6 +210,7 @@ def _untag(self, data: TaggedArrayData) -> Dict[str, Any]: raise ValueError("No data specified for the linear term.") if data.x is not None or data.y is not None: logger.warning("Ignoring `x` and `y` data as they are not needed for the linear term.") + # TODO: should this be an error? data_kwargs: Dict[str, Any] = {"xy": data.xy} elif self.problem_kind == "quadratic": if data.x is None or data.y is None: diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index 8f9958952..48b62ca21 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -437,8 +437,11 @@ def solve( - :attr:`solutions` - the :term:`OT` solutions for each subproblem. - :attr:`stage` - set to ``'solved'``. """ + sentence_end = "is not supported in the GWProblem. Use FGWProblem instead." + if kwargs.get("alpha", 1.0) != 1.0: + raise ValueError(f"The 'alpha' parameter {sentence_end}") if self._xy is not None: - raise ValueError("The `xy` cost matrix is not supported for the GWProblem.") + raise ValueError(f"Linear term {sentence_end}") return super().solve( # type: ignore[return-value] alpha=1.0, epsilon=epsilon, @@ -468,7 +471,7 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]: return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value] -class FGWProblem(GWProblem[K, B]): +class FGWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc] """Class for solving the :term:`FGW ` problem. Parameters @@ -580,8 +583,8 @@ def prepare( x = set_quad_defaults(x_attr) if "x_callback" not in kwargs else {} y = set_quad_defaults(y_attr) if "y_callback" not in kwargs else {} xy, x, y = handle_cost(xy=xy, x=x, y=y, cost=cost, cost_kwargs=cost_kwargs, **kwargs) # type: ignore[arg-type] - return CompoundProblem.prepare( - self, # type: ignore[return-value, arg-type] + return self.prepare( + self, # type: ignore[return-value] key=key, xy=xy, x=x, @@ -674,8 +677,8 @@ def solve( - :attr:`solutions` - the :term:`OT` solutions for each subproblem. - :attr:`stage` - set to ``'solved'``. """ - return CompoundProblem.solve( - self, # type: ignore[return-value, arg-type] + return self.solve( + self, # type: ignore[return-value] alpha=alpha, epsilon=epsilon, tau_a=tau_a, @@ -694,3 +697,11 @@ def solve( device=device, **kwargs, ) + + @property + def _base_problem_type(self) -> Type[B]: + return OTProblem # type: ignore[return-value] + + @property + def _valid_policies(self) -> Tuple[Policy_t, ...]: + return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value]