Skip to content

Commit

Permalink
Add Sparse Geodesic Cost Support (#677)
Browse files Browse the repository at this point in the history
* set ottjax version first and the tests

* recreate solution files with new ottjax version so it doesn't fail

* update tests with new version, specify inner_iterations

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add comments on issues

* first draft of geodesic sparse need to add more tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* set to new ottjax commit

* revert old commit that fixes the lr case

* set new version of ottjax instead of commit

* fix and modify the geodesic cost tests

* typo fix

* fix doc and add tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* still densify for other costs

* fix docs

* add the final tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Dominik Klein <[email protected]>
  • Loading branch information
3 people authored May 20, 2024
1 parent 4ba5f83 commit 1b951fe
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 31 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ dependencies = [
"scanpy>=1.9.3",
"wrapt>=1.13.2",
"docrep>=0.3.2",
"ott-jax>=0.4.5",
"ott-jax>=0.4.6",
"cloudpickle>=2.2.0",
"rich>=13.5",
]
Expand Down
30 changes: 27 additions & 3 deletions src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Literal, Optional, Tuple, Union

import jax
import jax.experimental.sparse as jesp
import jax.numpy as jnp
import scipy.sparse as sp
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
Expand Down Expand Up @@ -67,6 +68,25 @@ def alpha_to_fused_penalty(alpha: float) -> float:
return (1 - alpha) / alpha


def densify(arr: ArrayLike) -> jax.Array:
    """Convert a possibly sparse input to a dense :mod:`jax` array.

    Parameters
    ----------
    arr
        Array to convert. Inputs recognised by :func:`scipy.sparse.issparse`
        and :class:`jax.experimental.sparse.BCOO` instances are densified
        first; any other input is passed to :func:`jax.numpy.asarray` as is.

    Returns
    -------
    Dense :mod:`jax` array.
    """
    if sp.issparse(arr):
        # ``toarray()`` is available on every SciPy sparse container, whereas the
        # ``.A`` shortcut is not (sparse arrays and newer SciPy releases lack it).
        arr = arr.toarray()  # type: ignore[attr-defined]
    elif isinstance(arr, jesp.BCOO):
        arr = arr.todense()
    return jnp.asarray(arr)


def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
"""Ensure that an array is 2-dimensional.
Expand All @@ -81,16 +101,20 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
-------
2-dimensional :mod:`jax` array.
"""
if sp.issparse(arr):
arr = arr.A # type: ignore[attr-defined]
arr = jnp.asarray(arr)
if reshape and arr.ndim == 1:
return jnp.reshape(arr, (-1, 1))
if arr.ndim != 2:
raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
return arr


def convert_scipy_sparse(arr: Union[sp.spmatrix, jesp.BCOO]) -> jesp.BCOO:
    """Turn a SciPy sparse input into a :class:`jax.experimental.sparse.BCOO`.

    Any input that :func:`scipy.sparse.issparse` does not recognise is
    returned unchanged.
    """
    return jesp.BCOO.from_scipy_sparse(arr) if sp.issparse(arr) else arr


def _instantiate_geodesic_cost(
arr: jax.Array,
problem_shape: Tuple[int, int],
Expand Down
8 changes: 6 additions & 2 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
_instantiate_geodesic_cost,
alpha_to_fused_penalty,
check_shapes,
convert_scipy_sparse,
densify,
ensure_2d,
)
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
Expand Down Expand Up @@ -76,8 +78,8 @@ def _create_geometry(
if not isinstance(cost_fn, costs.CostFn):
raise TypeError(f"Expected `cost_fn` to be `ott.geometry.costs.CostFn`, found `{type(cost_fn)}`.")

y = None if x.data_tgt is None else ensure_2d(x.data_tgt, reshape=True)
x = ensure_2d(x.data_src, reshape=True)
y = None if x.data_tgt is None else densify(ensure_2d(x.data_tgt, reshape=True))
x = densify(ensure_2d(x.data_src, reshape=True))
if y is not None and x.shape[1] != y.shape[1]:
raise ValueError(
f"Expected `x/y` to have the same number of dimensions, found `{x.shape[1]}/{y.shape[1]}`."
Expand All @@ -94,6 +96,8 @@ def _create_geometry(
)

arr = ensure_2d(x.data_src, reshape=False)
arr = densify(arr) if x.is_graph else convert_scipy_sparse(arr)

if x.is_cost_matrix:
return geometry.Geometry(
cost_matrix=arr, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost
Expand Down
11 changes: 6 additions & 5 deletions src/moscot/utils/tagged_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def _extract_data(
*,
attr: Literal["X", "obsp", "obsm", "layers", "uns"],
key: Optional[str] = None,
densify: bool = False,
) -> ArrayLike:
modifier = f"adata.{attr}" if key is None else f"adata.{attr}[{key!r}]"
data = getattr(adata, attr)
Expand All @@ -67,7 +68,7 @@ def _extract_data(
except IndexError:
raise IndexError(f"Unable to fetch data from `{modifier}`.") from None

if sp.issparse(data):
if sp.issparse(data) and densify:
logger.warning(f"Densifying data in `{modifier}`")
data = data.A
if data.ndim != 2:
Expand Down Expand Up @@ -103,7 +104,7 @@ def from_adata(
"""Create tagged array from :class:`~anndata.AnnData`.
.. warning::
Sparse arrays will be always densified.
Sparse arrays will be densified except when ``tag = 'graph'``.
Parameters
----------
Expand Down Expand Up @@ -137,13 +138,13 @@ def from_adata(
if tag == Tag.GRAPH:
if cost == "geodesic":
dist_key = f"{dist_key[0]}_{dist_key[1]}" if isinstance(dist_key, tuple) else dist_key
data = cls._extract_data(adata, attr=attr, key=f"{dist_key}_{key}")
data = cls._extract_data(adata, attr=attr, key=f"{dist_key}_{key}", densify=False)
return cls(data_src=data, tag=Tag.GRAPH, cost="geodesic")
raise ValueError(f"Expected `cost=geodesic`, found `{cost}`.")
if tag == Tag.COST_MATRIX:
if cost == "custom": # our custom cost functions
modifier = f"adata.{attr}" if key is None else f"adata.{attr}[{key!r}]"
data = cls._extract_data(adata, attr=attr, key=key)
data = cls._extract_data(adata, attr=attr, key=key, densify=True)
if np.any(data < 0):
raise ValueError(f"Cost matrix in `{modifier}` contains negative values.")
return cls(data_src=data, tag=Tag.COST_MATRIX, cost=None)
Expand All @@ -153,7 +154,7 @@ def from_adata(
return cls(data_src=cost_matrix, tag=Tag.COST_MATRIX, cost=None)

# tag is either a point cloud or a kernel
data = cls._extract_data(adata, attr=attr, key=key)
data = cls._extract_data(adata, attr=attr, key=key, densify=True)
cost_fn = get_cost(cost, backend=backend, **kwargs)
return cls(data_src=data, tag=tag, cost=cost_fn)

Expand Down
23 changes: 23 additions & 0 deletions tests/backends/ott/test_backend_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

import jax.experimental.sparse as jesp
import numpy as np
import scipy.sparse as sp
from ott.geometry.geometry import Geometry

from moscot.backends.ott._utils import _instantiate_geodesic_cost


class TestBackendUtils:
    """Tests for helpers imported from ``moscot.backends.ott._utils``."""

    @staticmethod
    def test_instantiate_geodesic_cost():
        """Exercise ``_instantiate_geodesic_cost`` on a random sparse graph passed as BCOO."""
        m, n = 10, 10
        problem_shape = 10, 10
        # random 10 x 10 sparse matrix with density 0.1, converted to a jax BCOO
        g = sp.rand(m, n, 0.1, dtype=np.float64)
        g = jesp.BCOO.from_scipy_sparse(g)

        geom = _instantiate_geodesic_cost(g, problem_shape, 1.0, False)
        assert isinstance(geom, Geometry)

        # With the last argument set to True, the (10, 10) problem shape is rejected.
        # NOTE(review): presumably the two sizes must add up to the number of graph
        # nodes in this mode -- confirm against `_instantiate_geodesic_cost`.
        with pytest.raises(ValueError, match="Expected `x` to have"):
            _instantiate_geodesic_cost(g, problem_shape, 1.0, True)

        # A (5, 5) problem shape is accepted for the same graph; check the result
        # instead of discarding it.
        geom = _instantiate_geodesic_cost(g, (5, 5), 1.0, True)
        assert isinstance(geom, Geometry)
19 changes: 14 additions & 5 deletions tests/problems/space/test_alignment_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pandas as pd
import scipy.sparse as sp
from ott.geometry import epsilon_scheduler

import scanpy as sc
Expand Down Expand Up @@ -127,7 +128,8 @@ def test_solve_unbalanced(self, adata_space_rotate: AnnData):
assert np.allclose(*(sol.cost for sol in ap.solutions.values()), rtol=1e-5, atol=1e-5)

@pytest.mark.parametrize("key", ["connectivities", "distances"])
def test_geodesic_cost_xy(self, adata_space_rotate: AnnData, key: str):
@pytest.mark.parametrize("dense_input", [True, False])
def test_geodesic_cost_xy(self, adata_space_rotate: AnnData, key: str, dense_input: bool):
batch_column = "batch"
unique_batches = adata_space_rotate.obs[batch_column].unique()

Expand All @@ -140,13 +142,20 @@ def test_geodesic_cost_xy(self, adata_space_rotate: AnnData, key: str):
)[0]
adata_subset = adata_space_rotate[indices]
sc.pp.neighbors(adata_subset, n_neighbors=15, use_rep="X_pca")
dfs.append(
df = (
pd.DataFrame(
index=adata_subset.obs_names,
columns=adata_subset.obs_names,
data=adata_subset.obsp["connectivities"].A.astype("float64"),
data=adata_subset.obsp[key].A.astype("float64"),
)
if dense_input
else (
adata_subset.obsp[key].astype("float64"),
adata_subset.obs_names.to_series(),
adata_subset.obs_names.to_series(),
)
)
dfs.append(df)

ap: AlignmentProblem = AlignmentProblem(adata=adata_space_rotate)
ap = ap.prepare(batch_key=batch_column, joint_attr={"attr": "obsm", "key": "X_pca"})
Expand All @@ -157,14 +166,14 @@ def test_geodesic_cost_xy(self, adata_space_rotate: AnnData, key: str):

ta = ap[("0", "1")].xy
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"

ta = ap[("1", "2")].xy
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"
Expand Down
38 changes: 27 additions & 11 deletions tests/problems/space/test_mapping_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pandas as pd
import scipy.sparse as sp
from ott.geometry import epsilon_scheduler

import anndata as ad
Expand Down Expand Up @@ -134,7 +135,8 @@ def test_solve_balanced(

@pytest.mark.parametrize("key", ["connectivities", "distances"])
@pytest.mark.parametrize("geodesic_y", [True, False])
def test_geodesic_cost_xy(self, adata_mapping: AnnData, key: str, geodesic_y: bool):
@pytest.mark.parametrize("dense_input", [True, False])
def test_geodesic_cost_xy(self, adata_mapping: AnnData, key: str, geodesic_y: bool, dense_input: bool):
adataref, adatasp = _adata_spatial_split(adata_mapping)

batch_column = "batch"
Expand All @@ -146,20 +148,34 @@ def test_geodesic_cost_xy(self, adata_mapping: AnnData, key: str, geodesic_y: bo
adata_spatial_subset = adatasp[indices]
adata_subset = ad.concat([adata_spatial_subset, adataref])
sc.pp.neighbors(adata_subset, n_neighbors=15, use_rep="X")
dfs.append(
df = (
pd.DataFrame(
index=adata_subset.obs_names,
columns=adata_subset.obs_names,
data=adata_subset.obsp["connectivities"].A.astype("float64"),
data=adata_subset.obsp[key].A.astype("float64"),
)
if dense_input
else (
adata_subset.obsp[key].astype("float64"),
adata_subset.obs_names.to_series(),
adata_subset.obs_names.to_series(),
)
)
dfs.append(df)

if geodesic_y:
sc.pp.neighbors(adataref, n_neighbors=15, use_rep="X")
df_y = pd.DataFrame(
index=adataref.obs_names,
columns=adataref.obs_names,
data=adataref.obsp["connectivities"].A.astype("float64"),
df_y = (
pd.DataFrame(
index=adataref.obs_names,
columns=adataref.obs_names,
data=adataref.obsp[key].A.astype("float64"),
)
if dense_input
else (
adataref.obsp[key].astype("float64"),
adataref.obs_names.to_series(),
)
)

mp: MappingProblem = MappingProblem(adataref, adatasp)
Expand All @@ -175,28 +191,28 @@ def test_geodesic_cost_xy(self, adata_mapping: AnnData, key: str, geodesic_y: bo

ta = mp[("1", "ref")].xy
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"
if geodesic_y:
ta = mp[("1", "ref")].y
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"

ta = mp[("2", "ref")].xy
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"
if geodesic_y:
ta = mp[("2", "ref")].y
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"
Expand Down
17 changes: 13 additions & 4 deletions tests/problems/time/test_temporal_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax.numpy as jnp
import numpy as np
import pandas as pd
import scipy.sparse as sp
from ott.geometry import costs, epsilon_scheduler
from scipy.sparse import csr_matrix

Expand Down Expand Up @@ -249,7 +250,8 @@ def test_result_compares_to_wot(self, gt_temporal_adata: AnnData):
np.array(tp[key_1, key_3].solution.transport_matrix),
)

def test_geodesic_cost_set_xy_cost_dense(self, adata_time):
@pytest.mark.parametrize("dense_input", [True, False])
def test_geodesic_cost_set_xy_cost(self, adata_time, dense_input):
# TODO(@MUCDK) add test for failure case
tp = TemporalProblem(adata_time)
tp = tp.prepare("time", joint_attr="X_pca")
Expand All @@ -264,20 +266,27 @@ def test_geodesic_cost_set_xy_cost_dense(self, adata_time):
indices = np.where((adata_time.obs[batch_column] == batch1) | (adata_time.obs[batch_column] == batch2))[0]
adata_subset = adata_time[indices]
sc.pp.neighbors(adata_subset, n_neighbors=15, use_rep="X_pca")
dfs.append(
df = (
pd.DataFrame(
index=adata_subset.obs_names,
columns=adata_subset.obs_names,
data=adata_subset.obsp["connectivities"].A.astype("float64"),
)
if dense_input
else (
adata_subset.obsp["connectivities"].astype("float64"),
adata_subset.obs_names.to_series(),
adata_subset.obs_names.to_series(),
)
)
dfs.append(df)

tp[0, 1].set_graph_xy(dfs[0], cost="geodesic")
tp = tp.solve(max_iterations=2, lse_mode=False)

ta = tp[0, 1].xy
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray)
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"
Expand All @@ -287,7 +296,7 @@ def test_geodesic_cost_set_xy_cost_dense(self, adata_time):

ta = tp[1, 2].xy
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray)
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"
Expand Down

0 comments on commit 1b951fe

Please sign in to comment.