Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Sparse Geodesic Cost Support #677

Merged
merged 26 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
fd407df
set ottjax version first and the tests
selmanozleyen Mar 20, 2024
a2589da
recreate solution files with new ottjax version so it doesn't fail
selmanozleyen Mar 20, 2024
17c1f1e
update tests with new version, specify inner_iterations
selmanozleyen Mar 26, 2024
bdac69b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 26, 2024
f3c54be
add comments on issues
selmanozleyen Mar 26, 2024
ceb2f08
first draft of geodesic sparse need to add more tests
selmanozleyen Mar 26, 2024
bf25bb7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 26, 2024
85d780c
Merge branch 'main' into add/sparse-geodesic
selmanozleyen Mar 26, 2024
d41f2a0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 26, 2024
5b83621
set to new ottjax commit
selmanozleyen Apr 3, 2024
d85fa68
Merge branch 'main' into add/sparse-geodesic
selmanozleyen Apr 3, 2024
afcc83f
Merge branch 'main' into add/sparse-geodesic
selmanozleyen Apr 16, 2024
40c7966
Merge branch 'main' into add/sparse-geodesic
selmanozleyen May 7, 2024
be09dd7
revert old commit that fixes the lr case
selmanozleyen May 7, 2024
9672916
set new version of ottjax instead of commit
selmanozleyen May 7, 2024
1781ea4
fix and modify the geodesic cost tests
selmanozleyen May 7, 2024
93c39cb
Merge branch 'main' into add/sparse-geodesic
selmanozleyen May 10, 2024
b5d5309
typo fix
selmanozleyen May 13, 2024
2bc3f40
fix doc and add tests
selmanozleyen May 13, 2024
3b69af5
Merge branch 'main' into add/sparse-geodesic
selmanozleyen May 13, 2024
4b52324
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 13, 2024
34d0520
still densify for other costs
selmanozleyen May 13, 2024
cc21b1f
fix docs
selmanozleyen May 13, 2024
0811e82
add the final tests
selmanozleyen May 14, 2024
59a4656
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 14, 2024
b054584
Merge branch 'main' into add/sparse-geodesic
MUCDK May 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ dependencies = [
"scanpy>=1.9.3",
"wrapt>=1.13.2",
"docrep>=0.3.2",
"ott-jax>=0.4.5",
"ott-jax>=0.4.6",
"cloudpickle>=2.2.0",
"rich>=13.5",
]
Expand Down
30 changes: 27 additions & 3 deletions src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Literal, Optional, Tuple, Union

import jax
import jax.experimental.sparse as jesp
import jax.numpy as jnp
import scipy.sparse as sp
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
Expand Down Expand Up @@ -67,6 +68,25 @@ def alpha_to_fused_penalty(alpha: float) -> float:
return (1 - alpha) / alpha


def densify(arr: ArrayLike) -> jax.Array:
    """If the input is sparse, convert it to dense.

    Parameters
    ----------
    arr
        Array to check. May be a :mod:`scipy.sparse` container, a
        :class:`jax.experimental.sparse.BCOO` array or anything accepted by
        :func:`jax.numpy.asarray`.

    Returns
    -------
    dense :mod:`jax` array.
    """
    if sp.issparse(arr):
        # ``toarray()`` is available on both scipy sparse matrices and sparse
        # arrays, whereas the ``.A`` shorthand is not guaranteed on the latter.
        arr = arr.toarray()  # type: ignore[attr-defined]
    elif isinstance(arr, jesp.BCOO):
        arr = arr.todense()
    return jnp.asarray(arr)


def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
"""Ensure that an array is 2-dimensional.

Expand All @@ -81,16 +101,20 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
-------
2-dimensional :mod:`jax` array.
"""
if sp.issparse(arr):
arr = arr.A # type: ignore[attr-defined]
arr = jnp.asarray(arr)
if reshape and arr.ndim == 1:
return jnp.reshape(arr, (-1, 1))
if arr.ndim != 2:
raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
return arr


