diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index fcd3d8b2..9a819671 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -264,8 +264,8 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]: return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value] -class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc] - """Class for solving the :term:`GW ` or :term:`FGW ` problems. +class FGWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc] + """Class for solving the :term:`FGW ` problem. Parameters ---------- @@ -281,6 +281,7 @@ def __init__(self, adata: AnnData, **kwargs: Any): def prepare( self, key: str, + joint_attr: Optional[Union[str, Mapping[str, Any]]] = None, x_attr: Optional[Union[str, Mapping[str, Any]]] = None, y_attr: Optional[Union[str, Mapping[str, Any]]] = None, policy: Literal["sequential", "explicit", "star"] = "sequential", @@ -288,13 +289,15 @@ def prepare( cost_kwargs: CostKwargs_t = types.MappingProxyType({}), a: Optional[Union[bool, str]] = None, b: Optional[Union[bool, str]] = None, - subset: Optional[Sequence[Tuple[K, K]]] = None, - reference: Optional[Any] = None, + xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, + xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}), x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}), y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - ) -> "GWProblem[K, B]": + subset: Optional[Sequence[Tuple[K, K]]] = None, + reference: Optional[Any] = None, + ) -> "FGWProblem[K, B]": """Prepare the individual :term:`quadratic subproblems `. .. seealso:: @@ -304,6 +307,16 @@ def prepare( ---------- key Key in :attr:`~anndata.AnnData.obs` for the :class:`~moscot.utils.subset_policy.SubsetPolicy`. + joint_attr + How to get the data for the :term:`linear term` in the :term:`fused ` case: + + - :obj:`None` - run `PCA `_ + on :attr:`~anndata.AnnData.X` is computed. + - :class:`str` - a key in :attr:`~anndata.AnnData.obsm` where the data is stored. + - :class:`dict` - it should contain ``'attr'`` and ``'key'``, the attribute and the key + in :class:`~anndata.AnnData`, and optionally ``'tag'``, one of :class:`~moscot.utils.tagged_array.Tag`. + + By default, :attr:`tag = 'point_cloud' ` is used. x_attr How to get the data for the source :term:`quadratic term`: @@ -355,6 +368,18 @@ def prepare( :meth:`estimate the marginals `, otherwise use uniform marginals. - :obj:`None` - uniform marginals. + xy_callback + Callback function used to prepare the data in the :term:`linear term`. + x_callback + Callback function used to prepare the data in the source :term:`quadratic term`. + y_callback + Callback function used to prepare the data in the target :term:`quadratic term`. + xy_callback_kwargs + Keyword arguments for the ``xy_callback``. + x_callback_kwargs + Keyword arguments for the ``x_callback``. + y_callback_kwargs + Keyword arguments for the ``y_callback``. Returns ------- @@ -369,15 +394,16 @@ def prepare( self.batch_key = key x = set_quad_defaults(x_attr) if x_callback is None else {} y = set_quad_defaults(y_attr) if y_callback is None else {} - + xy, xy_callback, xy_callback_kwargs = handle_joint_attr(joint_attr, xy_callback, xy_callback_kwargs) xy, x, y = handle_cost( - xy={}, + xy=xy, x=x, y=y, cost=cost, - cost_kwargs=cost_kwargs, x_callback=x_callback, y_callback=y_callback, + xy_callback=xy_callback, + cost_kwargs=cost_kwargs, ) return super().prepare( # type: ignore[return-value] key=key, @@ -387,16 +413,19 @@ def prepare( policy=policy, a=a, b=b, + reference=reference, + subset=subset, x_callback=x_callback, y_callback=y_callback, + xy_callback=xy_callback, x_callback_kwargs=x_callback_kwargs, y_callback_kwargs=y_callback_kwargs, - subset=subset, - reference=reference, + xy_callback_kwargs=xy_callback_kwargs, ) def solve( self, + alpha: float = 0.5, epsilon: float = 1e-3, tau_a: float = 1.0, tau_b: float = 1.0, @@ -413,7 +442,7 @@ def solve( linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None, **kwargs: Any, - ) -> "GWProblem[K,B]": + ) -> "FGWProblem[K,B]": r"""Solve the individual :term:`quadratic subproblems `. .. seealso: @@ -424,6 +453,10 @@ def solve( Parameters ---------- + alpha + Parameter in :math:`(0, 1)` that interpolates between the :term:`quadratic term` and + the :term:`linear term`. :math:`\alpha = 1` corresponds to the pure :term:`Gromov-Wasserstein` problem while + :math:`\alpha \to 0` corresponds to the pure :term:`linear problem`. epsilon :term:`Entropic regularization`. tau_a @@ -471,8 +504,10 @@ def solve( - :attr:`solutions` - the :term:`OT` solutions for each subproblem. - :attr:`stage` - set to ``'solved'``. """ - return super().solve( # type: ignore[return-value] - alpha=1.0, + if alpha == 1.0: + raise ValueError("The `FGWProblem` is equivalent to the `GWProblem` when `alpha=1.0`.") + return super().solve( + alpha=alpha, epsilon=epsilon, tau_a=tau_a, tau_b=tau_b, @@ -489,7 +524,7 @@ def solve( linear_solver_kwargs=linear_solver_kwargs, device=device, **kwargs, - ) + ) # type: ignore[return-value] @property def _base_problem_type(self) -> Type[B]: @@ -500,8 +535,8 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]: return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value] -class FGWProblem(GWProblem[K, B]): - """Class for solving the :term:`FGW ` problem. +class GWProblem(FGWProblem[K, B]): + """Class for solving the :term:`GW ` or :term:`FGW ` problems. Parameters ---------- @@ -514,7 +549,6 @@ class FGWProblem(GWProblem[K, B]): def prepare( self, key: str, - joint_attr: Optional[Union[str, Mapping[str, Any]]] = None, x_attr: Optional[Union[str, Mapping[str, Any]]] = None, y_attr: Optional[Union[str, Mapping[str, Any]]] = None, policy: Literal["sequential", "explicit", "star"] = "sequential", @@ -522,15 +556,13 @@ def prepare( cost_kwargs: CostKwargs_t = types.MappingProxyType({}), a: Optional[Union[bool, str]] = None, b: Optional[Union[bool, str]] = None, - xy_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, + subset: Optional[Sequence[Tuple[K, K]]] = None, + reference: Optional[Any] = None, x_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, y_callback: Optional[Union[Literal["local-pca"], Callback_t]] = None, - xy_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}), x_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}), y_callback_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - subset: Optional[Sequence[Tuple[K, K]]] = None, - reference: Optional[Any] = None, - ) -> "FGWProblem[K, B]": + ) -> "GWProblem[K, B]": """Prepare the individual :term:`quadratic subproblems `. .. seealso:: @@ -540,16 +572,6 @@ def prepare( ---------- key Key in :attr:`~anndata.AnnData.obs` for the :class:`~moscot.utils.subset_policy.SubsetPolicy`. - joint_attr - How to get the data for the :term:`linear term` in the :term:`fused ` case: - - - :obj:`None` - run `PCA `_ - on :attr:`~anndata.AnnData.X` is computed. - - :class:`str` - a key in :attr:`~anndata.AnnData.obsm` where the data is stored. - - :class:`dict` - it should contain ``'attr'`` and ``'key'``, the attribute and the key - in :class:`~anndata.AnnData`, and optionally ``'tag'``, one of :class:`~moscot.utils.tagged_array.Tag`. - - By default, :attr:`tag = 'point_cloud' ` is used. x_attr How to get the data for the source :term:`quadratic term`: @@ -601,24 +623,6 @@ def prepare( :meth:`estimate the marginals `, otherwise use uniform marginals. - :obj:`None` - uniform marginals. - xy - Data for the :term:`linear term`. - x - Data for the source :term:`quadratic term`. - y - Data for the target :term:`quadratic term`. - xy_callback - Callback function used to prepare the data in the :term:`linear term`. - x_callback - Callback function used to prepare the data in the source :term:`quadratic term`. - y_callback - Callback function used to prepare the data in the target :term:`quadratic term`. - xy_callback_kwargs - Keyword arguments for the ``xy_callback``. - x_callback_kwargs - Keyword arguments for the ``x_callback``. - y_callback_kwargs - Keyword arguments for the ``y_callback``. Returns ------- @@ -630,42 +634,25 @@ def prepare( - :attr:`stage` - set to ``'prepared'``. - :attr:`problem_kind` - set to ``'quadratic'``. """ - self.batch_key = key - x = set_quad_defaults(x_attr) if x_callback is None else {} - y = set_quad_defaults(y_attr) if y_callback is None else {} - xy, xy_callback, xy_callback_kwargs = handle_joint_attr(joint_attr, xy_callback, xy_callback_kwargs) - xy, x, y = handle_cost( - xy=xy, - x=x, - y=y, - cost=cost, - x_callback=x_callback, - y_callback=y_callback, - xy_callback=xy_callback, - cost_kwargs=cost_kwargs, - ) - return CompoundProblem.prepare( - self, # type: ignore[return-value, arg-type] + return super().prepare( # type: ignore[return-value] key=key, - xy=xy, - x=x, - y=y, policy=policy, a=a, b=b, - reference=reference, - subset=subset, # type: ignore[arg-type] + x_attr=x_attr, + y_attr=y_attr, + cost=cost, + cost_kwargs=cost_kwargs, x_callback=x_callback, y_callback=y_callback, - xy_callback=xy_callback, x_callback_kwargs=x_callback_kwargs, y_callback_kwargs=y_callback_kwargs, - xy_callback_kwargs=xy_callback_kwargs, + subset=subset, + reference=reference, ) def solve( self, - alpha: float = 0.5, epsilon: float = 1e-3, tau_a: float = 1.0, tau_b: float = 1.0, @@ -682,7 +669,7 @@ def solve( linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), device: Optional[Literal["cpu", "gpu", "tpu"]] = None, **kwargs: Any, - ) -> "FGWProblem[K,B]": + ) -> "GWProblem[K,B]": r"""Solve the individual :term:`quadratic subproblems `. .. seealso: @@ -693,10 +680,6 @@ def solve( Parameters ---------- - alpha - Parameter in :math:`(0, 1)` that interpolates between the :term:`quadratic term` and - the :term:`linear term`. :math:`\alpha = 1` corresponds to the pure :term:`Gromov-Wasserstein` problem while - :math:`\alpha \to 0` corresponds to the pure :term:`linear problem`. epsilon :term:`Entropic regularization`. tau_a @@ -744,11 +727,8 @@ def solve( - :attr:`solutions` - the :term:`OT` solutions for each subproblem. - :attr:`stage` - set to ``'solved'``. """ - if alpha == 1.0: - raise ValueError("The `FGWProblem` is equivalent to the `GWProblem` when `alpha=1.0`.") - return CompoundProblem.solve( - self, # type: ignore[return-value, arg-type] - alpha=alpha, + return super().solve( # type: ignore[return-value] + alpha=1.0, epsilon=epsilon, tau_a=tau_a, tau_b=tau_b, @@ -767,14 +747,6 @@ def solve( **kwargs, ) - @property - def _base_problem_type(self) -> Type[B]: - return OTProblem # type: ignore[return-value] - - @property - def _valid_policies(self) -> Tuple[Policy_t, ...]: - return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value] - class GENOTLinProblem(CondOTProblem): """Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems."""