From b3f70deb2db771a3dac07b46da711611d1f7b138 Mon Sep 17 00:00:00 2001 From: michalk8 <46717574+michalk8@users.noreply.github.com> Date: Mon, 20 Mar 2023 18:14:13 +0100 Subject: [PATCH] Docs/nitpicky (#500) * Fix most links * Document __call__ * Fix last link * Enable docs linter in CI, disable spellcheck * Checkout submodules in CI * [ci skip] Fix typo --- .github/workflows/lint.yml | 5 +- .pre-commit-config.yaml | 4 +- docs/_templates/autosummary/class.rst | 7 +- docs/conf.py | 18 ++++- docs/developer.rst | 18 ++++- docs/notebooks | 2 +- pyproject.toml | 3 +- src/moscot/_docs/_docs.py | 35 +++----- src/moscot/_docs/_docs_mixins.py | 4 +- src/moscot/_docs/_docs_plot.py | 26 +++--- src/moscot/backends/__init__.py | 9 +-- src/moscot/backends/utils.py | 9 ++- src/moscot/base/cost.py | 13 +-- src/moscot/base/output.py | 10 +-- src/moscot/base/problems/_mixins.py | 4 +- src/moscot/base/problems/_utils.py | 2 +- src/moscot/base/problems/birth_death.py | 34 ++++---- src/moscot/base/problems/compound_problem.py | 46 +++++------ src/moscot/base/problems/problem.py | 29 +++---- src/moscot/base/solver.py | 30 ++++++- src/moscot/datasets.py | 3 +- src/moscot/problems/space/_alignment.py | 5 +- src/moscot/problems/space/_mixins.py | 8 +- .../spatiotemporal/_spatio_temporal.py | 13 ++- src/moscot/problems/time/_mixins.py | 6 +- src/moscot/utils/subset_policy.py | 81 +++++++++---------- src/moscot/utils/tagged_array.py | 15 ++-- .../time/test_temporal_base_problem.py | 6 +- 28 files changed, 240 insertions(+), 205 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d9a2cbf41..ff062ef09 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -19,11 +19,12 @@ jobs: strategy: fail-fast: false matrix: - # TODO(michalk8): enable in the future - lint-kind: [code] # , docs] + lint-kind: [code, docs] steps: - uses: actions/checkout@v3 + with: + submodules: true - name: Set up Python 3.10 uses: actions/setup-python@v4 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cefb6e97e..0cfe85e49 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -52,7 +52,7 @@ repos: - id: blacken-docs additional_dependencies: [black==23.1.0] - repo: https://github.com/rstcheck/rstcheck - rev: v6.1.1 + rev: v6.1.2 hooks: - id: rstcheck additional_dependencies: [tomli] @@ -63,7 +63,7 @@ repos: - id: doc8 - repo: https://github.com/charliermarsh/ruff-pre-commit # Ruff version. - rev: v0.0.252 + rev: v0.0.257 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/docs/_templates/autosummary/class.rst b/docs/_templates/autosummary/class.rst index 5a4db15dd..1c7f7bc9a 100644 --- a/docs/_templates/autosummary/class.rst +++ b/docs/_templates/autosummary/class.rst @@ -10,7 +10,12 @@ .. autosummary:: :toctree: . {% for item in methods %} - {%- if item not in ['__init__', 'tree_flatten', 'tree_unflatten', 'bind'] %} + {%- if item not in ['__init__'] %} + ~{{ name }}.{{ item }} + {%- endif %} + {%- endfor %} + {%- for item in all_methods %} + {%- if item in ['__call__'] %} ~{{ name }}.{{ item }} {%- endif %} {%- endfor %} diff --git a/docs/conf.py b/docs/conf.py index 77e0fcc0a..3a93ed163 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -59,6 +59,18 @@ pygments_style = "tango" pygments_dark_style = "monokai" +nitpicky = True +nitpick_ignore = [ + ("py:class", "numpy.float64"), +] +# TODO(michalk8): remove once typing has been cleaned-up +nitpick_ignore_regex = [ + (r"py:class", r"moscot\..*(K|B|O)"), + (r"py:class", r"numpy\._typing.*"), + (r"py:class", r"moscot\..*Protocol.*"), +] + + # bibliography bibtex_bibfiles = ["references.bib"] bibtex_reference_style = "author_year" @@ -87,7 +99,6 @@ napoleon_google_docstring = False napoleon_numpy_docstring = True - # spelling spelling_lang = "en_US" spelling_warning = True @@ -101,6 +112,11 @@ "enchant.tokenize.MentionFilter", ] +linkcheck_ignore = [ + # 403 Client Error + r"https://doi.org/10.1126/science.aad0501", +] + exclude_patterns = ["_build", "**.ipynb_checkpoints", "notebooks/README.rst", "notebooks/CONTRIBUTING.rst"] # -- Options for HTML output ------------------------------------------------- diff --git a/docs/developer.rst b/docs/developer.rst index 5a6cef3e6..a8e95a7bb 100644 --- a/docs/developer.rst +++ b/docs/developer.rst @@ -1,9 +1,6 @@ Developer API ############# -.. module:: moscot - :noindex: - Backends ~~~~~~~~ .. module:: moscot.backends @@ -15,6 +12,9 @@ Backends backends.ott.GWSolver backends.ott.OTTOutput + backends.utils.get_solver + backends.utils.get_available_backends + Costs ~~~~~ .. module:: moscot.costs @@ -43,7 +43,7 @@ Problems problems.BirthDeathProblem problems.BaseCompoundProblem problems.CompoundProblem - problems.ProblemManager + cost.BaseCost Mixins ^^^^^^ @@ -55,12 +55,14 @@ Mixins Solvers ^^^^^^^ +.. module:: moscot.solvers .. currentmodule:: moscot.base .. autosummary:: :toctree: genapi solver.BaseSolver solver.OTSolver + output.BaseSolverOutput Output ^^^^^^ @@ -80,6 +82,7 @@ Policies .. autosummary:: :toctree: genapi + subset_policy.SubsetPolicy subset_policy.StarPolicy subset_policy.ExternalStarPolicy subset_policy.ExplicitPolicy @@ -97,3 +100,10 @@ Miscellaneous data.apoptosis_markers tagged_array.TaggedArray tagged_array.Tag + +.. currentmodule:: moscot.base.problems +.. autosummary:: + :toctree: genapi + + birth_death.beta + birth_death.delta diff --git a/docs/notebooks b/docs/notebooks index 218a8e36f..3184a2da7 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit 218a8e36f429d20aa1d096f91f43f209c8032ff5 +Subproject commit 3184a2da75de1bd5e58925259437189ac33aa358 diff --git a/pyproject.toml b/pyproject.toml index 351186a42..4e3e803b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -289,7 +289,8 @@ set_env = SPHINXOPTS = -W -q --keep-going changedir = {tox_root}{/}docs commands = make linkcheck {posargs} - make spelling {posargs} + # TODO(michalk8): uncomment after https://github.com/theislab/moscot/issues/490 + # make spelling {posargs} [testenv:clean-docs] description = Remove the documentation. diff --git a/src/moscot/_docs/_docs.py b/src/moscot/_docs/_docs.py index badfe1289..62b61dcc0 100644 --- a/src/moscot/_docs/_docs.py +++ b/src/moscot/_docs/_docs.py @@ -9,11 +9,11 @@ """ _adata_x = """\ adata_x - Instance of :class:`anndata.AnnData` containing the data of the source distribution. + Instance of :class:`~anndata.AnnData` containing the data of the source distribution. """ _adata_y = """\ adata_y - Instance of :class:`anndata.AnnData` containing the data of the target distribution. + Instance of :class:`~anndata.AnnData` containing the data of the target distribution. """ _solver = """\ solver @@ -21,14 +21,14 @@ """ _source = """\ source - Value in :attr:`anndata.AnnData.obs` defining the assignment to the source distribution.""" + Value in :attr:`~anndata.AnnData.obs` defining the assignment to the source distribution.""" _target = """\ target - Value in :attr:`anndata.AnnData.obs` defining the assignment to the target distribution. + Value in :attr:`~anndata.AnnData.obs` defining the assignment to the target distribution. """ _reference = """\ reference - `reference` in :class:`moscot.problems._subset_policy.StarPolicy`. + `reference` in the :class:`~moscot.utils.subset_policy.StarPolicy`. """ _xy_callback = """\ xy_callback @@ -86,7 +86,7 @@ - If `data` is a :class:`str` this should correspond to a column in :attr:`anndata.AnnData.obs`. The transport map is applied to the subset corresponding to the source distribution (if `forward` is `True`) or target distribution (if `forward` is :obj:`False`) of that column. - - If `data` is a :class:npt.ArrayLike the transport map is applied to `data`. + - If `data` is a :class:`numpy.ndarray`, the transport map is applied to `data`. - If `data` is a :class:`dict` then the keys should correspond to the tuple defining a single optimal transport map and the value should be one of the two cases described above. """ @@ -96,23 +96,12 @@ """ _marginal_kwargs = r""" marginal_kwargs - Keyword arguments for :meth:`~moscot.problems.BirthDeathProblem._estimate_marginals`. If ``'scaling'`` + Keyword arguments for :meth:`~moscot.base.problems.BirthDeathProblem.estimate_marginals`. If ``'scaling'`` is in ``marginal_kwargs``, the left marginals are computed as :math:`\exp(\frac{(\textit{proliferation} - \textit{apoptosis}) \cdot (t_2 - t_1)}{\textit{scaling}})`. Otherwise, the left marginals are computed using a birth-death process. The keyword arguments - are either used for :func:`~moscot.problems.time._utils.beta`, i.e. one of: - - - beta_max: float - - beta_min: float - - beta_center: float - - beta_width: float - - or for :func:`~moscot.problems.time._utils.delta`, i.e. one of: - - - delta_max: float - - delta_min: float - - delta_center: float - - delta_width: float + are either used for :func:`~moscot.base.problems.birth_death.beta` or + :func:`~moscot.base.problems.birth_death.delta`. """ _shape = """\ shape @@ -140,10 +129,10 @@ a Specifies the left marginals. If - ``a`` is :class:`str` - the left marginals are taken from :attr:`anndata.AnnData.obs`, - - if :meth:`~moscot.problems.base._birth_death.BirthDeathMixin.score_genes_for_marginals` was run and + - if :meth:`score_genes_for_marginals` was run and if ``a`` is `None`, marginals are computed based on a birth-death process as suggested in :cite:`schiebinger:19`, - - if :meth:`~moscot.problems.base._birth_death.BirthDeathMixin.score_genes_for_marginals` was run and + - if :meth:`score_genes_for_marginals` was run and if ``a`` is `None`, and additionally ``'scaling'`` is provided in ``marginal_kwargs``, the marginals are computed as :math:`\exp(\frac{(\textit{proliferation} - \textit{apoptosis}) \cdot (t_2 - t_1)}{\textit{scaling}})` @@ -154,7 +143,7 @@ b Specifies the right marginals. If - ``b`` is :class:`str` - the left marginals are taken from :attr:`anndata.AnnData.obs`, - - if :meth:`~moscot.problems.base._birth_death.BirthDeathMixin.score_genes_for_marginals` was run + - if :meth:`score_genes_for_marginals` was run uniform (mean of left marginals) right marginals are used, - otherwise or if ``b`` is :obj:`False`, uniform marginals are used. """ diff --git a/src/moscot/_docs/_docs_mixins.py b/src/moscot/_docs/_docs_mixins.py index 171b04703..9cdf92161 100644 --- a/src/moscot/_docs/_docs_mixins.py +++ b/src/moscot/_docs/_docs_mixins.py @@ -57,7 +57,7 @@ """ _return_cell_transition = "Transition matrix of cells or groups of cells." _notes_cell_transition = """\ -To visualise the results, see :func:`moscot.pl.cell_transition`. +To visualize the results, see :func:`moscot.plotting.cell_transition`. """ _normalize = """\ normalize @@ -82,7 +82,7 @@ - If `data` is a :class:`str` this should correspond to a column in :attr:`anndata.AnnData.obs`. The transport map is applied to the subset corresponding to the source distribution (if `forward` is `True`) or target distribution (if `forward` is `False`) of that column. - - If `data` is a :class:npt.ArrayLike the transport map is applied to `data`. + - If `data` is a :class:`numpy.ndarray` the transport map is applied to `data`. - If `data` is a :class:`dict` then the keys should correspond to the tuple defining a single optimal transport map and the value should be one of the two cases described above. """ diff --git a/src/moscot/_docs/_docs_plot.py b/src/moscot/_docs/_docs_plot.py index cdeb567ba..ddc3da75f 100644 --- a/src/moscot/_docs/_docs_plot.py +++ b/src/moscot/_docs/_docs_plot.py @@ -18,15 +18,15 @@ """ _cbar_kwargs_cell_transition = """\ cbar_kwargs - Keyword arguments for :func:`matplotlib.figure.Figure.colorbar`.""" + Keyword arguments for :meth:`~matplotlib.figure.Figure.colorbar`.""" # return cell transition _return_cell_transition = """\ -:class:`matplotlib.figure.Figure` heatmap of cell transition matrix. +Heatmap of cell transition matrix. """ # notes cell transition _notes_cell_transition = """\ -This function looks for the following data in the :class:`anndata.AnnData` object -which is passed or saved as an attribute of :mod:`moscot.problems.base.CompoundProblem`. +This function looks for the following data in the :class:`~anndata.AnnData` object +which is passed or saved as an attribute of :class:`moscot.base.problems.CompoundProblem`. - `transition_matrix` - `source` @@ -57,8 +57,8 @@ """ # notes sankey _notes_sankey = """\ -This function looks for the following data in the :class:`anndata.AnnData` object -which is passed or saved as an attribute of :mod:`moscot.problems.base.CompoundProblem`. +This function looks for the following data in the :class:`~anndata.AnnData` object +which is passed or saved as an attribute of :class:`moscot.base.problems.CompoundProblem`. - `transition_matrices` - `captions` @@ -66,7 +66,7 @@ """ _alpha_transparency = """\ alpha - Transparancy value. + Transparency value. """ _interpolate_color = """\ interpolate_color @@ -102,12 +102,12 @@ """ # return push/pull _return_push_pull = """\ -:class:`matplotlib.figure.Figure` scatterplot in `basis` coordinates. +Scatterplot in `basis` coordinates. """ # notes push/pull _notes_push_pull = """\ -This function looks for the following data in the :class:`anndata.AnnData` object -which is passed or saved as an attribute of :mod:`moscot.problems.base.CompoundProblem`. +This function looks for the following data in the :class:`~anndata.AnnData` object +which is passed or saved as an attribute of :class:`moscot.base.problems.CompoundProblem`. - `temporal_key` """ @@ -116,7 +116,7 @@ # general input _input_plotting = """\ inp - An instance of :class:`anndata.AnnData` where the results of the corresponding method + An instance of :class:`~anndata.AnnData` where the results of the corresponding method of the :mod:`moscot.problems` instance is saved. Alternatively, the instance of the moscot problem can be passed, too.""" _uns_key = """\ @@ -125,7 +125,7 @@ of the moscot problem instance is saved.""" _cmap = """\ cmap - Colormap for continuous annotations, see :class:`matplotlib.colors.Colormap`.""" + Colormap for continuous annotations, see :class:`~matplotlib.colors.Colormap`.""" _title = """\ title Title of the plot. @@ -149,7 +149,7 @@ Whether to return the figure.""" _ax = f"""\ ax - Axes, :class:`matplotlib.axes.Axes`. + Axes. {_return_fig}""" _figsize_dpi_save = f"""\ figsize diff --git a/src/moscot/backends/__init__.py b/src/moscot/backends/__init__.py index d68f51f80..852fed392 100644 --- a/src/moscot/backends/__init__.py +++ b/src/moscot/backends/__init__.py @@ -1,11 +1,4 @@ -from typing import Tuple - from moscot.backends import ott -from moscot.backends.utils import get_solver, register_solver +from moscot.backends.utils import get_available_backends, get_solver, register_solver __all__ = ["ott", "get_solver", "register_solver", "get_available_backends"] - - -def get_available_backends() -> Tuple[str, ...]: - """TODO.""" - return ("ott",) diff --git a/src/moscot/backends/utils.py b/src/moscot/backends/utils.py index 6e18048e4..100a1dc71 100644 --- a/src/moscot/backends/utils.py +++ b/src/moscot/backends/utils.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Literal, Union +from typing import TYPE_CHECKING, Any, Literal, Tuple, Union from moscot import _registry from moscot._types import ProblemKind_t @@ -6,7 +6,7 @@ if TYPE_CHECKING: from moscot.backends import ott -__all__ = ["get_solver", "register_solver"] +__all__ = ["get_solver", "register_solver", "get_available_backends"] _REGISTRY = _registry.Registry() @@ -33,3 +33,8 @@ def _(problem_kind: Literal["linear", "quadratic"], **kwargs: Any) -> Union["ott if problem_kind == "quadratic": return ott.GWSolver(**kwargs) raise NotImplementedError(f"Unable to create solver for `{problem_kind!r}` problem.") + + +def get_available_backends() -> Tuple[str, ...]: + """TODO.""" + return tuple(backend for backend in _REGISTRY) diff --git a/src/moscot/base/cost.py b/src/moscot/base/cost.py index 7ddea772e..d5e7a9cfa 100644 --- a/src/moscot/base/cost.py +++ b/src/moscot/base/cost.py @@ -12,18 +12,18 @@ class BaseCost(ABC): - """Base class for all :mod:`moscot` losses. + """Base class for all :mod:`moscot.costs`. Parameters ---------- adata Annotated data object. attr - Attribute of :class:`anndata.AnnData` used when computing the cost. + Attribute of :class:`~anndata.AnnData` used when computing the cost. key - Key in the ``attr`` of :class:`anndata.AnnData` used when computing the cost. + Key in the ``attr`` of :class:`~anndata.AnnData` used when computing the cost. dist_key - Helper key which determines which distribution ``adata`` belongs to. + Helper key which determines which distribution :attr:`adata` belongs to. """ def __init__(self, adata: AnnData, attr: str, key: str, dist_key: Union[Any, Tuple[Any, Any]]): @@ -42,9 +42,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayLike: Parameters ---------- args - Positional arguments for :meth:`_compute`. + Positional arguments for computation. kwargs - Keyword arguments for :meth:`_compute`. + Keyword arguments for computation. Returns ------- @@ -64,6 +64,7 @@ def adata(self) -> AnnData: """Annotated data object.""" return self._adata + # TODO(michalk8): don't require impl. @property @abstractmethod def data(self) -> Any: diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index 0d1e3c22e..ff064a6a1 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -44,7 +44,7 @@ def converged(self) -> bool: def potentials(self) -> Optional[Tuple[ArrayLike, ArrayLike]]: """Dual potentials :math:`f` and :math:`g`. - Only valid for the Sinkhorn's algorithm. + Only valid for the Sinkhorn algorithm. """ @property @@ -54,13 +54,12 @@ def is_linear(self) -> bool: @abstractmethod def to(self, device: Optional[Device_t] = None) -> "BaseSolverOutput": - """Transfer self to another device using :func:`jax.device_put`. + """Transfer self to another device. Parameters ---------- device - Device where to transfer the solver output. - If `None`, use the default device. + Device where to transfer the solver output. If `None`, use the default device. Returns ------- @@ -211,7 +210,7 @@ def __str__(self) -> str: class MatrixSolverOutput(BaseSolverOutput): - """Optimal transport output with materialized :attr:`transport_matrix`. + """Optimal transport output with materialized transport matrix. Parameters ---------- @@ -225,6 +224,7 @@ class MatrixSolverOutput(BaseSolverOutput): TODO. """ + # TODO(michalk8): don't provide defaults? def __init__( self, transport_matrix: ArrayLike, *, cost: float = np.nan, converged: bool = True, is_linear: bool = True ): diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 79a9f661e..fac319b4a 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -534,7 +534,7 @@ def compute_feature_correlation( - If of type :obj:`list`, the elements should be from :attr:`anndata.AnnData.var_names` or :attr:`anndata.AnnData.obs_names`. - If `human`, `mouse`, or `drosophila`, the features are subsetted to transcription factors, - see :class:`moscot.utils._data.TranscriptionFactors`. + see :func:`~moscot.utils.data.transcription_factors`. confidence_level Confidence level for the confidence interval calculation. Must be in interval `[0, 1]`. @@ -543,7 +543,7 @@ def compute_feature_correlation( seed Random seed when ``method = perm_test``. kwargs - Keyword arguments for :func:`moscot._utils.parallelize`, e.g. `n_jobs`. + Keyword arguments for parallelization, e.g., `n_jobs`. # TODO(michalk8): consider making the function public Returns ------- diff --git a/src/moscot/base/problems/_utils.py b/src/moscot/base/problems/_utils.py index c2972cab8..035dccad3 100644 --- a/src/moscot/base/problems/_utils.py +++ b/src/moscot/base/problems/_utils.py @@ -233,7 +233,7 @@ def _correlation_test( seed Random seed if ``method = 'perm_test'``. kwargs - Keyword arguments for :func:`moscot._utils.parallelize`, e.g. `n_jobs`. + Keyword arguments for parallelization, e.g., `n_jobs`. Returns ------- diff --git a/src/moscot/base/problems/birth_death.py b/src/moscot/base/problems/birth_death.py index bfd5a4478..07a74e3a2 100644 --- a/src/moscot/base/problems/birth_death.py +++ b/src/moscot/base/problems/birth_death.py @@ -53,7 +53,7 @@ class BirthDeathProblemProtocol(BirthDeathProtocol, Protocol): # noqa: D101 class BirthDeathMixin: - """Mixin class for biological problems based on :class:`moscot.problems.mixins.BirthDeathProblem`.""" + """Mixin for biological problems based on :class:`~moscot.base.problems.BirthDeathProblem`.""" def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) @@ -70,16 +70,15 @@ def score_genes_for_marginals( proliferation_key: str = "proliferation", apoptosis_key: str = "apoptosis", **kwargs: Any, - ) -> "BirthDeathProtocol": - """ - Compute gene scores to obtain prior knowledge about proliferation and apoptosis. + ) -> "BirthDeathMixin": + """Compute gene scores to obtain prior knowledge about proliferation and apoptosis. - This method computes gene scores using :func:`scanpy.tl.score_genes`. Therefore, a list of genes corresponding + This method computes gene scores using :func:`~scanpy.tl.score_genes`. Therefore, a list of genes corresponding to proliferation and/or apoptosis must be passed. Alternatively, proliferation and apoptosis genes for humans and mice are saved in :mod:`moscot`. - The gene scores will be used in :meth:`~moscot.problems.CompoundBaseProblem.prepare` to estimate the initial - growth rates as suggested in :cite:`schiebinger:19` + The gene scores will be used in :meth:`~moscot.base.problems.BaseCompoundProblem.prepare` to estimate + the initial growth rates as suggested in :cite:`schiebinger:19` Parameters ---------- @@ -91,15 +90,17 @@ def score_genes_for_marginals( to be used the corresponding organism must be passed. proliferation_key Key in :attr:`anndata.AnnData.obs` where to add the genes scores. + apoptosis_key + TODO. kwargs - Keyword arguments for :func:`scanpy.tl.score_genes`. + Keyword arguments for :func:`~scanpy.tl.score_genes`. Returns ------- - Returns :class:`moscot.problems.time.TemporalProblem` and updates the following attributes + Returns self and updates the following: - - :attr:`proliferation_key` - - :attr:`apoptosis_key` + - :attr:`proliferation_key` - TODO: description + - :attr:`apoptosis_key` - TODO: description Notes ----- @@ -144,11 +145,11 @@ def score_genes_for_marginals( "At least one of `gene_set_proliferation` or `gene_set_apoptosis` must be provided to score genes." ) - return self + return self # type: ignore[return-value] @property def proliferation_key(self) -> Optional[str]: - """Key in :attr:`anndata.AnnData.obs` where cell proliferation is stored.""" + """Key in :attr:`~anndata.AnnData.obs` where cell proliferation is stored.""" return self._proliferation_key @proliferation_key.setter @@ -171,15 +172,14 @@ def apoptosis_key(self: BirthDeathProtocol, key: Optional[str]) -> None: @d.dedent class BirthDeathProblem(BirthDeathMixin, OTProblem): - """ - Class handling an optimal transport problem which allows to estimate the marginals with a birth-death process. + """Optimal transport problem which allows to estimate the marginals with a birth-death process. Parameters ---------- %(adata_x)s """ - def _estimate_marginals( + def estimate_marginals( self: BirthDeathProblemProtocol, adata: AnnData, source: bool, @@ -188,6 +188,8 @@ def _estimate_marginals( marginal_kwargs: Mapping[str, Any] = MappingProxyType({}), **_: Any, ) -> ArrayLike: + """TODO.""" + def estimate(key: Optional[str], *, fn: Callable[..., ArrayLike], **kwargs: Any) -> ArrayLike: if key is None: return np.zeros(adata.n_obs, dtype=float) diff --git a/src/moscot/base/problems/compound_problem.py b/src/moscot/base/problems/compound_problem.py index 39065cbbd..a3bd9b8e5 100644 --- a/src/moscot/base/problems/compound_problem.py +++ b/src/moscot/base/problems/compound_problem.py @@ -39,6 +39,7 @@ OrderedPolicy, StarPolicy, SubsetPolicy, + create_policy, ) from moscot.utils.tagged_array import Tag, TaggedArray @@ -69,7 +70,7 @@ class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]): Raises ------ TypeError - If `base_problem_type` is not a subclass of :class:`moscot.problems.OTProblem`. + If `base_problem_type` is not a subclass of :class:`~moscot.base.problems.OTProblem`. """ def __init__(self, adata: AnnData, **kwargs: Any): @@ -187,7 +188,7 @@ def prepare( x_callback_kwargs: Mapping[str, Any] = MappingProxyType({}), y_callback_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs: Any, - ) -> "BaseCompoundProblem[K,B]": + ) -> "BaseCompoundProblem[K, B]": """ Prepare the biological problem. @@ -208,7 +209,7 @@ def prepare( Returns ------- - :class:`moscot.problems.CompoundProblem`. + The prepared problem. """ self._ensure_valid_policy(policy) policy = self._create_policy(policy=policy, key=key) @@ -254,14 +255,11 @@ def solve( stage Some stage TODO. kwargs - Keyword arguments for one of: - - :meth:`moscot.problems.OTProblem.solve`. - - :meth:`moscot.problems.MultiMarginalProblem.solve`. - - :meth:`moscot.problems.BirthDeathProblem.solve`. + Keyword arguments for :meth:`~moscot.base.problems.OTProblem.solve`. Returns ------- - :class:`moscot.problems.CompoundProblem`. + The solver problem. """ if TYPE_CHECKING: assert isinstance(self._problem_manager, ProblemManager) @@ -369,7 +367,7 @@ def push(self, *args: Any, **kwargs: Any) -> ApplyOutput_t[K]: %(scale_by_marginals)s kwargs - keyword arguments for policy-specific `_apply` method of :class:`moscot.problems.CompoundProblem`. + keyword arguments for policy-specific `_apply` method of :class:`moscot.base.problems.CompoundProblem`. Returns ------- @@ -400,7 +398,7 @@ def pull(self, *args: Any, **kwargs: Any) -> ApplyOutput_t[K]: %(scale_by_marginals)s kwargs - keyword arguments for policy-specific `_apply` method of :class:`moscot.problems.CompoundProblem`. + keyword arguments for policy-specific `_apply` method of :class:`moscot.base.problems.CompoundProblem`. Returns ------- @@ -427,24 +425,23 @@ def add_problem( overwrite: bool = False, **kwargs: Any, ) -> "BaseCompoundProblem[K, B]": - """ - Add a problem. + """Add a subproblem. This function adds and prepares a problem, e.g. if it is not included by the initial - :class:`moscot.problems._subset_policy.SubsetPolicy`. + :class:`~moscot.utils.subset_policy.SubsetPolicy`. Parameters ---------- %(key)s problem - Instance of :class:`moscot.problems.base.OTProblem`. + Instance of a :class:`~moscot.base.problems.OTProblem`. overwrite If `True` the problem will be reinitialized and prepared even if a problem with `key` exists. Returns ------- - :class:`moscot.problems.base.BaseCompoundProblem`. + The updated compound problem. """ if TYPE_CHECKING: assert isinstance(self._problem_manager, ProblemManager) @@ -454,8 +451,7 @@ def add_problem( @d.dedent @require_prepare def remove_problem(self, key: Tuple[K, K]) -> "BaseCompoundProblem[K, B]": - """ - Remove a (sub)problem. + """Remove a subproblem. Parameters ---------- @@ -463,7 +459,7 @@ def remove_problem(self, key: Tuple[K, K]) -> "BaseCompoundProblem[K, B]": Returns ------- - :class:`moscot.problems.base.BaseCompoundProblem` + The updated compound problem. """ if TYPE_CHECKING: assert isinstance(self._problem_manager, ProblemManager) @@ -489,11 +485,11 @@ def save( file_prefix Prefix to prepend to the file name. overwrite - Overwrite existing data or not. + Whether to overwrite existing data or not. Returns ------- - None + Nothing, just saves the problem. """ file_name = ( f"{file_prefix}_{self.__class__.__name__}.pkl" @@ -525,10 +521,11 @@ def load( Returns ------- - Loaded instance of the model. + Loaded instance of the model. Examples -------- + #TODO(michalk8): make nicer >>> problem = ProblemClass.load(filename) # use the name of the model class used to save >>> problem.push.... """ @@ -582,10 +579,9 @@ def __str__(self) -> str: @d.get_sections(base="CompoundProblem", sections=["Parameters", "Raises"]) @d.dedent class CompoundProblem(BaseCompoundProblem[K, B], abc.ABC): - """ - Class handling biological problems composed of exactly one :class:`anndata.AnnData` instance. + """Class handling biological problems composed of exactly one :class:`~anndata.AnnData` instance. - This class is needed to apply the `policy` to one :class:`anndata.AnnData` objects and hence create the + This class is needed to apply the `policy` to one :class:`~anndata.AnnData` objects and hence create the Optimal Transport subproblems from the biological problem. Parameters @@ -610,7 +606,7 @@ def _create_policy( **_: Any, ) -> SubsetPolicy[K]: if isinstance(policy, str): - return SubsetPolicy.create(policy, adata=self.adata, key=key) + return create_policy(policy, adata=self.adata, key=key) return ExplicitPolicy(self.adata, key=key) def _callback_handler( diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index f67ed892c..5c1dd0f2d 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -132,7 +132,7 @@ class OTProblem(BaseProblem): adata Source annotated data object. adata_tgt - Target annotated data object. If `None`, use ``adata``. + Target annotated data object. If :obj:`None`, use ``adata``. src_obs_mask Source observation mask that defines :attr:`adata_src`. tgt_obs_mask @@ -142,11 +142,11 @@ class OTProblem(BaseProblem): tgt_var_mask Target variable mask that defines :attr:`adata_tgt`. src_key - Source key name, usually supplied by :class:`moscot.problems.CompoundBaseProblem`. + Source key name, usually supplied by :class:`~moscot.base.problems.BaseCompoundProblem`. tgt_key - Target key name, usually supplied by :class:`moscot.problems.CompoundBaseProblem`. + Target key name, usually supplied by :class:`~moscot.base.problems.BaseCompoundProblem`. kwargs - Keyword arguments for :class:`moscot.problems.base.BaseProblem.` + Keyword arguments for :class:`~moscot.base.problems.BaseProblem`. Notes ----- @@ -232,13 +232,13 @@ def prepare( ---------- xy Geometry defining the linear term. If passed as a :class:`dict`, - :meth:`~moscot.solvers.TaggedArray.from_adata` will be called. + :meth:`~moscot.utils.tagged_array.TaggedArray.from_adata` will be called. x First geometry defining the quadratic term. If passed as a :class:`dict`, - :meth:`~moscot.solvers.TaggedArray.from_adata` will be called. + :meth:`~moscot.utils.tagged_array.TaggedArray.from_adata` will be called. y Second geometry defining the quadratic term. If passed as a :class:`dict`, - :meth:`~moscot.solvers.TaggedArray.from_adata` will be called. + :meth:`~moscot.utils.tagged_array.TaggedArray.from_adata` will be called. a Source marginals. Valid value are: @@ -315,11 +315,11 @@ def solve( Parameters ---------- backend - Which backend to use, see :func:`moscot.backends.get_available_backends`. + Which backend to use, see :func:`~moscot.backends.utils.get_available_backends`. device - Device where to transfer the solution, see :meth:`moscot.solvers.BaseSolverOutput.to`. + Device where to transfer the solution, see :meth:`moscot.base.output.BaseSolverOutput.to`. kwargs - Keyword arguments for :meth:`moscot.solvers.BaseSolver.__call__`. + Keyword arguments for :meth:`moscot.base.solver.BaseSolver.__call__`. Returns ------- @@ -353,7 +353,7 @@ def push( split_mass: bool = False, **kwargs: Any, ) -> ArrayLike: - """Push mass through the :attr:`~moscot.solvers.BaseSolverOutput.transport_matrix`. + """Push mass through the :attr:`~moscot.base.output.BaseSolverOutput.transport_matrix`. Parameters ---------- @@ -393,7 +393,7 @@ def pull( split_mass: bool = False, **kwargs: Any, ) -> ArrayLike: - """Pull mass through the :attr:`~moscot.solvers.BaseSolverOutput.transport_matrix`. + """Pull mass through the :attr:`~moscot.base.output.BaseSolverOutput.transport_matrix`. Parameters ---------- @@ -477,7 +477,7 @@ def _create_marginals( self, adata: AnnData, *, source: bool, data: Optional[Union[bool, str, ArrayLike]] = None, **kwargs: Any ) -> ArrayLike: if data is True: - marginals = self._estimate_marginals(adata, source=source, **kwargs) + marginals = self.estimate_marginals(adata, source=source, **kwargs) elif data in (False, None): marginals = np.ones((adata.n_obs,), dtype=float) / adata.n_obs elif isinstance(data, str): @@ -495,7 +495,8 @@ def _create_marginals( ) return marginals - def _estimate_marginals(self, adata: AnnData, *, source: bool, **kwargs: Any) -> ArrayLike: + def estimate_marginals(self, adata: AnnData, *, source: bool, **kwargs: Any) -> ArrayLike: + """TODO.""" return np.ones((adata.n_obs,), dtype=float) / adata.n_obs @d.dedent diff --git a/src/moscot/base/solver.py b/src/moscot/base/solver.py index a100797c4..d0a28a016 100644 --- a/src/moscot/base/solver.py +++ b/src/moscot/base/solver.py @@ -89,11 +89,33 @@ class BaseSolver(Generic[O], abc.ABC): @abc.abstractmethod def _prepare(self, **kwargs: Any) -> Any: - pass + """Prepare a problem. + + Parameters + ---------- + kwargs + Keyword arguments. + + Returns + ------- + Object passed to :meth:`_solve`. + """ @abc.abstractmethod def _solve(self, data: Any, **kwargs: Any) -> O: - pass + """Solve a problem. + + Parameters + ---------- + data + Object returned by :meth:`_prepare`. + kwargs + Additional keyword arguments. + + Returns + ------- + The output. + """ @property @abc.abstractmethod @@ -106,7 +128,7 @@ def __call__(self, **kwargs: Any) -> O: Parameters ---------- kwargs - Keyword arguments for :meth:`_prepare`. + Keyword arguments for data preparation. Returns ------- @@ -143,7 +165,7 @@ def __call__( tags How to interpret the data in ``xy``, ``x`` and ``y``. device - Device to transfer the output to, see :meth:`moscot.solvers.BaseSolverOutput.to`. + Device to transfer the output to, see :meth:`~moscot.base.output.BaseSolverOutput.to`. kwargs Keyword arguments for parent's :meth:`__call__`. diff --git a/src/moscot/datasets.py b/src/moscot/datasets.py index 69783ed4b..0ab1eb7f6 100644 --- a/src/moscot/datasets.py +++ b/src/moscot/datasets.py @@ -79,8 +79,7 @@ def drosophila( ) -> AnnData: """Embryo of Drosophila melanogaster described in :cite:`Li-spatial:22`. - Minimal pre-processing was performed, such as gene and cell filtering, as well as normalization, - see the `processing steps `_. + Minimal pre-processing was performed, such as gene and cell filtering, as well as normalization. Parameters ---------- diff --git a/src/moscot/problems/space/_alignment.py b/src/moscot/problems/space/_alignment.py index df43eb918..d2f57d5c5 100644 --- a/src/moscot/problems/space/_alignment.py +++ b/src/moscot/problems/space/_alignment.py @@ -40,8 +40,7 @@ def prepare( b: Optional[str] = None, **kwargs: Any, ) -> "AlignmentProblem[K, B]": - """ - Prepare the :class:`moscot.problems.space.AlignmentProblem`. + """Prepare the problem. This method prepares the data to be passed to the optimal transport solver. @@ -54,7 +53,7 @@ def prepare( reference Only used if `policy="star"`, it's the value for reference stored - in :attr:`adata.obs` ``["batch_key"]``. + in :attr:`anndata.AnnData.obs` ``["batch_key"]``. %(cost)s %(a)s diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 25cc943a0..bece18ddd 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -387,16 +387,16 @@ def impute( # type: ignore[misc] var_names: Optional[Sequence[Any]] = None, device: Optional[Device_t] = None, ) -> AnnData: - """ - Impute expression of specific genes. + """Impute expression of specific genes. Parameters ---------- - TODO: don't use device from docstrings here, as different use + var_names: + TODO: don't use device from docstrings here, as different use Returns ------- - :class:`anndata.AnnData` with imputed gene expression values. + Annotated data object with imputed gene expression values. """ if var_names is None: var_names = self.adata_sc.var_names diff --git a/src/moscot/problems/spatiotemporal/_spatio_temporal.py b/src/moscot/problems/spatiotemporal/_spatio_temporal.py index 984883171..3b8c46050 100644 --- a/src/moscot/problems/spatiotemporal/_spatio_temporal.py +++ b/src/moscot/problems/spatiotemporal/_spatio_temporal.py @@ -38,6 +38,7 @@ class SpatioTemporalProblem( %(adata)s """ + # TODO(michalk8): check if this is necessary def __init__(self, adata: AnnData, **kwargs: Any): super().__init__(adata, **kwargs) @@ -57,11 +58,10 @@ def prepare( marginal_kwargs: Mapping[str, Any] = MappingProxyType({}), **kwargs: Any, ) -> "SpatioTemporalProblem": - """ - Prepare the :class:`moscot.problems.spatio_temporal.SpatioTemporalProblem`. + """Prepare the problem. This method executes multiple steps to prepare the problem for the Optimal Transport solver to be ready - to solve it + to solve it. Parameters ---------- @@ -76,7 +76,7 @@ def prepare( Returns ------- - :class:`moscot.problems.spatio_temporal.SpatioTemporalProblem`. + The prepared problem. Notes ----- @@ -136,8 +136,7 @@ def solve( device: Optional[Literal["cpu", "gpu", "tpu"]] = None, **kwargs: Any, ) -> "SpatioTemporalProblem": - """ - Solve optimal transport problems defined in :class:`moscot.problems.space.SpatioTemporalProblem`. + """Solve the problem. Parameters ---------- @@ -160,7 +159,7 @@ def solve( Returns ------- - :class:`moscot.problems.space.SpatioTemporalProblem`. + The solved problem. Examples -------- diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 62060b284..4656bcb12 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -248,7 +248,7 @@ def sankey( Notes ----- - To visualise the results, see :func:`moscot.pl.sankey`. + To visualise the results, see :func:`moscot.plotting.sankey`. """ tuples = self._policy.plan(start=source, end=target) cell_transitions = [] @@ -597,9 +597,7 @@ def compute_interpolated_distance( Therefore, the Wasserstein distance between the interpolated and the real distribution is computed. It is recommended to compare the Wasserstein distance to the ones obtained by - :meth:`compute_time_point_distances`, - :meth:`compute_random_distance`, and - :meth:`compute_time_point_distance`. + :meth:`compute_time_point_distances` and :meth:`compute_random_distance`. This method does not instantiate the transport matrix if the solver output does not. diff --git a/src/moscot/utils/subset_policy.py b/src/moscot/utils/subset_policy.py index 307ae25bf..8f90b00c2 100644 --- a/src/moscot/utils/subset_policy.py +++ b/src/moscot/utils/subset_policy.py @@ -37,6 +37,7 @@ "ExplicitPolicy", "DummyPolicy", "FormatterMixin", + "create_policy", ] @@ -126,47 +127,6 @@ def _filter_plan( ) -> Sequence[Tuple[K, K]]: return [step for step in plan if step in filter] - @classmethod - def create( - cls, - kind: Policy_t, - adata: Union[AnnData, pd.Series, pd.Categorical], - **kwargs: Any, - ) -> "SubsetPolicy[K]": - """ - Create an instance of a `moscot.utils.SubsetPolicy`. - - Parameters - ---------- - %(policy_kind)s - %(adata)s - - kwargs - Keyword arguments for a :class:`moscot.utils.SubsetPolicy`. - - Returns - ------- - An instance of a :class:`moscot.utils.SubsetPolicy`. - - Notes - ----- - TODO: link policy example. - """ - if kind == _constants.SEQUENTIAL: - return SequentialPolicy(adata, **kwargs) - if kind == _constants.STAR: - return StarPolicy(adata, **kwargs) - if kind == _constants.EXTERNAL_STAR: - return ExternalStarPolicy(adata, **kwargs) - if kind == _constants.TRIU: - return TriangularPolicy(adata, **kwargs, upper=True) - if kind == _constants.TRIL: - return TriangularPolicy(adata, **kwargs, upper=False) - if kind == _constants.EXPLICIT: - return ExplicitPolicy(adata, **kwargs) - - raise NotImplementedError(kind) - def create_mask(self, value: Union[K, Sequence[K]], *, allow_empty: bool = False) -> ArrayLike: if isinstance(value, str) or not isinstance(value, Iterable): mask = self._data == value @@ -405,3 +365,42 @@ def _filter_plan( self, plan: Sequence[Tuple[K, K]], filter: Sequence[Tuple[K, K]] # noqa: A002 ) -> Sequence[Tuple[K, K]]: return plan + + +# TODO(michalk8): in the future, use Registry +def create_policy( + kind: Policy_t, + adata: Union[AnnData, pd.Series, pd.Categorical], + **kwargs: Any, +) -> SubsetPolicy[K]: + """Create a policy. + + Parameters + ---------- + %(policy_kind)s + %(adata)s + kwargs + Keyword arguments for a :class:`~moscot.utils.subset_policy.SubsetPolicy`. + + Returns + ------- + The policy. + + Notes + ----- + TODO: link policy example. + """ + if kind == _constants.SEQUENTIAL: + return SequentialPolicy(adata, **kwargs) + if kind == _constants.STAR: + return StarPolicy(adata, **kwargs) + if kind == _constants.EXTERNAL_STAR: + return ExternalStarPolicy(adata, **kwargs) + if kind == _constants.TRIU: + return TriangularPolicy(adata, **kwargs, upper=True) + if kind == _constants.TRIL: + return TriangularPolicy(adata, **kwargs, upper=False) + if kind == _constants.EXPLICIT: + return ExplicitPolicy(adata, **kwargs) + + raise NotImplementedError(kind) diff --git a/src/moscot/utils/tagged_array.py b/src/moscot/utils/tagged_array.py index 9a7306e01..c2413c3f4 100644 --- a/src/moscot/utils/tagged_array.py +++ b/src/moscot/utils/tagged_array.py @@ -16,9 +16,8 @@ @enum.unique class Tag(str, enum.Enum): - """Tag used to interpret array-like data in :class:`moscot.solvers.TaggedArray`.""" + """Tag used to interpret array-like data in a class:`TaggedArray`.""" - # TODO(michalk8): document rest of the classes COST_MATRIX = "cost_matrix" #: Cost matrix. KERNEL = "kernel" #: Kernel matrix. POINT_CLOUD = "point_cloud" #: Point cloud. @@ -75,7 +74,7 @@ def from_adata( backend: Literal["ott"] = "ott", **kwargs: Any, ) -> "TaggedArray": - """Create tagged array from :class:`anndata.AnnData`. + """Create tagged array from :class:`~anndata.AnnData`. Parameters ---------- @@ -84,22 +83,22 @@ def from_adata( dist_key Helper key which determines into which subset ``adata`` belongs. attr - Attribute of :class:`anndata.AnnData` used when extracting/computing the cost. + Attribute of :class:`~anndata.AnnData` used when extracting/computing the cost. tag Tag used to interpret the extracted data. key - Key in the ``attr`` of :class:`anndata.AnnData` used when extracting/computing the cost. + Key in the ``attr`` of :class:`~anndata.AnnData` used when extracting/computing the cost. cost Cost function to apply to the extracted array, depending on ``tag``: - if ``tag = 'point_cloud'``, it is extracted from the ``backend``. - if ``tag = 'cost'`` or ``tag = 'kernel'``, and ``cost = 'custom'``, the extracted array is already assumed to be a cost/kernel matrix. - Otherwise, :class:`moscot.costs.BaseCost` is used to compute the cost matrix. + Otherwise, :class:`~moscot.base.cost.BaseCost` is used to compute the cost matrix. backend - Which backend to use, see :func:`moscot.backends.get_available_backends`. + Which backend to use, see :func:`~moscot.backends.utils.get_available_backends`. kwargs - Keyword arguments for :class:`moscot.costs.BaseCost`. + Keyword arguments for :class:`~moscot.base.cost.BaseCost`. Returns ------- diff --git a/tests/problems/time/test_temporal_base_problem.py b/tests/problems/time/test_temporal_base_problem.py index 5ca062123..1a2ace5b8 100644 --- a/tests/problems/time/test_temporal_base_problem.py +++ b/tests/problems/time/test_temporal_base_problem.py @@ -56,16 +56,16 @@ def test_estimate_marginals_pipeline( if proliferation_key is not None and "error" in proliferation_key: with pytest.raises(KeyError, match=r"Unable to find proliferation"): - _ = prob._estimate_marginals( + _ = prob.estimate_marginals( adata, source=source, proliferation_key=proliferation_key, apoptosis_key=apoptosis_key ) elif proliferation_key is None and apoptosis_key is None: with pytest.raises(ValueError, match=r"Either `proliferation_key` or `apoptosis_key`"): - _ = prob._estimate_marginals( + _ = prob.estimate_marginals( adata, source=source, proliferation_key=proliferation_key, apoptosis_key=apoptosis_key ) else: - a_estimated = prob._estimate_marginals( + a_estimated = prob.estimate_marginals( adata, source=source, proliferation_key=proliferation_key, apoptosis_key=apoptosis_key ) assert isinstance(a_estimated, np.ndarray)