def convert_scipy_sparse(arr: Union[sp.spmatrix, jesp.BCOO]) -> jesp.BCOO:
    """Turn a scipy sparse input into a jax BCOO array.

    Parameters
    ----------
    arr
        Input to convert.

    Returns
    -------
    A :class:`jax.experimental.sparse.BCOO` built from ``arr`` when ``arr`` is
    a scipy sparse container; otherwise ``arr`` itself, returned unchanged.
    """
    return jesp.BCOO.from_scipy_sparse(arr) if sp.issparse(arr) else arr


def _instantiate_geodesic_cost(
arr: jax.Array,
problem_shape: Tuple[int, int],
Expand Down
8 changes: 6 additions & 2 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
_instantiate_geodesic_cost,
alpha_to_fused_penalty,
check_shapes,
convert_scipy_sparse,
densify,
ensure_2d,
)
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
Expand Down Expand Up @@ -76,8 +78,8 @@ def _create_geometry(
if not isinstance(cost_fn, costs.CostFn):
raise TypeError(f"Expected `cost_fn` to be `ott.geometry.costs.CostFn`, found `{type(cost_fn)}`.")

y = None if x.data_tgt is None else ensure_2d(x.data_tgt, reshape=True)
x = ensure_2d(x.data_src, reshape=True)
y = None if x.data_tgt is None else densify(ensure_2d(x.data_tgt, reshape=True))
x = densify(ensure_2d(x.data_src, reshape=True))
if y is not None and x.shape[1] != y.shape[1]:
raise ValueError(
f"Expected `x/y` to have the same number of dimensions, found `{x.shape[1]}/{y.shape[1]}`."
Expand All @@ -94,6 +96,8 @@ def _create_geometry(
)

arr = ensure_2d(x.data_src, reshape=False)
arr = densify(arr) if x.is_graph else convert_scipy_sparse(arr)

if x.is_cost_matrix:
return geometry.Geometry(
cost_matrix=arr, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost
Expand Down
11 changes: 6 additions & 5 deletions src/moscot/utils/tagged_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def _extract_data(
*,
attr: Literal["X", "obsp", "obsm", "layers", "uns"],
key: Optional[str] = None,
densify: bool = False,
) -> ArrayLike:
modifier = f"adata.{attr}" if key is None else f"adata.{attr}[{key!r}]"
data = getattr(adata, attr)
Expand All @@ -67,7 +68,7 @@ def _extract_data(
except IndexError:
raise IndexError(f"Unable to fetch data from `{modifier}`.") from None

if sp.issparse(data):
if sp.issparse(data) and densify:
logger.warning(f"Densifying data in `{modifier}`")
data = data.A
if data.ndim != 2:
Expand Down Expand Up @@ -103,7 +104,7 @@ def from_adata(
"""Create tagged array from :class:`~anndata.AnnData`.

.. warning::
Sparse arrays will be always densified.
Sparse arrays will be densified except when ``tag = 'graph'``.

Parameters
----------
Expand Down Expand Up @@ -137,13 +138,13 @@ def from_adata(
if tag == Tag.GRAPH:
if cost == "geodesic":
dist_key = f"{dist_key[0]}_{dist_key[1]}" if isinstance(dist_key, tuple) else dist_key
data = cls._extract_data(adata, attr=attr, key=f"{dist_key}_{key}")
data = cls._extract_data(adata, attr=attr, key=f"{dist_key}_{key}", densify=False)
return cls(data_src=data, tag=Tag.GRAPH, cost="geodesic")
raise ValueError(f"Expected `cost=geodesic`, found `{cost}`.")
if tag == Tag.COST_MATRIX:
if cost == "custom": # our custom cost functions
modifier = f"adata.{attr}" if key is None else f"adata.{attr}[{key!r}]"
data = cls._extract_data(adata, attr=attr, key=key)
data = cls._extract_data(adata, attr=attr, key=key, densify=True)
if np.any(data < 0):
raise ValueError(f"Cost matrix in `{modifier}` contains negative values.")
return cls(data_src=data, tag=Tag.COST_MATRIX, cost=None)
Expand All @@ -153,7 +154,7 @@ def from_adata(
return cls(data_src=cost_matrix, tag=Tag.COST_MATRIX, cost=None)

# tag is either a point cloud or a kernel
data = cls._extract_data(adata, attr=attr, key=key)
data = cls._extract_data(adata, attr=attr, key=key, densify=True)
cost_fn = get_cost(cost, backend=backend, **kwargs)
return cls(data_src=data, tag=tag, cost=cost_fn)

Expand Down
23 changes: 23 additions & 0 deletions tests/backends/ott/test_backend_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

import jax.experimental.sparse as jesp
import numpy as np
import scipy.sparse as sp
from ott.geometry.geometry import Geometry

from moscot.backends.ott._utils import _instantiate_geodesic_cost


class TestBackendUtils:
    """Tests for the helper functions of the OTT backend."""

    @staticmethod
    def test_instantiate_geodesic_cost():
        """Check ``_instantiate_geodesic_cost`` on a sparse BCOO graph.

        A random 10x10 sparse graph is used three times: once with the last
        positional flag set to ``False``, once with the flag set to ``True``
        and a ``(10, 10)`` problem shape (expected to raise), and once with
        the flag set to ``True`` and a ``(5, 5)`` problem shape.
        """
        m, n = 10, 10
        problem_shape = 10, 10
        g = sp.rand(m, n, 0.1, dtype=np.float64)
        g = jesp.BCOO.from_scipy_sparse(g)
        geom = _instantiate_geodesic_cost(g, problem_shape, 1.0, False)
        assert isinstance(geom, Geometry)
        with pytest.raises(ValueError, match="Expected `x` to have"):
            _instantiate_geodesic_cost(g, problem_shape, 1.0, True)
        geom = _instantiate_geodesic_cost(g, (5, 5), 1.0, True)
        # Previously the result of the valid call was assigned but never checked.
        assert isinstance(geom, Geometry)
19 changes: 14 additions & 5 deletions tests/problems/space/test_alignment_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pandas as pd
import scipy.sparse as sp
from ott.geometry import epsilon_scheduler

import scanpy as sc
Expand Down Expand Up @@ -127,7 +128,8 @@ def test_solve_unbalanced(self, adata_space_rotate: AnnData):
assert np.allclose(*(sol.cost for sol in ap.solutions.values()), rtol=1e-5, atol=1e-5)

@pytest.mark.parametrize("key", ["connectivities", "distances"])
def test_geodesic_cost_xy(self, adata_space_rotate: AnnData, key: str):
@pytest.mark.parametrize("dense_input", [True, False])
def test_geodesic_cost_xy(self, adata_space_rotate: AnnData, key: str, dense_input: bool):
batch_column = "batch"
unique_batches = adata_space_rotate.obs[batch_column].unique()

Expand All @@ -140,13 +142,20 @@ def test_geodesic_cost_xy(self, adata_space_rotate: AnnData, key: str):
)[0]
adata_subset = adata_space_rotate[indices]
sc.pp.neighbors(adata_subset, n_neighbors=15, use_rep="X_pca")
dfs.append(
df = (
pd.DataFrame(
index=adata_subset.obs_names,
columns=adata_subset.obs_names,
data=adata_subset.obsp["connectivities"].A.astype("float64"),
data=adata_subset.obsp[key].A.astype("float64"),
)
if dense_input
else (
adata_subset.obsp[key].astype("float64"),
adata_subset.obs_names.to_series(),
adata_subset.obs_names.to_series(),
)
)
dfs.append(df)

ap: AlignmentProblem = AlignmentProblem(adata=adata_space_rotate)
ap = ap.prepare(batch_key=batch_column, joint_attr={"attr": "obsm", "key": "X_pca"})
Expand All @@ -157,14 +166,14 @@ def test_geodesic_cost_xy(self, adata_space_rotate: AnnData, key: str):

ta = ap[("0", "1")].xy
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"

ta = ap[("1", "2")].xy
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"
Expand Down
38 changes: 27 additions & 11 deletions tests/problems/space/test_mapping_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pandas as pd
import scipy.sparse as sp
from ott.geometry import epsilon_scheduler

import anndata as ad
Expand Down Expand Up @@ -134,7 +135,8 @@ def test_solve_balanced(

@pytest.mark.parametrize("key", ["connectivities", "distances"])
@pytest.mark.parametrize("geodesic_y", [True, False])
def test_geodesic_cost_xy(self, adata_mapping: AnnData, key: str, geodesic_y: bool):
@pytest.mark.parametrize("dense_input", [True, False])
def test_geodesic_cost_xy(self, adata_mapping: AnnData, key: str, geodesic_y: bool, dense_input: bool):
adataref, adatasp = _adata_spatial_split(adata_mapping)

batch_column = "batch"
Expand All @@ -146,20 +148,34 @@ def test_geodesic_cost_xy(self, adata_mapping: AnnData, key: str, geodesic_y: bo
adata_spatial_subset = adatasp[indices]
adata_subset = ad.concat([adata_spatial_subset, adataref])
sc.pp.neighbors(adata_subset, n_neighbors=15, use_rep="X")
dfs.append(
df = (
pd.DataFrame(
index=adata_subset.obs_names,
columns=adata_subset.obs_names,
data=adata_subset.obsp["connectivities"].A.astype("float64"),
data=adata_subset.obsp[key].A.astype("float64"),
)
if dense_input
else (
adata_subset.obsp[key].astype("float64"),
adata_subset.obs_names.to_series(),
adata_subset.obs_names.to_series(),
)
)
dfs.append(df)

if geodesic_y:
sc.pp.neighbors(adataref, n_neighbors=15, use_rep="X")
df_y = pd.DataFrame(
index=adataref.obs_names,
columns=adataref.obs_names,
data=adataref.obsp["connectivities"].A.astype("float64"),
df_y = (
pd.DataFrame(
index=adataref.obs_names,
columns=adataref.obs_names,
data=adataref.obsp[key].A.astype("float64"),
)
if dense_input
else (
adataref.obsp[key].astype("float64"),
adataref.obs_names.to_series(),
)
)

mp: MappingProblem = MappingProblem(adataref, adatasp)
Expand All @@ -175,28 +191,28 @@ def test_geodesic_cost_xy(self, adata_mapping: AnnData, key: str, geodesic_y: bo

ta = mp[("1", "ref")].xy
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"
if geodesic_y:
ta = mp[("1", "ref")].y
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"

ta = mp[("2", "ref")].xy
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"
if geodesic_y:
ta = mp[("2", "ref")].y
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"
Expand Down
17 changes: 13 additions & 4 deletions tests/problems/time/test_temporal_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import jax.numpy as jnp
import numpy as np
import pandas as pd
import scipy.sparse as sp
from ott.geometry import costs, epsilon_scheduler
from scipy.sparse import csr_matrix

Expand Down Expand Up @@ -249,7 +250,8 @@ def test_result_compares_to_wot(self, gt_temporal_adata: AnnData):
np.array(tp[key_1, key_3].solution.transport_matrix),
)

def test_geodesic_cost_set_xy_cost_dense(self, adata_time):
@pytest.mark.parametrize("dense_input", [True, False])
def test_geodesic_cost_set_xy_cost(self, adata_time, dense_input):
# TODO(@MUCDK) add test for failure case
tp = TemporalProblem(adata_time)
tp = tp.prepare("time", joint_attr="X_pca")
Expand All @@ -264,20 +266,27 @@ def test_geodesic_cost_set_xy_cost_dense(self, adata_time):
indices = np.where((adata_time.obs[batch_column] == batch1) | (adata_time.obs[batch_column] == batch2))[0]
adata_subset = adata_time[indices]
sc.pp.neighbors(adata_subset, n_neighbors=15, use_rep="X_pca")
dfs.append(
df = (
pd.DataFrame(
index=adata_subset.obs_names,
columns=adata_subset.obs_names,
data=adata_subset.obsp["connectivities"].A.astype("float64"),
)
if dense_input
else (
adata_subset.obsp["connectivities"].astype("float64"),
adata_subset.obs_names.to_series(),
adata_subset.obs_names.to_series(),
)
)
dfs.append(df)

tp[0, 1].set_graph_xy(dfs[0], cost="geodesic")
tp = tp.solve(max_iterations=2, lse_mode=False)

ta = tp[0, 1].xy
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray)
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"
Expand All @@ -287,7 +296,7 @@ def test_geodesic_cost_set_xy_cost_dense(self, adata_time):

ta = tp[1, 2].xy
assert isinstance(ta, TaggedArray)
assert isinstance(ta.data_src, np.ndarray)
assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src)
assert ta.data_tgt is None
assert ta.tag == Tag.GRAPH
assert ta.cost == "geodesic"
Expand Down
Loading