Skip to content

Commit

Permalink
first draft of geodesic sparse need to add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed Mar 26, 2024
1 parent f3c54be commit ceb2f08
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 5 deletions.
30 changes: 27 additions & 3 deletions src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Literal, Optional, Tuple, Union

import jax
import jax.experimental.sparse as jesp
import jax.numpy as jnp
import scipy.sparse as sp
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
Expand Down Expand Up @@ -67,6 +68,25 @@ def alpha_to_fused_penalty(alpha: float) -> float:
return (1 - alpha) / alpha


def densify(arr: ArrayLike) -> jax.Array:
"""If the input is sparse, convert it t dense.
Parameters
----------
arr
Array to check.
Returns
-------
dense :mod:`jax` array.
"""
if sp.issparse(arr):
arr = arr.A # type: ignore[attr-defined]
elif isinstance(arr, jesp.BCOO):
arr = arr.todense()
return jnp.asarray(arr)


def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
"""Ensure that an array is 2-dimensional.
Expand All @@ -81,16 +101,20 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
-------
2-dimensional :mod:`jax` array.
"""
if sp.issparse(arr):
arr = arr.A # type: ignore[attr-defined]
arr = jnp.asarray(arr)
if reshape and arr.ndim == 1:
return jnp.reshape(arr, (-1, 1))
if arr.ndim != 2:
raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
return arr


def convert_scipy_sparse(arr: Union[sp.spmatrix, jesp.BCOO]) -> jesp.BCOO:
""" If the input is a scipy sparse matrix, convert it to a jax BCOO."""
if sp.issparse(arr):
return jesp.BCOO.from_scipy_sparse(arr)
return arr


def _instantiate_geodesic_cost(
arr: jax.Array,
problem_shape: Tuple[int, int],
Expand Down
8 changes: 6 additions & 2 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
_instantiate_geodesic_cost,
alpha_to_fused_penalty,
check_shapes,
convert_scipy_sparse,
densify,
ensure_2d,
)
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
Expand Down Expand Up @@ -76,8 +78,8 @@ def _create_geometry(
if not isinstance(cost_fn, costs.CostFn):
raise TypeError(f"Expected `cost_fn` to be `ott.geometry.costs.CostFn`, found `{type(cost_fn)}`.")

y = None if x.data_tgt is None else ensure_2d(x.data_tgt, reshape=True)
x = ensure_2d(x.data_src, reshape=True)
y = None if x.data_tgt is None else densify(ensure_2d(x.data_tgt, reshape=True))
x = densify(ensure_2d(x.data_src, reshape=True))
if y is not None and x.shape[1] != y.shape[1]:
raise ValueError(
f"Expected `x/y` to have the same number of dimensions, found `{x.shape[1]}/{y.shape[1]}`."
Expand All @@ -94,6 +96,8 @@ def _create_geometry(
)

arr = ensure_2d(x.data_src, reshape=False)
arr = densify(arr) if x.is_graph else convert_scipy_sparse(arr)

if x.is_cost_matrix:
return geometry.Geometry(
cost_matrix=arr, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost
Expand Down
53 changes: 53 additions & 0 deletions tests/backends/ott/test_backend_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Optional, Tuple, Type, Union

import pytest

import jax
import jax.numpy as jnp
import numpy as np
from ott.geometry import costs
from ott.geometry.geometry import Geometry
from ott.geometry.low_rank import LRCGeometry
from ott.geometry.pointcloud import PointCloud
from ott.problems.linear.linear_problem import LinearProblem
from ott.problems.quadratic import quadratic_problem
from ott.problems.quadratic.quadratic_problem import QuadraticProblem
from ott.solvers.linear import solve as sinkhorn
from ott.solvers.linear.sinkhorn import Sinkhorn
from ott.solvers.linear.sinkhorn_lr import LRSinkhorn
from ott.solvers.quadratic.gromov_wasserstein import GromovWasserstein
from ott.solvers.quadratic.gromov_wasserstein_lr import LRGromovWasserstein

from moscot._types import ArrayLike, Device_t
from moscot.backends.ott import GWSolver, SinkhornSolver
from moscot.backends.ott._utils import alpha_to_fused_penalty
from moscot.base.output import BaseSolverOutput
from moscot.base.solver import O, OTSolver
from moscot.backends.ott._utils import _instantiate_geodesic_cost

import jax.experimental.sparse as jesp
import scipy.sparse as sp
import networkx as nx
from networkx.generators import balanced_tree, random_graphs


class TestBackendUtils:

@staticmethod
def test_instantiate_geodesic_cost():
m, n = 10, 10
problem_shape = 10, 10
g = sp.rand(m, n, 0.1, dtype=np.float64)
g = jesp.BCOO.from_scipy_sparse(g)
geom = _instantiate_geodesic_cost(g, problem_shape, 1.0, False)
assert isinstance(geom, Geometry)
with pytest.raises(ValueError, match="Expected `x` to have"):
_instantiate_geodesic_cost(g, problem_shape, 1.0, True)
geom = _instantiate_geodesic_cost(g, (5, 5), 1.0, True)







0 comments on commit ceb2f08

Please sign in to comment.