From ca92568790967e61a9068aea20b95af9d7524b46 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Mon, 26 Aug 2024 12:23:20 +0200 Subject: [PATCH 01/25] add pancreas dataset --- docs/conf.py | 1 + pyproject.toml | 3 +- src/moscot/datasets.py | 118 ++++++++++++++++++++++++++++++++--------- 3 files changed, 96 insertions(+), 26 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 311d1f6b2..34fe851cd 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -55,6 +55,7 @@ "anndata": ("https://anndata.readthedocs.io/en/latest/", None), "scanpy": ("https://scanpy.readthedocs.io/en/latest/", None), "squidpy": ("https://squidpy.readthedocs.io/en/latest/", None), + "mudata": ("https://mudata.readthedocs.io/en/latest/", None), } master_doc = "index" pygments_style = "tango" diff --git a/pyproject.toml b/pyproject.toml index 014775ffe..f3b5e3291 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,8 @@ dependencies = [ "ott-jax[neural]>=0.4.6", "cloudpickle>=2.2.0", "rich>=13.5", - "docstring_inheritance>=2.0.0" + "docstring_inheritance>=2.0.0", + "mudata>=0.3.0" ] [project.optional-dependencies] diff --git a/src/moscot/datasets.py b/src/moscot/datasets.py index a308d5867..b2c341807 100644 --- a/src/moscot/datasets.py +++ b/src/moscot/datasets.py @@ -7,7 +7,10 @@ import urllib.request from itertools import combinations from types import MappingProxyType -from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple +from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union + +import mudata +import mudata as mu import networkx as nx import numpy as np @@ -15,8 +18,6 @@ from scipy.linalg import block_diag import anndata as ad -from anndata import AnnData -from scanpy import read from moscot._types import PathLike @@ -36,7 +37,7 @@ def mosta( path: PathLike = "~/.cache/moscot/mosta.h5ad", force_download: bool = False, **kwargs: Any, -) -> AnnData: # pragma: no cover +) -> ad.AnnData: # pragma: no cover """Preprocessed and extracted data as provided in :cite:`chen:22`. Includes embryo sections `E9.5`, `E2S1`, `E10.5`, `E2S1`, `E11.5`, `E1S2`. @@ -59,6 +60,7 @@ def mosta( """ return _load_dataset_from_url( path, + type="h5ad", backup_url="https://figshare.com/ndownloader/files/40569779", expected_shape=(54134, 2000), force_download=force_download, @@ -70,7 +72,7 @@ def hspc( path: PathLike = "~/.cache/moscot/hspc.h5ad", force_download: bool = False, **kwargs: Any, -) -> AnnData: # pragma: no cover +) -> ad.AnnData: # pragma: no cover """CD34+ hematopoietic stem and progenitor cells from 4 healthy human donors. From the `NeurIPS Multimodal Single-Cell Integration Challenge @@ -95,6 +97,7 @@ def hspc( """ dataset = _load_dataset_from_url( path, + type="h5ad", backup_url="https://figshare.com/ndownloader/files/37993503", expected_shape=(4000, 2000), force_download=force_download, @@ -111,7 +114,7 @@ def drosophila( spatial: bool, force_download: bool = False, **kwargs: Any, -) -> AnnData: +) -> ad.AnnData: """Embryo of Drosophila melanogaster described in :cite:`Li-spatial:22`. Minimal pre-processing was performed, such as gene and cell filtering, as well as normalization. 
@@ -135,6 +138,7 @@ def drosophila( if spatial: return _load_dataset_from_url( path + "_sp.h5ad", + type="h5ad", backup_url="https://figshare.com/ndownloader/files/37984935", expected_shape=(3039, 82), force_download=force_download, @@ -143,6 +147,7 @@ def drosophila( return _load_dataset_from_url( path + "_sc.h5ad", + type="h5ad", backup_url="https://figshare.com/ndownloader/files/37984938", expected_shape=(1297, 2000), force_download=force_download, @@ -154,7 +159,7 @@ def c_elegans( path: PathLike = "~/.cache/moscot/c_elegans.h5ad", force_download: bool = False, **kwargs: Any, -) -> Tuple[AnnData, nx.DiGraph]: # pragma: no cover +) -> Tuple[ad.AnnData, nx.DiGraph]: # pragma: no cover """scRNA-seq time-series dataset of C.elegans embryogenesis :cite:`packer:19`. Contains raw counts of 46,151 cells with at least partial lineage information. @@ -175,6 +180,7 @@ def c_elegans( """ adata = _load_dataset_from_url( path, + type="h5ad", backup_url="https://figshare.com/ndownloader/files/39943585", expected_shape=(46151, 20222), force_download=force_download, @@ -191,7 +197,7 @@ def zebrafish( path: PathLike = "~/.cache/moscot/zebrafish.h5ad", force_download: bool = False, **kwargs: Any, -) -> Tuple[AnnData, Dict[str, nx.DiGraph]]: +) -> Tuple[ad.AnnData, Dict[str, nx.DiGraph]]: """Lineage-traced scRNA-seq time-series dataset of Zebrafish heart regeneration :cite:`hu:22`. Contains gene expression vectors, LINNAEUS :cite:`spanjaard:18` reconstructed lineage trees, @@ -212,6 +218,7 @@ def zebrafish( """ adata = _load_dataset_from_url( path, + type="h5ad", backup_url="https://figshare.com/ndownloader/files/39951073", expected_shape=(44014, 31466), force_download=force_download, @@ -230,7 +237,7 @@ def bone_marrow( rna: bool, force_download: bool = False, **kwargs: Any, -) -> AnnData: +) -> ad.AnnData: """Multiome data of bone marrow measurements :cite:`luecken:21`. Contains processed counts of 6,224 cells. The RNA data was filtered to 2,000 top @@ -256,6 +263,7 @@ def bone_marrow( if rna: return _load_dataset_from_url( path + "_rna.h5ad", + type="h5ad", backup_url="https://figshare.com/ndownloader/files/40195114", expected_shape=(6224, 2000), force_download=force_download, @@ -263,6 +271,7 @@ def bone_marrow( ) return _load_dataset_from_url( path + "_atac.h5ad", + type="h5ad", backup_url="https://figshare.com/ndownloader/files/41013551", expected_shape=(6224, 8000), force_download=force_download, @@ -270,11 +279,56 @@ def bone_marrow( ) +def pancreas_multiome( + rna_only: bool, + path: PathLike = "~/.cache/moscot/pancreas_multiome.h5mu", + force_download: bool = True, + **kwargs: Any, +) -> Union[mu.MuData, ad.AnnData]: # pragma: no cover + """Pancreatic endocrinogenesis dataset published with the moscot manuscript :cite:`Klein:23`. + + The dataset contains paired scRNA-seq and scATAC-seq data of pancreatic cells at embryonic days 14.5, 15.5, and + 16.5. The data was preprocessed and filtered as described in the manuscript, the raw data and the full processed + data are available via GEO accession code GSE275562. + + Parameters + ---------- + rna_only + Only load the RNA data, resulting in a smaller file. + path + Path where to save the file. + force_download + Whether to force-download the data. + kwargs + Keyword arguments for :func:`anndata.read_h5ad` if `rna_only`, else for :func:`mudata.read`. + + Returns + ------- + :class:`mudata.MuData` object with RNA and ATAC data if `rna_only`, else :class:`anndata.AnnData` with RNA only. 
+ """ + if rna_only: + return _load_dataset_from_url( + path, + type="h5ad", + backup_url="https://figshare.com/ndownloader/files/48785320", + expected_shape=(22604, 20242), + force_download=force_download, + **kwargs, + ) + return _load_dataset_from_url( + path, + type="h5mu", + backup_url="https://figshare.com/ndownloader/files/48782332", + expected_shape=(22604, 271918), + force_download=force_download, + ) + + def tedsim( path: PathLike = "~/.cache/moscot/tedsim.h5ad", force_download: bool = False, **kwargs: Any, -) -> AnnData: # pragma: no cover +) -> ad.AnnData: # pragma: no cover """Dataset simulated with TedSim :cite:`pan:22`. Simulated scRNA-seq dataset of a differentiation trajectory. For each cell, the dataset includes a (raw counts) @@ -302,6 +356,7 @@ def tedsim( """ return _load_dataset_from_url( path, + type="h5ad", backup_url="https://figshare.com/ndownloader/files/40178644", expected_shape=(8448, 500), force_download=force_download, @@ -313,7 +368,7 @@ def sciplex( path: PathLike = "~/.cache/moscot/sciplex.h5ad", force_download: bool = False, **kwargs: Any, -) -> AnnData: # pragma: no cover +) -> ad.AnnData: # pragma: no cover """Perturbation dataset published in :cite:`srivatsan:20`. Transcriptomes of A549, K562, and mCF7 cells exposed to 188 compounds. @@ -334,6 +389,7 @@ def sciplex( """ return _load_dataset_from_url( path, + type="h5ad", backup_url="https://figshare.com/ndownloader/files/43381398", expected_shape=(799317, 110984), force_download=force_download, @@ -345,7 +401,7 @@ def sim_align( path: PathLike = "~/.cache/moscot/sim_align.h5ad", force_download: bool = False, **kwargs: Any, -) -> AnnData: # pragma: no cover +) -> ad.AnnData: # pragma: no cover """Spatial transcriptomics simulated dataset as described in :cite:`Jones-spatial:22`. Parameters @@ -363,6 +419,7 @@ def sim_align( """ return _load_dataset_from_url( path, + type="h5ad", backup_url="https://figshare.com/ndownloader/files/37984926", expected_shape=(1200, 500), force_download=force_download, @@ -383,7 +440,7 @@ def simulate_data( lin_cost_matrix: Optional[str] = None, quad_cost_matrix: Optional[str] = None, **kwargs: Any, -) -> AnnData: +) -> ad.AnnData: """Simulate data. 
This function is used to generate data, mainly for the purpose of @@ -424,7 +481,7 @@ def simulate_data( """ rng = np.random.RandomState(seed) adatas = [ - AnnData( + ad.AnnData( X=rng.multivariate_normal( mean=kwargs.pop("mean", np.arange(n_genes)), cov=kwargs.pop("cov", var * np.diag(np.ones(n_genes))), @@ -477,32 +534,43 @@ def simulate_data( def _load_dataset_from_url( fpath: PathLike, + type: Literal["h5ad", "h5mu"], *, backup_url: str, expected_shape: Tuple[int, int], force_download: bool = False, - sparse: bool = True, - cache: bool = True, **kwargs: Any, -) -> AnnData: +) -> Union[ad.AnnData, mu.MuData]: + # TODO: make nicer once https://github.com/scverse/mudata/issues/76 resolved fpath = os.path.expanduser(fpath) - if not fpath.endswith(".h5ad"): + if type == "h5ad" and not fpath.endswith(".h5ad"): fpath += ".h5ad" - if force_download: + if type == "h5mu" and not fpath.endswith(".h5mu"): + fpath += ".h5mu" + + if not os.path.exists(fpath) or force_download: with tempfile.TemporaryDirectory() as tmpdir: - tmp = pathlib.Path(tmpdir) / "data.h5ad" - adata = read(filename=tmp, backup_url=backup_url, sparse=sparse, cache=cache, **kwargs) + tmp = pathlib.Path(tmpdir) / f"data.{type}" + urllib.request.urlretrieve(backup_url, tmp) + if type == "h5ad": + data = ad.read_h5ad(filename=tmp, **kwargs) + if type == "h5mu": + data = mudata.read(tmp, **kwargs) with contextlib.suppress(FileNotFoundError): os.remove(fpath) shutil.move(tmp, fpath) else: - adata = read(filename=fpath, backup_url=backup_url, sparse=sparse, cache=cache, **kwargs) + if type == "h5ad": + data = ad.read_h5ad(filename=fpath, **kwargs) + else: + raise NotImplementedError("MuData download only available with `force_download=True`.") - if adata.shape != expected_shape: - raise ValueError(f"Expected `AnnData` object to have shape `{expected_shape}`, found `{adata.shape}`.") + if data.shape != expected_shape: + data_str = "MuData" if type == "h5mu" else "AnnData" + raise ValueError(f"Expected {data_str} object to have shape `{expected_shape}`, found `{data.shape}`.") - return adata + return data def _get_random_trees( From 1ebc89cc6bdc71fd540b92dd92cd8caa9e45d933 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Mon, 26 Aug 2024 12:25:17 +0200 Subject: [PATCH 02/25] fix logic --- src/moscot/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/moscot/datasets.py b/src/moscot/datasets.py index b2c341807..c172a8166 100644 --- a/src/moscot/datasets.py +++ b/src/moscot/datasets.py @@ -549,7 +549,7 @@ def _load_dataset_from_url( if type == "h5mu" and not fpath.endswith(".h5mu"): fpath += ".h5mu" - if not os.path.exists(fpath) or force_download: + if not os.path.exists(fpath): with tempfile.TemporaryDirectory() as tmpdir: tmp = pathlib.Path(tmpdir) / f"data.{type}" urllib.request.urlretrieve(backup_url, tmp) From c699dc955f7e3d992fa26c8440f8dfe67ed5a498 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Mon, 26 Aug 2024 13:20:29 +0200 Subject: [PATCH 03/25] adapt mudata dep --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f3b5e3291..d816f3260 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ dependencies = [ "cloudpickle>=2.2.0", "rich>=13.5", "docstring_inheritance>=2.0.0", - "mudata>=0.3.0" + "mudata>=0.2.0" ] [project.optional-dependencies] From b885e3cec2761470dc64bc0af4feb2481725deb7 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 9 Sep 2024 14:08:31 +0200 Subject: [PATCH 04/25] commit to test if it 
works with scanpy download function --- src/moscot/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/moscot/datasets.py b/src/moscot/datasets.py index c172a8166..03109ecc2 100644 --- a/src/moscot/datasets.py +++ b/src/moscot/datasets.py @@ -548,11 +548,11 @@ def _load_dataset_from_url( if type == "h5mu" and not fpath.endswith(".h5mu"): fpath += ".h5mu" - + from scanpy.readwrite import _check_datafile_present_and_download if not os.path.exists(fpath): with tempfile.TemporaryDirectory() as tmpdir: tmp = pathlib.Path(tmpdir) / f"data.{type}" - urllib.request.urlretrieve(backup_url, tmp) + _check_datafile_present_and_download(backup_url=backup_url, path=tmp) if type == "h5ad": data = ad.read_h5ad(filename=tmp, **kwargs) if type == "h5mu": From d9ccf6016e7fc9f9b17d177521ef53bf93d6fbc5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 12:08:50 +0000 Subject: [PATCH 05/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/datasets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/moscot/datasets.py b/src/moscot/datasets.py index 03109ecc2..54b54dc4c 100644 --- a/src/moscot/datasets.py +++ b/src/moscot/datasets.py @@ -549,6 +549,7 @@ def _load_dataset_from_url( if type == "h5mu" and not fpath.endswith(".h5mu"): fpath += ".h5mu" from scanpy.readwrite import _check_datafile_present_and_download + if not os.path.exists(fpath): with tempfile.TemporaryDirectory() as tmpdir: tmp = pathlib.Path(tmpdir) / f"data.{type}" From dc84ff743ec4ecd75cd4d8c78d087ae0deed461b Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 9 Sep 2024 14:22:25 +0200 Subject: [PATCH 06/25] check again --- src/moscot/datasets.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/src/moscot/datasets.py b/src/moscot/datasets.py index 54b54dc4c..a90208864 100644 --- a/src/moscot/datasets.py +++ b/src/moscot/datasets.py @@ -550,22 +550,10 @@ def _load_dataset_from_url( fpath += ".h5mu" from scanpy.readwrite import _check_datafile_present_and_download - if not os.path.exists(fpath): - with tempfile.TemporaryDirectory() as tmpdir: - tmp = pathlib.Path(tmpdir) / f"data.{type}" - _check_datafile_present_and_download(backup_url=backup_url, path=tmp) - if type == "h5ad": - data = ad.read_h5ad(filename=tmp, **kwargs) - if type == "h5mu": - data = mudata.read(tmp, **kwargs) - with contextlib.suppress(FileNotFoundError): - os.remove(fpath) - shutil.move(tmp, fpath) + if not os.path.exists(fpath) or force_download: + _check_datafile_present_and_download(backup_url=backup_url, path=fpath) else: - if type == "h5ad": - data = ad.read_h5ad(filename=fpath, **kwargs) - else: - raise NotImplementedError("MuData download only available with `force_download=True`.") + data = ad.read_h5ad(filename=fpath, **kwargs) if type == "h5ad" else mu.read_h5mu(filename=fpath, backed=False) if data.shape != expected_shape: data_str = "MuData" if type == "h5mu" else "AnnData" From db5b81d79f7c68b5e0fffb02d721d47bfe4f02ee Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 12:22:50 +0000 Subject: [PATCH 07/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/datasets.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/moscot/datasets.py b/src/moscot/datasets.py index a90208864..38c0cfe8a 100644 
--- a/src/moscot/datasets.py +++ b/src/moscot/datasets.py @@ -1,15 +1,10 @@ -import contextlib import os -import pathlib import pickle -import shutil -import tempfile import urllib.request from itertools import combinations from types import MappingProxyType from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union -import mudata import mudata as mu import networkx as nx From c028bc720d5edf76654305a028f6a57f81e9e06a Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 9 Sep 2024 14:25:56 +0200 Subject: [PATCH 08/25] again --- src/moscot/datasets.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/moscot/datasets.py b/src/moscot/datasets.py index 38c0cfe8a..25051cf22 100644 --- a/src/moscot/datasets.py +++ b/src/moscot/datasets.py @@ -544,11 +544,9 @@ def _load_dataset_from_url( if type == "h5mu" and not fpath.endswith(".h5mu"): fpath += ".h5mu" from scanpy.readwrite import _check_datafile_present_and_download - if not os.path.exists(fpath) or force_download: _check_datafile_present_and_download(backup_url=backup_url, path=fpath) - else: - data = ad.read_h5ad(filename=fpath, **kwargs) if type == "h5ad" else mu.read_h5mu(filename=fpath, backed=False) + data = ad.read_h5ad(filename=fpath, **kwargs) if type == "h5ad" else mu.read_h5mu(filename=fpath, backed=False) if data.shape != expected_shape: data_str = "MuData" if type == "h5mu" else "AnnData" From 40d7b8ec09e44c9597641a955097fb17f55f7b8f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 12:27:05 +0000 Subject: [PATCH 09/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/datasets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/moscot/datasets.py b/src/moscot/datasets.py index 25051cf22..8d6ab4ef7 100644 --- a/src/moscot/datasets.py +++ b/src/moscot/datasets.py @@ -544,6 +544,7 @@ def _load_dataset_from_url( if type == "h5mu" and not fpath.endswith(".h5mu"): fpath += ".h5mu" from scanpy.readwrite import _check_datafile_present_and_download + if not os.path.exists(fpath) or force_download: _check_datafile_present_and_download(backup_url=backup_url, path=fpath) data = ad.read_h5ad(filename=fpath, **kwargs) if type == "h5ad" else mu.read_h5mu(filename=fpath, backed=False) From 0469bc50b84304bf06e0bddb4c4498f6f48ec476 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 9 Sep 2024 14:50:13 +0200 Subject: [PATCH 10/25] add force_download functionality and avoid using reserved name `type` --- src/moscot/datasets.py | 51 +++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/src/moscot/datasets.py b/src/moscot/datasets.py index 8d6ab4ef7..e1f97c4df 100644 --- a/src/moscot/datasets.py +++ b/src/moscot/datasets.py @@ -16,6 +16,9 @@ from moscot._types import PathLike +from scanpy.readwrite import _check_datafile_present_and_download + + __all__ = [ "mosta", "hspc", @@ -55,7 +58,7 @@ def mosta( """ return _load_dataset_from_url( path, - type="h5ad", + file_type="h5ad", backup_url="https://figshare.com/ndownloader/files/40569779", expected_shape=(54134, 2000), force_download=force_download, @@ -92,7 +95,7 @@ def hspc( """ dataset = _load_dataset_from_url( path, - type="h5ad", + file_type="h5ad", backup_url="https://figshare.com/ndownloader/files/37993503", expected_shape=(4000, 2000), force_download=force_download, @@ -133,7 +136,7 @@ def drosophila( if spatial: return _load_dataset_from_url( path + 
"_sp.h5ad", - type="h5ad", + file_type="h5ad", backup_url="https://figshare.com/ndownloader/files/37984935", expected_shape=(3039, 82), force_download=force_download, @@ -142,7 +145,7 @@ def drosophila( return _load_dataset_from_url( path + "_sc.h5ad", - type="h5ad", + file_type="h5ad", backup_url="https://figshare.com/ndownloader/files/37984938", expected_shape=(1297, 2000), force_download=force_download, @@ -175,7 +178,7 @@ def c_elegans( """ adata = _load_dataset_from_url( path, - type="h5ad", + file_type="h5ad", backup_url="https://figshare.com/ndownloader/files/39943585", expected_shape=(46151, 20222), force_download=force_download, @@ -213,7 +216,7 @@ def zebrafish( """ adata = _load_dataset_from_url( path, - type="h5ad", + file_type="h5ad", backup_url="https://figshare.com/ndownloader/files/39951073", expected_shape=(44014, 31466), force_download=force_download, @@ -258,7 +261,7 @@ def bone_marrow( if rna: return _load_dataset_from_url( path + "_rna.h5ad", - type="h5ad", + file_type="h5ad", backup_url="https://figshare.com/ndownloader/files/40195114", expected_shape=(6224, 2000), force_download=force_download, @@ -266,7 +269,7 @@ def bone_marrow( ) return _load_dataset_from_url( path + "_atac.h5ad", - type="h5ad", + file_type="h5ad", backup_url="https://figshare.com/ndownloader/files/41013551", expected_shape=(6224, 8000), force_download=force_download, @@ -304,7 +307,7 @@ def pancreas_multiome( if rna_only: return _load_dataset_from_url( path, - type="h5ad", + file_type="h5ad", backup_url="https://figshare.com/ndownloader/files/48785320", expected_shape=(22604, 20242), force_download=force_download, @@ -312,7 +315,7 @@ def pancreas_multiome( ) return _load_dataset_from_url( path, - type="h5mu", + file_type="h5mu", backup_url="https://figshare.com/ndownloader/files/48782332", expected_shape=(22604, 271918), force_download=force_download, @@ -351,7 +354,7 @@ def tedsim( """ return _load_dataset_from_url( path, - type="h5ad", + file_type="h5ad", backup_url="https://figshare.com/ndownloader/files/40178644", expected_shape=(8448, 500), force_download=force_download, @@ -384,7 +387,7 @@ def sciplex( """ return _load_dataset_from_url( path, - type="h5ad", + file_type="h5ad", backup_url="https://figshare.com/ndownloader/files/43381398", expected_shape=(799317, 110984), force_download=force_download, @@ -414,7 +417,7 @@ def sim_align( """ return _load_dataset_from_url( path, - type="h5ad", + file_type="h5ad", backup_url="https://figshare.com/ndownloader/files/37984926", expected_shape=(1200, 500), force_download=force_download, @@ -529,7 +532,7 @@ def simulate_data( def _load_dataset_from_url( fpath: PathLike, - type: Literal["h5ad", "h5mu"], + file_type: Literal["h5ad", "h5mu"], *, backup_url: str, expected_shape: Tuple[int, int], @@ -538,19 +541,17 @@ def _load_dataset_from_url( ) -> Union[ad.AnnData, mu.MuData]: # TODO: make nicer once https://github.com/scverse/mudata/issues/76 resolved fpath = os.path.expanduser(fpath) - if type == "h5ad" and not fpath.endswith(".h5ad"): - fpath += ".h5ad" - - if type == "h5mu" and not fpath.endswith(".h5mu"): - fpath += ".h5mu" - from scanpy.readwrite import _check_datafile_present_and_download - - if not os.path.exists(fpath) or force_download: - _check_datafile_present_and_download(backup_url=backup_url, path=fpath) - data = ad.read_h5ad(filename=fpath, **kwargs) if type == "h5ad" else mu.read_h5mu(filename=fpath, backed=False) + assert file_type in ["h5ad", "h5mu"], f"Invalid type `{file_type}`. Must be one of `['h5ad', 'h5mu']`." 
+ if not fpath.endswith(file_type): + fpath += f".{file_type}" + if force_download and os.path.exists(fpath): + os.remove(fpath) + if not _check_datafile_present_and_download(backup_url=backup_url, path=fpath): + raise FileNotFoundError(f"File `{fpath}` not found or download failed.") + data = ad.read_h5ad(filename=fpath, **kwargs) if file_type == "h5ad" else mu.read_h5mu(filename=fpath, backed=False) if data.shape != expected_shape: - data_str = "MuData" if type == "h5mu" else "AnnData" + data_str = "MuData" if file_type == "h5mu" else "AnnData" raise ValueError(f"Expected {data_str} object to have shape `{expected_shape}`, found `{data.shape}`.") return data From 40e63a4f34f714f27ed80b6ec6b1ca44b8c61e98 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 12:51:33 +0000 Subject: [PATCH 11/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/datasets.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/moscot/datasets.py b/src/moscot/datasets.py index e1f97c4df..58c1c8870 100644 --- a/src/moscot/datasets.py +++ b/src/moscot/datasets.py @@ -13,11 +13,9 @@ from scipy.linalg import block_diag import anndata as ad - -from moscot._types import PathLike - from scanpy.readwrite import _check_datafile_present_and_download +from moscot._types import PathLike __all__ = [ "mosta", From 73272be9fe2498ecdd8a85e5ddb14cdc28b4a5d3 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 9 Sep 2024 15:22:16 +0200 Subject: [PATCH 12/25] try this version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 10be37fce..faa5d4b35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,7 @@ dependencies = [ "cloudpickle>=2.2.0", "rich>=13.5", "docstring_inheritance>=2.0.0", - "mudata>=0.2.0" + "mudata>=0.2.2" ] [project.optional-dependencies] From c010e297538d7c5789bb433ec637e8191a500a27 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 9 Sep 2024 15:52:59 +0200 Subject: [PATCH 13/25] set versions --- .github/workflows/test.yml | 4 ++-- docs/installation.rst | 2 +- pyproject.toml | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e27a02960..4b61b858a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,10 +19,10 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python: ["3.9", "3.10"] + python: ["3.10", "3.11"] include: - os: macos-latest - python: "3.9" + python: "3.10" steps: - uses: actions/checkout@v3 diff --git a/docs/installation.rst b/docs/installation.rst index 5d0e416b5..344cc9efb 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -1,6 +1,6 @@ Installation ============ -:mod:`moscot` requires Python version >= 3.9 to run. +:mod:`moscot` requires Python version >= 3.10 to run. 
PyPI ---- diff --git a/pyproject.toml b/pyproject.toml index ea8f1a6ec..1c47e53a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "moscot" dynamic = ["version"] description = "Multi-omic single-cell optimal transport tools" readme = "README.rst" -requires-python = ">=3.9" +requires-python = ">=3.10" license = {file = "LICENSE"} classifiers = [ "Development Status :: 4 - Beta", @@ -232,7 +232,7 @@ ignore_roles = [ [tool.mypy] mypy_path = "$MYPY_CONFIG_FILE_DIR/src" -python_version = "3.9" +python_version = "3.10" plugins = "numpy.typing.mypy_plugin" ignore_errors = false @@ -269,7 +269,7 @@ max_line_length = 120 legacy_tox_ini = """ [tox] min_version = 4.0 -env_list = lint-code,py{3.9,3.10,3.11} +env_list = lint-code,py{3.10,3.11,3.12} skip_missing_interpreters = true [testenv] From 1a71ac77e173fc317b9e3e338ea8ef8aeb65495d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 21:01:33 +0000 Subject: [PATCH 14/25] [pre-commit.ci] pre-commit autoupdate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/pre-commit/mirrors-mypy: v1.10.1 → v1.11.2](https://github.com/pre-commit/mirrors-mypy/compare/v1.10.1...v1.11.2) - [github.com/psf/black: 24.4.2 → 24.8.0](https://github.com/psf/black/compare/24.4.2...24.8.0) - [github.com/asottile/pyupgrade: v3.16.0 → v3.17.0](https://github.com/asottile/pyupgrade/compare/v3.16.0...v3.17.0) - [github.com/rstcheck/rstcheck: v6.2.0 → v6.2.4](https://github.com/rstcheck/rstcheck/compare/v6.2.0...v6.2.4) - [github.com/PyCQA/doc8: v1.1.1 → v1.1.2](https://github.com/PyCQA/doc8/compare/v1.1.1...v1.1.2) - [github.com/astral-sh/ruff-pre-commit: v0.5.0 → v0.6.4](https://github.com/astral-sh/ruff-pre-commit/compare/v0.5.0...v0.6.4) --- .pre-commit-config.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 84aa5a90c..584248dcc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,13 +7,13 @@ default_stages: minimum_pre_commit_version: 3.0.0 repos: - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.10.1 + rev: v1.11.2 hooks: - id: mypy additional_dependencies: [numpy>=1.25.0] files: ^src - repo: https://github.com/psf/black - rev: 24.4.2 + rev: 24.8.0 hooks: - id: black additional_dependencies: [toml] @@ -42,7 +42,7 @@ repos: - id: check-yaml - id: check-toml - repo: https://github.com/asottile/pyupgrade - rev: v3.16.0 + rev: v3.17.0 hooks: - id: pyupgrade args: [--py3-plus, --py38-plus, --keep-runtime-typing] @@ -52,18 +52,18 @@ repos: - id: blacken-docs additional_dependencies: [black==23.1.0] - repo: https://github.com/rstcheck/rstcheck - rev: v6.2.0 + rev: v6.2.4 hooks: - id: rstcheck additional_dependencies: [tomli] args: [--config=pyproject.toml] - repo: https://github.com/PyCQA/doc8 - rev: v1.1.1 + rev: v1.1.2 hooks: - id: doc8 - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. 
- rev: v0.5.0 + rev: v0.6.4 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] From de8ac305888cc72e4d5a7685b5d68a7dc0d33554 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Sep 2024 21:02:06 +0000 Subject: [PATCH 15/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/backends/ott/test_backend.py | 4 +-- tests/conftest.py | 34 +++++++++---------- tests/datasets/test_dataset.py | 10 +++--- tests/plotting/conftest.py | 8 ++--- tests/problems/base/test_compound_problem.py | 12 +++---- tests/problems/base/test_general_problem.py | 2 +- tests/problems/conftest.py | 4 +-- tests/problems/cross_modality/test_mixins.py | 4 +-- .../test_translation_problem.py | 4 +-- tests/problems/generic/conftest.py | 2 +- .../test_conditional_neural_problem.py | 2 +- tests/problems/generic/test_fgw_problem.py | 10 +++--- tests/problems/generic/test_gw_problem.py | 10 +++--- .../problems/generic/test_sinkhorn_problem.py | 6 ++-- .../problems/space/test_alignment_problem.py | 4 +-- tests/problems/space/test_mapping_problem.py | 4 +-- tests/problems/space/test_mixins.py | 8 ++--- tests/problems/spatio_temporal/conftest.py | 2 +- .../test_spatio_temporal_problem.py | 10 +++--- tests/problems/time/conftest.py | 8 ++--- tests/problems/time/test_lineage_problem.py | 12 +++---- tests/problems/time/test_mixins.py | 14 ++++---- .../time/test_temporal_base_problem.py | 8 ++--- tests/problems/time/test_temporal_problem.py | 10 +++--- 24 files changed, 96 insertions(+), 96 deletions(-) diff --git a/tests/backends/ott/test_backend.py b/tests/backends/ott/test_backend.py index d0b8cbb60..3962c53d1 100644 --- a/tests/backends/ott/test_backend.py +++ b/tests/backends/ott/test_backend.py @@ -29,7 +29,7 @@ class TestSinkhorn: - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("jit", [False, True]) @pytest.mark.parametrize("eps", [None, 1e-2, 1e-1]) def test_matches_ott(self, x: Geom_t, eps: Optional[float], jit: bool): @@ -212,7 +212,7 @@ def test_matches_ott(self, x: Geom_t, y: Geom_t, xy: Geom_t, eps: Optional[float assert isinstance(solver.xy, PointCloud) np.testing.assert_allclose(gt.matrix, pred.transport_matrix, rtol=RTOL, atol=ATOL) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("alpha", [0.1, 0.9]) def test_alpha(self, x: Geom_t, y: Geom_t, xy: Geom_t, alpha: float) -> None: thresh, eps = 5e-2, 1e-1 diff --git a/tests/conftest.py b/tests/conftest.py index 95f13f5dc..0a0614630 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -39,7 +39,7 @@ def _close_figure(): plt.close() -@pytest.fixture() +@pytest.fixture def x() -> Geom_t: rng = np.random.RandomState(0) n = 20 # number of points in the first distribution @@ -51,7 +51,7 @@ def x() -> Geom_t: return jnp.asarray(xs) -@pytest.fixture() +@pytest.fixture def y() -> Geom_t: rng = np.random.RandomState(1) n2 = 30 # number of points in the second distribution @@ -63,7 +63,7 @@ def y() -> Geom_t: return jnp.asarray(xt) -@pytest.fixture() +@pytest.fixture def xy() -> Tuple[Geom_t, Geom_t]: rng = np.random.RandomState(2) n = 20 # number of points in the first distribution @@ -83,36 +83,36 @@ def xy() -> Tuple[Geom_t, Geom_t]: return jnp.asarray(ys), jnp.asarray(yt) -@pytest.fixture() +@pytest.fixture def ab() -> Tuple[np.ndarray, np.ndarray]: rng = np.random.RandomState(42) return rng.normal(size=(20, 2)), rng.normal(size=(30, 4)) -@pytest.fixture() +@pytest.fixture def x_cost(x: Geom_t) -> jnp.ndarray: 
return ((x[:, None, :] - x[None, ...]) ** 2).sum(-1) -@pytest.fixture() +@pytest.fixture def y_cost(y: Geom_t) -> jnp.ndarray: return ((y[:, None, :] - y[None, ...]) ** 2).sum(-1) -@pytest.fixture() +@pytest.fixture def xy_cost(xy: Geom_t) -> jnp.ndarray: x, y = xy return ((x[:, None, :] - y[None, ...]) ** 2).sum(-1) -@pytest.fixture() +@pytest.fixture def adata_x(x: Geom_t) -> AnnData: rng = np.random.RandomState(43) pc = rng.normal(size=(len(x), 4)) return AnnData(X=np.asarray(x, dtype=float), obsm={"X_pca": pc}) -@pytest.fixture() +@pytest.fixture def adata_y(y: Geom_t) -> AnnData: rng = np.random.RandomState(44) pc = rng.normal(size=(len(y), 4)) @@ -126,7 +126,7 @@ def creat_prob(n: int, *, uniform: bool = False, seed: Optional[int] = None) -> return jnp.asarray(a) -@pytest.fixture() +@pytest.fixture def adata_time() -> AnnData: rng = np.random.RandomState(42) @@ -156,7 +156,7 @@ def adata_time() -> AnnData: return adata -@pytest.fixture() +@pytest.fixture def gt_temporal_adata() -> AnnData: adata = _gt_temporal_adata.copy() # TODO(michalk8): remove both lines once data has been regenerated @@ -165,7 +165,7 @@ def gt_temporal_adata() -> AnnData: return adata -@pytest.fixture() +@pytest.fixture def adata_space_rotate() -> AnnData: rng = np.random.RandomState(31) grid = _make_grid(10) @@ -182,7 +182,7 @@ def adata_space_rotate() -> AnnData: return adata -@pytest.fixture() +@pytest.fixture def adata_mapping() -> AnnData: grid = _make_grid(10) adataref, adata1, adata2 = _make_adata(grid, n=3, seed=17, cat_key="covariate", num_categories=3) @@ -190,7 +190,7 @@ def adata_mapping() -> AnnData: return ad.concat([adataref, adata1, adata2], label="batch", join="outer", index_unique="-") -@pytest.fixture() +@pytest.fixture def adata_translation() -> AnnData: rng = np.random.RandomState(31) adatas = [AnnData(X=csr_matrix(rng.normal(size=(100, 60)))) for _ in range(3)] @@ -202,7 +202,7 @@ def adata_translation() -> AnnData: return adata -@pytest.fixture() +@pytest.fixture def adata_translation_split(adata_translation) -> Tuple[AnnData, AnnData]: rng = np.random.RandomState(15) adata_src = adata_translation[adata_translation.obs.batch != "0"].copy() @@ -212,7 +212,7 @@ def adata_translation_split(adata_translation) -> Tuple[AnnData, AnnData]: return adata_src, adata_tgt -@pytest.fixture() +@pytest.fixture def adata_anno( problem_kind: Literal["temporal", "cross_modality", "alignment", "mapping"], ) -> Union[AnnData, Tuple[AnnData, AnnData]]: @@ -258,7 +258,7 @@ def adata_anno( return adata -@pytest.fixture() +@pytest.fixture def gt_tm_annotation() -> np.ndarray: tm = np.zeros((10, 15)) for i in range(10): diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py index f75655ed7..d02a9d286 100644 --- a/tests/datasets/test_dataset.py +++ b/tests/datasets/test_dataset.py @@ -9,7 +9,7 @@ class TestSimulateData: - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("n_distributions", [2, 4]) @pytest.mark.parametrize("key", ["batch", "day"]) def test_n_distributions(self, n_distributions: int, key: str): @@ -17,7 +17,7 @@ def test_n_distributions(self, n_distributions: int, key: str): assert key in adata.obs.columns assert adata.obs[key].nunique() == n_distributions - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("obs_to_add", [{"celltype": 2}, {"celltype": 5, "cluster": 4}]) def test_obs_to_add(self, obs_to_add: Mapping[str, int]): adata = simulate_data(obs_to_add=obs_to_add) @@ -26,7 +26,7 @@ def test_obs_to_add(self, obs_to_add: Mapping[str, int]): 
assert colname in adata.obs.columns assert adata.obs[colname].nunique() == k - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("spatial_dim", [None, 2, 3]) def test_quad_term_spatial(self, spatial_dim: Optional[int]): kwargs = {} @@ -40,7 +40,7 @@ def test_quad_term_spatial(self, spatial_dim: Optional[int]): else: assert adata.obsm["spatial"].shape[1] == spatial_dim - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("n_intBCs", [None, 4, 7]) @pytest.mark.parametrize("barcode_dim", [None, 5, 8]) def test_quad_term_barcode(self, n_intBCs: Optional[int], barcode_dim: Optional[int]): @@ -63,7 +63,7 @@ def test_quad_term_barcode(self, n_intBCs: Optional[int], barcode_dim: Optional[ else: assert len(np.unique(adata.obsm["barcode"])) <= n_intBCs - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("n_initial_nodes", [None, 4, 7]) @pytest.mark.parametrize("n_distributions", [3, 6]) def test_quad_term_tree(self, n_initial_nodes: Optional[int], n_distributions: int): diff --git a/tests/plotting/conftest.py b/tests/plotting/conftest.py index 50ba864bf..baf0b04f8 100644 --- a/tests/plotting/conftest.py +++ b/tests/plotting/conftest.py @@ -24,7 +24,7 @@ DPI = 40 -@pytest.fixture() +@pytest.fixture def adata_pl_cell_transition(gt_temporal_adata: AnnData) -> AnnData: plot_vars = { "transition_matrix": gt_temporal_adata.uns["cell_transition_10_105_forward"], @@ -38,7 +38,7 @@ def adata_pl_cell_transition(gt_temporal_adata: AnnData) -> AnnData: return gt_temporal_adata -@pytest.fixture() +@pytest.fixture def adata_pl_push(adata_time: AnnData) -> AnnData: rng = np.random.RandomState(0) plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} @@ -57,7 +57,7 @@ def adata_pl_push(adata_time: AnnData) -> AnnData: return adata_time -@pytest.fixture() +@pytest.fixture def adata_pl_pull(adata_time: AnnData) -> AnnData: rng = np.random.RandomState(0) plot_vars = {"key": "time", "data": "celltype", "subset": "A", "source": 0, "target": 1} @@ -75,7 +75,7 @@ def adata_pl_pull(adata_time: AnnData) -> AnnData: return adata_time -@pytest.fixture() +@pytest.fixture def adata_pl_sankey(adata_time: AnnData) -> AnnData: rng = np.random.RandomState(0) celltypes = ["A", "B", "C", "D", "E"] diff --git a/tests/problems/base/test_compound_problem.py b/tests/problems/base/test_compound_problem.py index 12a75a69f..5810f3349 100644 --- a/tests/problems/base/test_compound_problem.py +++ b/tests/problems/base/test_compound_problem.py @@ -68,7 +68,7 @@ def test_sc_pipeline(self, adata_time: AnnData): assert problem[key].solution is problem.solutions[key] @pytest.mark.parametrize("scale", [True, False]) - @pytest.mark.fast() + @pytest.mark.fast def test_default_callback(self, adata_time: AnnData, mocker: MockerFixture, scale: bool): subproblem = OTProblem(adata_time, adata_tgt=adata_time.copy()) xy_callback_kwargs = {"n_comps": 5, "scale": scale} @@ -88,7 +88,7 @@ def test_default_callback(self, adata_time: AnnData, mocker: MockerFixture, scal assert isinstance(problem.problems, dict) spy.assert_called_with("xy", subproblem.adata_src, subproblem.adata_tgt, **xy_callback_kwargs) - @pytest.mark.fast() + @pytest.mark.fast def test_custom_callback_lin(self, adata_time: AnnData, mocker: MockerFixture): expected_keys = [(0, 1), (1, 2)] spy = mocker.spy(TestCompoundProblem, "xy_callback") @@ -106,7 +106,7 @@ def test_custom_callback_lin(self, adata_time: AnnData, mocker: MockerFixture): assert spy.call_count == len(expected_keys) - @pytest.mark.fast() + 
@pytest.mark.fast def test_custom_callback_quad(self, adata_time: AnnData, mocker: MockerFixture): expected_keys = [(0, 1), (1, 2)] spy_x = mocker.spy(TestCompoundProblem, "x_callback") @@ -164,7 +164,7 @@ def test_different_passings_linear(self, adata_with_cost_matrix: AnnData): np.testing.assert_allclose(gt.matrix, p1_tmap, rtol=RTOL, atol=ATOL) np.testing.assert_allclose(gt.matrix, p2_tmap, rtol=RTOL, atol=ATOL) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("cost", [("sq_euclidean", SqEuclidean), ("euclidean", Euclidean), ("cosine", Cosine)]) def test_prepare_cost(self, adata_time: AnnData, cost: Tuple[str, Any]): problem = Problem(adata=adata_time) @@ -179,7 +179,7 @@ def test_prepare_cost(self, adata_time: AnnData, cost: Tuple[str, Any]): assert isinstance(problem[0, 1].x.cost, cost[1]) assert isinstance(problem[0, 1].y.cost, cost[1]) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("cost", [("sq_euclidean", SqEuclidean), ("euclidean", Euclidean), ("cosine", Cosine)]) def test_prepare_cost_with_callback(self, adata_time: AnnData, cost: Tuple[str, Any]): problem = Problem(adata=adata_time) @@ -196,7 +196,7 @@ def test_prepare_cost_with_callback(self, adata_time: AnnData, cost: Tuple[str, assert isinstance(problem[0, 1].x.cost, cost[1]) assert isinstance(problem[0, 1].y.cost, cost[1]) - @pytest.mark.fast() + @pytest.mark.fast def test_prepare_different_costs(self, adata_time: AnnData): problem = Problem(adata=adata_time) problem = problem.prepare( diff --git a/tests/problems/base/test_general_problem.py b/tests/problems/base/test_general_problem.py index 431a87d93..5e17b6dec 100644 --- a/tests/problems/base/test_general_problem.py +++ b/tests/problems/base/test_general_problem.py @@ -29,7 +29,7 @@ def test_simple_run(self, adata_x: AnnData, adata_y: AnnData): assert isinstance(prob.solution, BaseDiscreteSolverOutput) - @pytest.mark.fast() + @pytest.mark.fast def test_output(self, adata_x: AnnData, x: Geom_t): problem = OTProblem(adata_x) problem._solution = MockSolverOutput(x * x.T) diff --git a/tests/problems/conftest.py b/tests/problems/conftest.py index 99ae64984..9ee32e693 100644 --- a/tests/problems/conftest.py +++ b/tests/problems/conftest.py @@ -10,7 +10,7 @@ from tests._utils import Geom_t -@pytest.fixture() +@pytest.fixture def adata_with_cost_matrix(adata_x: Geom_t, adata_y: Geom_t) -> AnnData: adata = ad.concat([adata_x, adata_y], label="batch", index_unique="-") C = pairwise_distances(adata_x.obsm["X_pca"], adata_y.obsm["X_pca"]) ** 2 @@ -19,7 +19,7 @@ def adata_with_cost_matrix(adata_x: Geom_t, adata_y: Geom_t) -> AnnData: return adata -@pytest.fixture() +@pytest.fixture def adata_time_with_tmap(adata_time: AnnData) -> AnnData: adata = adata_time[adata_time.obs["time"].isin([0, 1])].copy() rng = np.random.RandomState(42) diff --git a/tests/problems/cross_modality/test_mixins.py b/tests/problems/cross_modality/test_mixins.py index 5085819ab..8bf7c3565 100644 --- a/tests/problems/cross_modality/test_mixins.py +++ b/tests/problems/cross_modality/test_mixins.py @@ -65,7 +65,7 @@ def test_translate_alternative( trans_backward = tp.translate(source=src, target=tgt, forward=False, alternative_attr=alternative_attr) assert trans_backward.shape == adata_src[adata_src.obs["batch"] == "1"].obsm["X_pca"].shape - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("normalize", [True, False]) def test_cell_transition_pipeline( @@ -107,7 +107,7 @@ def test_cell_transition_pipeline( with 
pytest.raises(AssertionError): pd.testing.assert_frame_equal(result1, result2) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) @pytest.mark.parametrize("batch_size", [3, 7, None]) diff --git a/tests/problems/cross_modality/test_translation_problem.py b/tests/problems/cross_modality/test_translation_problem.py index b083ee745..b048fa9cf 100644 --- a/tests/problems/cross_modality/test_translation_problem.py +++ b/tests/problems/cross_modality/test_translation_problem.py @@ -25,7 +25,7 @@ class TestTranslationProblem: - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}]) @pytest.mark.parametrize("tgt_attr", ["emb_tgt", {"attr": "obsm", "key": "emb_tgt"}]) @pytest.mark.parametrize("joint_attr", [None, "X_pca", {"attr": "obsm", "key": "X_pca"}]) @@ -50,7 +50,7 @@ def test_prepare_dummy_policy( assert tp[prob_key].shape == (2 * n_obs, n_obs) np.testing.assert_array_equal(tp._policy._cat, prob_key) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}]) @pytest.mark.parametrize("tgt_attr", ["emb_tgt", {"attr": "obsm", "key": "emb_tgt"}]) @pytest.mark.parametrize("joint_attr", [None, "X_pca", {"attr": "obsm", "key": "X_pca"}]) diff --git a/tests/problems/generic/conftest.py b/tests/problems/generic/conftest.py index db43f57b7..9be929fb1 100644 --- a/tests/problems/generic/conftest.py +++ b/tests/problems/generic/conftest.py @@ -6,7 +6,7 @@ from anndata import AnnData -@pytest.fixture() +@pytest.fixture def adata_time_with_tmap(adata_time: AnnData) -> AnnData: adata = adata_time[adata_time.obs["time"].isin([0, 1])].copy() rng = np.random.RandomState(42) diff --git a/tests/problems/generic/test_conditional_neural_problem.py b/tests/problems/generic/test_conditional_neural_problem.py index 15dba8146..5a7297de0 100644 --- a/tests/problems/generic/test_conditional_neural_problem.py +++ b/tests/problems/generic/test_conditional_neural_problem.py @@ -15,7 +15,7 @@ class TestGENOTLinProblem: - @pytest.mark.fast() + @pytest.mark.fast def test_prepare(self, adata_time: ad.AnnData): problem = GENOTLinProblem(adata=adata_time) problem = problem.prepare(key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}) diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index 35a5efca2..3ee7b4f30 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -30,7 +30,7 @@ class TestFGWProblem: - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("policy", ["sequential", "star"]) def test_prepare(self, adata_space_rotate: AnnData, policy): expected_keys = { @@ -59,7 +59,7 @@ def test_prepare(self, adata_space_rotate: AnnData, policy): assert key in expected_keys[policy] assert isinstance(problem[key], OTProblem) - @pytest.mark.fast() + @pytest.mark.fast def test_prepare_marginals(self, adata_time: AnnData, marginal_keys): problem = FGWProblem(adata=adata_time) problem = problem.prepare( @@ -173,7 +173,7 @@ def test_set_xy(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"] assert isinstance(problem[0, 1].xy.data_src, np.ndarray) assert problem[0, 1].xy.data_tgt is None - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize( ("cost_str", "cost_inst", "cost_kwargs"), [ @@ -212,7 +212,7 @@ def test_prepare_costs(self, 
adata_time: AnnData, cost_str: str, cost_inst: Any, problem = problem.solve(max_iterations=2) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize( ("cost_str", "cost_inst", "cost_kwargs"), [ @@ -305,7 +305,7 @@ def test_set_y(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]) assert isinstance(problem[0, 1].y.data_src, np.ndarray) assert problem[0, 1].y.data_tgt is None - @pytest.mark.fast() + @pytest.mark.fast def test_prepare_different_costs(self, adata_time: AnnData): problem = FGWProblem(adata=adata_time) problem = problem.prepare( diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py index ccee4323f..7ae0cb7e5 100644 --- a/tests/problems/generic/test_gw_problem.py +++ b/tests/problems/generic/test_gw_problem.py @@ -28,7 +28,7 @@ class TestGWProblem: - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize( "policy", ["sequential", "star"], @@ -147,7 +147,7 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize( ("cost_str", "cost_inst", "cost_kwargs"), [ @@ -176,7 +176,7 @@ def test_prepare_costs(self, adata_time: AnnData, cost_str: str, cost_inst: Any, problem = problem.solve(max_iterations=2) - @pytest.mark.fast() + @pytest.mark.fast def test_prepare_marginals(self, adata_time: AnnData, marginal_keys): problem = GWProblem(adata=adata_time) problem = problem.prepare( @@ -185,7 +185,7 @@ def test_prepare_marginals(self, adata_time: AnnData, marginal_keys): for key in problem: _assert_marginals_set(adata_time, problem, key, marginal_keys) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize( ("cost_str", "cost_inst", "cost_kwargs"), [ @@ -273,7 +273,7 @@ def test_set_y(self, adata_time: AnnData, tag: Literal["cost_matrix", "kernel"]) assert isinstance(problem[0, 1].y.data_src, np.ndarray) assert problem[0, 1].y.data_tgt is None - @pytest.mark.fast() + @pytest.mark.fast def test_prepare_different_costs(self, adata_time: AnnData): problem = GWProblem(adata=adata_time) problem = problem.prepare( diff --git a/tests/problems/generic/test_sinkhorn_problem.py b/tests/problems/generic/test_sinkhorn_problem.py index d75d37f2d..1badbf49b 100644 --- a/tests/problems/generic/test_sinkhorn_problem.py +++ b/tests/problems/generic/test_sinkhorn_problem.py @@ -27,7 +27,7 @@ class TestSinkhornProblem: - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("policy", ["sequential", "star"]) def test_prepare(self, adata_time: AnnData, policy, marginal_keys): expected_keys = {"sequential": [(0, 1), (1, 2)], "star": [(1, 0), (2, 0)]} @@ -61,7 +61,7 @@ def test_solve_balanced(self, adata_time: AnnData, marginal_keys): assert np.allclose(subsol.a, problem[key].a, atol=1e-5) assert np.allclose(subsol.b, problem[key].b, atol=1e-5) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize( ("cost_str", "cost_inst", "cost_kwargs"), [ @@ -83,7 +83,7 @@ def test_prepare_costs(self, adata_time: AnnData, cost_str: str, cost_inst: Any, problem = problem.solve(max_iterations=2) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize( ("cost_str", "cost_inst", "cost_kwargs"), [ diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index 0ef586bc0..0d1b21a16 100644 --- a/tests/problems/space/test_alignment_problem.py +++ 
b/tests/problems/space/test_alignment_problem.py @@ -31,7 +31,7 @@ class TestAlignmentProblem: - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("joint_attr", [{"attr": "X"}]) @pytest.mark.parametrize("normalize_spatial", [True, False]) def test_prepare_sequential( @@ -62,7 +62,7 @@ def test_prepare_sequential( assert ap[prob_key].x.data_src.shape == ap[prob_key].y.data_src.shape == (n_obs, 2) assert ap[prob_key].xy.data_src.shape == ap[prob_key].xy.data_tgt.shape == (n_obs, n_var) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("reference", ["0", "1", "2"]) def test_prepare_star(self, adata_space_rotate: AnnData, reference: str): ap = AlignmentProblem(adata=adata_space_rotate) diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index 485ea52ff..a51bf4ed9 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -35,7 +35,7 @@ class TestMappingProblem: - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("sc_attr", [{"attr": "X"}, {"attr": "obsm", "key": "X_pca"}]) @pytest.mark.parametrize("joint_attr", [None, "X_pca", {"attr": "obsm", "key": "X_pca"}]) @pytest.mark.parametrize("normalize_spatial", [True, False]) @@ -79,7 +79,7 @@ def test_prepare( assert mp[prob_key].shape == (2 * n_obs, n_obs) np.testing.assert_array_equal(mp._policy._cat, prob_key) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("var_names", ["0", [], [str(i) for i in range(20)]]) def test_prepare_varnames(self, adata_mapping: AnnData, var_names: Optional[List[str]]): adataref, adatasp = _adata_spatial_split(adata_mapping) diff --git a/tests/problems/space/test_mixins.py b/tests/problems/space/test_mixins.py index b9c2051c1..a4f0145ea 100644 --- a/tests/problems/space/test_mixins.py +++ b/tests/problems/space/test_mixins.py @@ -65,7 +65,7 @@ def test_regression_testing(self, adata_space_rotate: AnnData): np.array(sol[k].transport_matrix), np.array(ap.solutions[k].transport_matrix), decimal=3 ) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("normalize", [True, False]) def test_cell_transition_pipeline(self, adata_space_rotate: AnnData, forward: bool, normalize: bool): @@ -93,7 +93,7 @@ def test_cell_transition_pipeline(self, adata_space_rotate: AnnData, forward: bo assert isinstance(result, pd.DataFrame) assert result.shape == (3, 3) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) @pytest.mark.parametrize("batch_size", [3, 7, None]) @@ -187,7 +187,7 @@ def test_regression_testing(self, adata_mapping: AnnData): np.array(sol[k].transport_matrix), np.array(mp.solutions[k].transport_matrix), decimal=3 ) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("normalize", [True, False]) def test_cell_transition_pipeline(self, adata_mapping: AnnData, forward: bool, normalize: bool): @@ -215,7 +215,7 @@ def test_cell_transition_pipeline(self, adata_mapping: AnnData, forward: bool, n assert isinstance(result, pd.DataFrame) assert result.shape == (3, 4) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) @pytest.mark.parametrize("batch_size", [3, 7, None]) diff --git a/tests/problems/spatio_temporal/conftest.py 
b/tests/problems/spatio_temporal/conftest.py index bc19bcd8a..e29c381a1 100644 --- a/tests/problems/spatio_temporal/conftest.py +++ b/tests/problems/spatio_temporal/conftest.py @@ -7,7 +7,7 @@ from tests._utils import _make_grid -@pytest.fixture() +@pytest.fixture def adata_spatio_temporal(adata_time: AnnData) -> AnnData: _, t_unique_counts = np.unique(adata_time.obs["time"], return_counts=True) grids = [] diff --git a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py index e7d35c82c..3b68561ba 100644 --- a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py +++ b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py @@ -27,7 +27,7 @@ class TestSpatioTemporalProblem: - @pytest.mark.fast() + @pytest.mark.fast def test_prepare(self, adata_spatio_temporal: AnnData): expected_keys = [(0, 1), (1, 2)] problem = SpatioTemporalProblem(adata=adata_spatio_temporal) @@ -88,7 +88,7 @@ def test_solve_unbalanced(self, adata_spatio_temporal: AnnData): div2 = np.linalg.norm(problem1[0, 1].b - problem1[0, 1].solution.b) assert div1 < div2 - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize( "gene_set_list", [ @@ -120,7 +120,7 @@ def test_score_genes(self, adata_spatio_temporal: AnnData, gene_set_list: List[L else: assert problem.apoptosis_key is None - @pytest.mark.fast() + @pytest.mark.fast def test_proliferation_key_pipeline(self, adata_spatio_temporal: AnnData): problem = SpatioTemporalProblem(adata_spatio_temporal) assert problem.proliferation_key is None @@ -132,7 +132,7 @@ def test_proliferation_key_pipeline(self, adata_spatio_temporal: AnnData): problem.proliferation_key = "new_proliferation" assert problem.proliferation_key == "new_proliferation" - @pytest.mark.fast() + @pytest.mark.fast def test_apoptosis_key_pipeline(self, adata_spatio_temporal: AnnData): problem = SpatioTemporalProblem(adata_spatio_temporal) assert problem.apoptosis_key is None @@ -144,7 +144,7 @@ def test_apoptosis_key_pipeline(self, adata_spatio_temporal: AnnData): problem.apoptosis_key = "new_apoptosis" assert problem.apoptosis_key == "new_apoptosis" - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("scaling", [0.1, 1, 4]) def test_proliferation_key_c_pipeline(self, adata_spatio_temporal: AnnData, scaling: float): key0, key1, *_ = np.sort(np.unique(adata_spatio_temporal.obs["time"].values)) diff --git a/tests/problems/time/conftest.py b/tests/problems/time/conftest.py index 8975411ac..4b69cd4c7 100644 --- a/tests/problems/time/conftest.py +++ b/tests/problems/time/conftest.py @@ -8,7 +8,7 @@ from moscot.datasets import _get_random_trees -@pytest.fixture() +@pytest.fixture def adata_time_trees(adata_time: AnnData) -> AnnData: trees = _get_random_trees( n_leaves=96, n_trees=3, leaf_names=[list(adata_time[adata_time.obs.time == i].obs.index) for i in range(3)] @@ -17,7 +17,7 @@ def adata_time_trees(adata_time: AnnData) -> AnnData: return adata_time -@pytest.fixture() +@pytest.fixture def adata_time_custom_cost_xy(adata_time: AnnData) -> AnnData: rng = np.random.RandomState(42) cost_m1 = np.abs(rng.randn(96, 96)) @@ -27,14 +27,14 @@ def adata_time_custom_cost_xy(adata_time: AnnData) -> AnnData: return adata_time -@pytest.fixture() +@pytest.fixture def adata_time_barcodes(adata_time: AnnData) -> AnnData: rng = np.random.RandomState(42) adata_time.obsm["barcodes"] = rng.randn(len(adata_time), 30) return adata_time -@pytest.fixture() +@pytest.fixture def adata_time_marginal_estimations(adata_time: 
AnnData) -> AnnData: rng = np.random.RandomState(42) adata_time.obs["proliferation"] = rng.randn(len(adata_time)) diff --git a/tests/problems/time/test_lineage_problem.py b/tests/problems/time/test_lineage_problem.py index fa79fbb74..5375c621e 100644 --- a/tests/problems/time/test_lineage_problem.py +++ b/tests/problems/time/test_lineage_problem.py @@ -26,7 +26,7 @@ class TestLineageProblem: - @pytest.mark.fast() + @pytest.mark.fast def test_prepare(self, adata_time_barcodes: AnnData): expected_keys = [(0, 1), (1, 2)] problem = LineageProblem(adata=adata_time_barcodes) @@ -94,7 +94,7 @@ def test_solve_unbalanced(self, adata_time_barcodes: AnnData): div2 = np.linalg.norm(problem1[0, 1].b - problem1[0, 1].solution.b) assert div1 < div2 - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize( "gene_set_list", [ @@ -126,7 +126,7 @@ def test_score_genes(self, adata_time_barcodes: AnnData, gene_set_list: List[Lis else: assert problem.apoptosis_key is None - @pytest.mark.fast() + @pytest.mark.fast def test_proliferation_key_pipeline(self, adata_time_barcodes: AnnData): problem = LineageProblem(adata_time_barcodes) assert problem.proliferation_key is None @@ -138,7 +138,7 @@ def test_proliferation_key_pipeline(self, adata_time_barcodes: AnnData): problem.proliferation_key = "new_proliferation" assert problem.proliferation_key == "new_proliferation" - @pytest.mark.fast() + @pytest.mark.fast def test_apoptosis_key_pipeline(self, adata_time_barcodes: AnnData): problem = LineageProblem(adata_time_barcodes) assert problem.apoptosis_key is None @@ -150,7 +150,7 @@ def test_apoptosis_key_pipeline(self, adata_time_barcodes: AnnData): problem.apoptosis_key = "new_apoptosis" assert problem.apoptosis_key == "new_apoptosis" - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("scaling", [0.1, 1, 4]) def test_proliferation_key_c_pipeline(self, adata_time_barcodes: AnnData, scaling: float): key0, key1, *_ = np.sort(np.unique(adata_time_barcodes.obs["time"].values)) @@ -173,7 +173,7 @@ def test_proliferation_key_c_pipeline(self, adata_time_barcodes: AnnData, scalin expected_marginals = np.exp((prolif - apopt) * delta / scaling) np.testing.assert_allclose(problem[key0, key1]._prior_growth, expected_marginals, rtol=RTOL, atol=ATOL) - @pytest.mark.fast() + @pytest.mark.fast def test_barcodes_pipeline(self, adata_time_barcodes: AnnData): expected_keys = [(0, 1), (1, 2)] problem = LineageProblem(adata=adata_time_barcodes) diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index ebb7f950f..bc116a704 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -13,7 +13,7 @@ class TestTemporalMixin: - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("forward", [True, False]) def test_cell_transition_full_pipeline(self, gt_temporal_adata: AnnData, forward: bool): config = gt_temporal_adata.uns @@ -51,7 +51,7 @@ def test_cell_transition_full_pipeline(self, gt_temporal_adata: AnnData, forward present_cell_type_marginal = marginal[marginal > 0] np.testing.assert_allclose(present_cell_type_marginal, 1.0) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) @pytest.mark.parametrize("batch_size", [3, 7, None]) @@ -81,7 +81,7 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo ) assert (result[annotation_label] == expected_result).all() - @pytest.mark.fast() + @pytest.mark.fast 
@pytest.mark.parametrize("forward", [True, False]) def test_cell_transition_different_groups(self, gt_temporal_adata: AnnData, forward: bool): config = gt_temporal_adata.uns @@ -111,7 +111,7 @@ def test_cell_transition_different_groups(self, gt_temporal_adata: AnnData, forw assert set(result.index) == cell_types assert set(result.columns) == batches - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("forward", [True, False]) def test_cell_transition_subset_pipeline(self, gt_temporal_adata: AnnData, forward: bool): config = gt_temporal_adata.uns @@ -339,7 +339,7 @@ def test_compute_random_distance_regression(self, gt_temporal_adata: AnnData): np.testing.assert_allclose(result, gt_temporal_adata.uns["random_distance_10_105_11"], rtol=1e-6, atol=1e-6) # TODO(MUCDK): split into 2 tests - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("only_start", [True, False]) def test_get_data_pipeline(self, adata_time: AnnData, only_start: bool): problem = TemporalProblem(adata_time) @@ -379,7 +379,7 @@ def test_get_interp_param_pipeline(self, adata_time: AnnData, time_points: Tuple inter_param = problem._get_interp_param(start, intermediate, end, interpolation_parameter) assert inter_param == 0.5 - @pytest.mark.fast() + @pytest.mark.fast def test_cell_transition_regression_notparam( self, adata_time_with_tmap: AnnData, @@ -399,7 +399,7 @@ def test_cell_transition_regression_notparam( # TODO(MUCDK): use pandas.testing np.testing.assert_allclose(res.values, df_expected.values, rtol=1e-6, atol=1e-6) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("temporal_key", ["celltype", "time", "missing"]) def test_temporal_key_numeric(self, adata_time: AnnData, temporal_key: str): problem = TemporalProblem(adata_time) diff --git a/tests/problems/time/test_temporal_base_problem.py b/tests/problems/time/test_temporal_base_problem.py index 880fbb7dd..abdfe0f90 100644 --- a/tests/problems/time/test_temporal_base_problem.py +++ b/tests/problems/time/test_temporal_base_problem.py @@ -11,7 +11,7 @@ # TODO(@MUCDK) put file in different folder according to moscot.problems structure class TestBirthDeathProblem: - @pytest.mark.fast() + @pytest.mark.fast def test_initialization_pipeline(self, adata_time_marginal_estimations: AnnData): t1, t2 = 0, 1 adata_x = adata_time_marginal_estimations[adata_time_marginal_estimations.obs["time"] == t1] @@ -35,7 +35,7 @@ def test_initialization_pipeline(self, adata_time_marginal_estimations: AnnData) assert isinstance(prob.b, np.ndarray) # TODO(MUCDK): break this test - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize( "adata_obs_keys", [ @@ -77,7 +77,7 @@ def test_estimate_marginals_pipeline( if not source: assert len(np.unique(a_estimated)) == 1 - @pytest.mark.fast() + @pytest.mark.fast def test_prior_growth_rates(self, adata_time_marginal_estimations: AnnData): t1, t2 = 0, 1 adata_x = adata_time_marginal_estimations[adata_time_marginal_estimations.obs["time"] == t1] @@ -117,7 +117,7 @@ def test_posterior_growth_rates(self, adata_time_marginal_estimations: AnnData): gr = prob.posterior_growth_rates assert isinstance(gr, np.ndarray) - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize( "marginal_kwargs", [{}, {"delta_width": 0.9}, {"delta_center": 0.9}, {"beta_width": 0.9}, {"beta_center": 0.9}] ) diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index 5a4f92df7..f53a745c9 100644 --- a/tests/problems/time/test_temporal_problem.py +++ 
b/tests/problems/time/test_temporal_problem.py @@ -31,7 +31,7 @@ class TestTemporalProblem: - @pytest.mark.fast() + @pytest.mark.fast def test_prepare(self, adata_time: AnnData): expected_keys = [(0, 1), (1, 2)] problem = TemporalProblem(adata=adata_time) @@ -90,7 +90,7 @@ def test_solve_unbalanced(self, adata_time: AnnData): div2 = np.linalg.norm(problem1[0, 1].b - problem1[0, 1].solution.b) assert div1 < div2 - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize( "gene_set_list", [ @@ -122,7 +122,7 @@ def test_score_genes(self, adata_time: AnnData, gene_set_list: List[List[str]]): else: assert problem.apoptosis_key is None - @pytest.mark.fast() + @pytest.mark.fast def test_proliferation_key_pipeline(self, adata_time: AnnData): problem = TemporalProblem(adata_time) assert problem.proliferation_key is None @@ -134,7 +134,7 @@ def test_proliferation_key_pipeline(self, adata_time: AnnData): problem.proliferation_key = "new_proliferation" assert problem.proliferation_key == "new_proliferation" - @pytest.mark.fast() + @pytest.mark.fast def test_apoptosis_key_pipeline(self, adata_time: AnnData): problem = TemporalProblem(adata_time) assert problem.apoptosis_key is None @@ -146,7 +146,7 @@ def test_apoptosis_key_pipeline(self, adata_time: AnnData): problem.apoptosis_key = "new_apoptosis" assert problem.apoptosis_key == "new_apoptosis" - @pytest.mark.fast() + @pytest.mark.fast @pytest.mark.parametrize("scaling", [0.1, 1, 4]) def test_proliferation_key_c_pipeline(self, adata_time: AnnData, scaling: float): key0, key1, *_ = np.sort(np.unique(adata_time.obs["time"].values)) From 337577c9b527c2ad7732af8cd1b7a5abca4b4938 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 10 Sep 2024 15:16:30 +0200 Subject: [PATCH 16/25] remove unused function --- src/moscot/backends/ott/solver.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/moscot/backends/ott/solver.py b/src/moscot/backends/ott/solver.py index b4408cf8b..5a66d47a4 100644 --- a/src/moscot/backends/ott/solver.py +++ b/src/moscot/backends/ott/solver.py @@ -22,7 +22,6 @@ import jax import jax.numpy as jnp import numpy as np -import scipy.sparse as sp from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud from ott.neural.datasets import OTData, OTDataset from ott.neural.methods.flows import dynamics, genot @@ -652,15 +651,6 @@ def _prepare( # type: ignore[override] MultiLoader(datasets=validate_loaders, seed=seed), ) - @staticmethod - def _assert2d(arr: ArrayLike, *, allow_reshape: bool = True) -> jnp.ndarray: - arr: jnp.ndarray = jnp.asarray(arr.A if sp.issparse(arr) else arr) # type: ignore[no-redef, attr-defined] # noqa:E501 - if allow_reshape and arr.ndim == 1: - return jnp.reshape(arr, (-1, 1)) - if arr.ndim != 2: - raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.") - return arr - def _split_data( # TODO: adapt for Gromov terms self, x: ArrayLike, From 2e87abea310c91839a10a5a1250ba0e5ba760b98 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 10 Sep 2024 15:37:03 +0200 Subject: [PATCH 17/25] asses mypy errors --- src/moscot/_types.py | 4 ++-- src/moscot/backends/ott/_utils.py | 2 +- src/moscot/backends/ott/output.py | 4 ++-- src/moscot/backends/utils.py | 3 +-- src/moscot/base/cost.py | 4 ++-- src/moscot/base/output.py | 2 +- src/moscot/base/problems/_mixins.py | 4 ++-- src/moscot/base/problems/_utils.py | 2 +- src/moscot/base/problems/problem.py | 2 +- src/moscot/costs/_costs.py | 6 +++--- src/moscot/plotting/_utils.py | 2 +- 
src/moscot/problems/generic/_generic.py | 4 ++-- src/moscot/problems/space/_mapping.py | 4 ++-- src/moscot/problems/space/_mixins.py | 6 +++--- src/moscot/problems/time/_lineage.py | 2 +- src/moscot/problems/time/_mixins.py | 6 +++--- src/moscot/utils/tagged_array.py | 2 +- 17 files changed, 29 insertions(+), 30 deletions(-) diff --git a/src/moscot/_types.py b/src/moscot/_types.py index 1c60884a2..c3f599d2d 100644 --- a/src/moscot/_types.py +++ b/src/moscot/_types.py @@ -10,8 +10,8 @@ ArrayLike = NDArray[np.float64] except (ImportError, TypeError): - ArrayLike = np.ndarray # type: ignore[misc] - DTypeLike = np.dtype # type: ignore[misc] + ArrayLike = np.ndarray + DTypeLike = np.dtype ProblemKind_t = Literal["linear", "quadratic", "unknown"] Numeric_t = Union[int, float] # type of `time_key` arguments diff --git a/src/moscot/backends/ott/_utils.py b/src/moscot/backends/ott/_utils.py index 2cac53b30..42cc87af5 100644 --- a/src/moscot/backends/ott/_utils.py +++ b/src/moscot/backends/ott/_utils.py @@ -108,7 +108,7 @@ def densify(arr: ArrayLike) -> jax.Array: dense :mod:`jax` array. """ if sp.issparse(arr): - arr = arr.toarray() # type: ignore[attr-defined] + arr = arr.toarray() elif isinstance(arr, jesp.BCOO): arr = arr.todense() return jnp.asarray(arr) diff --git a/src/moscot/backends/ott/output.py b/src/moscot/backends/ott/output.py index 60d727faf..c50850ec3 100644 --- a/src/moscot/backends/ott/output.py +++ b/src/moscot/backends/ott/output.py @@ -375,8 +375,8 @@ def project_to_transport_matrix( # type:ignore[override] The projected transport matrix. """ src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells) - push = self.push if condition is None else lambda x: self.push(x, condition) - pull = self.pull if condition is None else lambda x: self.pull(x, condition) + push: Callable[[Any], Any] = self.push if condition is None else lambda x: self.push(x, condition) # type: ignore + pull: Callable[[Any], Any] = self.pull if condition is None else lambda x: self.pull(x, condition) # type: ignore func, src_dist, tgt_dist = (push, src_cells, tgt_cells) if forward else (pull, tgt_cells, src_cells) return self._project_transport_matrix( src_dist=src_dist, diff --git a/src/moscot/backends/utils.py b/src/moscot/backends/utils.py index fde874c0f..988e05413 100644 --- a/src/moscot/backends/utils.py +++ b/src/moscot/backends/utils.py @@ -42,8 +42,7 @@ def register_solver( return _REGISTRY.register(backend) # type: ignore[return-value] -# TODO(@MUCDK) fix mypy error -@register_solver("ott") # type: ignore[arg-type] +@register_solver("ott") def _( problem_kind: Literal["linear", "quadratic"], solver_name: Optional[Literal["GENOTLinSolver"]] = None, diff --git a/src/moscot/base/cost.py b/src/moscot/base/cost.py index 8bea310be..73adfe8c9 100644 --- a/src/moscot/base/cost.py +++ b/src/moscot/base/cost.py @@ -53,12 +53,12 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayLike: """ cost = self._compute(*args, **kwargs) if np.any(np.isnan(cost)): - maxx = np.nanmax(cost) + maxx = np.nanmax(cost) # type: ignore[var-annotated] logger.warning( f"Cost matrix contains `{np.sum(np.isnan(cost))}` NaN values, " f"setting them to the maximum value `{maxx}`." 
) - cost = np.nan_to_num(cost, nan=maxx) # type: ignore[call-overload] + cost = np.nan_to_num(cost, nan=maxx) if np.any(cost < 0): raise ValueError(f"Cost matrix contains `{np.sum(cost < 0)}` negative values.") return cost diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index 9617b8729..3565c5790 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -354,7 +354,7 @@ def transport_matrix(self) -> ArrayLike: # noqa: D102 @property def shape(self) -> tuple[int, int]: # noqa: D102 - return self.transport_matrix.shape # type: ignore[return-value] + return self.transport_matrix.shape def to( # noqa: D102 self, device: Optional[Device_t] = None, dtype: Optional[DTypeLike] = None diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 9c460eac2..8d27cd177 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -388,7 +388,7 @@ def _sample_from_tmap( account_for_unbalancedness: bool = False, interpolation_parameter: Optional[Numeric_t] = None, seed: Optional[int] = None, - ) -> tuple[list[Any], list[ArrayLike]]: + ) -> tuple[Any, list[str]]: rng = np.random.RandomState(seed) if account_for_unbalancedness and interpolation_parameter is None: raise ValueError("When accounting for unbalancedness, interpolation parameter must be provided.") @@ -453,7 +453,7 @@ def _sample_from_tmap( for i in range(len(rows_batch)) ] all_cols_sampled.extend(cols_sampled) - return rows, all_cols_sampled # type: ignore[return-value] + return rows, all_cols_sampled def _interpolate_transport( self: AnalysisMixinProtocol[K, B], diff --git a/src/moscot/base/problems/_utils.py b/src/moscot/base/problems/_utils.py index 40bf0a99a..482f79b75 100644 --- a/src/moscot/base/problems/_utils.py +++ b/src/moscot/base/problems/_utils.py @@ -386,7 +386,7 @@ def perm_test_extractor(res: Sequence[Tuple[ArrayLike, ArrayLike]]) -> Tuple[Arr corr_bs = np.concatenate(corr_bs, axis=0) corr_ci_low, corr_ci_high = np.quantile(corr_bs, q=ql, axis=0), np.quantile(corr_bs, q=qh, axis=0) - return pvals, corr_ci_low, corr_ci_high # type:ignore[return-value] + return pvals, corr_ci_low, corr_ci_high if not (0 <= confidence_level <= 1): raise ValueError(f"Expected `confidence_level` to be in interval `[0, 1]`, found `{confidence_level}`.") diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index 708b35d44..93a3e24d0 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -183,7 +183,7 @@ def _split_mass(arr: ArrayLike) -> ArrayLike: if start >= adata.n_obs: raise IndexError(f"Expected starting index to be smaller than `{adata.n_obs}`, found `{start}`.") data = np.zeros((adata.n_obs,), dtype=float) - data[range(start, min(start + offset, adata.n_obs))] = 1.0 + data[range(start, min(start + offset, adata.n_obs))] = 1.0 # type: ignore[index] else: raise TypeError(f"Unable to interpret subset of type `{type(subset)}`.") elif not hasattr(data, "shape"): diff --git a/src/moscot/costs/_costs.py b/src/moscot/costs/_costs.py index f41c0f7f2..fd79b8a71 100644 --- a/src/moscot/costs/_costs.py +++ b/src/moscot/costs/_costs.py @@ -138,7 +138,7 @@ def _scaled_hamming_dist(x: ArrayLike, y: ArrayLike) -> float: raise ValueError("No shared indices.") b2 = y[shared_indices] - differences = b1 != b2 - double_scars = differences & (b1 != 0) & (b2 != 0) + differences: ArrayLike = b1 != b2 + double_scars: ArrayLike = differences & (b1 != 0) & (b2 != 0) - return (np.sum(differences) + 
np.sum(double_scars)) / len(b1) + return float(float(np.sum(differences)) + np.sum(double_scars)) / len(b1) diff --git a/src/moscot/plotting/_utils.py b/src/moscot/plotting/_utils.py index 358ae5fa2..7047945b0 100644 --- a/src/moscot/plotting/_utils.py +++ b/src/moscot/plotting/_utils.py @@ -463,7 +463,7 @@ def _plot_scatter( _ = kwargs.pop("palette", None) if (time_points[i] == source and push) or (time_points[i] == target and not push): st = f"not in {time_points[i]}" - vmin, vmax = np.nanmin(tmp[mask]), np.nanmax(tmp[mask]) + vmin, vmax = np.nanmin(tmp[mask]), np.nanmax(tmp[mask]) # type: ignore[var-annotated] column = pd.Series(tmp).fillna(st).astype("category") # TODO(michalk8): check if len(np.unique(column[mask.values].values)) > 2: diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index 8d727d0d0..c5b885dcc 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -37,7 +37,7 @@ def set_quad_defaults(z: Optional[Union[str, Mapping[str, Any]]]) -> Dict[str, s raise TypeError("`x_attr` and `y_attr` must be of type `str` or `dict` if no callback is provided.") -class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc] +class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): """Class for solving a :term:`linear problem`. Parameters @@ -264,7 +264,7 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]: return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value] -class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc] +class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): """Class for solving the :term:`GW ` or :term:`FGW ` problems. 
Parameters diff --git a/src/moscot/problems/space/_mapping.py b/src/moscot/problems/space/_mapping.py index abe440e28..dcb74d60a 100644 --- a/src/moscot/problems/space/_mapping.py +++ b/src/moscot/problems/space/_mapping.py @@ -65,8 +65,8 @@ def _create_problem( adata_tgt=self.adata_sc, src_obs_mask=src_mask, tgt_obs_mask=None, - src_var_mask=self.filtered_vars, # type: ignore[arg-type] - tgt_var_mask=self.filtered_vars, # type: ignore[arg-type] + src_var_mask=self.filtered_vars, + tgt_var_mask=self.filtered_vars, src_key=src, tgt_key=tgt, **kwargs, diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index d72c79c2b..a182999f1 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -42,7 +42,7 @@ class SpatialAlignmentMixinProtocol(AnalysisMixinProtocol[K, B]): _spatial_key: Optional[str] batch_key: Optional[str] - def _subset_spatial( # type:ignore[empty-body] + def _subset_spatial( self: "SpatialAlignmentMixinProtocol[K, B]", k: K, spatial_key: str, @@ -780,13 +780,13 @@ def _compute_correspondence( def pdist(row_idx: ArrayLike, col_idx: float, feat: ArrayLike) -> Any: if len(row_idx) > 0: - return pairwise_distances(feat[row_idx, :], feat[[col_idx], :]).mean() # type: ignore[index] + return pairwise_distances(feat[row_idx, :], feat[[col_idx], :]).mean() return np.nan # TODO(michalk8): vectorize using jax, this is just a for loop vpdist = np.vectorize(pdist, excluded=["feat"]) if sp.issparse(features): - features = features.toarray() # type: ignore[attr-defined] + features = features.toarray() feat_arr, index_arr, support_arr = [], [], [] for ind, i in enumerate(support): diff --git a/src/moscot/problems/time/_lineage.py b/src/moscot/problems/time/_lineage.py index 809b7f8b9..b20ee5e11 100644 --- a/src/moscot/problems/time/_lineage.py +++ b/src/moscot/problems/time/_lineage.py @@ -23,7 +23,7 @@ __all__ = ["TemporalProblem", "LineageProblem"] -class TemporalProblem( # type: ignore[misc] +class TemporalProblem( TemporalMixin[Numeric_t, BirthDeathProblem], BirthDeathMixin, CompoundProblem[Numeric_t, BirthDeathProblem] ): """Class for analyzing time-series single cell data based on :cite:`schiebinger:19`. 
diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 29a44aed0..3d565bdec 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -587,7 +587,7 @@ def cell_costs_source(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFram # TODO(michalk8): `[1]` will fail if potentials is None df_list = [ pd.DataFrame( - np.asarray(problem.solution.potentials[0]), # type: ignore[union-attr,index] + np.asarray(problem.solution.potentials[0]), index=problem.adata_src.obs_names, columns=cols, ) @@ -612,7 +612,7 @@ def cell_costs_target(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFram # TODO(michalk8): `[1]` will fail if potentials is None df_list = [ pd.DataFrame( - np.array(problem.solution.potentials[1]), # type: ignore[union-attr,index] + np.array(problem.solution.potentials[1]), index=problem.adata_tgt.obs_names, columns=cols, ) @@ -664,7 +664,7 @@ def _get_data( else: raise ValueError(f"No data found for `{target}` time point.") - return ( # type:ignore[return-value] + return ( source_data, growth_rates_source, intermediate_data, diff --git a/src/moscot/utils/tagged_array.py b/src/moscot/utils/tagged_array.py index 050573534..5a1ed5781 100644 --- a/src/moscot/utils/tagged_array.py +++ b/src/moscot/utils/tagged_array.py @@ -167,7 +167,7 @@ def shape(self) -> Tuple[int, int]: x, y = self.data_src, (self.data_src if self.data_tgt is None else self.data_tgt) return x.shape[0], y.shape[0] - return self.data_src.shape # type: ignore[return-value] + return self.data_src.shape @property def is_cost_matrix(self) -> bool: From c8df47ca51f748d55491511ac7eaedfd60341618 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 10 Sep 2024 16:02:21 +0200 Subject: [PATCH 18/25] handle redirected links for CI --- docs/conf.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/conf.py b/docs/conf.py index 311d1f6b2..8e155d798 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -140,6 +140,12 @@ r"https://doi.org/10.1093/nar/gkac235", r"https://www.science.org/doi/abs/10.1126/science.aax1971", ] +linkcheck_ignore_redirects = { + r"https://doi.org/10.1101/2022.01.10.475692": r"https://www.biorxiv.org/lookup/doi/10.1101/2022.01.10.475692", + r"https://www.biorxiv.org/content/10.1101/2023.04.14.536867v1": r"https://www.biorxiv.org/content/10.1101/2023.04.14.536867v1", + r"https://www.biorxiv.org/content/10.1101/2023.05.11.540374v2": r"https://www.biorxiv.org/content/10.1101/2023.05.11.540374v2", + r"https://www.biorxiv.org/content/early/2022/01/11/2022.01.10.475692": r"https://www.biorxiv.org/lookup/doi/10.1101/2022.01.10.475692", +} exclude_patterns = ["_build", "**.ipynb_checkpoints", "notebooks/README.rst", "notebooks/CONTRIBUTING.rst"] From a75139f98b312dba982e576b98a8fd3659d7a6bf Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 10 Sep 2024 16:15:20 +0200 Subject: [PATCH 19/25] add problematic links to ignore --- docs/conf.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 8e155d798..289640f5f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -139,13 +139,11 @@ r"https://doi.org/10.1126/science.aax1971", r"https://doi.org/10.1093/nar/gkac235", r"https://www.science.org/doi/abs/10.1126/science.aax1971", + r"https://doi.org/10.1101/2022.01.10.475692", + r"https://www.biorxiv.org/content/10.1101/2023.04.14.536867v1", + r"https://www.biorxiv.org/content/10.1101/2023.05.11.540374v2", + r"https://www.biorxiv.org/content/early/2022/01/11/2022.01.10.475692", ] 
-linkcheck_ignore_redirects = { - r"https://doi.org/10.1101/2022.01.10.475692": r"https://www.biorxiv.org/lookup/doi/10.1101/2022.01.10.475692", - r"https://www.biorxiv.org/content/10.1101/2023.04.14.536867v1": r"https://www.biorxiv.org/content/10.1101/2023.04.14.536867v1", - r"https://www.biorxiv.org/content/10.1101/2023.05.11.540374v2": r"https://www.biorxiv.org/content/10.1101/2023.05.11.540374v2", - r"https://www.biorxiv.org/content/early/2022/01/11/2022.01.10.475692": r"https://www.biorxiv.org/lookup/doi/10.1101/2022.01.10.475692", -} exclude_patterns = ["_build", "**.ipynb_checkpoints", "notebooks/README.rst", "notebooks/CONTRIBUTING.rst"] From 27a857499445606d744b4a95906f6b91d23e15f9 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 10 Sep 2024 16:38:03 +0200 Subject: [PATCH 20/25] fix _get_array_data --- src/moscot/base/solver.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/moscot/base/solver.py b/src/moscot/base/solver.py index 0e69bd404..45949337e 100644 --- a/src/moscot/base/solver.py +++ b/src/moscot/base/solver.py @@ -55,9 +55,6 @@ def to_tuple( loss_x = {k[2:]: v for k, v in kwargs.items() if k.startswith("x_")} loss_y = {k[2:]: v for k, v in kwargs.items() if k.startswith("y_")} - if isinstance(xy, dict) and np.all([isinstance(v, tuple) for v in xy.values()]): # handling joint learning - return xy - # fmt: off xy = xy if isinstance(xy, TaggedArray) else self._convert(*to_tuple(xy), tag=tags.get("xy", None), **loss_xy) x = x if isinstance(x, TaggedArray) else self._convert(*to_tuple(x), tag=tags.get("x", None), **loss_x) From de93eef8027652d15d552fb21f198f82f6de22b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Sep 2024 14:39:23 +0000 Subject: [PATCH 21/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/moscot/base/solver.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/moscot/base/solver.py b/src/moscot/base/solver.py index 45949337e..155584fcf 100644 --- a/src/moscot/base/solver.py +++ b/src/moscot/base/solver.py @@ -14,7 +14,6 @@ Union, ) -import numpy as np from moscot._logging import logger from moscot._types import ArrayLike, Device_t, ProblemKind_t From 22aaccfd4aa090f96a8da97b90f783c953c9f800 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 10 Sep 2024 16:45:52 +0200 Subject: [PATCH 22/25] formatting --- src/moscot/base/solver.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/moscot/base/solver.py b/src/moscot/base/solver.py index 155584fcf..d4ec22360 100644 --- a/src/moscot/base/solver.py +++ b/src/moscot/base/solver.py @@ -14,7 +14,6 @@ Union, ) - from moscot._logging import logger from moscot._types import ArrayLike, Device_t, ProblemKind_t from moscot.base.output import BaseDiscreteSolverOutput From 02882826ea027976dcb524ac7d4c218516425fee Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 10 Sep 2024 15:37:03 +0200 Subject: [PATCH 23/25] asses mypy errors From 1559c363655eec1d0464dfb10eee04b0d295e83e Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 12 Sep 2024 14:54:56 +0200 Subject: [PATCH 24/25] Revert "asses mypy errors" This reverts commit 2e87abea310c91839a10a5a1250ba0e5ba760b98. 
--- src/moscot/_types.py | 4 ++-- src/moscot/backends/ott/_utils.py | 2 +- src/moscot/backends/ott/output.py | 4 ++-- src/moscot/backends/utils.py | 3 ++- src/moscot/base/cost.py | 4 ++-- src/moscot/base/output.py | 2 +- src/moscot/base/problems/_mixins.py | 4 ++-- src/moscot/base/problems/_utils.py | 2 +- src/moscot/base/problems/problem.py | 2 +- src/moscot/costs/_costs.py | 6 +++--- src/moscot/plotting/_utils.py | 2 +- src/moscot/problems/generic/_generic.py | 4 ++-- src/moscot/problems/space/_mapping.py | 4 ++-- src/moscot/problems/space/_mixins.py | 6 +++--- src/moscot/problems/time/_lineage.py | 2 +- src/moscot/problems/time/_mixins.py | 6 +++--- src/moscot/utils/tagged_array.py | 2 +- 17 files changed, 30 insertions(+), 29 deletions(-) diff --git a/src/moscot/_types.py b/src/moscot/_types.py index c3f599d2d..1c60884a2 100644 --- a/src/moscot/_types.py +++ b/src/moscot/_types.py @@ -10,8 +10,8 @@ ArrayLike = NDArray[np.float64] except (ImportError, TypeError): - ArrayLike = np.ndarray - DTypeLike = np.dtype + ArrayLike = np.ndarray # type: ignore[misc] + DTypeLike = np.dtype # type: ignore[misc] ProblemKind_t = Literal["linear", "quadratic", "unknown"] Numeric_t = Union[int, float] # type of `time_key` arguments diff --git a/src/moscot/backends/ott/_utils.py b/src/moscot/backends/ott/_utils.py index 42cc87af5..2cac53b30 100644 --- a/src/moscot/backends/ott/_utils.py +++ b/src/moscot/backends/ott/_utils.py @@ -108,7 +108,7 @@ def densify(arr: ArrayLike) -> jax.Array: dense :mod:`jax` array. """ if sp.issparse(arr): - arr = arr.toarray() + arr = arr.toarray() # type: ignore[attr-defined] elif isinstance(arr, jesp.BCOO): arr = arr.todense() return jnp.asarray(arr) diff --git a/src/moscot/backends/ott/output.py b/src/moscot/backends/ott/output.py index c50850ec3..60d727faf 100644 --- a/src/moscot/backends/ott/output.py +++ b/src/moscot/backends/ott/output.py @@ -375,8 +375,8 @@ def project_to_transport_matrix( # type:ignore[override] The projected transport matrix. 
""" src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells) - push: Callable[[Any], Any] = self.push if condition is None else lambda x: self.push(x, condition) # type: ignore - pull: Callable[[Any], Any] = self.pull if condition is None else lambda x: self.pull(x, condition) # type: ignore + push = self.push if condition is None else lambda x: self.push(x, condition) + pull = self.pull if condition is None else lambda x: self.pull(x, condition) func, src_dist, tgt_dist = (push, src_cells, tgt_cells) if forward else (pull, tgt_cells, src_cells) return self._project_transport_matrix( src_dist=src_dist, diff --git a/src/moscot/backends/utils.py b/src/moscot/backends/utils.py index 988e05413..fde874c0f 100644 --- a/src/moscot/backends/utils.py +++ b/src/moscot/backends/utils.py @@ -42,7 +42,8 @@ def register_solver( return _REGISTRY.register(backend) # type: ignore[return-value] -@register_solver("ott") +# TODO(@MUCDK) fix mypy error +@register_solver("ott") # type: ignore[arg-type] def _( problem_kind: Literal["linear", "quadratic"], solver_name: Optional[Literal["GENOTLinSolver"]] = None, diff --git a/src/moscot/base/cost.py b/src/moscot/base/cost.py index 73adfe8c9..8bea310be 100644 --- a/src/moscot/base/cost.py +++ b/src/moscot/base/cost.py @@ -53,12 +53,12 @@ def __call__(self, *args: Any, **kwargs: Any) -> ArrayLike: """ cost = self._compute(*args, **kwargs) if np.any(np.isnan(cost)): - maxx = np.nanmax(cost) # type: ignore[var-annotated] + maxx = np.nanmax(cost) logger.warning( f"Cost matrix contains `{np.sum(np.isnan(cost))}` NaN values, " f"setting them to the maximum value `{maxx}`." ) - cost = np.nan_to_num(cost, nan=maxx) + cost = np.nan_to_num(cost, nan=maxx) # type: ignore[call-overload] if np.any(cost < 0): raise ValueError(f"Cost matrix contains `{np.sum(cost < 0)}` negative values.") return cost diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index 3565c5790..9617b8729 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -354,7 +354,7 @@ def transport_matrix(self) -> ArrayLike: # noqa: D102 @property def shape(self) -> tuple[int, int]: # noqa: D102 - return self.transport_matrix.shape + return self.transport_matrix.shape # type: ignore[return-value] def to( # noqa: D102 self, device: Optional[Device_t] = None, dtype: Optional[DTypeLike] = None diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 8d27cd177..9c460eac2 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -388,7 +388,7 @@ def _sample_from_tmap( account_for_unbalancedness: bool = False, interpolation_parameter: Optional[Numeric_t] = None, seed: Optional[int] = None, - ) -> tuple[Any, list[str]]: + ) -> tuple[list[Any], list[ArrayLike]]: rng = np.random.RandomState(seed) if account_for_unbalancedness and interpolation_parameter is None: raise ValueError("When accounting for unbalancedness, interpolation parameter must be provided.") @@ -453,7 +453,7 @@ def _sample_from_tmap( for i in range(len(rows_batch)) ] all_cols_sampled.extend(cols_sampled) - return rows, all_cols_sampled + return rows, all_cols_sampled # type: ignore[return-value] def _interpolate_transport( self: AnalysisMixinProtocol[K, B], diff --git a/src/moscot/base/problems/_utils.py b/src/moscot/base/problems/_utils.py index 482f79b75..40bf0a99a 100644 --- a/src/moscot/base/problems/_utils.py +++ b/src/moscot/base/problems/_utils.py @@ -386,7 +386,7 @@ def perm_test_extractor(res: Sequence[Tuple[ArrayLike, ArrayLike]]) 
-> Tuple[Arr corr_bs = np.concatenate(corr_bs, axis=0) corr_ci_low, corr_ci_high = np.quantile(corr_bs, q=ql, axis=0), np.quantile(corr_bs, q=qh, axis=0) - return pvals, corr_ci_low, corr_ci_high + return pvals, corr_ci_low, corr_ci_high # type:ignore[return-value] if not (0 <= confidence_level <= 1): raise ValueError(f"Expected `confidence_level` to be in interval `[0, 1]`, found `{confidence_level}`.") diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index 93a3e24d0..708b35d44 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -183,7 +183,7 @@ def _split_mass(arr: ArrayLike) -> ArrayLike: if start >= adata.n_obs: raise IndexError(f"Expected starting index to be smaller than `{adata.n_obs}`, found `{start}`.") data = np.zeros((adata.n_obs,), dtype=float) - data[range(start, min(start + offset, adata.n_obs))] = 1.0 # type: ignore[index] + data[range(start, min(start + offset, adata.n_obs))] = 1.0 else: raise TypeError(f"Unable to interpret subset of type `{type(subset)}`.") elif not hasattr(data, "shape"): diff --git a/src/moscot/costs/_costs.py b/src/moscot/costs/_costs.py index fd79b8a71..f41c0f7f2 100644 --- a/src/moscot/costs/_costs.py +++ b/src/moscot/costs/_costs.py @@ -138,7 +138,7 @@ def _scaled_hamming_dist(x: ArrayLike, y: ArrayLike) -> float: raise ValueError("No shared indices.") b2 = y[shared_indices] - differences: ArrayLike = b1 != b2 - double_scars: ArrayLike = differences & (b1 != 0) & (b2 != 0) + differences = b1 != b2 + double_scars = differences & (b1 != 0) & (b2 != 0) - return float(float(np.sum(differences)) + np.sum(double_scars)) / len(b1) + return (np.sum(differences) + np.sum(double_scars)) / len(b1) diff --git a/src/moscot/plotting/_utils.py b/src/moscot/plotting/_utils.py index 7047945b0..358ae5fa2 100644 --- a/src/moscot/plotting/_utils.py +++ b/src/moscot/plotting/_utils.py @@ -463,7 +463,7 @@ def _plot_scatter( _ = kwargs.pop("palette", None) if (time_points[i] == source and push) or (time_points[i] == target and not push): st = f"not in {time_points[i]}" - vmin, vmax = np.nanmin(tmp[mask]), np.nanmax(tmp[mask]) # type: ignore[var-annotated] + vmin, vmax = np.nanmin(tmp[mask]), np.nanmax(tmp[mask]) column = pd.Series(tmp).fillna(st).astype("category") # TODO(michalk8): check if len(np.unique(column[mask.values].values)) > 2: diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index c5b885dcc..8d727d0d0 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -37,7 +37,7 @@ def set_quad_defaults(z: Optional[Union[str, Mapping[str, Any]]]) -> Dict[str, s raise TypeError("`x_attr` and `y_attr` must be of type `str` or `dict` if no callback is provided.") -class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): +class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc] """Class for solving a :term:`linear problem`. Parameters @@ -264,7 +264,7 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]: return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value] -class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): +class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ignore[misc] """Class for solving the :term:`GW ` or :term:`FGW ` problems. 
Parameters diff --git a/src/moscot/problems/space/_mapping.py b/src/moscot/problems/space/_mapping.py index dcb74d60a..abe440e28 100644 --- a/src/moscot/problems/space/_mapping.py +++ b/src/moscot/problems/space/_mapping.py @@ -65,8 +65,8 @@ def _create_problem( adata_tgt=self.adata_sc, src_obs_mask=src_mask, tgt_obs_mask=None, - src_var_mask=self.filtered_vars, - tgt_var_mask=self.filtered_vars, + src_var_mask=self.filtered_vars, # type: ignore[arg-type] + tgt_var_mask=self.filtered_vars, # type: ignore[arg-type] src_key=src, tgt_key=tgt, **kwargs, diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index a182999f1..d72c79c2b 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -42,7 +42,7 @@ class SpatialAlignmentMixinProtocol(AnalysisMixinProtocol[K, B]): _spatial_key: Optional[str] batch_key: Optional[str] - def _subset_spatial( + def _subset_spatial( # type:ignore[empty-body] self: "SpatialAlignmentMixinProtocol[K, B]", k: K, spatial_key: str, @@ -780,13 +780,13 @@ def _compute_correspondence( def pdist(row_idx: ArrayLike, col_idx: float, feat: ArrayLike) -> Any: if len(row_idx) > 0: - return pairwise_distances(feat[row_idx, :], feat[[col_idx], :]).mean() + return pairwise_distances(feat[row_idx, :], feat[[col_idx], :]).mean() # type: ignore[index] return np.nan # TODO(michalk8): vectorize using jax, this is just a for loop vpdist = np.vectorize(pdist, excluded=["feat"]) if sp.issparse(features): - features = features.toarray() + features = features.toarray() # type: ignore[attr-defined] feat_arr, index_arr, support_arr = [], [], [] for ind, i in enumerate(support): diff --git a/src/moscot/problems/time/_lineage.py b/src/moscot/problems/time/_lineage.py index b20ee5e11..809b7f8b9 100644 --- a/src/moscot/problems/time/_lineage.py +++ b/src/moscot/problems/time/_lineage.py @@ -23,7 +23,7 @@ __all__ = ["TemporalProblem", "LineageProblem"] -class TemporalProblem( +class TemporalProblem( # type: ignore[misc] TemporalMixin[Numeric_t, BirthDeathProblem], BirthDeathMixin, CompoundProblem[Numeric_t, BirthDeathProblem] ): """Class for analyzing time-series single cell data based on :cite:`schiebinger:19`. 
diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 3d565bdec..29a44aed0 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -587,7 +587,7 @@ def cell_costs_source(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFram # TODO(michalk8): `[1]` will fail if potentials is None df_list = [ pd.DataFrame( - np.asarray(problem.solution.potentials[0]), + np.asarray(problem.solution.potentials[0]), # type: ignore[union-attr,index] index=problem.adata_src.obs_names, columns=cols, ) @@ -612,7 +612,7 @@ def cell_costs_target(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFram # TODO(michalk8): `[1]` will fail if potentials is None df_list = [ pd.DataFrame( - np.array(problem.solution.potentials[1]), + np.array(problem.solution.potentials[1]), # type: ignore[union-attr,index] index=problem.adata_tgt.obs_names, columns=cols, ) @@ -664,7 +664,7 @@ def _get_data( else: raise ValueError(f"No data found for `{target}` time point.") - return ( + return ( # type:ignore[return-value] source_data, growth_rates_source, intermediate_data, diff --git a/src/moscot/utils/tagged_array.py b/src/moscot/utils/tagged_array.py index 5a1ed5781..050573534 100644 --- a/src/moscot/utils/tagged_array.py +++ b/src/moscot/utils/tagged_array.py @@ -167,7 +167,7 @@ def shape(self) -> Tuple[int, int]: x, y = self.data_src, (self.data_src if self.data_tgt is None else self.data_tgt) return x.shape[0], y.shape[0] - return self.data_src.shape + return self.data_src.shape # type: ignore[return-value] @property def is_cost_matrix(self) -> bool: From f92b6bc9a8f620efda60e8384bdcfcc34cde8358 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 12 Sep 2024 14:57:01 +0200 Subject: [PATCH 25/25] fix mypy errors --- src/moscot/__init__.py | 6 +++--- src/moscot/base/output.py | 2 +- src/moscot/utils/tagged_array.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/moscot/__init__.py b/src/moscot/__init__.py index 02925d781..dc69df60d 100644 --- a/src/moscot/__init__.py +++ b/src/moscot/__init__.py @@ -4,9 +4,9 @@ try: md = metadata.metadata(__name__) - __version__ = md.get("version", "") - __author__ = md.get("Author", "") - __maintainer__ = md.get("Maintainer-email", "") + __version__ = md.get("version", "") # type: ignore[attr-defined] + __author__ = md.get("Author", "") # type: ignore[attr-defined] + __maintainer__ = md.get("Maintainer-email", "") # type: ignore[attr-defined] except ImportError: md = None diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index 9617b8729..3565c5790 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -354,7 +354,7 @@ def transport_matrix(self) -> ArrayLike: # noqa: D102 @property def shape(self) -> tuple[int, int]: # noqa: D102 - return self.transport_matrix.shape # type: ignore[return-value] + return self.transport_matrix.shape def to( # noqa: D102 self, device: Optional[Device_t] = None, dtype: Optional[DTypeLike] = None diff --git a/src/moscot/utils/tagged_array.py b/src/moscot/utils/tagged_array.py index 050573534..5a1ed5781 100644 --- a/src/moscot/utils/tagged_array.py +++ b/src/moscot/utils/tagged_array.py @@ -167,7 +167,7 @@ def shape(self) -> Tuple[int, int]: x, y = self.data_src, (self.data_src if self.data_tgt is None else self.data_tgt) return x.shape[0], y.shape[0] - return self.data_src.shape # type: ignore[return-value] + return self.data_src.shape @property def is_cost_matrix(self) -> bool:
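
A minimal sketch related to the final patch's change in src/moscot/__init__.py: importlib.metadata.metadata() returns a PackageMetadata object whose .get() method is not declared in every supported typeshed version, which is why the patch adds # type: ignore[attr-defined] on those lines. One possible alternative that stays within the typed surface of importlib.metadata is sketched below; the "moscot" distribution name and the Author/Maintainer-email fields are taken from the patch itself, and this is an illustration only — it is not part of the patch series and has not been checked against the project's mypy configuration.

    from importlib import metadata

    _PKG = "moscot"  # distribution name, as used in the package __init__

    try:
        # version() is fully typed (returns str), so no ignore is needed here.
        __version__ = metadata.version(_PKG)
        # Indexing uses __getitem__/__contains__, which the PackageMetadata
        # protocol does declare, unlike .get().
        _md = metadata.metadata(_PKG)
        __author__ = _md["Author"] if "Author" in _md else ""
        __maintainer__ = _md["Maintainer-email"] if "Maintainer-email" in _md else ""
    except metadata.PackageNotFoundError:
        # Metadata is unavailable, e.g. when running from a plain source checkout.
        __version__ = __author__ = __maintainer__ = ""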