From 1b951fe389d910917f20c92cacee406ee79ac1f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Selman=20=C3=96zleyen?= <32667648+selmanozleyen@users.noreply.github.com> Date: Mon, 20 May 2024 20:36:33 +0200 Subject: [PATCH] Add Sparse Geodesic Cost Support (#677) * set ottjax version first and the tests * recreate solution files with new ottjax version so it doesn't fail * update tests with new version, specify inner_iterations * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add comments on issues * first draft of geodesic sparse need to add more tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * set to new ottjax commit * revert old commit that fixes the lr case * set new version of ottjax instead of commit * fix and modify the geodesic cost tests * typo fix * fix doc and add tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * still densify for other costs * fix docs * add the final tests * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Dominik Klein --- pyproject.toml | 2 +- src/moscot/backends/ott/_utils.py | 30 +++++++++++++-- src/moscot/backends/ott/solver.py | 8 +++- src/moscot/utils/tagged_array.py | 11 +++--- tests/backends/ott/test_backend_utils.py | 23 +++++++++++ .../problems/space/test_alignment_problem.py | 19 +++++++--- tests/problems/space/test_mapping_problem.py | 38 +++++++++++++------ tests/problems/time/test_temporal_problem.py | 17 +++++++-- 8 files changed, 117 insertions(+), 31 deletions(-) create mode 100644 tests/backends/ott/test_backend_utils.py diff --git a/pyproject.toml b/pyproject.toml index 3190970f9..84f8a9ccb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ dependencies = [ "scanpy>=1.9.3", "wrapt>=1.13.2", "docrep>=0.3.2", - "ott-jax>=0.4.5", + "ott-jax>=0.4.6", "cloudpickle>=2.2.0", "rich>=13.5", ] diff --git a/src/moscot/backends/ott/_utils.py b/src/moscot/backends/ott/_utils.py index fc623ac8e..34cd5916b 100644 --- a/src/moscot/backends/ott/_utils.py +++ b/src/moscot/backends/ott/_utils.py @@ -1,6 +1,7 @@ from typing import Any, Literal, Optional, Tuple, Union import jax +import jax.experimental.sparse as jesp import jax.numpy as jnp import scipy.sparse as sp from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud @@ -67,6 +68,25 @@ def alpha_to_fused_penalty(alpha: float) -> float: return (1 - alpha) / alpha +def densify(arr: ArrayLike) -> jax.Array: + """If the input is sparse, convert it to dense. + + Parameters + ---------- + arr + Array to check. + + Returns + ------- + dense :mod:`jax` array. + """ + if sp.issparse(arr): + arr = arr.A # type: ignore[attr-defined] + elif isinstance(arr, jesp.BCOO): + arr = arr.todense() + return jnp.asarray(arr) + + def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array: """Ensure that an array is 2-dimensional. @@ -81,9 +101,6 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array: ------- 2-dimensional :mod:`jax` array. """ - if sp.issparse(arr): - arr = arr.A # type: ignore[attr-defined] - arr = jnp.asarray(arr) if reshape and arr.ndim == 1: return jnp.reshape(arr, (-1, 1)) if arr.ndim != 2: @@ -91,6 +108,13 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array: return arr +def convert_scipy_sparse(arr: Union[sp.spmatrix, jesp.BCOO]) -> jesp.BCOO: + """If the input is a scipy sparse matrix, convert it to a jax BCOO.""" + if sp.issparse(arr): + return jesp.BCOO.from_scipy_sparse(arr) + return arr + + def _instantiate_geodesic_cost( arr: jax.Array, problem_shape: Tuple[int, int], diff --git a/src/moscot/backends/ott/solver.py b/src/moscot/backends/ott/solver.py index f8130eab2..9d280d892 100644 --- a/src/moscot/backends/ott/solver.py +++ b/src/moscot/backends/ott/solver.py @@ -16,6 +16,8 @@ _instantiate_geodesic_cost, alpha_to_fused_penalty, check_shapes, + convert_scipy_sparse, + densify, ensure_2d, ) from moscot.backends.ott.output import GraphOTTOutput, OTTOutput @@ -76,8 +78,8 @@ def _create_geometry( if not isinstance(cost_fn, costs.CostFn): raise TypeError(f"Expected `cost_fn` to be `ott.geometry.costs.CostFn`, found `{type(cost_fn)}`.") - y = None if x.data_tgt is None else ensure_2d(x.data_tgt, reshape=True) - x = ensure_2d(x.data_src, reshape=True) + y = None if x.data_tgt is None else densify(ensure_2d(x.data_tgt, reshape=True)) + x = densify(ensure_2d(x.data_src, reshape=True)) if y is not None and x.shape[1] != y.shape[1]: raise ValueError( f"Expected `x/y` to have the same number of dimensions, found `{x.shape[1]}/{y.shape[1]}`." @@ -94,6 +96,8 @@ def _create_geometry( ) arr = ensure_2d(x.data_src, reshape=False) + arr = densify(arr) if x.is_graph else convert_scipy_sparse(arr) + if x.is_cost_matrix: return geometry.Geometry( cost_matrix=arr, epsilon=epsilon, relative_epsilon=relative_epsilon, scale_cost=scale_cost diff --git a/src/moscot/utils/tagged_array.py b/src/moscot/utils/tagged_array.py index 2bb62f682..7f89fa073 100644 --- a/src/moscot/utils/tagged_array.py +++ b/src/moscot/utils/tagged_array.py @@ -55,6 +55,7 @@ def _extract_data( *, attr: Literal["X", "obsp", "obsm", "layers", "uns"], key: Optional[str] = None, + densify: bool = False, ) -> ArrayLike: modifier = f"adata.{attr}" if key is None else f"adata.{attr}[{key!r}]" data = getattr(adata, attr) @@ -67,7 +68,7 @@ def _extract_data( except IndexError: raise IndexError(f"Unable to fetch data from `{modifier}`.") from None - if sp.issparse(data): + if sp.issparse(data) and densify: logger.warning(f"Densifying data in `{modifier}`") data = data.A if data.ndim != 2: @@ -103,7 +104,7 @@ def from_adata( """Create tagged array from :class:`~anndata.AnnData`. .. warning:: - Sparse arrays will be always densified. + Sparse arrays will be densified except when ``tag = 'graph'``. Parameters ---------- @@ -137,13 +138,13 @@ def from_adata( if tag == Tag.GRAPH: if cost == "geodesic": dist_key = f"{dist_key[0]}_{dist_key[1]}" if isinstance(dist_key, tuple) else dist_key - data = cls._extract_data(adata, attr=attr, key=f"{dist_key}_{key}") + data = cls._extract_data(adata, attr=attr, key=f"{dist_key}_{key}", densify=False) return cls(data_src=data, tag=Tag.GRAPH, cost="geodesic") raise ValueError(f"Expected `cost=geodesic`, found `{cost}`.") if tag == Tag.COST_MATRIX: if cost == "custom": # our custom cost functions modifier = f"adata.{attr}" if key is None else f"adata.{attr}[{key!r}]" - data = cls._extract_data(adata, attr=attr, key=key) + data = cls._extract_data(adata, attr=attr, key=key, densify=True) if np.any(data < 0): raise ValueError(f"Cost matrix in `{modifier}` contains negative values.") return cls(data_src=data, tag=Tag.COST_MATRIX, cost=None) @@ -153,7 +154,7 @@ def from_adata( return cls(data_src=cost_matrix, tag=Tag.COST_MATRIX, cost=None) # tag is either a point cloud or a kernel - data = cls._extract_data(adata, attr=attr, key=key) + data = cls._extract_data(adata, attr=attr, key=key, densify=True) cost_fn = get_cost(cost, backend=backend, **kwargs) return cls(data_src=data, tag=tag, cost=cost_fn) diff --git a/tests/backends/ott/test_backend_utils.py b/tests/backends/ott/test_backend_utils.py new file mode 100644 index 000000000..c7296a7b6 --- /dev/null +++ b/tests/backends/ott/test_backend_utils.py @@ -0,0 +1,23 @@ +import pytest + +import jax.experimental.sparse as jesp +import numpy as np +import scipy.sparse as sp +from ott.geometry.geometry import Geometry + +from moscot.backends.ott._utils import _instantiate_geodesic_cost + + +class TestBackendUtils: + + @staticmethod + def test_instantiate_geodesic_cost(): + m, n = 10, 10 + problem_shape = 10, 10 + g = sp.rand(m, n, 0.1, dtype=np.float64) + g = jesp.BCOO.from_scipy_sparse(g) + geom = _instantiate_geodesic_cost(g, problem_shape, 1.0, False) + assert isinstance(geom, Geometry) + with pytest.raises(ValueError, match="Expected `x` to have"): + _instantiate_geodesic_cost(g, problem_shape, 1.0, True) + geom = _instantiate_geodesic_cost(g, (5, 5), 1.0, True) diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index 163cc99b1..d6a78c063 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd +import scipy.sparse as sp from ott.geometry import epsilon_scheduler import scanpy as sc @@ -127,7 +128,8 @@ def test_solve_unbalanced(self, adata_space_rotate: AnnData): assert np.allclose(*(sol.cost for sol in ap.solutions.values()), rtol=1e-5, atol=1e-5) @pytest.mark.parametrize("key", ["connectivities", "distances"]) - def test_geodesic_cost_xy(self, adata_space_rotate: AnnData, key: str): + @pytest.mark.parametrize("dense_input", [True, False]) + def test_geodesic_cost_xy(self, adata_space_rotate: AnnData, key: str, dense_input: bool): batch_column = "batch" unique_batches = adata_space_rotate.obs[batch_column].unique() @@ -140,13 +142,20 @@ def test_geodesic_cost_xy(self, adata_space_rotate: AnnData, key: str): )[0] adata_subset = adata_space_rotate[indices] sc.pp.neighbors(adata_subset, n_neighbors=15, use_rep="X_pca") - dfs.append( + df = ( pd.DataFrame( index=adata_subset.obs_names, columns=adata_subset.obs_names, - data=adata_subset.obsp["connectivities"].A.astype("float64"), + data=adata_subset.obsp[key].A.astype("float64"), + ) + if dense_input + else ( + adata_subset.obsp[key].astype("float64"), + adata_subset.obs_names.to_series(), + adata_subset.obs_names.to_series(), ) ) + dfs.append(df) ap: AlignmentProblem = AlignmentProblem(adata=adata_space_rotate) ap = ap.prepare(batch_key=batch_column, joint_attr={"attr": "obsm", "key": "X_pca"}) @@ -157,14 +166,14 @@ def test_geodesic_cost_xy(self, adata_space_rotate: AnnData, key: str): ta = ap[("0", "1")].xy assert isinstance(ta, TaggedArray) - assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices + assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src) assert ta.data_tgt is None assert ta.tag == Tag.GRAPH assert ta.cost == "geodesic" ta = ap[("1", "2")].xy assert isinstance(ta, TaggedArray) - assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices + assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src) assert ta.data_tgt is None assert ta.tag == Tag.GRAPH assert ta.cost == "geodesic" diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index 1922d8bc1..903c02611 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd +import scipy.sparse as sp from ott.geometry import epsilon_scheduler import anndata as ad @@ -134,7 +135,8 @@ def test_solve_balanced( @pytest.mark.parametrize("key", ["connectivities", "distances"]) @pytest.mark.parametrize("geodesic_y", [True, False]) - def test_geodesic_cost_xy(self, adata_mapping: AnnData, key: str, geodesic_y: bool): + @pytest.mark.parametrize("dense_input", [True, False]) + def test_geodesic_cost_xy(self, adata_mapping: AnnData, key: str, geodesic_y: bool, dense_input: bool): adataref, adatasp = _adata_spatial_split(adata_mapping) batch_column = "batch" @@ -146,20 +148,34 @@ def test_geodesic_cost_xy(self, adata_mapping: AnnData, key: str, geodesic_y: bo adata_spatial_subset = adatasp[indices] adata_subset = ad.concat([adata_spatial_subset, adataref]) sc.pp.neighbors(adata_subset, n_neighbors=15, use_rep="X") - dfs.append( + df = ( pd.DataFrame( index=adata_subset.obs_names, columns=adata_subset.obs_names, - data=adata_subset.obsp["connectivities"].A.astype("float64"), + data=adata_subset.obsp[key].A.astype("float64"), + ) + if dense_input + else ( + adata_subset.obsp[key].astype("float64"), + adata_subset.obs_names.to_series(), + adata_subset.obs_names.to_series(), ) ) + dfs.append(df) if geodesic_y: sc.pp.neighbors(adataref, n_neighbors=15, use_rep="X") - df_y = pd.DataFrame( - index=adataref.obs_names, - columns=adataref.obs_names, - data=adataref.obsp["connectivities"].A.astype("float64"), + df_y = ( + pd.DataFrame( + index=adataref.obs_names, + columns=adataref.obs_names, + data=adataref.obsp[key].A.astype("float64"), + ) + if dense_input + else ( + adataref.obsp[key].astype("float64"), + adataref.obs_names.to_series(), + ) ) mp: MappingProblem = MappingProblem(adataref, adatasp) @@ -175,28 +191,28 @@ def test_geodesic_cost_xy(self, adata_mapping: AnnData, key: str, geodesic_y: bo ta = mp[("1", "ref")].xy assert isinstance(ta, TaggedArray) - assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices + assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src) assert ta.data_tgt is None assert ta.tag == Tag.GRAPH assert ta.cost == "geodesic" if geodesic_y: ta = mp[("1", "ref")].y assert isinstance(ta, TaggedArray) - assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices + assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src) assert ta.data_tgt is None assert ta.tag == Tag.GRAPH assert ta.cost == "geodesic" ta = mp[("2", "ref")].xy assert isinstance(ta, TaggedArray) - assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices + assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src) assert ta.data_tgt is None assert ta.tag == Tag.GRAPH assert ta.cost == "geodesic" if geodesic_y: ta = mp[("2", "ref")].y assert isinstance(ta, TaggedArray) - assert isinstance(ta.data_src, np.ndarray) # this will change once OTT-JAX allows for sparse matrices + assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src) assert ta.data_tgt is None assert ta.tag == Tag.GRAPH assert ta.cost == "geodesic" diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index a11e2a93c..7858eb613 100644 --- a/tests/problems/time/test_temporal_problem.py +++ b/tests/problems/time/test_temporal_problem.py @@ -5,6 +5,7 @@ import jax.numpy as jnp import numpy as np import pandas as pd +import scipy.sparse as sp from ott.geometry import costs, epsilon_scheduler from scipy.sparse import csr_matrix @@ -249,7 +250,8 @@ def test_result_compares_to_wot(self, gt_temporal_adata: AnnData): np.array(tp[key_1, key_3].solution.transport_matrix), ) - def test_geodesic_cost_set_xy_cost_dense(self, adata_time): + @pytest.mark.parametrize("dense_input", [True, False]) + def test_geodesic_cost_set_xy_cost(self, adata_time, dense_input): # TODO(@MUCDK) add test for failure case tp = TemporalProblem(adata_time) tp = tp.prepare("time", joint_attr="X_pca") @@ -264,20 +266,27 @@ def test_geodesic_cost_set_xy_cost_dense(self, adata_time): indices = np.where((adata_time.obs[batch_column] == batch1) | (adata_time.obs[batch_column] == batch2))[0] adata_subset = adata_time[indices] sc.pp.neighbors(adata_subset, n_neighbors=15, use_rep="X_pca") - dfs.append( + df = ( pd.DataFrame( index=adata_subset.obs_names, columns=adata_subset.obs_names, data=adata_subset.obsp["connectivities"].A.astype("float64"), ) + if dense_input + else ( + adata_subset.obsp["connectivities"].astype("float64"), + adata_subset.obs_names.to_series(), + adata_subset.obs_names.to_series(), + ) ) + dfs.append(df) tp[0, 1].set_graph_xy(dfs[0], cost="geodesic") tp = tp.solve(max_iterations=2, lse_mode=False) ta = tp[0, 1].xy assert isinstance(ta, TaggedArray) - assert isinstance(ta.data_src, np.ndarray) + assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src) assert ta.data_tgt is None assert ta.tag == Tag.GRAPH assert ta.cost == "geodesic" @@ -287,7 +296,7 @@ def test_geodesic_cost_set_xy_cost_dense(self, adata_time): ta = tp[1, 2].xy assert isinstance(ta, TaggedArray) - assert isinstance(ta.data_src, np.ndarray) + assert isinstance(ta.data_src, np.ndarray) if dense_input else sp.issparse(ta.data_src) assert ta.data_tgt is None assert ta.tag == Tag.GRAPH assert ta.cost == "geodesic"