Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 26, 2024
1 parent 37c3757 commit 5b4afdd
Show file tree
Hide file tree
Showing 11 changed files with 67 additions and 96 deletions.
21 changes: 3 additions & 18 deletions src/moscot/backends/ott/__init__.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,11 @@
from ott.geometry import costs

from moscot.backends.ott._utils import sinkhorn_divergence
from moscot.backends.ott.output import (
OTTOutput,
OTTNeuralOutput,
GraphOTTOutput
)
from moscot.backends.ott.solver import (
GWSolver,
SinkhornSolver,
GENOTLinSolver,
)
from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput
from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver
from moscot.costs import register_cost

__all__ = [
"OTTOutput",
"GWSolver",
"SinkhornSolver",
"OTTNeuralOutput",
"sinkhorn_divergence",
"GENOTLinSolver"
]
__all__ = ["OTTOutput", "GWSolver", "SinkhornSolver", "OTTNeuralOutput", "sinkhorn_divergence", "GENOTLinSolver"]


register_cost("euclidean", backend="ott")(costs.Euclidean)
Expand Down
14 changes: 2 additions & 12 deletions src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
from functools import partial
from typing import (
Any,
Dict,
Mapping,
Optional,
Tuple,
Type,
Union,
Literal
)

import optax
from typing import Any, Literal, Optional, Tuple, Union


import jax
import jax.numpy as jnp
Expand Down
37 changes: 17 additions & 20 deletions src/moscot/backends/ott/output.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union

import jaxlib.xla_extension as xla_ext

import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from ott.problems.linear import potentials
from ott.neural.flow_models.genot import (
GENOTBase, # TODO(ilan-gold): will neeed to update for ICNN's
)
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr
from ott.neural.flow_models.genot import GENOTBase # TODO(ilan-gold): will neeed to update for ICNN's

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from moscot._types import ArrayLike, Device_t
from moscot.backends.ott._utils import get_nearest_neighbors
Expand Down Expand Up @@ -247,7 +246,7 @@ def _ones(self, n: int) -> ArrayLike: # noqa: D102
class OTTNeuralOutput(BaseNeuralOutput):
def __init__(self, model: GENOTBase):
self._model = model

