From b9b66d5a480c121cab19850b7708caa49b07af01 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 24 May 2024 10:18:44 +0200 Subject: [PATCH] (chore): fix docs --- docs/conf.py | 12 +++++++ docs/developer.rst | 6 +++- docs/notebooks | 2 +- docs/user.rst | 1 + src/moscot/backends/ott/output.py | 10 +++--- src/moscot/base/output.py | 44 ++++++++++++-------------- src/moscot/base/problems/problem.py | 16 ++++------ src/moscot/base/solver.py | 6 ++-- src/moscot/problems/generic/_mixins.py | 8 ++--- src/moscot/utils/tagged_array.py | 26 +++++---------- 10 files changed, 65 insertions(+), 66 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 5fa891ac2..483626db6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -63,12 +63,24 @@ nitpicky = True nitpick_ignore = [ ("py:class", "numpy.float64"), + # see: https://github.com/numpy/numpydoc/issues/275 + ("py:class", "None. Remove all items from D."), + ("py:class", "a set-like object providing a view on D's items"), + ("py:class", "a set-like object providing a view on D's keys"), + ("py:class", "v, remove specified key and return the corresponding value."), # noqa: E501 + ("py:class", "None. Update D from dict/iterable E and F."), + ("py:class", "an object providing a view on D's values"), + ("py:class", "a shallow copy of D"), ] # TODO(michalk8): remove once typing has been cleaned-up nitpick_ignore_regex = [ (r"py:class", r"moscot\..*(K|B|O)"), (r"py:class", r"numpy\._typing.*"), (r"py:class", r"moscot\..*Protocol.*"), + ( + r"py:class", + r"moscot.base.output.BaseSolverOutput", + ), # https://github.com/sphinx-doc/sphinx/issues/10974 means there is simply no way around this with generics ] diff --git a/docs/developer.rst b/docs/developer.rst index da236defa..26466b03a 100644 --- a/docs/developer.rst +++ b/docs/developer.rst @@ -12,6 +12,8 @@ Backends backends.ott.GWSolver backends.ott.OTTOutput backends.ott.GraphOTTOutput + backends.ott.GENOTLinSolver + backends.ott.output.OTTNeuralOutput backends.utils.get_solver backends.utils.get_available_backends @@ -44,6 +46,7 @@ Problems problems.BaseCompoundProblem problems.CompoundProblem cost.BaseCost + problems.CondOTProblem Mixins ^^^^^^ @@ -62,7 +65,6 @@ Solvers solver.BaseSolver solver.OTSolver - output.BaseDiscreteSolverOutput Output ^^^^^^ @@ -100,6 +102,8 @@ Miscellaneous data.apoptosis_markers tagged_array.TaggedArray tagged_array.Tag + tagged_array.DistributionCollection + tagged_array.DistributionContainer .. currentmodule:: moscot.base.problems .. autosummary:: diff --git a/docs/notebooks b/docs/notebooks index c38161a78..fd6959de0 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit c38161a78e986e303d9622647af3242b1221077f +Subproject commit fd6959de020c7eb45c10bab8c5890e61eba9cc79 diff --git a/docs/user.rst b/docs/user.rst index 8c6b2d59a..f4291892a 100644 --- a/docs/user.rst +++ b/docs/user.rst @@ -27,6 +27,7 @@ Generic Problems generic.SinkhornProblem generic.GWProblem generic.FGWProblem + generic.GENOTLinProblem Plotting ~~~~~~~~ diff --git a/src/moscot/backends/ott/output.py b/src/moscot/backends/ott/output.py index a5b7b4215..0e244f9c7 100644 --- a/src/moscot/backends/ott/output.py +++ b/src/moscot/backends/ott/output.py @@ -400,12 +400,10 @@ def pull(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike: Parameters ---------- - x : ArrayLike - cond : Optional[ArrayLike], optional - - Returns - ------- - ArrayLike + x + Distribution to push. + cond + Condition of conditional neural OT. Raises ------ diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index 4750fa0d7..590c89bfe 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -1,18 +1,10 @@ +from __future__ import annotations + import abc import copy import functools from abc import abstractmethod -from typing import ( - Any, - Callable, - Iterable, - List, - Literal, - Optional, - Tuple, - TypeVar, - Union, -) +from typing import Any, Callable, Iterable, Literal, Optional, Union import numpy as np import scipy.sparse as sp @@ -23,19 +15,25 @@ __all__ = ["BaseDiscreteSolverOutput", "MatrixSolverOutput", "BaseNeuralOutput"] -OutputClass = TypeVar("OutputClass", bound="BaseSolverOutput") # use string - class BaseSolverOutput(abc.ABC): """Base class for all solver outputs.""" + @abc.abstractmethod + def pull(self, x: ArrayLike, **kwargs) -> ArrayLike: + """Pull the solution based on a condition.""" + + @abc.abstractmethod + def push(self, x: ArrayLike, **kwargs) -> ArrayLike: + """Push the solution based on a condition.""" + @property @abc.abstractmethod - def shape(self) -> Tuple[int, int]: + def shape(self) -> tuple[int, int]: """Shape of the problem.""" @abc.abstractmethod - def to(self: OutputClass, device: Optional[Device_t] = None) -> OutputClass: + def to(self: BaseSolverOutput, device: Optional[Device_t] = None) -> BaseSolverOutput: """Transfer self to another compute device. Parameters @@ -83,7 +81,7 @@ def converged(self) -> bool: @property @abc.abstractmethod - def potentials(self) -> Optional[Tuple[ArrayLike, ArrayLike]]: + def potentials(self) -> Optional[tuple[ArrayLike, ArrayLike]]: """:term:`Dual potentials` :math:`f` and :math:`g`. Only valid for the :term:`Sinkhorn` algorithm. @@ -91,7 +89,7 @@ def potentials(self) -> Optional[Tuple[ArrayLike, ArrayLike]]: @property @abc.abstractmethod - def shape(self) -> Tuple[int, int]: + def shape(self) -> tuple[int, int]: """Shape of the :attr:`transport_matrix`.""" @property @@ -179,7 +177,7 @@ def as_linear_operator(self, scale_by_marginals: bool = False) -> LinearOperator # pull: X @ a (matvec) return LinearOperator(shape=self.shape, dtype=self.dtype, matvec=pull, rmatvec=push) - def chain(self, outputs: Iterable["BaseDiscreteSolverOutput"], scale_by_marginals: bool = False) -> LinearOperator: + def chain(self, outputs: Iterable[BaseDiscreteSolverOutput], scale_by_marginals: bool = False) -> LinearOperator: """Chain subsequent applications of :attr:`transport_matrix`. Parameters @@ -206,7 +204,7 @@ def sparsify( batch_size: int = 1024, n_samples: Optional[int] = None, seed: Optional[int] = None, - ) -> "MatrixSolverOutput": + ) -> MatrixSolverOutput: """Sparsify the :attr:`transport_matrix`. This function sets all entries of the transport matrix below a certain threshold to :math:`0` and @@ -267,7 +265,7 @@ def sparsify( raise NotImplementedError(f"Mode `{mode}` is not yet implemented.") k, func, fn_stack = (n, self.push, sp.vstack) if n < m else (m, self.pull, sp.hstack) - tmaps_sparse: List[sp.csr_matrix] = [] + tmaps_sparse: list[sp.csr_matrix] = [] for batch in range(0, k, batch_size): x = np.eye(k, min(batch_size, k - batch), -(min(batch, k)), dtype=float) @@ -355,12 +353,12 @@ def transport_matrix(self) -> ArrayLike: # noqa: D102 return self._transport_matrix @property - def shape(self) -> Tuple[int, int]: # noqa: D102 + def shape(self) -> tuple[int, int]: # noqa: D102 return self.transport_matrix.shape # type: ignore[return-value] def to( # noqa: D102 self, device: Optional[Device_t] = None, dtype: Optional[DTypeLike] = None - ) -> "BaseDiscreteSolverOutput": + ) -> BaseDiscreteSolverOutput: if device is not None: logger.warning(f"`{self!r}` does not support the `device` argument, ignoring.") if dtype is None: @@ -379,7 +377,7 @@ def converged(self) -> bool: # noqa: D102 return self._converged @property - def potentials(self) -> Optional[Tuple[ArrayLike, ArrayLike]]: # noqa: D102 + def potentials(self) -> Optional[tuple[ArrayLike, ArrayLike]]: # noqa: D102 return None @property diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index fe3908890..f9640d5f7 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -1049,18 +1049,14 @@ def __str__(self) -> str: class CondOTProblem(BaseProblem): # TODO(@MUCDK) check generic types, save and load """ - Base class for all optimal transport problems. + Base class for all conditional (nerual) optimal transport problems. Parameters ---------- adata Source annotated data object. kwargs - Keyword arguments for :class:`moscot.problems.base.BaseProblem.` - - Notes - ----- - If any of the source/target masks are specified, :attr:`adata_src`/:attr:`adata_tgt` will be a view. + Keyword arguments for :class:`moscot.base.problems.problem.BaseProblem` """ def __init__( @@ -1101,7 +1097,7 @@ def prepare( ---------- xy Geometry defining the linear term. If passed as a :class:`dict`, - :meth:`~moscot.solvers.TaggedArray.from_adata` will be called. + :meth:`~moscot.utils.tagged_array.TaggedArray.from_adata` will be called. policy Policy defining which pairs of distributions to sample from during training. policy_key @@ -1159,11 +1155,11 @@ def solve( Parameters ---------- backend - Which backend to use, see :func:`moscot.backends.get_available_backends`. + Which backend to use, see :func:`moscot.backends.utils.get_available_backends`. device - Device where to transfer the solution, see :meth:`moscot.solvers.BaseDiscreteSolverOutput.to`. + Device where to transfer the solution, see :meth:`moscot.base.output.BaseDiscreteSolverOutput.to`. kwargs - Keyword arguments for :meth:`moscot.solvers.BaseSolver.__call__`. + Keyword arguments for :meth:`moscot.base.solver.BaseSolver.__call__`. Returns diff --git a/src/moscot/base/solver.py b/src/moscot/base/solver.py index d58b48ac8..b78ad0cf2 100644 --- a/src/moscot/base/solver.py +++ b/src/moscot/base/solver.py @@ -18,7 +18,7 @@ from moscot._logging import logger from moscot._types import ArrayLike, Device_t, ProblemKind_t -from moscot.base.output import BaseDiscreteSolverOutput +from moscot.base.output import BaseDiscreteSolverOutput, BaseSolverOutput from moscot.utils.tagged_array import Tag, TaggedArray __all__ = ["BaseSolver", "OTSolver"] @@ -151,7 +151,7 @@ def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]: def _partition_kwargs(cls, **kwargs: Any) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Partition keyword arguments. - Used by the :meth:`~moscot.problems.base.BaseProblem.solve`. + Used by the :meth:`~moscot.base.problems.problem.BaseProblem.solve`. Parameters ---------- @@ -180,7 +180,7 @@ def __call__( tags: Mapping[Literal["x", "y", "xy"], Tag] = types.MappingProxyType({}), device: Optional[Device_t] = None, **kwargs: Any, - ) -> O: + ) -> BaseSolverOutput: """Solve an optimal transport problem. Parameters diff --git a/src/moscot/problems/generic/_mixins.py b/src/moscot/problems/generic/_mixins.py index 9f4fbd649..929570d19 100644 --- a/src/moscot/problems/generic/_mixins.py +++ b/src/moscot/problems/generic/_mixins.py @@ -143,9 +143,9 @@ def push( Parameters ---------- source - Source key in :attr:`solutions`. + Source key in `solutions`. target - Target key in :attr:`solutions`. + Target key in `solutions`. data Initial data to push, see :meth:`~moscot.base.problems.OTProblem.push` for information. subset @@ -212,9 +212,9 @@ def pull( Parameters ---------- source - Source key in :attr:`solutions`. + Source key in `solutions`. target - Target key in :attr:`solutions`. + Target key in `solutions`. data Initial data to pull, see :meth:`~moscot.base.problems.OTProblem.pull` for information. subset diff --git a/src/moscot/utils/tagged_array.py b/src/moscot/utils/tagged_array.py index 028bc5d65..dd488acf5 100644 --- a/src/moscot/utils/tagged_array.py +++ b/src/moscot/utils/tagged_array.py @@ -1,16 +1,6 @@ import enum from dataclasses import dataclass -from typing import ( - Any, - Callable, - Dict, - Hashable, - Literal, - Optional, - Tuple, - TypeVar, - Union, -) +from typing import Any, Callable, Hashable, Literal, Optional, Tuple, TypeVar, Union import numpy as np import scipy.sparse as sp @@ -325,21 +315,21 @@ def from_adata( b Marginals when used as target distribution. xy_attr - Attribute of :paramref:`adata` containing the data for the shared space. + Attribute of `adata` containing the data for the shared space. xy_key - Key of :paramref:`xy_attr` containing the data for the shared space. + Key of `xy_attr` containing the data for the shared space. xy_cost Cost function when in the shared space. xx_attr - Attribute of :paramref:`adata` containing the data for the incomparable space. + Attribute of `adata` containing the data for the incomparable space. xx_key - Key of :paramref:`xx_attr` containing the data for the incomparable space. + Key of `xx_attr` containing the data for the incomparable space. xx_cost Cost function in the incomparable space. conditions_attr - Attribute of :paramref:`adata` containing the conditions. + Attribute of `adata` containing the conditions. conditions_key - Key of :paramref:`conditions_attr` containing the conditions. + Key of `conditions_attr` containing the conditions. backend Backend to use. kwargs @@ -373,7 +363,7 @@ def from_adata( return cls(xy=xy_data, xx=xx_data, a=a, b=b, conditions=conditions_data, cost_xy=xy_cost_fn, cost_xx=xx_cost_fn) -class DistributionCollection(Dict[K, DistributionContainer]): +class DistributionCollection(dict[K, DistributionContainer]): """Collection of distributions.""" def __repr__(self) -> str: