diff --git a/src/moscot/backends/ott/_utils.py b/src/moscot/backends/ott/_utils.py index fc623ac8e..1ed7a0cc8 100644 --- a/src/moscot/backends/ott/_utils.py +++ b/src/moscot/backends/ott/_utils.py @@ -1,6 +1,7 @@ from typing import Any, Literal, Optional, Tuple, Union import jax +import jax.experimental.sparse as jesp import jax.numpy as jnp import scipy.sparse as sp from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud @@ -67,6 +68,25 @@ def alpha_to_fused_penalty(alpha: float) -> float: return (1 - alpha) / alpha +def densify(arr: ArrayLike) -> jax.Array: + """If the input is sparse, convert it t dense. + + Parameters + ---------- + arr + Array to check. + + Returns + ------- + dense :mod:`jax` array. + """ + if sp.issparse(arr): + arr = arr.A # type: ignore[attr-defined] + elif isinstance(arr, jesp.BCOO): + arr = arr.todense() + return jnp.asarray(arr) + + def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array: """Ensure that an array is 2-dimensional. @@ -81,9 +101,6 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array: ------- 2-dimensional :mod:`jax` array. """ - if sp.issparse(arr): - arr = arr.A # type: ignore[attr-defined] - arr = jnp.asarray(arr) if reshape and arr.ndim == 1: return jnp.reshape(arr, (-1, 1)) if arr.ndim != 2: @@ -91,6 +108,13 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array: return arr +def convert_scipy_sparse(arr: Union[sp.spmatrix, jesp.BCOO]) -> jesp.BCOO: + """ If the input is a scipy sparse matrix, convert it to a jax BCOO.""" + if sp.issparse(arr): + return jesp.BCOO.from_scipy_sparse(arr) + return arr + + def _instantiate_geodesic_cost( arr: jax.Array, problem_shape: Tuple[int, int], diff --git a/src/moscot/backends/ott/solver.py b/src/moscot/backends/ott/solver.py index f8130eab2..9d280d892 100644 --- a/src/moscot/backends/ott/solver.py +++ b/src/moscot/backends/ott/solver.py @@ -16,6 +16,8 @@ _instantiate_geodesic_cost, alpha_to_fused_penalty, check_shapes, + convert_scipy_sparse, + densify, ensure_2d, ) from moscot.backends.ott.output import GraphOTTOutput, OTTOutput @@ -76,8 +78,8 @@ def _create_geometry( if not isinstance(cost_fn, costs.CostFn): raise TypeError(f"Expected `cost_fn` to be `ott.geometry.costs.CostFn`, found `{type(cost_fn)}`.") - y = None if x.data_tgt is None else ensure_2d(x.data_tgt, reshape=True) - x = ensure_2d(x.data_src, reshape=True) + y = None if x.data_tgt is None else densify(ensure_2d(x.data_tgt, reshape=True)) + x = densify(ensure_2d(x.data_src, reshape=True)) if y is not None and x.shape[1] != y.shape[1]: raise ValueError( f"Expected `x/y` to have the same number of dimensions, found `{x.shape[1]}/{y.shape[1]}`." @@ -94,6 +96,8 @@ def _create_geometry( ) arr = ensure_2d(x.data_src, reshape=False) + arr = densify(arr) if x.is_graph else convert_scipy_sparse(arr) + if x.is_cost_matrix: return geometry.Geometry( cost_matrix=arr, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost diff --git a/tests/backends/ott/test_backend_utils.py b/tests/backends/ott/test_backend_utils.py new file mode 100644 index 000000000..394d77038 --- /dev/null +++ b/tests/backends/ott/test_backend_utils.py @@ -0,0 +1,53 @@ +from typing import Optional, Tuple, Type, Union + +import pytest + +import jax +import jax.numpy as jnp +import numpy as np +from ott.geometry import costs +from ott.geometry.geometry import Geometry +from ott.geometry.low_rank import LRCGeometry +from ott.geometry.pointcloud import PointCloud +from ott.problems.linear.linear_problem import LinearProblem +from ott.problems.quadratic import quadratic_problem +from ott.problems.quadratic.quadratic_problem import QuadraticProblem +from ott.solvers.linear import solve as sinkhorn +from ott.solvers.linear.sinkhorn import Sinkhorn +from ott.solvers.linear.sinkhorn_lr import LRSinkhorn +from ott.solvers.quadratic.gromov_wasserstein import GromovWasserstein +from ott.solvers.quadratic.gromov_wasserstein_lr import LRGromovWasserstein + +from moscot._types import ArrayLike, Device_t +from moscot.backends.ott import GWSolver, SinkhornSolver +from moscot.backends.ott._utils import alpha_to_fused_penalty +from moscot.base.output import BaseSolverOutput +from moscot.base.solver import O, OTSolver +from moscot.backends.ott._utils import _instantiate_geodesic_cost + +import jax.experimental.sparse as jesp +import scipy.sparse as sp +import networkx as nx +from networkx.generators import balanced_tree, random_graphs + + +class TestBackendUtils: + + @staticmethod + def test_instantiate_geodesic_cost(): + m, n = 10, 10 + problem_shape = 10, 10 + g = sp.rand(m, n, 0.1, dtype=np.float64) + g = jesp.BCOO.from_scipy_sparse(g) + geom = _instantiate_geodesic_cost(g, problem_shape, 1.0, False) + assert isinstance(geom, Geometry) + with pytest.raises(ValueError, match="Expected `x` to have"): + _instantiate_geodesic_cost(g, problem_shape, 1.0, True) + geom = _instantiate_geodesic_cost(g, (5, 5), 1.0, True) + + + + + + +