def _project_transport_matrix(
self,
src_dist: ArrayLike,
Expand Down Expand Up @@ -289,7 +288,7 @@ def _project_transport_matrix(
if save_transport_matrix:
self._inverse_transport_matrix = tm
return tm

def project_transport_matrix( # type:ignore[override]
self,
src_cells: ArrayLike,
Expand Down Expand Up @@ -341,13 +340,9 @@ def project_transport_matrix( # type:ignore[override]
The projected transport matrix.
"""
src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells)
push = self.push if condition is None else lambda x : self.push(x, condition)
pull = self.pull if condition is None else lambda x : self.pull(x, condition)
func, src_dist, tgt_dist = (
(push, src_cells, tgt_cells)
if forward
else (pull, tgt_cells, src_cells)
)
push = self.push if condition is None else lambda x: self.push(x, condition)
pull = self.pull if condition is None else lambda x: self.pull(x, condition)
func, src_dist, tgt_dist = (push, src_cells, tgt_cells) if forward else (pull, tgt_cells, src_cells)
return self._project_transport_matrix(
src_dist=src_dist,
tgt_dist=tgt_dist,
Expand All @@ -359,7 +354,7 @@ def project_transport_matrix( # type:ignore[override]
length_scale=length_scale,
seed=seed,
)

def push(self, x: ArrayLike, cond: ArrayLike | None = None) -> ArrayLike: # type: ignore[override]
"""Push distribution `x` conditioned on condition `cond`.
Expand Down Expand Up @@ -395,24 +390,26 @@ def pull(self, x: ArrayLike, cond: ArrayLike | None = None) -> ArrayLike: # typ
if x.ndim not in (1, 2):
raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
return self._apply(x, cond=cond, forward=False)

def _apply(self, x: ArrayLike, forward: bool, cond: ArrayLike | None = None) -> ArrayLike:
return self._model.transport(x, condition=cond, forward=forward)

@property
def shape(self) -> Tuple[int, int]:
"""%(shape)s."""
raise NotImplementedError()

def to(
self,
device: Optional[Device_t] = None,
) -> "OTTNeuralOutput":
"""Transfer the output to another device or change its data type.
Parameters
----------
device
If not `None`, the output will be transferred to `device`.
Returns
-------
The output on a saved on `device`.
Expand All @@ -432,8 +429,8 @@ def to(

# out = jax.device_put(self._model, device)
# return OTTNeuralOutput(out)
return self #TODO(ilan-gold) move model to device
return self # TODO(ilan-gold) move model to device

@property
def converged(self) -> bool:
"""%(converged)s."""
Expand Down
40 changes: 22 additions & 18 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,43 @@
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
)

import optax
from torch.utils.data import DataLoader, RandomSampler

import jax
import jax.numpy as jnp
import numpy as np
import optax
import scipy.sparse as sp
from ott.geometry import costs, epsilon_scheduler, geometry, pointcloud, geodesic
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr
from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud
from ott.neural.data.datasets import ConditionalOTDataset, OTDataset
from ott.neural.flow_models.genot import GENOTLin
from ott.neural.data.datasets import OTDataset, ConditionalOTDataset
from ott.neural.flow_models.models import VelocityField
from ott.neural.flow_models.samplers import uniform_sampler
from ott.neural.models.base_solver import UnbalancednessHandler, OTMatcherLinear
from ott.neural.models.base_solver import OTMatcherLinear, UnbalancednessHandler
from ott.neural.models.nets import RescalingMLP
from torch.utils.data import DataLoader, RandomSampler
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr

from moscot._types import (
ArrayLike,
ProblemKind_t,
QuadInitializer_t,
SinkhornInitializer_t,
)
from moscot.backends.ott._utils import alpha_to_fused_penalty, check_shapes, ensure_2d, _instantiate_geodesic_cost
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
from moscot.backends.ott.output import OTTNeuralOutput, OTTOutput,
from moscot.backends.ott._utils import (
_instantiate_geodesic_cost,
alpha_to_fused_penalty,
check_shapes,
ensure_2d,
)
from moscot.backends.ott.output import GraphOTTOutput, OTTNeuralOutput, OTTOutput
from moscot.base.problems._utils import TimeScalesHeatKernel

from moscot.base.solver import OTSolver
from moscot.costs import get_cost
from moscot.utils.tagged_array import DistributionCollection, TaggedArray
Expand Down Expand Up @@ -500,9 +504,9 @@ def __init__(self, **kwargs: Any) -> None:
self._neural_kwargs = kwargs

@property
def problem_kind(self) -> ProblemKind_t: # noqa: D102
return "linear"
def problem_kind(self) -> ProblemKind_t: # noqa: D102
return "linear"

def _prepare( # type: ignore[override]
self,
distributions: DistributionCollection[K],
Expand Down Expand Up @@ -638,7 +642,7 @@ def _prepare( # type: ignore[override]
)
return ConditionalOTDataset(datasets=train_loaders, seed=seed), ConditionalOTDataset(datasets=validate_loaders, seed=seed)


@staticmethod
def _assert2d(arr: ArrayLike, *, allow_reshape: bool = True) -> jnp.ndarray:
arr: jnp.ndarray = jnp.asarray(arr.A if sp.issparse(arr) else arr) # type: ignore[no-redef, attr-defined] # noqa:E501
Expand Down
6 changes: 1 addition & 5 deletions src/moscot/base/problems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,7 @@
from moscot.base.problems.birth_death import BirthDeathMixin, BirthDeathProblem
from moscot.base.problems.compound_problem import BaseCompoundProblem, CompoundProblem
from moscot.base.problems.manager import ProblemManager
from moscot.base.problems.problem import (
BaseProblem,
CondOTProblem,
OTProblem,
)
from moscot.base.problems.problem import BaseProblem, CondOTProblem, OTProblem

__all__ = [
"AnalysisMixin",
Expand Down
1 change: 0 additions & 1 deletion src/moscot/base/problems/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from anndata import AnnData

from moscot import _constants
from moscot._logging import logger
from moscot._types import ArrayLike, Numeric_t, Str_Dict_t
from moscot.base.output import BaseDiscreteSolverOutput
from moscot.base.problems._utils import (
Expand Down
8 changes: 7 additions & 1 deletion src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@
from moscot._logging import logger
from moscot._types import ArrayLike, CostFn_t, Device_t, ProblemKind_t
from moscot.base.output import BaseDiscreteSolverOutput, MatrixSolverOutput
from moscot.base.problems._utils import TimeScalesHeatKernel, require_solution, wrap_prepare, wrap_solve
from moscot.base.problems._utils import (
TimeScalesHeatKernel,
require_solution,
wrap_prepare,
wrap_solve,
)
from moscot.base.solver import OTSolver
from moscot.utils.subset_policy import ( # type:ignore[attr-defined]
Policy_t,
Expand Down Expand Up @@ -1025,6 +1030,7 @@ def __repr__(self) -> str:
def __str__(self) -> str:
return repr(self)


class CondOTProblem(BaseProblem): # TODO(@MUCDK) check generic types, save and load
"""
Base class for all optimal transport problems.
Expand Down
6 changes: 2 additions & 4 deletions src/moscot/problems/generic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from moscot.problems.generic._generic import (
FGWProblem,
GENOTLinProblem,
GWProblem,
SinkhornProblem,
FGWProblem
)

from moscot.problems.generic._mixins import GenericAnalysisMixin

__all__ = [
"FGWProblem"
"SinkhornProblem",
"FGWProblem" "SinkhornProblem",
"GENOTLinProblem",
"GWProblem",
"GenericAnalysisMixin",
Expand Down
6 changes: 3 additions & 3 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import inspect
import types
from types import MappingProxyType
from typing import Any, Dict, Literal, Mapping, Optional, Set, Tuple, Type, Union
from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Type, Union

import optax

from anndata import AnnData

Expand Down Expand Up @@ -31,6 +29,7 @@

__all__ = ["SinkhornProblem", "GWProblem", "GENOTLinProblem", "FGWProblem"]


def set_quad_defaults(z: Optional[Union[str, Mapping[str, Any]]]) -> Dict[str, str]:
if isinstance(z, str):
return {"attr": "obsm", "key": z, "tag": "point_cloud"} # cost handled by handle_cost
Expand Down Expand Up @@ -710,6 +709,7 @@ def _base_problem_type(self) -> Type[B]:
def _valid_policies(self) -> Tuple[Policy_t, ...]:
return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value]


class GENOTLinProblem(CondOTProblem, GenericAnalysisMixin[K, B]):
"""Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems."""

Expand Down
2 changes: 1 addition & 1 deletion tests/problems/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,4 +262,4 @@ def adata_time_with_tmap(adata_time: AnnData) -> AnnData:
"seed": 0,
"iterations": 2,
"valid_freq": 4,
}
}
22 changes: 9 additions & 13 deletions tests/problems/generic/test_conditional_neural_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,12 @@

import anndata as ad

from moscot.backends.ott.output import OTTNeuralOutput
from moscot.base.output import BaseDiscreteSolverOutput
from moscot.base.problems import CondOTProblem
from moscot.problems.generic import (
GENOTLinProblem, # type: ignore[attr-defined]
)
from moscot.problems.generic import GENOTLinProblem # type: ignore[attr-defined]
from moscot.utils.tagged_array import DistributionCollection, DistributionContainer
from tests._utils import ATOL, RTOL
from tests.problems.conftest import (
neurallin_cond_args_1,
neurallin_cond_args_2,
)
from tests.problems.conftest import neurallin_cond_args_1, neurallin_cond_args_2


class TestGENOTLinProblem:
Expand Down Expand Up @@ -71,8 +65,8 @@ def test_reproducibility(self, adata_time: ad.AnnData):
problem_two = problem_one.prepare("time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"})
problem_two = problem_one.solve(**neurallin_cond_args_1)
assert np.allclose(
problem_one.solution.push(pc_tzero, cond=np.zeros((cond_zero_mask.sum(),1))),
problem_two.solution.push(pc_tzero, cond=np.zeros((cond_zero_mask.sum(),1))),
problem_one.solution.push(pc_tzero, cond=np.zeros((cond_zero_mask.sum(), 1))),
problem_two.solution.push(pc_tzero, cond=np.zeros((cond_zero_mask.sum(), 1))),
rtol=RTOL,
atol=ATOL,
)
Expand Down Expand Up @@ -105,9 +99,11 @@ def test_learning_rescaling_factors(self, adata_time: ad.AnnData):
assert isinstance(problem.solution, BaseDiscreteSolverOutput)

array = np.asarray(adata_time.obsm["X_pca"].copy())
cond1 = jnp.ones((array.shape[0],1))
cond2 = jnp.zeros((array.shape[0],1))
learnt_eta_1 = problem.solver.solver.unbalancedness_handler.evaluate_eta(array, cond1) # TODO(ilan-gold): sould this be wrapped?
cond1 = jnp.ones((array.shape[0], 1))
cond2 = jnp.zeros((array.shape[0], 1))
learnt_eta_1 = problem.solver.solver.unbalancedness_handler.evaluate_eta(
array, cond1
) # TODO(ilan-gold): sould this be wrapped?
learnt_xi_1 = problem.solver.solver.unbalancedness_handler.evaluate_xi(array, cond1)
learnt_eta_2 = problem.solver.solver.unbalancedness_handler.evaluate_eta(array, cond2)
learnt_xi_2 = problem.solver.solver.unbalancedness_handler.evaluate_xi(array, cond2)
Expand Down

0 comments on commit 5b4afdd

Please sign in to comment.