Skip to content

Commit

Permalink
change implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed Mar 31, 2024
1 parent 9afacd4 commit 2662022
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/moscot/base/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def _untag(self, data: TaggedArrayData) -> Dict[str, Any]:
raise ValueError("No data specified for the linear term.")
if data.x is not None or data.y is not None:
logger.warning("Ignoring `x` and `y` data as they are not needed for the linear term.")
# TODO: should this be an error?
data_kwargs: Dict[str, Any] = {"xy": data.xy}
elif self.problem_kind == "quadratic":
if data.x is None or data.y is None:
Expand Down
23 changes: 17 additions & 6 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,11 @@ def solve(
- :attr:`solutions` - the :term:`OT` solutions for each subproblem.
- :attr:`stage` - set to ``'solved'``.
"""
sentence_end = "is not supported in the GWProblem. Use FGWProblem instead."
if kwargs.get("alpha", 1.0) != 1.0:
raise ValueError(f"The 'alpha' parameter {sentence_end}")
if self._xy is not None:
raise ValueError("The `xy` cost matrix is not supported for the GWProblem.")
raise ValueError(f"Linear term {sentence_end}")
return super().solve( # type: ignore[return-value]
alpha=1.0,
epsilon=epsilon,
Expand Down Expand Up @@ -468,7 +471,7 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]:
return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value]


class FGWProblem(GWProblem[K, B]):
class FGWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc]
"""Class for solving the :term:`FGW <fused Gromov-Wasserstein>` problem.
Parameters
Expand Down Expand Up @@ -580,8 +583,8 @@ def prepare(
x = set_quad_defaults(x_attr) if "x_callback" not in kwargs else {}
y = set_quad_defaults(y_attr) if "y_callback" not in kwargs else {}
xy, x, y = handle_cost(xy=xy, x=x, y=y, cost=cost, cost_kwargs=cost_kwargs, **kwargs) # type: ignore[arg-type]
return CompoundProblem.prepare(
self, # type: ignore[return-value, arg-type]
return self.prepare(
self, # type: ignore[return-value]
key=key,
xy=xy,
x=x,
Expand Down Expand Up @@ -674,8 +677,8 @@ def solve(
- :attr:`solutions` - the :term:`OT` solutions for each subproblem.
- :attr:`stage` - set to ``'solved'``.
"""
return CompoundProblem.solve(
self, # type: ignore[return-value, arg-type]
return self.solve(
self, # type: ignore[return-value]
alpha=alpha,
epsilon=epsilon,
tau_a=tau_a,
Expand All @@ -694,3 +697,11 @@ def solve(
device=device,
**kwargs,
)

@property
def _base_problem_type(self) -> Type[B]:
return OTProblem # type: ignore[return-value]

@property
def _valid_policies(self) -> Tuple[Policy_t, ...]:
return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value]

0 comments on commit 2662022

Please sign in to comment.