Skip to content

Commit

Permalink
(chore): fix docs
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed May 24, 2024
1 parent ea3dc93 commit b9b66d5
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 66 deletions.
12 changes: 12 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,24 @@
nitpicky = True
nitpick_ignore = [
("py:class", "numpy.float64"),
# see: https://github.com/numpy/numpydoc/issues/275
("py:class", "None. Remove all items from D."),
("py:class", "a set-like object providing a view on D's items"),
("py:class", "a set-like object providing a view on D's keys"),
("py:class", "v, remove specified key and return the corresponding value."), # noqa: E501
("py:class", "None. Update D from dict/iterable E and F."),
("py:class", "an object providing a view on D's values"),
("py:class", "a shallow copy of D"),
]
# TODO(michalk8): remove once typing has been cleaned-up
nitpick_ignore_regex = [
(r"py:class", r"moscot\..*(K|B|O)"),
(r"py:class", r"numpy\._typing.*"),
(r"py:class", r"moscot\..*Protocol.*"),
(
r"py:class",
r"moscot.base.output.BaseSolverOutput",
), # https://github.com/sphinx-doc/sphinx/issues/10974 means there is simply no way around this with generics
]


Expand Down
6 changes: 5 additions & 1 deletion docs/developer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ Backends
backends.ott.GWSolver
backends.ott.OTTOutput
backends.ott.GraphOTTOutput
backends.ott.GENOTLinSolver
backends.ott.output.OTTNeuralOutput
backends.utils.get_solver
backends.utils.get_available_backends

Expand Down Expand Up @@ -44,6 +46,7 @@ Problems
problems.BaseCompoundProblem
problems.CompoundProblem
cost.BaseCost
problems.CondOTProblem

Mixins
^^^^^^
Expand All @@ -62,7 +65,6 @@ Solvers

solver.BaseSolver
solver.OTSolver
output.BaseDiscreteSolverOutput

Output
^^^^^^
Expand Down Expand Up @@ -100,6 +102,8 @@ Miscellaneous
data.apoptosis_markers
tagged_array.TaggedArray
tagged_array.Tag
tagged_array.DistributionCollection
tagged_array.DistributionContainer

.. currentmodule:: moscot.base.problems
.. autosummary::
Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks
1 change: 1 addition & 0 deletions docs/user.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Generic Problems
generic.SinkhornProblem
generic.GWProblem
generic.FGWProblem
generic.GENOTLinProblem

