From 6e65ac14dc98f36459caf7ad5b90580429893f95 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Wed, 21 Aug 2024 09:10:39 -0400 Subject: [PATCH] (feat): Linear GENOT (#662) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix jaxsampler * fix jaxsampler * fix jaxsampler * fix tests * add plot_convergence * remove jit from _compute_unbalanced marginals * fix sinkhorn_divergence * adapt tox.ini file * shape mismatch fixed without precommit * remove print statement * finish merge * adapt callbacks and rename tag `cost` to `cost_matrix` (#426) * rename tag cost to cost_matrix * fix renaming * [CI skip], adapt callback * incorporate requested changes * add test for quad custom callback * adapt kwargs for callback * fix handle_joint_attr * incorporate requested changes * Feature/correlation test (#423) * fix test in FGWProblem * add correlation test * add first test for correlation computation * add more tests * fix tests * add tfs to compute_feature_correlation * add testing for no nans in compute_feature_correlation * incorporate requested changes * fix docstring * fix sankey return statement (#428) * fix sankey return statement * adapt test * adapt return_fig * Bump version: 0.1.0 → 0.1.1 * fix return statements * add save tests * fix return type in mpl (#432) * fix return type in mpl * change import acronyms * fix tests * Simplify linear operator (#431) * Simplify linear operator * Simplify `align`, fix test * Explicitly jit the solvers (#433) * Feature/interpolate colors sankey (#434) * fix return type in mpl * change import acronyms * fix tests * add interpolation option to sankey * add test to interpolate color * define colors for pull/push` * adapt tests * introduce axes in mpl.push/pull * incorporate requested changes * change default color * adapt plotting * introduce scaling * fix scale * make start/end categorical in plot * regenerate images * Remove `FGWSolver` (#437) * Remove `FGWSolver` * Fix `tox.ini` * Fix wrong 
shape check * Use pure GW in generic solver * Update tests * fix bug in SinkhornProblem (#442) * fix bug in SinkhornProblem * fix tox.ini * fix pre commits * make push/pull always use source/target (#443) * make push/pull always use source/target * fix bug in StarPolicy _apply * adapt plotting to source/target * fix strip plotting in sankey (#445) * fix strip plotting in sankey * simplify code * Feature/spearman correlation (#444) * add spearman correlation * add tests * adapt tests * Delete logo.png * Feature/plot order (#453) * make push/pull plot in good order * [CI skip], try setting adata.uns color explicitly * [CI skip], fix copying of adata * fix pre commits * fix bug * Expose marginal kwargs for `moscot.temporal` and check for numeric type of `temporal_key` (#449) * make marginal_kwargs explicit in temporal problems * introduce check for numeric dtype in temporal mixin * add alternative way for marginal prior * adapt tolerances in tests * correct docs * fix bug * Fix math rendering * fix test Co-authored-by: Michal Klein <46717574+michalk8@users.noreply.github.com> * adapt plot_convergence (#454) * Bug/docs generic analysis mixin (#455) * adapt plot_convergence * remove temporal-alluding docs in generic analysis mixin * Docs/improvements (#456) * adapt plot_convergence * remove temporal-alluding docs in generic analysis mixin * docs suggestions * remove uns_key from set_plotting_vars (#458) * resolve `fig referenced before assignment` (#460) * move generic mixins tests to problems` (#461) * Tests/spatiotemporalproblem (#464) * add more tests for spatiotemporalProblem * move some functions from TemporalProblem to TemporalMixin * add tests LineageProblem * fix tests * Feature/move taggedarray (#457) * adapt plot_convergence * remove temporal-alluding docs in generic analysis mixin * docs suggestions * move tagged array * move taggedarray back to solvers * add marginal_kwargs to prepare method of TemporalNeuralProblem * fix to scaling in * Revert "fix to 
scaling in" This reverts commit 0a6f7dbdd70b954694f109456274febe6ee46c0c. * fix to scaling argument in marginal_kwargs * updated conditional not pipeline * merge into condot branch * incoporated comments * incoporated comments * incoporated comments * removed new_adata for push/pull * [ci skip] start docs * added temporal neural test * [ci skip] continue docs * continue docs * continue docs * change validation epsilon * fixed error when not computing wasserstein baseline * fixed error when not computing wasserstein baseline * correct typo * fix bug * added neural tests * [ci skip] draft CondNeuralOutput * include CondDualPotentials and CondDualSolver * fixes to main merge * fix test_cell_transition_subset_pipeline * fix tests * update conditionalDualPotentials * update conditionalDualPotentials * fix most pre-commit hooks and fix tests * fix pandas version to <2.0 * fix tests for non-conditional solvers * continue * fix * continue fixing * fix ICNN setup * fix tests * swap role of f and g, such that push/pull is correct again * [ci skip] restructure to include more general neural solvers * [ci skip] restructure ICNNs to allow passing instances of ICNN * adapt tests * Filled in Monge Gap structure * Added Monge Gap paper to documentation * Ammend PointCloud Import * Update _utils.py Ammend PointCloud import * Solve compatibility issue with ProblemKind * Solve missing Import * Fix call to deprecated function * Fix style and comment issues * add callback, swap f & g * add callback, swap f & g * add callback, swap f & g * intermediate save * intermediate save * intermediate save * [ci skip] fix merge conflicts * resolve conflict * remove pairwise policy * add neural dependencies * add neural dependencies * add flax * fix _call_kwargs * fix marginal kwargs * remove monge gap solver * clean condneuralsolver * [ci skip] introduce new data container for joint neural problems * add conditions in distirbutioncontainer * resolve unfreeze/freeze * enable pretraining and weight 
clipping * make dicts compatible with older python versions * resolve precommit errors partially * resolve precommit errors partially * adapt tests * [ci skip] draft unbalancedNeuralMixin * [ci skip] fix naming of posterior marginals * [ci skip] add MLP_marginals * adapt neural output to incorporate learnt rescaling functions * fix _solve in neuraldualsolver * incorporate feedback * fix distributioncollection class * unify _split_data * fix tests * fix some precommit hooks * make neural dependencies optional * make neural dependencies optional * delete old files * adapt pyproject.toml * adapt pyproject.toml * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [ci skip] adjust _format_params * adapt neuraldualsolver to be more similar to ott-jax * adapt neuraldualsolver * TODO: make JaxSampler return conditions * add basic neural test * [ci skip] intermediate save * adapt neuraldualsolver and finish tests for neural backend * [ci skip] TODO: re-iterate on initialisation of neural solver * adapt distributioncontainer * fix dict bug * resolve passing of arguments in solver call methods * [ci skip] adapt `solve` in `CondOTProblem` * adapt tests and valid loader conditions * adapt neural backend tests * fix mypy errors * make basesolveroutput to basediscretesolveroutput * move `to` to BaseSolverOutput` " * adapt transport_matrix docs * adapt transport_matrix docs * adapt tests * adapt tests * update unbalancedness mixin * use implementation from moscot * uncomment unused code * before passing states to loss-fn * intermediate save * adapt neuraldualsolver * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * resolve some / not all pre commit errors * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * (wip): tests run, code swapped out for now * (wip): `NeuralSolver`s implemented minus quad/linear * (wip): begin more 
generic problem * (wip): more refactoring to pass arguments to GENOT * (chore): remove more kantorovich * (chore): update branch to moscot neural + first test moving to solving * (fix): split data remains in numpy * (fix): push/pull api * (fix): make push test work * (feat): allow for custom optimziers * (chore): remove unclear test * (refactor): change to composition API * (refactor): start towards model-specific problems * (chore): clean up all unnecessary classes * (chore): updating to moscot latest * Merge branch 'main' into ig/neural_solvers * (chore): remove (hopefully) final ICNN vestiges * (chore): more cleanup * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * (fix): pass pre-commit hooks * (chore): remove duplicatec docs * (chore): add torch for testing * (fix): add ott jax branch as dep * (fix): repo name * (chore): remove unbalanced, update api, fix tests + drive by typing fix * (feat): first pass at neural mixin * (chore): add my name to todos * (fix): conditions left out if not necessary * (feat): logs and fix conditional attr * (fix): add `seed` to call_kwargs so reproducibility works * (chore): remove `is_conditional` business * (fix): create hidden dims arg for velocity field * (chore): raise not implemented error for `pull` * (fix): default args * (fix): add explicit policy * (fix): allow iteration to continue * (chore): add star policy to GENOT * (chore): notebooks * (chore): remove deps * (chore): remove unnecessary spaces * (chore): simplify quad handling * (fix): need to require `optax`/`flax` * (fix): use `ott-jax[neural]` * (chore): fix docs * (fix): small test fixes * (chore): small notebook changes * (fix): broken link in citation * (chore): make notebook dependent on ci * (fix): small todos just to push something * (fix): variable is a string * (fix): pass environment variable to tox 
* (fix): actually pass through * (fix): hidden dims ci * (fix): re-add notebook * (chore): make`recall_target` and `aggregate_to_topk` * (chore): fix default arguments * (chore): `project_transport_matrix` -> `project_to_transport_matrix` * (fix): remove dead `NeuralAnalysisMixin` code * (feat): allow custom `data_match_fn` * (fix): inherit from `MutableMapping` instead of `dict` * (Fix): docs * (fix): notebooks * (fix): docs reference * (fix): remove `attr` * (fix): erroneous change * (fix): remove empty * (fix): notebooks again? * (chore): ok? --------- Co-authored-by: Dominik Klein Co-authored-by: Dominik Klein Co-authored-by: AlejandroTL Co-authored-by: michalk8 <46717574+michalk8@users.noreply.github.com> Co-authored-by: lucaeyring Co-authored-by: gocato <104785310+gocato@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/test.yml | 2 +- docs/conf.py | 12 + docs/developer.rst | 8 +- docs/installation.rst | 2 +- docs/notebooks | 2 +- docs/references.bib | 9 + docs/user.rst | 1 + pyproject.toml | 9 +- src/moscot/backends/ott/__init__.py | 7 +- src/moscot/backends/ott/_utils.py | 115 +++++++- src/moscot/backends/ott/output.py | 251 ++++++++++++++++- src/moscot/backends/ott/solver.py | 252 ++++++++++++++++- src/moscot/backends/utils.py | 26 +- src/moscot/base/output.py | 137 ++++++--- src/moscot/base/problems/__init__.py | 3 +- src/moscot/base/problems/_mixins.py | 4 +- src/moscot/base/problems/compound_problem.py | 4 +- src/moscot/base/problems/manager.py | 6 +- src/moscot/base/problems/problem.py | 264 +++++++++++++++++- src/moscot/base/solver.py | 20 +- src/moscot/problems/__init__.py | 2 + src/moscot/problems/_utils.py | 79 +++++- .../problems/cross_modality/_translation.py | 2 +- src/moscot/problems/generic/__init__.py | 14 +- src/moscot/problems/generic/_generic.py | 91 +++++- src/moscot/problems/generic/_mixins.py | 8 +- src/moscot/problems/space/_alignment.py | 2 +- 
src/moscot/problems/space/_mapping.py | 2 +- src/moscot/problems/space/_mixins.py | 4 +- .../spatiotemporal/_spatio_temporal.py | 2 +- src/moscot/problems/time/_lineage.py | 4 +- src/moscot/utils/subset_policy.py | 5 + src/moscot/utils/tagged_array.py | 191 ++++++++++++- tests/backends/ott/test_backend.py | 6 +- tests/problems/base/test_general_problem.py | 6 +- tests/problems/conftest.py | 7 + .../test_translation_problem.py | 4 +- .../test_conditional_neural_problem.py | 86 ++++++ tests/problems/generic/test_fgw_problem.py | 4 +- tests/problems/generic/test_gw_problem.py | 4 +- .../problems/generic/test_sinkhorn_problem.py | 4 +- .../test_spatio_temporal_problem.py | 4 +- tests/problems/time/test_lineage_problem.py | 4 +- tests/problems/time/test_temporal_problem.py | 6 +- tests/solvers/test_base_solver.py | 2 +- 45 files changed, 1530 insertions(+), 147 deletions(-) create mode 100644 tests/problems/generic/test_conditional_neural_problem.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c57a981c1..e27a02960 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,7 +19,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python: ["3.8", "3.10"] + python: ["3.9", "3.10"] include: - os: macos-latest python: "3.9" diff --git a/docs/conf.py b/docs/conf.py index 98d7c7237..311d1f6b2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -63,12 +63,24 @@ nitpicky = True nitpick_ignore = [ ("py:class", "numpy.float64"), + # see: https://github.com/numpy/numpydoc/issues/275 + ("py:class", "None. Remove all items from D."), + ("py:class", "a set-like object providing a view on D's items"), + ("py:class", "a set-like object providing a view on D's keys"), + ("py:class", "v, remove specified key and return the corresponding value."), # noqa: E501 + ("py:class", "None. 
Update D from dict/iterable E and F."), + ("py:class", "an object providing a view on D's values"), + ("py:class", "a shallow copy of D"), ] # TODO(michalk8): remove once typing has been cleaned-up nitpick_ignore_regex = [ (r"py:class", r"moscot\..*(K|B|O)"), (r"py:class", r"numpy\._typing.*"), (r"py:class", r"moscot\..*Protocol.*"), + ( + r"py:class", + r"moscot.base.output.BaseSolverOutput", + ), # https://github.com/sphinx-doc/sphinx/issues/10974 means there is simply no way around this with generics ] diff --git a/docs/developer.rst b/docs/developer.rst index 349d1f559..26466b03a 100644 --- a/docs/developer.rst +++ b/docs/developer.rst @@ -12,6 +12,8 @@ Backends backends.ott.GWSolver backends.ott.OTTOutput backends.ott.GraphOTTOutput + backends.ott.GENOTLinSolver + backends.ott.output.OTTNeuralOutput backends.utils.get_solver backends.utils.get_available_backends @@ -44,6 +46,7 @@ Problems problems.BaseCompoundProblem problems.CompoundProblem cost.BaseCost + problems.CondOTProblem Mixins ^^^^^^ @@ -62,14 +65,13 @@ Solvers solver.BaseSolver solver.OTSolver - output.BaseSolverOutput Output ^^^^^^ .. autosummary:: :toctree: genapi - output.BaseSolverOutput + output.BaseDiscreteSolverOutput output.MatrixSolverOutput Utils @@ -100,6 +102,8 @@ Miscellaneous data.apoptosis_markers tagged_array.TaggedArray tagged_array.Tag + tagged_array.DistributionCollection + tagged_array.DistributionContainer .. currentmodule:: moscot.base.problems .. autosummary:: diff --git a/docs/installation.rst b/docs/installation.rst index 12c9946e9..5d0e416b5 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -1,6 +1,6 @@ Installation ============ -:mod:`moscot` requires Python version >= 3.8 to run. +:mod:`moscot` requires Python version >= 3.9 to run. 
PyPI ---- diff --git a/docs/notebooks b/docs/notebooks index 5b9d4e07d..c48edf3d0 160000 --- a/docs/notebooks +++ b/docs/notebooks @@ -1 +1 @@ -Subproject commit 5b9d4e07d7188b1c391ec47a1c5d957da1ab2bca +Subproject commit c48edf3d0acb6dc191bb571320357b9119a6c559 diff --git a/docs/references.bib b/docs/references.bib index 3863d9673..47331da70 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -477,3 +477,12 @@ @article{srivatsan:20 year={2020}, publisher={American Association for the Advancement of Science} } + +@misc{klein2023generative, + title={Generative Entropic Neural Optimal Transport To Map Within and Across Spaces}, + author={Dominik Klein and Théo Uscidda and Fabian Theis and Marco Cuturi}, + year={2023}, + eprint={2310.09254}, + archivePrefix={arXiv}, + primaryClass={stat.ML} +} diff --git a/docs/user.rst b/docs/user.rst index 8c6b2d59a..f4291892a 100644 --- a/docs/user.rst +++ b/docs/user.rst @@ -27,6 +27,7 @@ Generic Problems generic.SinkhornProblem generic.GWProblem generic.FGWProblem + generic.GENOTLinProblem Plotting ~~~~~~~~ diff --git a/pyproject.toml b/pyproject.toml index 196ac594d..014775ffe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "moscot" dynamic = ["version"] description = "Multi-omic single-cell optimal transport tools" readme = "README.rst" -requires-python = ">=3.8" +requires-python = ">=3.9" license = {file = "LICENSE"} classifiers = [ "Development Status :: 4 - Beta", @@ -19,7 +19,6 @@ classifiers = [ "Operating System :: Microsoft :: Windows", "Typing :: Typed", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Topic :: Scientific/Engineering :: Bio-Informatics", @@ -55,7 +54,7 @@ dependencies = [ "scanpy>=1.9.3", "wrapt>=1.13.2", "docrep>=0.3.2", - "ott-jax>=0.4.6", + "ott-jax[neural]>=0.4.6", "cloudpickle>=2.2.0", "rich>=13.5", "docstring_inheritance>=2.0.0" @@ -263,16 +262,16 
@@ max_line_length = 120 legacy_tox_ini = """ [tox] min_version = 4.0 -env_list = lint-code,py{3.8,3.9,3.10,3.11} +env_list = lint-code,py{3.9,3.10,3.11} skip_missing_interpreters = true [testenv] extras = test -pass_env = PYTEST_*,CI commands = python -m pytest {tty:--color=yes} {posargs: \ --cov={env_site_packages_dir}{/}moscot --cov-config={tox_root}{/}pyproject.toml \ --no-cov-on-fail --cov-report=xml --cov-report=term-missing:skip-covered} +passenv = PYTEST_*,CI [testenv:lint-code] description = Lint the code. diff --git a/src/moscot/backends/ott/__init__.py b/src/moscot/backends/ott/__init__.py index 40f1cba6c..48ffdec64 100644 --- a/src/moscot/backends/ott/__init__.py +++ b/src/moscot/backends/ott/__init__.py @@ -1,11 +1,12 @@ from ott.geometry import costs from moscot.backends.ott._utils import sinkhorn_divergence -from moscot.backends.ott.output import GraphOTTOutput, OTTOutput -from moscot.backends.ott.solver import GWSolver, SinkhornSolver +from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput +from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver from moscot.costs import register_cost -__all__ = ["OTTOutput", "GraphOTTOutput", "GWSolver", "SinkhornSolver", "sinkhorn_divergence"] +__all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "OTTNeuralOutput", "sinkhorn_divergence", "GENOTLinSolver"] + register_cost("euclidean", backend="ott")(costs.Euclidean) register_cost("sq_euclidean", backend="ott")(costs.SqEuclidean) diff --git a/src/moscot/backends/ott/_utils.py b/src/moscot/backends/ott/_utils.py index 1934cdb2c..2cac53b30 100644 --- a/src/moscot/backends/ott/_utils.py +++ b/src/moscot/backends/ott/_utils.py @@ -1,11 +1,16 @@ -from typing import Any, Literal, Optional, Tuple, Union +from collections import defaultdict +from functools import partial +from typing import Any, Dict, Iterable, Literal, Optional, Tuple, Union import jax import jax.experimental.sparse as jesp import jax.numpy as jnp +import 
numpy as np import scipy.sparse as sp from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud -from ott.tools import sinkhorn_divergence as sdiv +from ott.neural import datasets +from ott.solvers import utils as solver_utils +from ott.tools.sinkhorn_divergence import sinkhorn_divergence as sinkhorn_div from moscot._logging import logger from moscot._types import ArrayLike, ScaleCost_t @@ -22,7 +27,10 @@ def sinkhorn_divergence( a: Optional[ArrayLike] = None, b: Optional[ArrayLike] = None, epsilon: Union[float, epsilon_scheduler.Epsilon] = 1e-1, + tau_a: float = 1.0, + tau_b: float = 1.0, scale_cost: ScaleCost_t = 1.0, + batch_size: Optional[int] = None, **kwargs: Any, ) -> float: point_cloud_1 = jnp.asarray(point_cloud_1) @@ -30,14 +38,16 @@ def sinkhorn_divergence( a = None if a is None else jnp.asarray(a) b = None if b is None else jnp.asarray(b) - output = sdiv.sinkhorn_divergence( + output = sinkhorn_div( pointcloud.PointCloud, x=point_cloud_1, y=point_cloud_2, + batch_size=batch_size, a=a, b=b, - epsilon=epsilon, + sinkhorn_kwargs={"tau_a": tau_a, "tau_b": tau_b}, scale_cost=scale_cost, + epsilon=epsilon, **kwargs, ) xy_conv, xx_conv, *yy_conv = output.converged @@ -52,6 +62,23 @@ def sinkhorn_divergence( return float(output.divergence) +@partial(jax.jit, static_argnames=["k"]) +def get_nearest_neighbors( + input_batch: jnp.ndarray, + target: jnp.ndarray, + k: int = 30, + recall_target: float = 0.95, + aggregate_to_topk: bool = True, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Get the k nearest neighbors of the input batch in the target.""" + if target.shape[0] < k: + raise ValueError(f"k is {k}, but must be smaller or equal than {target.shape[0]}.") + pairwise_euclidean_distances = pointcloud.PointCloud(input_batch, target).cost_matrix + return jax.lax.approx_min_k( + pairwise_euclidean_distances, k=k, recall_target=recall_target, aggregate_to_topk=aggregate_to_topk + ) + + def check_shapes(geom_x: geometry.Geometry, geom_y: 
geometry.Geometry, geom_xy: geometry.Geometry) -> None: n, m = geom_xy.shape n_, m_ = geom_x.shape[0], geom_y.shape[0] @@ -133,3 +160,83 @@ def _instantiate_geodesic_cost( cm_full = geodesic.Geodesic.from_graph(arr, t=t, directed=directed, **kwargs).cost_matrix cm = cm_full[:n_src, n_src:] if is_linear_term else cm_full return geometry.Geometry(cm, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost) + + +def data_match_fn( + src_lin: Optional[jnp.ndarray] = None, + tgt_lin: Optional[jnp.ndarray] = None, + src_quad: Optional[jnp.ndarray] = None, + tgt_quad: Optional[jnp.ndarray] = None, + *, + typ: Literal["lin", "quad", "fused"], + **data_match_fn_kwargs, +) -> jnp.ndarray: + if typ == "lin": + return solver_utils.match_linear(x=src_lin, y=tgt_lin, **data_match_fn_kwargs) + if typ == "quad": + return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad, **data_match_fn_kwargs) + if typ == "fused": + return solver_utils.match_quadratic(xx=src_quad, yy=tgt_quad, x=src_lin, y=tgt_lin, **data_match_fn_kwargs) + raise NotImplementedError(f"Unknown type: {typ}.") + + +class Loader: + + def __init__(self, dataset: datasets.OTDataset, batch_size: int, seed: Optional[int] = None): + self.dataset = dataset + self.batch_size = batch_size + self._rng = np.random.default_rng(seed) + + def __iter__(self): + return self + + def __next__(self) -> Dict[str, jnp.ndarray]: + data = defaultdict(list) + for _ in range(self.batch_size): + ix = self._rng.integers(0, len(self.dataset)) + for k, v in self.dataset[ix].items(): + data[k].append(v) + return {k: jnp.vstack(v) for k, v in data.items()} + + def __len__(self): + return len(self.dataset) + + +class MultiLoader: + """Dataset for OT problems with conditions. + + This data loader wraps several data loaders and samples from them. + + Args: + datasets: Datasets to sample from. + seed: Random seed. 
+ """ + + def __init__( + self, + datasets: Iterable[Loader], + seed: Optional[int] = None, + ): + self.datasets = tuple(datasets) + self._rng = np.random.default_rng(seed) + self._iterators: list[MultiLoader] = [] + self._it = 0 + + def __next__(self) -> Dict[str, jnp.ndarray]: + self._it += 1 + + ix = self._rng.choice(len(self._iterators)) + iterator = self._iterators[ix] + if self._it < len(self): + return next(iterator) + # reset the consumed iterator and return its first element + self._iterators[ix] = iterator = iter(self.datasets[ix]) + return next(iterator) + + def __iter__(self) -> "MultiLoader": + self._it = 0 + self._iterators = [iter(ds) for ds in self.datasets] + return self + + def __len__(self) -> int: + return max((len(ds) for ds in self.datasets), default=0) diff --git a/src/moscot/backends/ott/output.py b/src/moscot/backends/ott/output.py index 57f6500ac..60d727faf 100644 --- a/src/moscot/backends/ott/output.py +++ b/src/moscot/backends/ott/output.py @@ -1,10 +1,12 @@ -from typing import Any, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union import jaxlib.xla_extension as xla_ext import jax import jax.numpy as jnp import numpy as np +import scipy.sparse as sp +from ott.neural.methods.flows.genot import GENOT from ott.solvers.linear import sinkhorn, sinkhorn_lr from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr @@ -12,12 +14,13 @@ import matplotlib.pyplot as plt from moscot._types import ArrayLike, Device_t -from moscot.base.output import BaseSolverOutput +from moscot.backends.ott._utils import get_nearest_neighbors +from moscot.base.output import BaseDiscreteSolverOutput, BaseNeuralOutput -__all__ = ["OTTOutput", "GraphOTTOutput"] +__all__ = ["OTTOutput", "GraphOTTOutput", "OTTNeuralOutput"] -class OTTOutput(BaseSolverOutput): +class OTTOutput(BaseDiscreteSolverOutput): """Output of various :term:`OT` problems.
Parameters @@ -238,6 +241,246 @@ def _ones(self, n: int) -> ArrayLike: # noqa: D102 return jnp.ones((n,)) +class OTTNeuralOutput(BaseNeuralOutput): + """Output wrapper for GENOT.""" + + def __init__(self, model: GENOT, logs: dict[str, list[float]]): + """Initialize `OTTNeuralOutput`. + + Parameters + ---------- + model : GENOT + The OTT-Jax GENOT model + """ + self._logs = logs + self._model = model + + @property + def logs(self): + """Logs of the training. A dictionary containing what the numeric values are i.e., loss. + + Returns + ------- + dict[str, list[float]] + """ + return self._logs + + def _project_transport_matrix( + self, + src_dist: ArrayLike, + tgt_dist: ArrayLike, + forward: bool, + func: Callable[[jnp.ndarray], jnp.ndarray], + save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments + batch_size: int = 1024, + k: int = 30, + length_scale: Optional[float] = None, + seed: int = 42, + recall_target: float = 0.95, + aggregate_to_topk: bool = True, + ) -> sp.csr_matrix: + row_indices: Union[jnp.ndarray, List[jnp.ndarray]] = [] + column_indices: Union[jnp.ndarray, List[jnp.ndarray]] = [] + distances_list: Union[jnp.ndarray, List[jnp.ndarray]] = [] + if length_scale is None: + key = jax.random.PRNGKey(seed) + src_batch = src_dist[jax.random.choice(key, src_dist.shape[0], shape=((batch_size,)))] + tgt_batch = tgt_dist[jax.random.choice(key, tgt_dist.shape[0], shape=((batch_size,)))] + length_scale = jnp.std(jnp.concatenate((func(src_batch), tgt_batch))) + for index in range(0, len(src_dist), batch_size): + distances, indices = get_nearest_neighbors( + func(src_dist[index : index + batch_size, :]), + tgt_dist, + k, + recall_target=recall_target, + aggregate_to_topk=aggregate_to_topk, + ) + distances = jnp.exp(-((distances / length_scale) ** 2)) + distances /= jnp.expand_dims(jnp.sum(distances, axis=1), axis=1) + distances_list.append(distances.flatten()) + column_indices.append(indices.flatten()) + row_indices.append( + 
+ jnp.repeat(jnp.arange(index, index + min(batch_size, len(src_dist) - index)), min(k, len(tgt_dist))) + ) + distances = jnp.concatenate(distances_list) + row_indices = jnp.concatenate(row_indices) + column_indices = jnp.concatenate(column_indices) + tm = sp.csr_matrix((distances, (row_indices, column_indices)), shape=[len(src_dist), len(tgt_dist)]) + if forward: + if save_transport_matrix: + self._transport_matrix = tm + else: + tm = tm.T + if save_transport_matrix: + self._inverse_transport_matrix = tm + return tm + + def project_to_transport_matrix( # type:ignore[override] + self, + src_cells: ArrayLike, + tgt_cells: ArrayLike, + forward: bool = True, + condition: ArrayLike = None, + save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments + batch_size: int = 1024, + k: int = 30, + length_scale: Optional[float] = None, + seed: int = 42, + recall_target: float = 0.95, + aggregate_to_topk: bool = True, + ) -> sp.csr_matrix: + """Project conditional neural OT map onto cells. + + In contrast to discrete OT, (conditional) neural OT does not necessarily map cells onto cells, + but a cell can also be mapped to a location between two cells. This function computes + a pseudo-transport matrix considering the neighborhood of where a cell is mapped to. + Therefore, a neighborhood graph of `k` target cells is computed around each transported cell + of the source distribution. The assignment likelihood of each mapped cell to the target cells is then + computed with a Gaussian kernel with parameter `length_scale`. + + Parameters + ---------- + condition + Condition `src_cells` correspond to. + src_cells + Cells which are to be mapped. + tgt_cells + Cells from which the neighborhood graph around the mapped `src_cells` is computed. + forward + Whether to map cells based on the forward transport map or backward transport map. + save_transport_matrix + Whether to save the transport matrix.
+ batch_size + Number of data points in the source distribution the neighborhood graph is computed + for in parallel. + k + Number of neighbors to construct the k-nearest neighbor graph of a mapped cell. + length_scale + Length scale of the Gaussian kernel used to compute the assignment likelihood. If `None`, + `length_scale` is set to the empirical standard deviation of `batch_size` pairs of data points of the + mapped source and target distribution. + seed + Random seed for sampling the pairs of distributions for computing the standard deviation in case `length_scale` + is `None`. + recall_target + Recall target for the approximation. + aggregate_to_topk + When true, the nearest neighbor aggregates approximate results to the top-k in sorted order. + When false, returns the approximate results unsorted. + In this case, the number of the approximate results is implementation defined and is greater or + equal to the specified k. + + Returns + ------- + The projected transport matrix. + """ + src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells) + push = self.push if condition is None else lambda x: self.push(x, condition) + pull = self.pull if condition is None else lambda x: self.pull(x, condition) + func, src_dist, tgt_dist = (push, src_cells, tgt_cells) if forward else (pull, tgt_cells, src_cells) + return self._project_transport_matrix( + src_dist=src_dist, + tgt_dist=tgt_dist, + forward=forward, + func=func, + save_transport_matrix=save_transport_matrix, # TODO(@MUCDK) adapt order of arguments + batch_size=batch_size, + k=k, + length_scale=length_scale, + seed=seed, + recall_target=recall_target, + aggregate_to_topk=aggregate_to_topk, + ) + + def push(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike: + """Push distribution `x` conditioned on condition `cond`. + + Parameters + ---------- + x + Distribution to push. + cond + Condition of conditional neural OT. + + Returns + ------- + Pushed distribution.
+ """ + if x.ndim not in (1, 2): + raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.") + return self._apply(x, cond=cond, forward=True) + + def pull(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike: + """Pull distribution `x` conditioned on condition `cond`. + + This does not make sense for some neural models and is therefore left unimplemented. + + Parameters + ---------- + x + Distribution to pull. + cond + Condition of conditional neural OT. + + Raises + ------ + NotImplementedError + """ + raise NotImplementedError("`pull` does not make sense for neural OT.") + + def _apply(self, x: ArrayLike, forward: bool, cond: Optional[ArrayLike] = None) -> ArrayLike: + if not forward: + raise NotImplementedError("Backward i.e., pull on neural OT is not supported.") + return self._model.transport(x, condition=cond) + + @property + def is_linear(self) -> bool: # noqa: D102 + return True # TODO(ilan-gold): need to contribute something to ott-jax so this is resolvable from GENOT + + @property + def shape(self) -> Tuple[int, int]: + """%(shape)s.""" + raise NotImplementedError() + + def to( + self, + device: Optional[Device_t] = None, + ) -> "OTTNeuralOutput": + """Transfer the output to another device or change its data type. + + Parameters + ---------- + device + If not `None`, the output will be transferred to `device`. + + Returns + ------- + The output saved on `device`.
+ """ + # # TODO(michalk8): when polishing docs, move the definition to the base class + use docrep + # if isinstance(device, str) and ":" in device: + # device, ix = device.split(":") + # idx = int(ix) + # else: + # idx = 0 + + # if not isinstance(device, xla_ext.Device): + # try: + # device = jax.devices(device)[idx] + # except IndexError as err: + # raise IndexError(f"Unable to fetch the device with `id={idx}`.") from err + + # out = jax.device_put(self._model, device) + # return OTTNeuralOutput(out) + return self # TODO(ilan-gold) move model to device + + @property + def converged(self) -> bool: + """%(converged)s.""" + # always return True for now + return True + + class GraphOTTOutput(OTTOutput): """Output of :term:`OT` problems with a graph geometry in the linear term. diff --git a/src/moscot/backends/ott/solver.py b/src/moscot/backends/ott/solver.py index 9d280d892..732404dbc 100644 --- a/src/moscot/backends/ott/solver.py +++ b/src/moscot/backends/ott/solver.py @@ -1,32 +1,63 @@ import abc +import functools import inspect +import math import types -from typing import Any, Literal, Mapping, Optional, Set, Tuple, Union +from typing import ( + Any, + Hashable, + List, + Literal, + Mapping, + NamedTuple, + Optional, + Set, + Tuple, + TypeVar, + Union, +) + +import optax import jax import jax.numpy as jnp +import numpy as np +import scipy.sparse as sp from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud +from ott.neural.datasets import OTData, OTDataset +from ott.neural.methods.flows import dynamics, genot +from ott.neural.networks.layers import time_encoder +from ott.neural.networks.velocity_field import VelocityField from ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import sinkhorn, sinkhorn_lr from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr +from ott.solvers.utils import uniform_sampler -from moscot._types import ProblemKind_t, 
QuadInitializer_t, SinkhornInitializer_t +from moscot._types import ( + ArrayLike, + ProblemKind_t, + QuadInitializer_t, + SinkhornInitializer_t, +) from moscot.backends.ott._utils import ( + Loader, + MultiLoader, _instantiate_geodesic_cost, alpha_to_fused_penalty, check_shapes, convert_scipy_sparse, + data_match_fn, densify, ensure_2d, ) -from moscot.backends.ott.output import GraphOTTOutput, OTTOutput +from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput from moscot.base.problems._utils import TimeScalesHeatKernel from moscot.base.solver import OTSolver from moscot.costs import get_cost -from moscot.utils.tagged_array import TaggedArray +from moscot.utils.tagged_array import DistributionCollection, TaggedArray -__all__ = ["SinkhornSolver", "GWSolver"] +__all__ = ["SinkhornSolver", "GWSolver", "GENOTLinSolver"] OTTSolver_t = Union[ sinkhorn.Sinkhorn, @@ -36,6 +67,18 @@ ] OTTProblem_t = Union[linear_problem.LinearProblem, quadratic_problem.QuadraticProblem] Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]] +K = TypeVar("K", bound=Hashable) + + +class SingleDistributionData(NamedTuple): + data_train: ArrayLike + data_valid: ArrayLike + conditions_train: Optional[ArrayLike] + conditions_valid: Optional[ArrayLike] + a_train: Optional[ArrayLike] + a_valid: Optional[ArrayLike] + b_train: Optional[ArrayLike] + b_valid: Optional[ArrayLike] class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC): @@ -457,3 +500,202 @@ def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]: problem_kwargs -= {"geom_xx", "geom_yy", "geom_xy", "fused_penalty"} problem_kwargs |= {"alpha"} return geom_kwargs | problem_kwargs, {"epsilon"} + + +class GENOTLinSolver(OTSolver[OTTOutput]): + """Solver class for genot.GENOT linear :cite:`klein2023generative`.""" + + def __init__(self, **kwargs: Any) -> None: + """Initiate the class with any kwargs passed to the ott-jax class.""" + super().__init__() + self._train_sampler: Optional[MultiLoader] = 
None + self._valid_sampler: Optional[MultiLoader] = None + self._neural_kwargs = kwargs + + @property + def problem_kind(self) -> ProblemKind_t: # noqa: D102 + return "linear" + + def _prepare( # type: ignore[override] + self, + distributions: DistributionCollection[K], + sample_pairs: List[Tuple[Any, Any]], + train_size: float = 0.9, + batch_size: int = 1024, + is_conditional: bool = True, + **kwargs: Any, + ) -> Tuple[MultiLoader, MultiLoader]: + train_loaders = [] + validate_loaders = [] + seed = kwargs.get("seed", None) + is_aligned = kwargs.get("is_aligned", False) + if train_size == 1.0: + for sample_pair in sample_pairs: + source_key = sample_pair[0] + target_key = sample_pair[1] + src_data = OTData( + lin=distributions[source_key].xy, + condition=distributions[source_key].conditions if is_conditional else None, + ) + tgt_data = OTData( + lin=distributions[target_key].xy, + condition=distributions[target_key].conditions if is_conditional else None, + ) + dataset = OTDataset(src_data=src_data, tgt_data=tgt_data, seed=seed, is_aligned=is_aligned) + loader = Loader(dataset, batch_size=batch_size, seed=seed) + train_loaders.append(loader) + validate_loaders.append(loader) + else: + if train_size > 1.0 or train_size <= 0.0: + raise ValueError("Invalid train_size. 
Must be: 0 < train_size <= 1") + + seed = kwargs.get("seed", 0) + for sample_pair in sample_pairs: + source_key = sample_pair[0] + target_key = sample_pair[1] + source_data: ArrayLike = distributions[source_key].xy + target_data: ArrayLike = distributions[target_key].xy + source_split_data = self._split_data( + source_data, + conditions=distributions[source_key].conditions, + train_size=train_size, + seed=seed, + a=distributions[source_key].a, + b=distributions[source_key].b, + ) + target_split_data = self._split_data( + target_data, + conditions=distributions[target_key].conditions, + train_size=train_size, + seed=seed, + a=distributions[target_key].a, + b=distributions[target_key].b, + ) + src_data_train = OTData( + lin=source_split_data.data_train, + condition=source_split_data.conditions_train if is_conditional else None, + ) + tgt_data_train = OTData( + lin=target_split_data.data_train, + condition=target_split_data.conditions_train if is_conditional else None, + ) + train_dataset = OTDataset( + src_data=src_data_train, tgt_data=tgt_data_train, seed=seed, is_aligned=is_aligned + ) + train_loader = Loader(train_dataset, batch_size=batch_size, seed=seed) + src_data_validate = OTData( + lin=source_split_data.data_valid, + condition=source_split_data.conditions_valid if is_conditional else None, + ) + tgt_data_validate = OTData( + lin=target_split_data.data_valid, + condition=target_split_data.conditions_valid if is_conditional else None, + ) + validate_dataset = OTDataset( + src_data=src_data_validate, tgt_data=tgt_data_validate, seed=seed, is_aligned=is_aligned + ) + validate_loader = Loader(validate_dataset, batch_size=batch_size, seed=seed) + train_loaders.append(train_loader) + validate_loaders.append(validate_loader) + source_dim = self._neural_kwargs.get("input_dim", 0) + target_dim = source_dim + condition_dim = self._neural_kwargs.get("cond_dim", 0) + # TODO(ilan-gold): What are reasonable defaults here? 
+ neural_vf = VelocityField( + output_dims=[*self._neural_kwargs.get("velocity_field_output_dims", []), target_dim], + condition_dims=( + self._neural_kwargs.get("velocity_field_condition_dims", [source_dim + condition_dim]) + if is_conditional + else None + ), + hidden_dims=self._neural_kwargs.get("velocity_field_hidden_dims", [1024, 1024, 1024]), + time_dims=self._neural_kwargs.get("velocity_field_time_dims", None), + time_encoder=self._neural_kwargs.get( + "velocity_field_time_encoder", functools.partial(time_encoder.cyclical_time_encoder, n_freqs=1024) + ), + ) + seed = self._neural_kwargs.get("seed", 0) + rng = jax.random.PRNGKey(seed) + data_match_fn_kwargs = self._neural_kwargs.get( + "data_match_fn_kwargs", + {} if "data_match_fn" in self._neural_kwargs else {"epsilon": 1e-1, "tau_a": 1.0, "tau_b": 1.0}, + ) + time_sampler = self._neural_kwargs.get("time_sampler", uniform_sampler) + optimizer = self._neural_kwargs.get("optimizer", optax.adam(learning_rate=1e-4)) + self._solver = genot.GENOT( + vf=neural_vf, + flow=self._neural_kwargs.get( + "flow", + dynamics.ConstantNoiseFlow(0.1), + ), + data_match_fn=functools.partial( + self._neural_kwargs.get("data_match_fn", data_match_fn), typ="lin", **data_match_fn_kwargs + ), + source_dim=source_dim, + target_dim=target_dim, + condition_dim=condition_dim if is_conditional else None, + optimizer=optimizer, + time_sampler=time_sampler, + rng=rng, + latent_noise_fn=self._neural_kwargs.get("latent_noise_fn", None), + **self._neural_kwargs.get("velocity_field_train_state_kwargs", {}), + ) + return ( + MultiLoader(datasets=train_loaders, seed=seed), + MultiLoader(datasets=validate_loaders, seed=seed), + ) + + @staticmethod + def _assert2d(arr: ArrayLike, *, allow_reshape: bool = True) -> jnp.ndarray: + arr: jnp.ndarray = jnp.asarray(arr.A if sp.issparse(arr) else arr) # type: ignore[no-redef, attr-defined] # noqa:E501 + if allow_reshape and arr.ndim == 1: + return jnp.reshape(arr, (-1, 1)) + if arr.ndim != 2: + raise 
ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.") + return arr + + def _split_data( # TODO: adapt for Gromov terms + self, + x: ArrayLike, + conditions: Optional[ArrayLike], + train_size: float, + seed: int, + a: Optional[ArrayLike] = None, + b: Optional[ArrayLike] = None, + ) -> SingleDistributionData: + n_samples_x = x.shape[0] + n_train_x = math.ceil(train_size * n_samples_x) + rng = np.random.default_rng(seed) + x = rng.permutation(x) + if a is not None: + a = rng.permutation(a) + if b is not None: + b = rng.permutation(b) + + return SingleDistributionData( + data_train=x[:n_train_x], + data_valid=x[n_train_x:], + conditions_train=conditions[:n_train_x] if conditions is not None else None, + conditions_valid=conditions[n_train_x:] if conditions is not None else None, + a_train=a[:n_train_x] if a is not None else None, + a_valid=a[n_train_x:] if a is not None else None, + b_train=b[:n_train_x] if b is not None else None, + b_valid=b[n_train_x:] if b is not None else None, + ) + + @property + def solver(self) -> genot.GENOT: + """Underlying optimal transport solver.""" + return self._solver + + @classmethod + def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]: + return {"batch_size", "train_size", "trainloader", "validloader", "seed"}, {} # type: ignore[return-value] + + def _solve(self, data_samplers: Tuple[MultiLoader, MultiLoader]) -> OTTNeuralOutput: # type: ignore[override] + seed = self._neural_kwargs.get("seed", 0) # TODO(ilan-gold): unify rng hadnling like OTT tests + rng = jax.random.PRNGKey(seed) + logs = self.solver( + data_samplers[0], n_iters=self._neural_kwargs.get("n_iters", 100), rng=rng + ) # TODO(ilan-gold): validation and figure out defualts + return OTTNeuralOutput(self.solver, logs) diff --git a/src/moscot/backends/utils.py b/src/moscot/backends/utils.py index a8881d2fc..fde874c0f 100644 --- a/src/moscot/backends/utils.py +++ b/src/moscot/backends/utils.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Callable, 
Literal, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Tuple, Union from moscot import _registry from moscot._types import ProblemKind_t @@ -8,6 +8,11 @@ __all__ = ["get_solver", "register_solver", "get_available_backends"] +register_solver_t = Callable[ + [Literal["linear", "quadratic"], Optional[Literal["GENOTLinSolver"]]], + Union["ott.SinkhornSolver", "ott.GWSolver", "ott.GENOTLinSolver"], +] + _REGISTRY = _registry.Registry() @@ -16,13 +21,13 @@ def get_solver(problem_kind: ProblemKind_t, *, backend: str = "ott", return_clas """TODO.""" if backend not in _REGISTRY: raise ValueError(f"Backend `{backend!r}` is not available.") - solver_class = _REGISTRY[backend](problem_kind) + solver_class = _REGISTRY[backend](problem_kind, solver_name=kwargs.pop("solver_name", None)) return solver_class if return_class else solver_class(**kwargs) def register_solver( backend: str, -) -> Callable[[Literal["linear", "quadratic"]], Union[Type["ott.SinkhornSolver"], Type["ott.GWSolver"]]]: +) -> Union["ott.SinkhornSolver", "ott.GWSolver", "ott.GENOTLinSolver"]: """Register a solver for a specific backend. 
Parameters @@ -37,15 +42,22 @@ def register_solver( return _REGISTRY.register(backend) # type: ignore[return-value] +# TODO(@MUCDK) fix mypy error @register_solver("ott") # type: ignore[arg-type] -def _(problem_kind: Literal["linear", "quadratic"]) -> Union[Type["ott.SinkhornSolver"], Type["ott.GWSolver"]]: +def _( + problem_kind: Literal["linear", "quadratic"], + solver_name: Optional[Literal["GENOTLinSolver"]] = None, +) -> Union["ott.SinkhornSolver", "ott.GWSolver", "ott.GENOTLinSolver"]: from moscot.backends import ott if problem_kind == "linear": - return ott.SinkhornSolver + if solver_name == "GENOTLinSolver": + return ott.GENOTLinSolver # type: ignore[return-value] + if solver_name is None: + return ott.SinkhornSolver # type: ignore[return-value] if problem_kind == "quadratic": - return ott.GWSolver - raise NotImplementedError(f"Unable to create solver for `{problem_kind!r}` problem.") + return ott.GWSolver # type: ignore[return-value] + raise NotImplementedError(f"Unable to create solver for `{problem_kind!r}`, {solver_name} problem.") def get_available_backends() -> Tuple[str, ...]: diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index f1bfa78cc..9617b8729 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import abc import copy import functools -from typing import Any, Callable, Iterable, List, Literal, Optional, Tuple, Union +from abc import abstractmethod +from typing import Any, Callable, Iterable, Literal, Optional, Union import numpy as np import scipy.sparse as sp @@ -10,12 +13,53 @@ from moscot._logging import logger from moscot._types import ArrayLike, Device_t, DTypeLike # type: ignore[attr-defined] -__all__ = ["BaseSolverOutput", "MatrixSolverOutput"] +__all__ = ["BaseDiscreteSolverOutput", "MatrixSolverOutput", "BaseNeuralOutput"] class BaseSolverOutput(abc.ABC): """Base class for all solver outputs.""" + @abc.abstractmethod + def pull(self, x: 
ArrayLike, **kwargs) -> ArrayLike: + """Pull the solution based on a condition.""" + + @abc.abstractmethod + def push(self, x: ArrayLike, **kwargs) -> ArrayLike: + """Push the solution based on a condition.""" + + @property + @abc.abstractmethod + def shape(self) -> tuple[int, int]: + """Shape of the problem.""" + + @abc.abstractmethod + def to(self: BaseSolverOutput, device: Optional[Device_t] = None) -> BaseSolverOutput: + """Transfer self to another compute device. + + Parameters + ---------- + device + Device where to transfer the solver output. If :obj:`None`, use the default device. + + Returns + ------- + Self transferred to the ``device``. + """ + + def _format_params(self, fmt: Callable[[Any], str]) -> str: + params = {"shape": self.shape} + return ", ".join(f"{name}={fmt(val)}" for name, val in params.items()) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}[{self._format_params(repr)}]" + + def __str__(self) -> str: + return f"{self.__class__.__name__}[{self._format_params(str)}]" + + +class BaseDiscreteSolverOutput(BaseSolverOutput, abc.ABC): + """Base class for all discrete solver outputs.""" + @abc.abstractmethod def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike: """Apply :attr:`transport_matrix` to an array of shape ``[n, d]`` or ``[m, d]``.""" @@ -25,11 +69,6 @@ def _apply(self, x: ArrayLike, *, forward: bool) -> ArrayLike: def transport_matrix(self) -> ArrayLike: """Transport matrix of shape ``[n, m]``.""" - @property - @abc.abstractmethod - def shape(self) -> Tuple[int, int]: - """Shape of the :attr:`transport_matrix`.""" - @property @abc.abstractmethod def cost(self) -> float: @@ -42,7 +81,7 @@ def converged(self) -> bool: @property @abc.abstractmethod - def potentials(self) -> Optional[Tuple[ArrayLike, ArrayLike]]: + def potentials(self) -> Optional[tuple[ArrayLike, ArrayLike]]: """:term:`Dual potentials` :math:`f` and :math:`g`. Only valid for the :term:`Sinkhorn` algorithm. 
@@ -50,22 +89,13 @@ def potentials(self) -> Optional[Tuple[ArrayLike, ArrayLike]]: @property @abc.abstractmethod - def is_linear(self) -> bool: - """Whether the output is a solution to a :term:`linear problem`.""" + def shape(self) -> tuple[int, int]: + """Shape of the :attr:`transport_matrix`.""" + @property @abc.abstractmethod - def to(self, device: Optional[Device_t] = None) -> "BaseSolverOutput": - """Transfer self to another compute device. - - Parameters - ---------- - device - Device where to transfer the solver output. If :obj:`None`, use the default device. - - Returns - ------- - Self transferred to the ``device``. - """ + def is_linear(self) -> bool: + """Whether the output is a solution to a :term:`linear problem`.""" @property def rank(self) -> int: @@ -147,7 +177,7 @@ def as_linear_operator(self, scale_by_marginals: bool = False) -> LinearOperator # pull: X @ a (matvec) return LinearOperator(shape=self.shape, dtype=self.dtype, matvec=pull, rmatvec=push) - def chain(self, outputs: Iterable["BaseSolverOutput"], scale_by_marginals: bool = False) -> LinearOperator: + def chain(self, outputs: Iterable[BaseDiscreteSolverOutput], scale_by_marginals: bool = False) -> LinearOperator: """Chain subsequent applications of :attr:`transport_matrix`. Parameters @@ -174,7 +204,7 @@ def sparsify( batch_size: int = 1024, n_samples: Optional[int] = None, seed: Optional[int] = None, - ) -> "MatrixSolverOutput": + ) -> MatrixSolverOutput: """Sparsify the :attr:`transport_matrix`. 
This function sets all entries of the transport matrix below a certain threshold to :math:`0` and @@ -235,7 +265,7 @@ def sparsify( raise NotImplementedError(f"Mode `{mode}` is not yet implemented.") k, func, fn_stack = (n, self.push, sp.vstack) if n < m else (m, self.pull, sp.hstack) - tmaps_sparse: List[sp.csr_matrix] = [] + tmaps_sparse: list[sp.csr_matrix] = [] for batch in range(0, k, batch_size): x = np.eye(k, min(batch_size, k - batch), -(min(batch, k)), dtype=float) @@ -268,6 +298,10 @@ def dtype(self) -> DTypeLike: """Underlying data type.""" return self.a.dtype + def _format_params(self, fmt: Callable[[Any], str]) -> str: + params = {"shape": self.shape, "cost": round(self.cost, 4), "converged": self.converged} + return ", ".join(f"{name}={fmt(val)}" for name, val in params.items()) + def _scale_by_marginals(self, x: ArrayLike, *, forward: bool, eps: float = 1e-12) -> ArrayLike: # alt. we could use the public push/pull marginals = self.a if forward else self.b @@ -275,21 +309,11 @@ def _scale_by_marginals(self, x: ArrayLike, *, forward: bool, eps: float = 1e-12 marginals = marginals[:, None] return x / (marginals + eps) - def _format_params(self, fmt: Callable[[Any], str]) -> str: - params = {"shape": self.shape, "cost": round(self.cost, 4), "converged": self.converged} - return ", ".join(f"{name}={fmt(val)}" for name, val in params.items()) - def __bool__(self) -> bool: return self.converged - def __repr__(self) -> str: - return f"{self.__class__.__name__}[{self._format_params(repr)}]" - - def __str__(self) -> str: - return f"{self.__class__.__name__}[{self._format_params(str)}]" - -class MatrixSolverOutput(BaseSolverOutput): +class MatrixSolverOutput(BaseDiscreteSolverOutput): """:term:`OT` solution with a materialized transport matrix. 
Parameters @@ -329,12 +353,12 @@ def transport_matrix(self) -> ArrayLike: # noqa: D102 return self._transport_matrix @property - def shape(self) -> Tuple[int, int]: # noqa: D102 + def shape(self) -> tuple[int, int]: # noqa: D102 return self.transport_matrix.shape # type: ignore[return-value] def to( # noqa: D102 self, device: Optional[Device_t] = None, dtype: Optional[DTypeLike] = None - ) -> "BaseSolverOutput": + ) -> BaseDiscreteSolverOutput: if device is not None: logger.warning(f"`{self!r}` does not support the `device` argument, ignoring.") if dtype is None: @@ -353,7 +377,7 @@ def converged(self) -> bool: # noqa: D102 return self._converged @property - def potentials(self) -> Optional[Tuple[ArrayLike, ArrayLike]]: # noqa: D102 + def potentials(self) -> Optional[tuple[ArrayLike, ArrayLike]]: # noqa: D102 return None @property @@ -367,3 +391,38 @@ def _ones(self, n: int) -> ArrayLike: import jax.numpy as jnp return jnp.ones((n,), dtype=self.transport_matrix.dtype) + + +class BaseNeuralOutput(BaseDiscreteSolverOutput, abc.ABC): + """Base class for output of.""" + + @abstractmethod + def project_to_transport_matrix( + self, + source: Optional[ArrayLike] = None, + target: Optional[ArrayLike] = None, + condition: Optional[ArrayLike] = None, + forward: bool = True, + save_transport_matrix: bool = False, + batch_size: int = 1024, + k: int = 30, + length_scale: Optional[float] = None, + seed: int = 42, + ) -> sp.csr_matrix: + """Project transport matrix.""" + pass + + @property + def transport_matrix(self): # noqa: D102 + raise NotImplementedError("Neural output does not require a transport matrix.") + + @property + def cost(self): # noqa: D102 + raise NotImplementedError("Neural output does not implement a cost property.") + + @property + def potentials(self): # noqa: D102 + raise NotImplementedError("Neural output does not need to implement a potentials property.") + + def _ones(self, n: int): # noqa: D102 + raise NotImplementedError("Neural output does not need to 
implement a `_ones` property.") diff --git a/src/moscot/base/problems/__init__.py b/src/moscot/base/problems/__init__.py index b554b3ce5..544631d48 100644 --- a/src/moscot/base/problems/__init__.py +++ b/src/moscot/base/problems/__init__.py @@ -2,7 +2,7 @@ from moscot.base.problems.birth_death import BirthDeathMixin, BirthDeathProblem from moscot.base.problems.compound_problem import BaseCompoundProblem, CompoundProblem from moscot.base.problems.manager import ProblemManager -from moscot.base.problems.problem import BaseProblem, OTProblem +from moscot.base.problems.problem import BaseProblem, CondOTProblem, OTProblem __all__ = [ "AnalysisMixin", @@ -13,4 +13,5 @@ "ProblemManager", "BaseProblem", "OTProblem", + "CondOTProblem", ] diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 0cbdcb6cd..9c460eac2 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -23,7 +23,7 @@ from moscot import _constants from moscot._types import ArrayLike, Numeric_t, Str_Dict_t -from moscot.base.output import BaseSolverOutput +from moscot.base.output import BaseDiscreteSolverOutput from moscot.base.problems._utils import ( _check_argument_compatibility_cell_transition, _correlation_test, @@ -45,7 +45,7 @@ class AnalysisMixinProtocol(Protocol[K, B]): adata: AnnData _policy: SubsetPolicy[K] - solutions: dict[tuple[K, K], BaseSolverOutput] + solutions: dict[tuple[K, K], BaseDiscreteSolverOutput] problems: dict[tuple[K, K], B] def _apply( diff --git a/src/moscot/base/problems/compound_problem.py b/src/moscot/base/problems/compound_problem.py index 7caea0b46..6070b34c7 100644 --- a/src/moscot/base/problems/compound_problem.py +++ b/src/moscot/base/problems/compound_problem.py @@ -24,7 +24,7 @@ from moscot._logging import logger from moscot._types import ArrayLike, Policy_t, ProblemStage_t -from moscot.base.output import BaseSolverOutput +from moscot.base.output import BaseDiscreteSolverOutput from 
moscot.base.problems._utils import attributedispatch, require_prepare from moscot.base.problems.manager import ProblemManager from moscot.base.problems.problem import BaseProblem, OTProblem @@ -512,7 +512,7 @@ def remove_problem(self, key: Tuple[K, K]) -> "BaseCompoundProblem[K, B]": return self @property - def solutions(self) -> Dict[Tuple[K, K], BaseSolverOutput]: + def solutions(self) -> Dict[Tuple[K, K], BaseDiscreteSolverOutput]: """Solutions to the :attr:`problems`.""" if self._problem_manager is None: return {} diff --git a/src/moscot/base/problems/manager.py b/src/moscot/base/problems/manager.py index 1994f27d8..bdf4c59e9 100644 --- a/src/moscot/base/problems/manager.py +++ b/src/moscot/base/problems/manager.py @@ -11,7 +11,7 @@ ) from moscot._types import ProblemStage_t -from moscot.base.output import BaseSolverOutput +from moscot.base.output import BaseDiscreteSolverOutput from moscot.base.problems.problem import OTProblem from moscot.utils.subset_policy import SubsetPolicy @@ -144,7 +144,7 @@ def get_problems( stage = (stage,) if isinstance(stage, str) else stage return {k: v for k, v in self.problems.items() if v.stage in stage} - def get_solutions(self, only_converged: bool = False) -> Dict[Tuple[K, K], BaseSolverOutput]: + def get_solutions(self, only_converged: bool = False) -> Dict[Tuple[K, K], BaseDiscreteSolverOutput]: """Get solutions to the :term:`OT` subproblems. 
Parameters @@ -174,7 +174,7 @@ def _verify_shape_integrity(self) -> None: raise ValueError(f"Problem `{key}` is associated with different dimensions: `{dim}`.") @property - def solutions(self) -> Dict[Tuple[K, K], BaseSolverOutput]: + def solutions(self) -> Dict[Tuple[K, K], BaseDiscreteSolverOutput]: """Solutions for the :term:`OT` :attr:`problems`.""" return self.get_solutions(only_converged=False) diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index eb0a340de..708b35d44 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -5,11 +5,15 @@ TYPE_CHECKING, Any, Dict, + Hashable, + Iterable, List, Literal, Mapping, Optional, + Sequence, Tuple, + TypeVar, Union, ) @@ -29,7 +33,7 @@ from moscot import backends from moscot._logging import logger from moscot._types import ArrayLike, CostFn_t, Device_t, ProblemKind_t -from moscot.base.output import BaseSolverOutput, MatrixSolverOutput +from moscot.base.output import BaseDiscreteSolverOutput, MatrixSolverOutput from moscot.base.problems._utils import ( TimeScalesHeatKernel, _assert_columns_and_index_match, @@ -39,9 +43,23 @@ wrap_solve, ) from moscot.base.solver import OTSolver -from moscot.utils.tagged_array import Tag, TaggedArray +from moscot.utils.subset_policy import ( # type:ignore[attr-defined] + ExplicitPolicy, + Policy_t, + StarPolicy, + SubsetPolicy, + create_policy, +) +from moscot.utils.tagged_array import ( + DistributionCollection, + DistributionContainer, + Tag, + TaggedArray, +) -__all__ = ["BaseProblem", "OTProblem"] +K = TypeVar("K", bound=Hashable) + +__all__ = ["BaseProblem", "OTProblem", "CondOTProblem"] class CombinedMeta(abc.ABCMeta, NumpyDocstringInheritanceMeta): @@ -252,8 +270,8 @@ def __init__( self._src_key = src_key self._tgt_key = tgt_key - self._solver: Optional[OTSolver[BaseSolverOutput]] = None - self._solution: Optional[BaseSolverOutput] = None + self._solver: Optional[OTSolver[BaseDiscreteSolverOutput]] = None 
+ self._solution: Optional[BaseDiscreteSolverOutput] = None self._x: Optional[TaggedArray] = None self._y: Optional[TaggedArray] = None @@ -383,6 +401,7 @@ def prepare( def solve( self, backend: Literal["ott"] = "ott", + solver_name: Optional[str] = None, device: Optional[Device_t] = None, **kwargs: Any, ) -> "OTProblem": @@ -392,8 +411,10 @@ def solve( ---------- backend Which backend to use, see :func:`~moscot.backends.utils.get_available_backends`. + solver_name + Literal defining the solver. If `None`, automatically infers the discrete OT solver. device - Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseSolverOutput.to`. + Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`. If :obj:`None`, keep the output on the original device. kwargs Keyword arguments for :class:`~moscot.base.solver.BaseSolver` or its @@ -406,7 +427,9 @@ def solve( - :attr:`solver` - the :term:`OT` solver. - :attr:`solution` - the :term:`OT` solution. """ - solver_class = backends.get_solver(self.problem_kind, backend=backend, return_class=True) + solver_class = backends.get_solver( + self.problem_kind, solver_name=solver_name, backend=backend, return_class=True + ) init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs) # if linear problem, then alpha is 0.0 by default # if quadratic problem, then alpha is 1.0 by default @@ -427,6 +450,7 @@ def solve( self._solver = solver_class(**init_kwargs) + # note that the solver call consists of solver._prepare and solver._solve self._solution = self._solver( # type: ignore[misc] xy=self._xy, x=self._x, @@ -449,7 +473,7 @@ def push( split_mass: bool = False, scale_by_marginals: bool = False, ) -> ArrayLike: - r"""Push data through the :attr:`~moscot.base.output.BaseSolverOutput.transport_matrix`. + r"""Push data through the :attr:`~moscot.base.output.BaseDiscreteSolverOutput.transport_matrix`. 
Parameters ---------- @@ -479,7 +503,7 @@ def push( The transported values, array of shape ``[m, d]``. """ if TYPE_CHECKING: - assert isinstance(self.solution, BaseSolverOutput) + assert isinstance(self.solution, BaseDiscreteSolverOutput) data = self._get_mass(self.adata_src, data=data, subset=subset, normalize=normalize, split_mass=split_mass) return self.solution.push(data, scale_by_marginals=scale_by_marginals) @@ -493,7 +517,7 @@ def pull( split_mass: bool = False, scale_by_marginals: bool = False, ) -> ArrayLike: - r"""Pull data through the :attr:`~moscot.base.output.BaseSolverOutput.transport_matrix`. + r"""Pull data through the :attr:`~moscot.base.output.BaseDiscreteSolverOutput.transport_matrix`. Parameters ---------- @@ -523,12 +547,16 @@ def pull( The transported values, array of shape ``[n, d]``. """ if TYPE_CHECKING: - assert isinstance(self.solution, BaseSolverOutput) + assert isinstance(self.solution, BaseDiscreteSolverOutput) data = self._get_mass(self.adata_tgt, data=data, subset=subset, normalize=normalize, split_mass=split_mass) return self.solution.pull(data, scale_by_marginals=scale_by_marginals) def set_solution( - self, solution: Union[ArrayLike, pd.DataFrame, BaseSolverOutput], *, overwrite: bool = False, **kwargs: Any + self, + solution: Union[ArrayLike, pd.DataFrame, BaseDiscreteSolverOutput], + *, + overwrite: bool = False, + **kwargs: Any, ) -> "OTProblem": """Set a :attr:`solution` to the :term:`OT` problem. 
@@ -557,7 +585,7 @@ def set_solution( _assert_series_match(self.adata_src.obs_names.to_series(), solution.index.to_series()) _assert_series_match(self.adata_tgt.obs_names.to_series(), solution.columns.to_series()) solution = solution.to_numpy() - if not isinstance(solution, BaseSolverOutput): + if not isinstance(solution, BaseDiscreteSolverOutput): solution = MatrixSolverOutput(solution, **kwargs) if solution.shape != self.shape: @@ -980,12 +1008,12 @@ def shape(self) -> Tuple[int, int]: return self.adata_src.n_obs, self.adata_tgt.n_obs @property - def solution(self) -> Optional[BaseSolverOutput]: + def solution(self) -> Optional[BaseDiscreteSolverOutput]: """Solution of the :term:`OT` problem.""" return self._solution @property - def solver(self) -> Optional[OTSolver[BaseSolverOutput]]: + def solver(self) -> Optional[OTSolver[BaseDiscreteSolverOutput]]: """:term:`OT` solver.""" return self._solver @@ -1019,3 +1047,209 @@ def __repr__(self) -> str: def __str__(self) -> str: return repr(self) + + +class CondOTProblem(BaseProblem): # TODO(@MUCDK) check generic types, save and load + """ + Base class for all conditional (neural) optimal transport problems. + + Parameters + ---------- + adata + Source annotated data object. 
+ kwargs + Keyword arguments for :class:`moscot.base.problems.problem.BaseProblem` + """ + + def __init__( + self, + adata: AnnData, + **kwargs: Any, + ): + super().__init__(**kwargs) + self._adata = adata + + self._distributions: Optional[DistributionCollection[K]] = None # type: ignore[valid-type] + self._policy: Optional[SubsetPolicy[Any]] = None + self._sample_pairs: Optional[List[Tuple[Any, Any]]] = None + + self._solver: Optional[OTSolver[BaseDiscreteSolverOutput]] = None + self._solution: Optional[BaseDiscreteSolverOutput] = None + + self._a: Optional[str] = None + self._b: Optional[str] = None + + @wrap_prepare + def prepare( + self, + policy_key: str, + policy: Policy_t, + xy: Mapping[str, Any], + xx: Mapping[str, Any], + conditions: Mapping[str, Any], + a: Optional[str] = None, + b: Optional[str] = None, + subset: Optional[Sequence[Tuple[K, K]]] = None, + reference: K = None, + **kwargs: Any, + ) -> "CondOTProblem": + """Prepare conditional optimal transport problem. + + Parameters + ---------- + xy + Geometry defining the linear term. If passed as a :class:`dict`, + :meth:`~moscot.utils.tagged_array.TaggedArray.from_adata` will be called. + policy + Policy defining which pairs of distributions to sample from during training. + policy_key + %(key)s + a + Source marginals. + b + Target marginals. + kwargs + Keyword arguments when creating the source/target marginals. + + + Returns + ------- + Self and modifies the following attributes: + TODO. 
+ """ + self._problem_kind = "linear" + self._distributions = DistributionCollection() + self._solution = None + self._policy_key = policy_key + try: + self._distribution_id = pd.Series(self.adata.obs[policy_key]) + except KeyError: + raise KeyError(f"Unable to find data in `adata.obs[{policy_key!r}]`.") from None + + self._policy = create_policy(policy, adata=self.adata, key=policy_key) + if isinstance(self._policy, ExplicitPolicy): + self._policy = self._policy.create_graph(subset=subset) + elif isinstance(self._policy, StarPolicy): + self._policy = self._policy.create_graph(reference=reference) + else: + _ = self.policy.create_graph() # type: ignore[union-attr] + self._sample_pairs = list(self.policy._graph) # type: ignore[union-attr] + + for el in self.policy.categories: # type: ignore[union-attr] + adata_masked = self.adata[self._create_mask(el)] + a_created = self._create_marginals(adata_masked, data=a, source=True, **kwargs) + b_created = self._create_marginals(adata_masked, data=b, source=False, **kwargs) + self.distributions[el] = DistributionContainer.from_adata( # type: ignore[index] + adata_masked, a=a_created, b=b_created, **xy, **xx, **conditions + ) + return self + + @wrap_solve + def solve( + self, + backend: Literal["ott"] = "ott", + solver_name: Literal["GENOTLinSolver"] = "GENOTLinSolver", + device: Optional[Device_t] = None, + **kwargs: Any, + ) -> "CondOTProblem": + """Solve optimal transport problem. + + Parameters + ---------- + backend + Which backend to use, see :func:`moscot.backends.utils.get_available_backends`. + device + Device where to transfer the solution, see :meth:`moscot.base.output.BaseDiscreteSolverOutput.to`. + kwargs + Keyword arguments for :meth:`moscot.base.solver.BaseSolver.__call__`. + + + Returns + ------- + Self and modifies the following attributes: + - :attr:`solver`: optimal transport solver. + - :attr:`solution`: optimal transport solution. 
+ """ + tmp = next(iter(self.distributions)) # type: ignore[arg-type] + input_dim = self.distributions[tmp].xy.shape[1] # type: ignore[union-attr, index] + cond_dim = self.distributions[tmp].conditions.shape[1] # type: ignore[union-attr, index] + + solver_class = backends.get_solver( + self.problem_kind, solver_name=solver_name, backend=backend, return_class=True + ) + init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs) + self._solver = solver_class(input_dim=input_dim, cond_dim=cond_dim, **init_kwargs) + # note that the solver call consists of solver._prepare and solver._solve + sample_pairs = self._sample_pairs if self._sample_pairs is not None else [] + self._solution = self._solver( # type: ignore[misc] + device=device, + distributions=self.distributions, + sample_pairs=self._sample_pairs, + is_conditional=len(sample_pairs) > 1, + **call_kwargs, + ) + + return self + + def _create_marginals( + self, adata: AnnData, *, source: bool, data: Optional[str] = None, **kwargs: Any + ) -> ArrayLike: + if data is True: + marginals = self.estimate_marginals(adata, source=source, **kwargs) + elif data in (False, None): + marginals = np.ones((adata.n_obs,), dtype=float) / adata.n_obs + elif isinstance(data, str): + try: + marginals = np.asarray(adata.obs[data], dtype=float) + except KeyError: + raise KeyError(f"Unable to find data in `adata.obs[{data!r}]`.") from None + return marginals + + def _create_mask(self, value: Union[K, Sequence[K]], *, allow_empty: bool = False) -> ArrayLike: + """Create a mask used to subset the data. + + TODO(@MUCDK): this is copied from SubsetPolicy, consider making this a function. + + Parameters + ---------- + value + Values in the data which determine the mask. + allow_empty + Whether to allow empty mask. + + Returns + ------- + Boolean mask of the same shape as the data. 
+ """ + if isinstance(value, str) or not isinstance(value, Iterable): + mask = self._distribution_id == value + else: + mask = self._distribution_id.isin(value) + if not allow_empty and not np.sum(mask): + raise ValueError("Unable to construct an empty mask, use `allow_empty=True` to override.") + return np.asarray(mask) + + @property + def distributions(self) -> Optional[DistributionCollection[K]]: + """Collection of distributions.""" + return self._distributions + + @property + def adata(self) -> AnnData: + """Source annotated data object.""" + return self._adata + + @property + def solution(self) -> Optional[BaseDiscreteSolverOutput]: + """Solution of the optimal transport problem.""" + return self._solution + + @property + def solver(self) -> Optional[OTSolver[BaseDiscreteSolverOutput]]: + """Solver of the optimal transport problem.""" + return self._solver + + @property + def policy(self) -> Optional[SubsetPolicy[Any]]: + """Policy used to subset the data.""" + return self._policy diff --git a/src/moscot/base/solver.py b/src/moscot/base/solver.py index 6ec316d76..0e69bd404 100644 --- a/src/moscot/base/solver.py +++ b/src/moscot/base/solver.py @@ -14,15 +14,17 @@ Union, ) +import numpy as np + from moscot._logging import logger from moscot._types import ArrayLike, Device_t, ProblemKind_t -from moscot.base.output import BaseSolverOutput +from moscot.base.output import BaseDiscreteSolverOutput from moscot.utils.tagged_array import Tag, TaggedArray __all__ = ["BaseSolver", "OTSolver"] -O = TypeVar("O", bound=BaseSolverOutput) +O = TypeVar("O", bound=BaseDiscreteSolverOutput) class TaggedArrayData(NamedTuple): # noqa: D101 @@ -53,6 +55,9 @@ def to_tuple( loss_x = {k[2:]: v for k, v in kwargs.items() if k.startswith("x_")} loss_y = {k[2:]: v for k, v in kwargs.items() if k.startswith("y_")} + if isinstance(xy, dict) and np.all([isinstance(v, tuple) for v in xy.values()]): # handling joint learning + return xy + # fmt: off xy = xy if isinstance(xy, TaggedArray) else 
self._convert(*to_tuple(xy), tag=tags.get("xy", None), **loss_xy) x = x if isinstance(x, TaggedArray) else self._convert(*to_tuple(x), tag=tags.get("x", None), **loss_x) @@ -146,7 +151,7 @@ def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]: def _partition_kwargs(cls, **kwargs: Any) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Partition keyword arguments. - Used by the :meth:`~moscot.problems.base.BaseProblem.solve`. + Used by the :meth:`~moscot.base.problems.problem.BaseProblem.solve`. Parameters ---------- @@ -189,7 +194,9 @@ def __call__( tags How to interpret the data in ``xy``, ``x`` and ``y``. device - Device to transfer the output to, see :meth:`~moscot.base.output.BaseSolverOutput.to`. + Device to transfer the output to, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`. + is_conditional + Whether the OT problem is conditional. kwargs Keyword arguments for parent's :meth:`__call__`. @@ -197,8 +204,9 @@ def __call__( ------- The optimal transport solution. """ - data = self._get_array_data(xy=xy, x=x, y=y, tags=tags) - kwargs = {**kwargs, **self._untag(data)} + if not kwargs.get("is_conditional", False): # signals that this is a neural problem + data = self._get_array_data(xy=xy, x=x, y=y, tags=tags) + kwargs = {**kwargs, **self._untag(data)} res = super().__call__(**kwargs) if not res.converged: logger.warning("Solver did not converge") diff --git a/src/moscot/problems/__init__.py b/src/moscot/problems/__init__.py index b96b993b4..14f4422f5 100644 --- a/src/moscot/problems/__init__.py +++ b/src/moscot/problems/__init__.py @@ -1,4 +1,5 @@ from moscot.problems.cross_modality import TranslationProblem +from moscot.problems.generic import GENOTLinProblem from moscot.problems.space import AlignmentProblem, MappingProblem from moscot.problems.spatiotemporal import SpatioTemporalProblem from moscot.problems.time import LineageProblem, TemporalProblem @@ -10,4 +11,5 @@ "SpatioTemporalProblem", "LineageProblem", "TemporalProblem", + "GENOTLinProblem", ] 
diff --git a/src/moscot/problems/_utils.py b/src/moscot/problems/_utils.py index a7d9a9633..a405a3da6 100644 --- a/src/moscot/problems/_utils.py +++ b/src/moscot/problems/_utils.py @@ -1,7 +1,7 @@ import types from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Union -from moscot._types import CostKwargs_t, OttCostFnMap_t +from moscot._types import CostFn_t, CostKwargs_t, OttCostFnMap_t from moscot.base.problems.compound_problem import Callback_t @@ -124,3 +124,80 @@ def handle_cost( if "y" in cost_candidates: y.update(cost_kwargs.get("y", cost_kwargs)) # type:ignore[call-overload] return xy, x, y + + +def handle_conditional_attr(conditional_attr: Optional[Union[str, Mapping[str, Any]]]) -> Dict[str, Any]: + if isinstance(conditional_attr, str): + conditional_attr = {"attr": "obsm", "key": conditional_attr} + elif isinstance(conditional_attr, Mapping): + conditional_attr = dict(conditional_attr) + if "attr" not in conditional_attr: + raise KeyError("`attr` must be provided when `conditional_attr` is a mapping.") + if conditional_attr["attr"] == "X": + conditions_attr = "X" + conditions_key = None + else: + if "key" not in conditional_attr: + raise KeyError("`key` must be provided when `attr` is not `X`.") + conditions_attr = conditional_attr["attr"] + conditions_key = conditional_attr["key"] + else: + raise TypeError("Expected `conditional_attr` to be either `str` or `dict`.") + return {"conditions_attr": conditions_attr, "conditions_key": conditions_key} + + +def handle_joint_attr_tmp( + joint_attr: Union[str, Mapping[str, Any]], kwargs: Dict[str, Any] +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + if isinstance(joint_attr, str): + xy = { + "xy_attr": "obsm", + "xy_key": joint_attr, + } + return xy, kwargs + if isinstance(joint_attr, Mapping): # input mapping does not distinguish between x and y as it's a shared space + joint_attr = dict(joint_attr) + if "attr" in joint_attr and joint_attr["attr"] == "X": # we have a point cloud + return {"xy_attr": 
"X"}, kwargs + if "attr" in joint_attr and joint_attr["attr"] == "obsm": # we have a point cloud + if "key" not in joint_attr: + raise KeyError("`key` must be provided when `attr` is `obsm`.") + xy = { + "xy_attr": "obsm", + "xy_key": joint_attr["key"], + } + return xy, kwargs + + raise TypeError(f"Expected `joint_attr` to be either `str` or `dict`, found `{type(joint_attr)}`.") + + +def handle_cost_tmp( + xy: Mapping[str, Any] = types.MappingProxyType({}), + xx: Mapping[str, Any] = types.MappingProxyType({}), + cost: Optional[Union[CostFn_t, Mapping[str, CostFn_t]]] = None, + cost_kwargs: CostKwargs_t = types.MappingProxyType({}), + **_: Any, +) -> Tuple[Dict[str, Any], Dict[str, Any]]: + xy, xx = dict(xy), dict(xx) + if cost is None: + return xy, xx + if isinstance(cost, str): # if cost is a str, we use it in all terms + if xy and "cost" not in xy: + xy["xy_cost"] = cost + if xx and "cost" not in xx: + xx["xy_cost"] = cost + elif isinstance(cost, Mapping): # if cost is a dict, the cost is specified for each term + if xy and ("xy_cost" not in xy or "xx_cost" not in xy): + xy["xy_cost"] = cost["xy"] + if xx and "cost" not in xx: + xx["xx_cost"] = cost["xx_cost"] + else: + raise TypeError(f"Expected `cost` to be either `str` or `dict`, found `{type(cost)}`.") + if xy and cost_kwargs: # distribute the cost_kwargs, possibly explicit to x/y/xy-term + # extract cost_kwargs explicit to xy-term if possible + items = cost_kwargs["xy"].items() if "xy" in cost_kwargs else cost_kwargs.items() + for k, v in items: + xy[f"xy_{k}"] = xy[f"xy_{k}"] = v + if xx and cost_kwargs: # extract cost_kwargs explicit to x-term if possible + xx.update(cost_kwargs.get("xx", cost_kwargs)) # type:ignore[call-overload] + return xy, xx diff --git a/src/moscot/problems/cross_modality/_translation.py b/src/moscot/problems/cross_modality/_translation.py index 26ed9ee66..5c8f9081c 100644 --- a/src/moscot/problems/cross_modality/_translation.py +++ b/src/moscot/problems/cross_modality/_translation.py 
@@ -275,7 +275,7 @@ def solve( # type: ignore[override] linear_solver_kwargs Keyword arguments for the inner :term:`linear problem` solver. device - Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseSolverOutput.to`. + Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`. If :obj:`None`, keep the output on the original device. kwargs Keyword arguments for :meth:`~moscot.base.problems.CompoundProblem.solve`. diff --git a/src/moscot/problems/generic/__init__.py b/src/moscot/problems/generic/__init__.py index f0640cbf5..d96dc4db6 100644 --- a/src/moscot/problems/generic/__init__.py +++ b/src/moscot/problems/generic/__init__.py @@ -1,4 +1,14 @@ -from moscot.problems.generic._generic import FGWProblem, GWProblem, SinkhornProblem +from moscot.problems.generic._generic import ( + FGWProblem, + GENOTLinProblem, + GWProblem, + SinkhornProblem, +) from moscot.problems.generic._mixins import GenericAnalysisMixin -__all__ = ["FGWProblem", "GWProblem", "SinkhornProblem", "GenericAnalysisMixin"] +__all__ = [ + "FGWProblem" "SinkhornProblem", + "GENOTLinProblem", + "GWProblem", + "GenericAnalysisMixin", +] diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index 4bd62e192..8d727d0d0 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -1,4 +1,5 @@ import types +from types import MappingProxyType from typing import Any, Dict, Literal, Mapping, Optional, Sequence, Tuple, Type, Union from anndata import AnnData @@ -15,11 +16,17 @@ SinkhornInitializer_t, ) from moscot.base.problems.compound_problem import B, Callback_t, CompoundProblem, K -from moscot.base.problems.problem import OTProblem -from moscot.problems._utils import handle_cost, handle_joint_attr +from moscot.base.problems.problem import CondOTProblem, OTProblem +from moscot.problems._utils import ( + handle_conditional_attr, + handle_cost, + 
handle_cost_tmp, + handle_joint_attr, + handle_joint_attr_tmp, +) from moscot.problems.generic._mixins import GenericAnalysisMixin -__all__ = ["SinkhornProblem", "GWProblem", "FGWProblem"] +__all__ = ["SinkhornProblem", "GWProblem", "GENOTLinProblem", "FGWProblem"] def set_quad_defaults(z: Optional[Union[str, Mapping[str, Any]]]) -> Dict[str, str]: @@ -216,7 +223,7 @@ def solve( max_iterations Maximum number of :term:`Sinkhorn` iterations. device - Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseSolverOutput.to`. + Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`. If :obj:`None`, keep the output on the original device. kwargs Keyword arguments for :meth:`~moscot.base.problems.CompoundProblem.solve`. @@ -452,7 +459,7 @@ def solve( linear_solver_kwargs Keyword arguments for the inner :term:`linear problem` solver. device - Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseSolverOutput.to`. + Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`. If :obj:`None`, keep the output on the original device. kwargs Keyword arguments for :meth:`~moscot.base.problems.CompoundProblem.solve`. @@ -725,7 +732,7 @@ def solve( linear_solver_kwargs Keyword arguments for the inner :term:`linear problem` solver. device - Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseSolverOutput.to`. + Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`. If :obj:`None`, keep the output on the original device. kwargs Keyword arguments for :meth:`~moscot.base.problems.CompoundProblem.solve`. 
@@ -759,3 +766,75 @@ def solve( device=device, **kwargs, ) + + @property + def _base_problem_type(self) -> Type[B]: + return OTProblem # type: ignore[return-value] + + @property + def _valid_policies(self) -> Tuple[Policy_t, ...]: + return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value] + + +class GENOTLinProblem(CondOTProblem, GenericAnalysisMixin[K, B]): + """Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems.""" + + def prepare( + self, + key: str, + joint_attr: Union[str, Mapping[str, Any]], + conditional_attr: Union[str, Mapping[str, Any]], + policy: Literal["sequential", "star", "explicit"] = "sequential", + a: Optional[str] = None, + b: Optional[str] = None, + cost: OttCostFn_t = "sq_euclidean", + cost_kwargs: CostKwargs_t = types.MappingProxyType({}), + **kwargs: Any, + ) -> "GENOTLinProblem[K, B]": + """Prepare the :class:`moscot.problems.generic.GENOTLinProblem`.""" + self.batch_key = key # type:ignore[misc] + xy, kwargs = handle_joint_attr_tmp(joint_attr, kwargs) + conditions = handle_conditional_attr(conditional_attr) + xy, xx = handle_cost_tmp(xy=xy, x={}, y={}, cost=cost, cost_kwargs=cost_kwargs) + return super().prepare( + policy_key=key, + policy=policy, + xy=xy, + xx=xx, + conditions=conditions, + a=a, + b=b, + **kwargs, + ) + + def solve( + self, + batch_size: int = 1024, + seed: int = 0, + iterations: int = 25000, # TODO(@MUCDK): rename to max_iterations + valid_freq: int = 50, + valid_sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}), + train_size: float = 1.0, + **kwargs: Any, + ) -> "GENOTLinProblem[K, B]": + """Solve.""" + return super().solve( + batch_size=batch_size, + # tau_a=tau_a, # TODO: unbalancedness handler + # tau_b=tau_b, + seed=seed, + n_iters=iterations, + valid_freq=valid_freq, + valid_sinkhorn_kwargs=valid_sinkhorn_kwargs, + train_size=train_size, + solver_name="GENOTLinSolver", + **kwargs, + ) + + @property + def _base_problem_type(self) 
-> Type[CondOTProblem]: + return CondOTProblem + + @property + def _valid_policies(self) -> Tuple[Policy_t, ...]: + return _constants.SEQUENTIAL, _constants.EXPLICIT # type: ignore[return-value] diff --git a/src/moscot/problems/generic/_mixins.py b/src/moscot/problems/generic/_mixins.py index 6580c2ad5..ed74ebffd 100644 --- a/src/moscot/problems/generic/_mixins.py +++ b/src/moscot/problems/generic/_mixins.py @@ -127,9 +127,9 @@ def push( Parameters ---------- source - Source key in :attr:`solutions`. + Source key in `solutions`. target - Target key in :attr:`solutions`. + Target key in `solutions`. data Initial data to push, see :meth:`~moscot.base.problems.OTProblem.push` for information. subset @@ -196,9 +196,9 @@ def pull( Parameters ---------- source - Source key in :attr:`solutions`. + Source key in `solutions`. target - Target key in :attr:`solutions`. + Target key in `solutions`. data Initial data to pull, see :meth:`~moscot.base.problems.OTProblem.pull` for information. subset diff --git a/src/moscot/problems/space/_alignment.py b/src/moscot/problems/space/_alignment.py index 2567878e7..5e8d2009e 100644 --- a/src/moscot/problems/space/_alignment.py +++ b/src/moscot/problems/space/_alignment.py @@ -255,7 +255,7 @@ def solve( linear_solver_kwargs Keyword arguments for the inner :term:`linear problem` solver. device - Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseSolverOutput.to`. + Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`. If :obj:`None`, keep the output on the original device. kwargs Keyword arguments for :meth:`~moscot.base.problems.CompoundProblem.solve`. 
diff --git a/src/moscot/problems/space/_mapping.py b/src/moscot/problems/space/_mapping.py index 8d4a0cb67..abe440e28 100644 --- a/src/moscot/problems/space/_mapping.py +++ b/src/moscot/problems/space/_mapping.py @@ -302,7 +302,7 @@ def solve( linear_solver_kwargs Keyword arguments for the inner :term:`linear problem` solver. Only used when `alpha` > 0. device - Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseSolverOutput.to`. + Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`. If :obj:`None`, keep the output on the original device. kwargs Keyword arguments for :meth:`~moscot.base.problems.CompoundProblem.solve`. diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index b6cd39538..d72c79c2b 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -416,7 +416,7 @@ def correlate( # type: ignore[misc] - ``'spearman'`` - `Spearman rank correlation `_. device - Device where to transfer the solutions, see :meth:`~moscot.base.output.BaseSolverOutput.to`. + Device where to transfer the solutions, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`. groupby Optional key in :attr:`~anndata.AnnData.obs`, containing categorical annotations for grouping. batch_size @@ -504,7 +504,7 @@ def impute( # type: ignore[misc] var_names Genes in :attr:`~anndata.AnnData.var_names` to impute. If :obj:`None`, use all genes in :attr:`adata_sc`. device - Device where to transfer the solutions, see :meth:`~moscot.base.output.BaseSolverOutput.to`. + Device where to transfer the solutions, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`. batch_size: Number of features to process at once. If :obj:`None`, process all features at once. Larger values will require more memory. 
diff --git a/src/moscot/problems/spatiotemporal/_spatio_temporal.py b/src/moscot/problems/spatiotemporal/_spatio_temporal.py index caf2604f2..89228df63 100644 --- a/src/moscot/problems/spatiotemporal/_spatio_temporal.py +++ b/src/moscot/problems/spatiotemporal/_spatio_temporal.py @@ -242,7 +242,7 @@ def solve( linear_solver_kwargs Keyword arguments for the inner :term:`linear problem` solver. device - Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseSolverOutput.to`. + Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`. If :obj:`None`, keep the output on the original device. kwargs Keyword arguments for :meth:`~moscot.problems.space.AlignmentProblem.solve`. diff --git a/src/moscot/problems/time/_lineage.py b/src/moscot/problems/time/_lineage.py index 5e1370d78..809b7f8b9 100644 --- a/src/moscot/problems/time/_lineage.py +++ b/src/moscot/problems/time/_lineage.py @@ -241,7 +241,7 @@ def solve( max_iterations Maximum number of :term:`Sinkhorn` iterations. device - Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseSolverOutput.to`. + Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`. If :obj:`None`, keep the output on the original device. kwargs Keyword arguments for :meth:`~moscot.base.problems.CompoundProblem.solve`. @@ -510,7 +510,7 @@ def solve( linear_solver_kwargs Keyword arguments for the inner :term:`linear problem` solver. device - Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseSolverOutput.to`. + Transfer the solution to a different device, see :meth:`~moscot.base.output.BaseDiscreteSolverOutput.to`. If :obj:`None`, keep the output on the original device. kwargs Keyword arguments for :meth:`~moscot.problems.time.TemporalProblem.solve`. 
diff --git a/src/moscot/utils/subset_policy.py b/src/moscot/utils/subset_policy.py index a5375de84..98e08476e 100644 --- a/src/moscot/utils/subset_policy.py +++ b/src/moscot/utils/subset_policy.py @@ -253,6 +253,11 @@ def remove_node(self, node: Tuple[K, K]) -> "SubsetPolicy[K]": self._graph.remove(node) return self + @property + def categories(self) -> Sequence[K]: + """Categories in the policy.""" + return self._cat + @property def key(self) -> Optional[str]: """Key in :attr:`~anndata.AnnData.obs` defining the policy.""" diff --git a/src/moscot/utils/tagged_array.py b/src/moscot/utils/tagged_array.py index 0a1a776ad..050573534 100644 --- a/src/moscot/utils/tagged_array.py +++ b/src/moscot/utils/tagged_array.py @@ -1,6 +1,6 @@ import enum from dataclasses import dataclass -from typing import Any, Callable, Literal, Optional, Tuple, Union +from typing import Any, Callable, Hashable, Literal, Optional, Tuple, TypeVar, Union import numpy as np import scipy.sparse as sp @@ -8,10 +8,12 @@ from anndata import AnnData from moscot._logging import logger -from moscot._types import ArrayLike, CostFn_t +from moscot._types import ArrayLike, CostFn_t, OttCostFn_t from moscot.costs import get_cost -__all__ = ["Tag", "TaggedArray"] +K = TypeVar("K", bound=Hashable) + +__all__ = ["Tag", "TaggedArray", "DistributionContainer", "DistributionCollection"] @enum.unique @@ -186,3 +188,186 @@ def is_point_cloud(self) -> bool: def is_graph(self) -> bool: """Whether :attr:`data_src` is a graph.""" return self.tag == Tag.GRAPH + + +@dataclass(frozen=True, repr=True) +class DistributionContainer: + """Data container for OT problems involving more than two distributions. + + TODO + + Parameters + ---------- + xy + Distribution living in a shared space. + xx + Distribution living in an incomparable space. + a + Marginals when used as source distribution. + b + Marginals when used as target distribution. + conditions + Conditions for the distributions. 
+ cost_xy + Cost function when in the shared space. + cost_xx + Cost function in the incomparable space. + """ + + xy: Optional[ArrayLike] + xx: Optional[ArrayLike] + a: ArrayLike + b: ArrayLike + conditions: Optional[ArrayLike] + cost_xy: OttCostFn_t + cost_xx: OttCostFn_t + + @property + def contains_linear(self) -> bool: + """Whether the distribution contains data corresponding to the linear term.""" + return self.xy is not None + + @property + def contains_quadratic(self) -> bool: + """Whether the distribution contains data corresponding to the quadratic term.""" + return self.xx is not None + + @property + def contains_condition(self) -> bool: + """Whether the distribution contains data corresponding to the condition.""" + return self.conditions is not None + + @staticmethod + def _extract_data( + adata: AnnData, + *, + attr: Literal["X", "obs", "obsp", "obsm", "var", "varm", "layers", "uns"], + key: Optional[str] = None, + ) -> ArrayLike: + modifier = f"adata.{attr}" if key is None else f"adata.{attr}[{key!r}]" + data = getattr(adata, attr) + + try: + if key is not None: + data = data[key] + except KeyError: + raise KeyError(f"Unable to fetch data from `{modifier}`.") from None + except IndexError: + raise IndexError(f"Unable to fetch data from `{modifier}`.") from None + + if attr == "obs": + data = np.asarray(data)[:, None] + if sp.issparse(data): + logger.warning(f"Densifying data in `{modifier}`") + data = data.A + if data.ndim != 2: + raise ValueError(f"Expected `{modifier}` to have `2` dimensions, found `{data.ndim}`.") + + return data + + @staticmethod + def _verify_input( + xy_attr: Optional[Literal["X", "obsp", "obsm", "layers", "uns"]], + xy_key: Optional[str], + xx_attr: Optional[Literal["X", "obsp", "obsm", "layers", "uns"]], + xx_key: Optional[str], + conditions_attr: Optional[Literal["obs", "var", "obsm", "varm", "layers", "uns"]], + conditions_key: Optional[str], + ) -> Tuple[bool, bool, bool]: + if (xy_attr is None and xy_key is not None) or 
(xy_attr is not None and xy_key is None): + raise ValueError(r"Either both `xy_attr` and `xy_key` must be `None` or none of them.") + if (xx_attr is None and xx_key is not None) or (xx_attr is not None and xx_key is None): + raise ValueError(r"Either both `xy_attr` and `xy_key` must be `None` or none of them.") + if (conditions_attr is None and conditions_key is not None) or ( + conditions_attr is not None and conditions_key is None + ): + raise ValueError(r"Either both `conditions_attr` and `conditions_key` must be `None` or none of them.") + return xy_attr is not None, xx_attr is not None, conditions_attr is not None + + @classmethod + def from_adata( + cls, + adata: AnnData, + a: ArrayLike, + b: ArrayLike, + xy_attr: Literal["X", "obsp", "obsm", "layers", "uns"] = None, + xy_key: Optional[str] = None, + xy_cost: CostFn_t = "sq_euclidean", + xx_attr: Literal["X", "obsp", "obsm", "layers", "uns"] = None, + xx_key: Optional[str] = None, + xx_cost: CostFn_t = "sq_euclidean", + conditions_attr: Optional[Literal["obs", "var", "obsm", "varm", "layers", "uns"]] = None, + conditions_key: Optional[str] = None, + backend: Literal["ott"] = "ott", + **kwargs: Any, + ) -> "DistributionContainer": + """Create distribution container from :class:`~anndata.AnnData`. + + .. warning:: + Sparse arrays will be always densified. + + Parameters + ---------- + adata + Annotated data object. + a + Marginals when used as source distribution. + b + Marginals when used as target distribution. + xy_attr + Attribute of `adata` containing the data for the shared space. + xy_key + Key of `xy_attr` containing the data for the shared space. + xy_cost + Cost function when in the shared space. + xx_attr + Attribute of `adata` containing the data for the incomparable space. + xx_key + Key of `xx_attr` containing the data for the incomparable space. + xx_cost + Cost function in the incomparable space. + conditions_attr + Attribute of `adata` containing the conditions. 
+ conditions_key + Key of `conditions_attr` containing the conditions. + backend + Backend to use. + kwargs + Keyword arguments to pass to the cost functions. + + Returns + ------- + The distribution container. + """ + contains_linear, contains_quadratic, contains_condition = cls._verify_input( + xy_attr, xy_key, xx_attr, xx_key, conditions_attr, conditions_key + ) + + if contains_linear: + xy_data = cls._extract_data(adata, attr=xy_attr, key=xy_key) + xy_cost_fn = get_cost(xy_cost, backend=backend, **kwargs) + else: + xy_data = None + xy_cost_fn = None + + if contains_quadratic: + xx_data = cls._extract_data(adata, attr=xx_attr, key=xx_key) + xx_cost_fn = get_cost(xx_cost, backend=backend, **kwargs) + else: + xx_data = None + xx_cost_fn = None + + conditions_data = ( + cls._extract_data(adata, attr=conditions_attr, key=conditions_key) if contains_condition else None # type: ignore[arg-type] # noqa:E501 + ) + return cls(xy=xy_data, xx=xx_data, a=a, b=b, conditions=conditions_data, cost_xy=xy_cost_fn, cost_xx=xx_cost_fn) + + +class DistributionCollection(dict[K, DistributionContainer]): + """Collection of distributions.""" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}{list(self.keys())}" + + def __str__(self) -> str: + return repr(self) diff --git a/tests/backends/ott/test_backend.py b/tests/backends/ott/test_backend.py index 51c996946..161b67f3b 100644 --- a/tests/backends/ott/test_backend.py +++ b/tests/backends/ott/test_backend.py @@ -21,7 +21,7 @@ from moscot._types import ArrayLike, Device_t from moscot.backends.ott import GWSolver, SinkhornSolver from moscot.backends.ott._utils import alpha_to_fused_penalty -from moscot.base.output import BaseSolverOutput +from moscot.base.output import BaseDiscreteSolverOutput from moscot.base.solver import O, OTSolver from moscot.utils.tagged_array import Tag, TaggedArray from tests._utils import ATOL, RTOL, Geom_t @@ -321,7 +321,7 @@ def test_push( out = solver(a=jnp.ones(len(x)) / len(x), 
b=jnp.ones(len(y)) / len(y), xy=(x, y)) p = out.push(a, scale_by_marginals=False) - assert isinstance(out, BaseSolverOutput) + assert isinstance(out, BaseDiscreteSolverOutput) assert isinstance(p, jnp.ndarray) if batched: assert p.shape == (out.shape[1], ndim) @@ -347,7 +347,7 @@ def test_pull( out = solver(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, xy=(xx, yy)) p = out.pull(b, scale_by_marginals=False) - assert isinstance(out, BaseSolverOutput) + assert isinstance(out, BaseDiscreteSolverOutput) assert isinstance(p, jnp.ndarray) if batched: assert p.shape == (out.shape[0], ndim) diff --git a/tests/problems/base/test_general_problem.py b/tests/problems/base/test_general_problem.py index 80d7637c2..431a87d93 100644 --- a/tests/problems/base/test_general_problem.py +++ b/tests/problems/base/test_general_problem.py @@ -12,7 +12,7 @@ from anndata import AnnData from moscot.backends.ott.output import GraphOTTOutput, OTTOutput -from moscot.base.output import BaseSolverOutput, MatrixSolverOutput +from moscot.base.output import BaseDiscreteSolverOutput, MatrixSolverOutput from moscot.base.problems import OTProblem from moscot.utils.tagged_array import Tag, TaggedArray from tests._utils import ATOL, RTOL, Geom_t, MockSolverOutput @@ -27,7 +27,7 @@ def test_simple_run(self, adata_x: AnnData, adata_y: AnnData): y={"attr": "X"}, ).solve(epsilon=5e-1, alpha=0.5) - assert isinstance(prob.solution, BaseSolverOutput) + assert isinstance(prob.solution, BaseDiscreteSolverOutput) @pytest.mark.fast() def test_output(self, adata_x: AnnData, x: Geom_t): @@ -160,7 +160,7 @@ def test_set_solution(self, adata_x: AnnData, adata_y: AnnData, clazz: type): prob = prob.set_solution(solution, cost=42, converged=True) assert prob.stage == "solved" - assert isinstance(prob.solution, BaseSolverOutput) + assert isinstance(prob.solution, BaseDiscreteSolverOutput) assert prob.solution.shape == prob.shape assert prob.solution.cost == 42 assert prob.solution.converged diff --git 
a/tests/problems/conftest.py b/tests/problems/conftest.py index fd2c76c4d..99ae64984 100644 --- a/tests/problems/conftest.py +++ b/tests/problems/conftest.py @@ -257,3 +257,10 @@ def marginal_keys(request): "tau_a": "tau_a", "tau_b": "tau_b", } + +neurallin_cond_args_1 = { + "batch_size": 8, + "seed": 0, + "iterations": 2, + "valid_freq": 4, +} diff --git a/tests/problems/cross_modality/test_translation_problem.py b/tests/problems/cross_modality/test_translation_problem.py index f5199877a..b083ee745 100644 --- a/tests/problems/cross_modality/test_translation_problem.py +++ b/tests/problems/cross_modality/test_translation_problem.py @@ -9,7 +9,7 @@ from anndata import AnnData from moscot.backends.ott._utils import alpha_to_fused_penalty -from moscot.base.output import BaseSolverOutput +from moscot.base.output import BaseDiscreteSolverOutput from moscot.problems.cross_modality import TranslationProblem from tests.problems.conftest import ( fgw_args_1, @@ -122,7 +122,7 @@ def test_solve_balanced( tp = tp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs) for key, subsol in tp.solutions.items(): - assert isinstance(subsol, BaseSolverOutput) + assert isinstance(subsol, BaseDiscreteSolverOutput) assert key in expected_keys assert tp[key].solution.rank == rank diff --git a/tests/problems/generic/test_conditional_neural_problem.py b/tests/problems/generic/test_conditional_neural_problem.py new file mode 100644 index 000000000..15dba8146 --- /dev/null +++ b/tests/problems/generic/test_conditional_neural_problem.py @@ -0,0 +1,86 @@ +import optax +import pytest + +import numpy as np +from ott.geometry import costs + +import anndata as ad + +from moscot.base.output import BaseDiscreteSolverOutput +from moscot.base.problems import CondOTProblem +from moscot.problems.generic import GENOTLinProblem # type: ignore[attr-defined] +from moscot.utils.tagged_array import DistributionCollection, DistributionContainer +from tests._utils import ATOL, RTOL +from 
tests.problems.conftest import neurallin_cond_args_1 + + +class TestGENOTLinProblem: + @pytest.mark.fast() + def test_prepare(self, adata_time: ad.AnnData): + problem = GENOTLinProblem(adata=adata_time) + problem = problem.prepare(key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}) + assert isinstance(problem, CondOTProblem) + assert isinstance(problem.distributions, DistributionCollection) + assert list(problem.distributions.keys()) == [0, 1, 2] + + container = problem.distributions[0] + n_obs_0 = adata_time[adata_time.obs["time"] == 0].n_obs + assert isinstance(container, DistributionContainer) + assert isinstance(container.xy, np.ndarray) + assert container.xy.shape == (n_obs_0, 50) + assert container.xx is None + assert isinstance(container.conditions, np.ndarray) + assert container.conditions.shape == (n_obs_0, 1) + assert isinstance(container.a, np.ndarray) + assert container.a.shape == (n_obs_0,) + assert isinstance(container.b, np.ndarray) + assert container.b.shape == (n_obs_0,) + assert isinstance(container.cost_xy, costs.SqEuclidean) + assert container.cost_xx is None + + @pytest.mark.parametrize("train_size", [0.9, 1.0]) + def test_solve_balanced_no_baseline(self, adata_time: ad.AnnData, train_size: float): # type: ignore[no-untyped-def] # noqa: E501 + problem = GENOTLinProblem(adata=adata_time) + problem = problem.prepare(key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}) + problem = problem.solve(train_size=train_size, **neurallin_cond_args_1) + assert isinstance(problem.solution, BaseDiscreteSolverOutput) + + def test_reproducibility(self, adata_time: ad.AnnData): + cond_zero_mask = np.array(adata_time.obs["time"] == 0) + pc_tzero = adata_time[cond_zero_mask].obsm["X_pca"] + problem_one = GENOTLinProblem(adata=adata_time) + problem_one = problem_one.prepare( + key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}, seed=0 + ) + problem_one = 
problem_one.solve(**neurallin_cond_args_1) + problem_two = GENOTLinProblem(adata=adata_time) + problem_two = problem_two.prepare( + key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}, seed=0 + ) + problem_two = problem_two.solve(**neurallin_cond_args_1) + assert np.allclose( + problem_one.solution.push(pc_tzero, cond=np.zeros((cond_zero_mask.sum(), 1))), + problem_two.solution.push(pc_tzero, cond=np.zeros((cond_zero_mask.sum(), 1))), + rtol=RTOL, + atol=ATOL, + ) + + # def test_pass_arguments(self, adata_time: ad.AnnData): # TODO(ilan-gold) implement this once the OTT PR is settled + # problem = GENOTLinProblem(adata=adata_time) + # adata_time = adata_time[adata_time.obs["time"].isin((0, 1))] + # problem = problem.prepare(key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}) + # problem = problem.solve(**neurallin_cond_args_1) + + # solver = problem.solver._solver + # for arg, val in neurallin_cond_args_1.items(): + # assert hasattr(solver, val) + # el = getattr(solver, val)[0] if isinstance(getattr(solver, val), tuple) else getattr(solver, val) + # assert el == neurallin_cond_args_1[arg] + + def test_pass_custom_optimizers(self, adata_time: ad.AnnData): + problem = GENOTLinProblem(adata=adata_time) + adata_time = adata_time[adata_time.obs["time"].isin((0, 1))] + problem = problem.prepare(key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}) + custom_opt = optax.adagrad(1e-4) + + problem = problem.solve(iterations=2, optimizer=custom_opt) diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index d2724e80c..35a5efca2 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -12,7 +12,7 @@ from moscot._types import CostKwargs_t from moscot.backends.ott._utils import alpha_to_fused_penalty -from moscot.base.output import BaseSolverOutput +from moscot.base.output import 
BaseDiscreteSolverOutput from moscot.base.problems import OTProblem from moscot.problems.generic import FGWProblem from tests._utils import _assert_marginals_set @@ -89,7 +89,7 @@ def test_solve_balanced(self, adata_space_rotate: AnnData): problem = problem.solve(alpha=0.5, epsilon=eps) for key, subsol in problem.solutions.items(): - assert isinstance(subsol, BaseSolverOutput) + assert isinstance(subsol, BaseDiscreteSolverOutput) assert key in expected_keys # assert that prior and posterior marginals same assert np.allclose(subsol.a, problem[key].a, atol=1e-5) diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py index 1ed381af8..ccee4323f 100644 --- a/tests/problems/generic/test_gw_problem.py +++ b/tests/problems/generic/test_gw_problem.py @@ -11,7 +11,7 @@ from anndata import AnnData from moscot._types import CostKwargs_t -from moscot.base.output import BaseSolverOutput +from moscot.base.output import BaseDiscreteSolverOutput from moscot.base.problems import OTProblem from moscot.problems.generic import GWProblem from tests._utils import _assert_marginals_set @@ -72,7 +72,7 @@ def test_solve_balanced(self, adata_space_rotate: AnnData): # type: ignore[no-u problem = problem.solve(epsilon=eps) for key, subsol in problem.solutions.items(): - assert isinstance(subsol, BaseSolverOutput) + assert isinstance(subsol, BaseDiscreteSolverOutput) assert key in expected_keys assert problem[key].solver._problem.geom_xy is None # assert prior and posterior marginals are the same diff --git a/tests/problems/generic/test_sinkhorn_problem.py b/tests/problems/generic/test_sinkhorn_problem.py index 64a9bf249..d75d37f2d 100644 --- a/tests/problems/generic/test_sinkhorn_problem.py +++ b/tests/problems/generic/test_sinkhorn_problem.py @@ -10,7 +10,7 @@ from anndata import AnnData -from moscot.base.output import BaseSolverOutput +from moscot.base.output import BaseDiscreteSolverOutput from moscot.base.problems import OTProblem from 
moscot.problems.generic import SinkhornProblem from tests._utils import _assert_marginals_set @@ -55,7 +55,7 @@ def test_solve_balanced(self, adata_time: AnnData, marginal_keys): problem = problem.solve(epsilon=eps) for key, subsol in problem.solutions.items(): - assert isinstance(subsol, BaseSolverOutput) + assert isinstance(subsol, BaseDiscreteSolverOutput) assert key in expected_keys assert subsol.converged assert np.allclose(subsol.a, problem[key].a, atol=1e-5) diff --git a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py index a22d5cd35..e7d35c82c 100644 --- a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py +++ b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py @@ -9,7 +9,7 @@ from anndata import AnnData from moscot.backends.ott._utils import alpha_to_fused_penalty -from moscot.base.output import BaseSolverOutput +from moscot.base.output import BaseDiscreteSolverOutput from moscot.base.problems import BirthDeathProblem from moscot.problems.spatiotemporal import SpatioTemporalProblem from tests._utils import ATOL, RTOL @@ -57,7 +57,7 @@ def test_solve_balanced(self, adata_spatio_temporal: AnnData): problem = problem.solve(alpha=alpha, epsilon=eps) for key, subsol in problem.solutions.items(): - assert isinstance(subsol, BaseSolverOutput) + assert isinstance(subsol, BaseDiscreteSolverOutput) assert key in expected_keys @pytest.mark.skip(reason="unbalanced does not work yet") diff --git a/tests/problems/time/test_lineage_problem.py b/tests/problems/time/test_lineage_problem.py index fed68cf6f..fa79fbb74 100644 --- a/tests/problems/time/test_lineage_problem.py +++ b/tests/problems/time/test_lineage_problem.py @@ -8,7 +8,7 @@ from anndata import AnnData from moscot.backends.ott._utils import alpha_to_fused_penalty -from moscot.base.output import BaseSolverOutput +from moscot.base.output import BaseDiscreteSolverOutput from moscot.base.problems import 
BirthDeathProblem from moscot.problems.time import LineageProblem from tests._utils import ATOL, RTOL @@ -58,7 +58,7 @@ def test_solve_balanced(self, adata_time_barcodes: AnnData): problem = problem.solve(epsilon=eps) for _, subsol in problem.solutions.items(): - assert isinstance(subsol, BaseSolverOutput) + assert isinstance(subsol, BaseDiscreteSolverOutput) def test_solve_unbalanced(self, adata_time_barcodes: AnnData): taus = [9e-1, 1e-2] diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index 43923f285..5a4f92df7 100644 --- a/tests/problems/time/test_temporal_problem.py +++ b/tests/problems/time/test_temporal_problem.py @@ -13,7 +13,7 @@ from anndata import AnnData from moscot.backends.ott.output import GraphOTTOutput -from moscot.base.output import BaseSolverOutput +from moscot.base.output import BaseDiscreteSolverOutput from moscot.base.problems import BirthDeathProblem from moscot.problems.time import TemporalProblem from moscot.utils.tagged_array import Tag, TaggedArray @@ -63,7 +63,7 @@ def test_solve_balanced(self, adata_time: AnnData, callback: Optional[str]): assert isinstance(problem[0, 1].xy.cost, costs.Cosine) for key, subsol in problem.solutions.items(): - assert isinstance(subsol, BaseSolverOutput) + assert isinstance(subsol, BaseDiscreteSolverOutput) assert key in expected_keys def test_solve_unbalanced(self, adata_time: AnnData): @@ -361,7 +361,7 @@ def test_graph_construction_callback(self, adata_time: AnnData, callback_kwargs: assert problem[0, 1].xy.cost == "geodesic" for key, subsol in problem.solutions.items(): - assert isinstance(subsol, BaseSolverOutput) + assert isinstance(subsol, BaseDiscreteSolverOutput) assert key in expected_keys if "n_neighbors" in callback_kwargs: diff --git a/tests/solvers/test_base_solver.py b/tests/solvers/test_base_solver.py index fd503f179..046ef6e0e 100644 --- a/tests/solvers/test_base_solver.py +++ b/tests/solvers/test_base_solver.py @@ -11,7 +11,7 @@ from 
tests._utils import ATOL, RTOL, MockSolverOutput -class TestBaseSolverOutput: +class TestBaseDiscreteSolverOutput: @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("threshold", [0.0, 1e-1, 1.0]) @pytest.mark.parametrize("shape", [(7, 2), (91, 103)])