Skip to content

Commit

Permalink
refactor the gwproblem and fgwproblem inheritance
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed Dec 10, 2024
1 parent ef50b55 commit 45509e0
Showing 1 changed file with 64 additions and 92 deletions.
156 changes: 64 additions & 92 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]:
return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value]


class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc]
"""Class for solving the :term:`GW <Gromov-Wasserstein>` or :term:`FGW <fused Gromov-Wasserstein>` problems.
class FGWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc]
"""Class for solving the :term:`FGW <fused Gromov-Wasserstein>` problem.
Parameters
----------
Expand All @@ -281,20 +281,23 @@ def __init__(self, adata: AnnData, **kwargs: Any):
def prepare(
self,
key: str,
joint_attr: Optional[Union[str, Mapping[str, Any]]] = None,
x_attr: Optional[Union[str, Mapping[str, Any]]] = None,
y_attr: Optional[Union[str, Mapping[str, Any]]] = None,
policy: Literal["sequential", "explicit", "star"] = "sequential",
cost: OttCostFnMap_t = "sq_euclidean",
cost_kwargs: CostKwargs_t = types.MappingProxyType({}),
a: Optional[Union[bool, str]] = None,
b: Optional[Union[bool, str]] = None,
subset: Optional[Sequence[Tuple[K, K]]] = None,
reference: Optional[Any] = None,
xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> "GWProblem[K, B]":
subset: Optional[Sequence[Tuple[K, K]]] = None,
reference: Optional[Any] = None,
) -> "FGWProblem[K, B]":
"""Prepare the individual :term:`quadratic subproblems <quadratic problem>`.
.. seealso::
Expand All @@ -304,6 +307,16 @@ def prepare(
----------
key
Key in :attr:`~anndata.AnnData.obs` for the :class:`~moscot.utils.subset_policy.SubsetPolicy`.
joint_attr
How to get the data for the :term:`linear term` in the :term:`fused <fused Gromov-Wasserstein>` case:
- :obj:`None` - `PCA <https://en.wikipedia.org/wiki/Principal_component_analysis>`_
on :attr:`~anndata.AnnData.X` is computed.
- :class:`str` - a key in :attr:`~anndata.AnnData.obsm` where the data is stored.
- :class:`dict` - it should contain ``'attr'`` and ``'key'``, the attribute and the key
in :class:`~anndata.AnnData`, and optionally ``'tag'``, one of :class:`~moscot.utils.tagged_array.Tag`.
By default, :attr:`tag = 'point_cloud' <moscot.utils.tagged_array.Tag.POINT_CLOUD>` is used.
x_attr
How to get the data for the source :term:`quadratic term`:
Expand Down Expand Up @@ -355,6 +368,18 @@ def prepare(
:meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
otherwise use uniform marginals.
- :obj:`None` - uniform marginals.
xy_callback
Callback function used to prepare the data in the :term:`linear term`.
x_callback
Callback function used to prepare the data in the source :term:`quadratic term`.
y_callback
Callback function used to prepare the data in the target :term:`quadratic term`.
xy_callback_kwargs
Keyword arguments for the ``xy_callback``.
x_callback_kwargs
Keyword arguments for the ``x_callback``.
y_callback_kwargs
Keyword arguments for the ``y_callback``.
Returns
-------
Expand All @@ -369,15 +394,16 @@ def prepare(
self.batch_key = key
x = set_quad_defaults(x_attr) if x_callback is None else {}
y = set_quad_defaults(y_attr) if y_callback is None else {}

xy, xy_callback, xy_callback_kwargs = handle_joint_attr(joint_attr, xy_callback, xy_callback_kwargs)
xy, x, y = handle_cost(
xy={},
xy=xy,
x=x,
y=y,
cost=cost,
cost_kwargs=cost_kwargs,
x_callback=x_callback,
y_callback=y_callback,
xy_callback=xy_callback,
cost_kwargs=cost_kwargs,
)
return super().prepare( # type: ignore[return-value]
key=key,
Expand All @@ -387,16 +413,19 @@ def prepare(
policy=policy,
a=a,
b=b,
reference=reference,
subset=subset,
x_callback=x_callback,
y_callback=y_callback,
xy_callback=xy_callback,
x_callback_kwargs=x_callback_kwargs,
y_callback_kwargs=y_callback_kwargs,
subset=subset,
reference=reference,
xy_callback_kwargs=xy_callback_kwargs,
)

def solve(
self,
alpha: float = 0.5,
epsilon: float = 1e-3,
tau_a: float = 1.0,
tau_b: float = 1.0,
Expand All @@ -413,7 +442,7 @@ def solve(
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
**kwargs: Any,
) -> "GWProblem[K,B]":
) -> "FGWProblem[K,B]":
r"""Solve the individual :term:`quadratic subproblems <quadratic problem>`.
.. seealso::
Expand All @@ -424,6 +453,10 @@ def solve(
Parameters
----------
alpha
Parameter in :math:`(0, 1)` that interpolates between the :term:`quadratic term` and
the :term:`linear term`. :math:`\alpha = 1` corresponds to the pure :term:`Gromov-Wasserstein` problem while
:math:`\alpha \to 0` corresponds to the pure :term:`linear problem`.
epsilon
:term:`Entropic regularization`.
tau_a
Expand Down Expand Up @@ -471,8 +504,10 @@ def solve(
- :attr:`solutions` - the :term:`OT` solutions for each subproblem.
- :attr:`stage` - set to ``'solved'``.
"""
return super().solve( # type: ignore[return-value]
alpha=1.0,
if alpha == 1.0:
raise ValueError("The `FGWProblem` is equivalent to the `GWProblem` when `alpha=1.0`.")
return super().solve(
alpha=alpha,
epsilon=epsilon,
tau_a=tau_a,
tau_b=tau_b,
Expand All @@ -489,7 +524,7 @@ def solve(
linear_solver_kwargs=linear_solver_kwargs,
device=device,
**kwargs,
)
) # type: ignore[return-value]

@property
def _base_problem_type(self) -> Type[B]:
Expand All @@ -500,8 +535,8 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]:
return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value]


class FGWProblem(GWProblem[K, B]):
"""Class for solving the :term:`FGW <fused Gromov-Wasserstein>` problem.
class GWProblem(FGWProblem[K, B]):
"""Class for solving the :term:`GW <Gromov-Wasserstein>` or :term:`FGW <fused Gromov-Wasserstein>` problems.
Parameters
----------
Expand All @@ -514,23 +549,20 @@ class FGWProblem(GWProblem[K, B]):
def prepare(
self,
key: str,
joint_attr: Optional[Union[str, Mapping[str, Any]]] = None,
x_attr: Optional[Union[str, Mapping[str, Any]]] = None,
y_attr: Optional[Union[str, Mapping[str, Any]]] = None,
policy: Literal["sequential", "explicit", "star"] = "sequential",
cost: OttCostFnMap_t = "sq_euclidean",
cost_kwargs: CostKwargs_t = types.MappingProxyType({}),
a: Optional[Union[bool, str]] = None,
b: Optional[Union[bool, str]] = None,
xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
subset: Optional[Sequence[Tuple[K, K]]] = None,
reference: Optional[Any] = None,
x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None,
xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
subset: Optional[Sequence[Tuple[K, K]]] = None,
reference: Optional[Any] = None,
) -> "FGWProblem[K, B]":
) -> "GWProblem[K, B]":
"""Prepare the individual :term:`quadratic subproblems <quadratic problem>`.
.. seealso::
Expand All @@ -540,16 +572,6 @@ def prepare(
----------
key
Key in :attr:`~anndata.AnnData.obs` for the :class:`~moscot.utils.subset_policy.SubsetPolicy`.
joint_attr
How to get the data for the :term:`linear term` in the :term:`fused <fused Gromov-Wasserstein>` case:
- :obj:`None` - `PCA <https://en.wikipedia.org/wiki/Principal_component_analysis>`_
on :attr:`~anndata.AnnData.X` is computed.
- :class:`str` - a key in :attr:`~anndata.AnnData.obsm` where the data is stored.
- :class:`dict` - it should contain ``'attr'`` and ``'key'``, the attribute and the key
in :class:`~anndata.AnnData`, and optionally ``'tag'``, one of :class:`~moscot.utils.tagged_array.Tag`.
By default, :attr:`tag = 'point_cloud' <moscot.utils.tagged_array.Tag.POINT_CLOUD>` is used.
x_attr
How to get the data for the source :term:`quadratic term`:
Expand Down Expand Up @@ -601,24 +623,6 @@ def prepare(
:meth:`estimate the marginals <moscot.base.problems.OTProblem.estimate_marginals>`,
otherwise use uniform marginals.
- :obj:`None` - uniform marginals.
xy
Data for the :term:`linear term`.
x
Data for the source :term:`quadratic term`.
y
Data for the target :term:`quadratic term`.
xy_callback
Callback function used to prepare the data in the :term:`linear term`.
x_callback
Callback function used to prepare the data in the source :term:`quadratic term`.
y_callback
Callback function used to prepare the data in the target :term:`quadratic term`.
xy_callback_kwargs
Keyword arguments for the ``xy_callback``.
x_callback_kwargs
Keyword arguments for the ``x_callback``.
y_callback_kwargs
Keyword arguments for the ``y_callback``.
Returns
-------
Expand All @@ -630,42 +634,25 @@ def prepare(
- :attr:`stage` - set to ``'prepared'``.
- :attr:`problem_kind` - set to ``'quadratic'``.
"""
self.batch_key = key
x = set_quad_defaults(x_attr) if x_callback is None else {}
y = set_quad_defaults(y_attr) if y_callback is None else {}
xy, xy_callback, xy_callback_kwargs = handle_joint_attr(joint_attr, xy_callback, xy_callback_kwargs)
xy, x, y = handle_cost(
xy=xy,
x=x,
y=y,
cost=cost,
x_callback=x_callback,
y_callback=y_callback,
xy_callback=xy_callback,
cost_kwargs=cost_kwargs,
)
return CompoundProblem.prepare(
self, # type: ignore[return-value, arg-type]
return super().prepare( # type: ignore[return-value]
key=key,
xy=xy,
x=x,
y=y,
policy=policy,
a=a,
b=b,
reference=reference,
subset=subset, # type: ignore[arg-type]
x_attr=x_attr,
y_attr=y_attr,
cost=cost,
cost_kwargs=cost_kwargs,
x_callback=x_callback,
y_callback=y_callback,
xy_callback=xy_callback,
x_callback_kwargs=x_callback_kwargs,
y_callback_kwargs=y_callback_kwargs,
xy_callback_kwargs=xy_callback_kwargs,
subset=subset,
reference=reference,
)

def solve(
self,
alpha: float = 0.5,
epsilon: float = 1e-3,
tau_a: float = 1.0,
tau_b: float = 1.0,
Expand All @@ -682,7 +669,7 @@ def solve(
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
**kwargs: Any,
) -> "FGWProblem[K,B]":
) -> "GWProblem[K,B]":
r"""Solve the individual :term:`quadratic subproblems <quadratic problem>`.
.. seealso::
Expand All @@ -693,10 +680,6 @@ def solve(
Parameters
----------
alpha
Parameter in :math:`(0, 1)` that interpolates between the :term:`quadratic term` and
the :term:`linear term`. :math:`\alpha = 1` corresponds to the pure :term:`Gromov-Wasserstein` problem while
:math:`\alpha \to 0` corresponds to the pure :term:`linear problem`.
epsilon
:term:`Entropic regularization`.
tau_a
Expand Down Expand Up @@ -744,11 +727,8 @@ def solve(
- :attr:`solutions` - the :term:`OT` solutions for each subproblem.
- :attr:`stage` - set to ``'solved'``.
"""
if alpha == 1.0:
raise ValueError("The `FGWProblem` is equivalent to the `GWProblem` when `alpha=1.0`.")
return CompoundProblem.solve(
self, # type: ignore[return-value, arg-type]
alpha=alpha,
return super().solve( # type: ignore[return-value]
alpha=1.0,
epsilon=epsilon,
tau_a=tau_a,
tau_b=tau_b,
Expand All @@ -767,14 +747,6 @@ def solve(
**kwargs,
)

@property
def _base_problem_type(self) -> Type[B]:
return OTProblem # type: ignore[return-value]

@property
def _valid_policies(self) -> Tuple[Policy_t, ...]:
return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value]


class GENOTLinProblem(CondOTProblem):
"""Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems."""
Expand Down

0 comments on commit 45509e0

Please sign in to comment.