Plotting
~~~~~~~~
Expand Down
10 changes: 4 additions & 6 deletions src/moscot/backends/ott/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,12 +400,10 @@ def pull(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
Parameters
----------
x : ArrayLike
cond : Optional[ArrayLike], optional
Returns
-------
ArrayLike
x
Distribution to push.
cond
Condition of conditional neural OT.
Raises
------
Expand Down
44 changes: 21 additions & 23 deletions src/moscot/base/output.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
from __future__ import annotations

import abc
import copy
import functools
from abc import abstractmethod
from typing import (
Any,
Callable,
Iterable,
List,
Literal,
Optional,
Tuple,
TypeVar,
Union,
)
from typing import Any, Callable, Iterable, Literal, Optional, Union

import numpy as np
import scipy.sparse as sp
Expand All @@ -23,19 +15,25 @@

__all__ = ["BaseDiscreteSolverOutput", "MatrixSolverOutput", "BaseNeuralOutput"]

OutputClass = TypeVar("OutputClass", bound="BaseSolverOutput") # use string


class BaseSolverOutput(abc.ABC):
"""Base class for all solver outputs."""

@abc.abstractmethod
def pull(self, x: ArrayLike, **kwargs) -> ArrayLike:
"""Pull the solution based on a condition."""

@abc.abstractmethod
def push(self, x: ArrayLike, **kwargs) -> ArrayLike:
"""Push the solution based on a condition."""

@property
@abc.abstractmethod
def shape(self) -> Tuple[int, int]:
def shape(self) -> tuple[int, int]:
"""Shape of the problem."""

@abc.abstractmethod
def to(self: OutputClass, device: Optional[Device_t] = None) -> OutputClass:
def to(self: BaseSolverOutput, device: Optional[Device_t] = None) -> BaseSolverOutput:
"""Transfer self to another compute device.
Parameters
Expand Down Expand Up @@ -83,15 +81,15 @@ def converged(self) -> bool:

@property
@abc.abstractmethod
def potentials(self) -> Optional[Tuple[ArrayLike, ArrayLike]]:
def potentials(self) -> Optional[tuple[ArrayLike, ArrayLike]]:
""":term:`Dual potentials` :math:`f` and :math:`g`.
Only valid for the :term:`Sinkhorn` algorithm.
"""

@property
@abc.abstractmethod
def shape(self) -> Tuple[int, int]:
def shape(self) -> tuple[int, int]:
"""Shape of the :attr:`transport_matrix`."""

@property
Expand Down Expand Up @@ -179,7 +177,7 @@ def as_linear_operator(self, scale_by_marginals: bool = False) -> LinearOperator
# pull: X @ a (matvec)
return LinearOperator(shape=self.shape, dtype=self.dtype, matvec=pull, rmatvec=push)

def chain(self, outputs: Iterable["BaseDiscreteSolverOutput"], scale_by_marginals: bool = False) -> LinearOperator:
def chain(self, outputs: Iterable[BaseDiscreteSolverOutput], scale_by_marginals: bool = False) -> LinearOperator:
"""Chain subsequent applications of :attr:`transport_matrix`.
Parameters
Expand All @@ -206,7 +204,7 @@ def sparsify(
batch_size: int = 1024,
n_samples: Optional[int] = None,
seed: Optional[int] = None,
) -> "MatrixSolverOutput":
) -> MatrixSolverOutput:
"""Sparsify the :attr:`transport_matrix`.
This function sets all entries of the transport matrix below a certain threshold to :math:`0` and
Expand Down Expand Up @@ -267,7 +265,7 @@ def sparsify(
raise NotImplementedError(f"Mode `{mode}` is not yet implemented.")

k, func, fn_stack = (n, self.push, sp.vstack) if n < m else (m, self.pull, sp.hstack)
tmaps_sparse: List[sp.csr_matrix] = []
tmaps_sparse: list[sp.csr_matrix] = []

for batch in range(0, k, batch_size):
x = np.eye(k, min(batch_size, k - batch), -(min(batch, k)), dtype=float)
Expand Down Expand Up @@ -355,12 +353,12 @@ def transport_matrix(self) -> ArrayLike: # noqa: D102
return self._transport_matrix

@property
def shape(self) -> Tuple[int, int]: # noqa: D102
def shape(self) -> tuple[int, int]: # noqa: D102
return self.transport_matrix.shape # type: ignore[return-value]

def to( # noqa: D102
self, device: Optional[Device_t] = None, dtype: Optional[DTypeLike] = None
) -> "BaseDiscreteSolverOutput":
) -> BaseDiscreteSolverOutput:
if device is not None:
logger.warning(f"`{self!r}` does not support the `device` argument, ignoring.")
if dtype is None:
Expand All @@ -379,7 +377,7 @@ def converged(self) -> bool: # noqa: D102
return self._converged

@property
def potentials(self) -> Optional[Tuple[ArrayLike, ArrayLike]]: # noqa: D102
def potentials(self) -> Optional[tuple[ArrayLike, ArrayLike]]: # noqa: D102
return None

@property
Expand Down
16 changes: 6 additions & 10 deletions src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,18 +1049,14 @@ def __str__(self) -> str:

class CondOTProblem(BaseProblem): # TODO(@MUCDK) check generic types, save and load
"""
Base class for all optimal transport problems.
Base class for all conditional (nerual) optimal transport problems.
Parameters
----------
adata
Source annotated data object.
kwargs
Keyword arguments for :class:`moscot.problems.base.BaseProblem.`
Notes
-----
If any of the source/target masks are specified, :attr:`adata_src`/:attr:`adata_tgt` will be a view.
Keyword arguments for :class:`moscot.base.problems.problem.BaseProblem`
"""

def __init__(
Expand Down Expand Up @@ -1101,7 +1097,7 @@ def prepare(
----------
xy
Geometry defining the linear term. If passed as a :class:`dict`,
:meth:`~moscot.solvers.TaggedArray.from_adata` will be called.
:meth:`~moscot.utils.tagged_array.TaggedArray.from_adata` will be called.
policy
Policy defining which pairs of distributions to sample from during training.
policy_key
Expand Down Expand Up @@ -1159,11 +1155,11 @@ def solve(
Parameters
----------
backend
Which backend to use, see :func:`moscot.backends.get_available_backends`.
Which backend to use, see :func:`moscot.backends.utils.get_available_backends`.
device
Device where to transfer the solution, see :meth:`moscot.solvers.BaseDiscreteSolverOutput.to`.
Device where to transfer the solution, see :meth:`moscot.base.output.BaseDiscreteSolverOutput.to`.
kwargs
Keyword arguments for :meth:`moscot.solvers.BaseSolver.__call__`.
Keyword arguments for :meth:`moscot.base.solver.BaseSolver.__call__`.
Returns
Expand Down
6 changes: 3 additions & 3 deletions src/moscot/base/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from moscot._logging import logger
from moscot._types import ArrayLike, Device_t, ProblemKind_t
from moscot.base.output import BaseDiscreteSolverOutput
from moscot.base.output import BaseDiscreteSolverOutput, BaseSolverOutput
from moscot.utils.tagged_array import Tag, TaggedArray

__all__ = ["BaseSolver", "OTSolver"]
Expand Down Expand Up @@ -151,7 +151,7 @@ def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]:
def _partition_kwargs(cls, **kwargs: Any) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Partition keyword arguments.
Used by the :meth:`~moscot.problems.base.BaseProblem.solve`.
Used by the :meth:`~moscot.base.problems.problem.BaseProblem.solve`.
Parameters
----------
Expand Down Expand Up @@ -180,7 +180,7 @@ def __call__(
tags: Mapping[Literal["x", "y", "xy"], Tag] = types.MappingProxyType({}),
device: Optional[Device_t] = None,
**kwargs: Any,
) -> O:
) -> BaseSolverOutput:
"""Solve an optimal transport problem.
Parameters
Expand Down
8 changes: 4 additions & 4 deletions src/moscot/problems/generic/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,9 @@ def push(
Parameters
----------
source
Source key in :attr:`solutions`.
Source key in `solutions`.
target
Target key in :attr:`solutions`.
Target key in `solutions`.
data
Initial data to push, see :meth:`~moscot.base.problems.OTProblem.push` for information.
subset
Expand Down Expand Up @@ -212,9 +212,9 @@ def pull(
Parameters
----------
source
Source key in :attr:`solutions`.
Source key in `solutions`.
target
Target key in :attr:`solutions`.
Target key in `solutions`.
data
Initial data to pull, see :meth:`~moscot.base.problems.OTProblem.pull` for information.
subset
Expand Down
26 changes: 8 additions & 18 deletions src/moscot/utils/tagged_array.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
import enum
from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
Hashable,
Literal,
Optional,
Tuple,
TypeVar,
Union,
)
from typing import Any, Callable, Hashable, Literal, Optional, Tuple, TypeVar, Union

import numpy as np
import scipy.sparse as sp
Expand Down Expand Up @@ -325,21 +315,21 @@ def from_adata(
b
Marginals when used as target distribution.
xy_attr
Attribute of :paramref:`adata` containing the data for the shared space.
Attribute of `adata` containing the data for the shared space.
xy_key
Key of :paramref:`xy_attr` containing the data for the shared space.
Key of `xy_attr` containing the data for the shared space.
xy_cost
Cost function when in the shared space.
xx_attr
Attribute of :paramref:`adata` containing the data for the incomparable space.
Attribute of `adata` containing the data for the incomparable space.
xx_key
Key of :paramref:`xx_attr` containing the data for the incomparable space.
Key of `xx_attr` containing the data for the incomparable space.
xx_cost
Cost function in the incomparable space.
conditions_attr
Attribute of :paramref:`adata` containing the conditions.
Attribute of `adata` containing the conditions.
conditions_key
Key of :paramref:`conditions_attr` containing the conditions.
Key of `conditions_attr` containing the conditions.
backend
Backend to use.
kwargs
Expand Down Expand Up @@ -373,7 +363,7 @@ def from_adata(
return cls(xy=xy_data, xx=xx_data, a=a, b=b, conditions=conditions_data, cost_xy=xy_cost_fn, cost_xx=xx_cost_fn)


class DistributionCollection(Dict[K, DistributionContainer]):
class DistributionCollection(dict[K, DistributionContainer]):
"""Collection of distributions."""

def __repr__(self) -> str:
Expand Down

0 comments on commit b9b66d5

Please sign in to comment.