From 3762d708e302390ad817ab062bff22f25f823a53 Mon Sep 17 00:00:00 2001 From: lautisilber Date: Thu, 30 May 2024 14:29:38 +0200 Subject: [PATCH 1/6] using custom __deepcopy__ methods to chose what to deep-copy and what to shallow-copy --- src/moscot/base/problems/_utils.py | 20 ++++++++++ src/moscot/base/problems/compound_problem.py | 26 ++++++++++++- src/moscot/base/problems/manager.py | 15 ++++++++ src/moscot/base/problems/problem.py | 40 ++++++++++++++++++++ 4 files changed, 100 insertions(+), 1 deletion(-) diff --git a/src/moscot/base/problems/_utils.py b/src/moscot/base/problems/_utils.py index 40bf0a99a..25fd0f992 100644 --- a/src/moscot/base/problems/_utils.py +++ b/src/moscot/base/problems/_utils.py @@ -1,3 +1,4 @@ +import copy import functools import multiprocessing import threading @@ -16,6 +17,7 @@ Sequence, Tuple, Type, + TypeVar, Union, ) @@ -730,3 +732,21 @@ def _get_n_cores(n_cores: Optional[int], n_jobs: Optional[int]) -> int: return multiprocessing.cpu_count() + 1 + n_cores return n_cores + + +C = TypeVar("C", bound=object) + + +def _copy_deep_shallow_helper( + original: C, memo: dict[int, Any], shallow_copy: tuple[str, ...] = (), dont_copy: tuple[str, ...] = () +) -> C: + cls = original.__class__ + result = cls.__new__(cls) + memo[id(original)] = result + for k, v in original.__dict__.items(): + if k in shallow_copy: + setattr(result, k, v) + memo[id(v)] = v + elif k not in dont_copy: + setattr(result, k, copy.deepcopy(v, memo)) + return result diff --git a/src/moscot/base/problems/compound_problem.py b/src/moscot/base/problems/compound_problem.py index eef27314b..8ec2b73eb 100644 --- a/src/moscot/base/problems/compound_problem.py +++ b/src/moscot/base/problems/compound_problem.py @@ -1,4 +1,5 @@ import abc +import copy import types from typing import ( TYPE_CHECKING, @@ -25,7 +26,11 @@ from moscot._logging import logger from moscot._types import ArrayLike, Policy_t, ProblemStage_t from moscot.base.output import BaseSolverOutput -from moscot.base.problems._utils import attributedispatch, require_prepare +from moscot.base.problems._utils import ( + _copy_deep_shallow_helper, + attributedispatch, + require_prepare, +) from moscot.base.problems.manager import ProblemManager from moscot.base.problems.problem import BaseProblem, OTProblem from moscot.utils.subset_policy import ( @@ -39,6 +44,8 @@ ) from moscot.utils.tagged_array import Tag, TaggedArray +# from moscot.base.problems._utils import _custom_copy + __all__ = ["BaseCompoundProblem", "CompoundProblem"] K = TypeVar("K", bound=Hashable) @@ -65,6 +72,23 @@ def __init__(self, adata: AnnData, **kwargs: Any): self._adata = adata self._problem_manager: Optional[ProblemManager[K, B]] = None + def __deepcopy__(self, memo) -> "BaseCompoundProblem[K, B]": + vars_to_shallow_copy = ("_adata",) + + return _copy_deep_shallow_helper(self, memo, vars_to_shallow_copy) + + def copy(self) -> "BaseCompoundProblem[K, B]": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint + + Returns + ------- + Copy of Self + """ + return copy.deepcopy(self) + @abc.abstractmethod def _create_problem(self, src: K, tgt: K, src_mask: ArrayLike, tgt_mask: ArrayLike, **kwargs: Any) -> B: """Create an :term:`OT` subproblem. diff --git a/src/moscot/base/problems/manager.py b/src/moscot/base/problems/manager.py index 1994f27d8..23cd7ec25 100644 --- a/src/moscot/base/problems/manager.py +++ b/src/moscot/base/problems/manager.py @@ -1,4 +1,7 @@ import collections + +# from moscot.base.problems._utils import _custom_copy +import copy from typing import ( TYPE_CHECKING, Dict, @@ -40,6 +43,18 @@ def __init__(self, compound_problem: "BaseCompoundProblem[K, B]", policy: Subset self._policy = policy self._problems: Dict[Tuple[K, K], B] = {} + def copy(self) -> "ProblemManager[K, B]": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint + + Returns + ------- + Copy of Self + """ + return copy.deepcopy(self) + def add_problem( self, key: Tuple[K, K], problem: B, *, overwrite: bool = False, verify_integrity: bool = True ) -> None: diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index b249a46f6..624248983 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -1,4 +1,5 @@ import abc +import copy import pathlib import types from typing import ( @@ -33,6 +34,7 @@ TimeScalesHeatKernel, _assert_columns_and_index_match, _assert_series_match, + _copy_deep_shallow_helper, require_solution, wrap_prepare, wrap_solve, @@ -69,6 +71,19 @@ def prepare(self, *args: Any, **kwargs: Any) -> "BaseProblem": - :attr:`problem_kind` - kind of the :term:`OT` problem. """ + @abc.abstractmethod + def copy(self) -> "BaseProblem": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint + + Returns + ------- + Copy of Self + """ + pass + @abc.abstractmethod def solve(self, *args: Any, **kwargs: Any) -> "BaseProblem": """Solve the problem. @@ -199,6 +214,7 @@ def problem_kind(self) -> ProblemKind_t: return self._problem_kind +# from moscot.base.problems._utils import _custom_copy class OTProblem(BaseProblem): """Base class for all :term:`OT` problems. @@ -259,6 +275,30 @@ def __init__( self._time_scales_heat_kernel = TimeScalesHeatKernel(None, None, None) + def __deepcopy__(self, memo) -> "OTProblem": + vars_to_shallow_copy = ( + "_adata_src", + "_adata_tgt", + "_src_obs_mask", + "_tgt_obs_mask", + "_src_var_mask", + "_tgt_var_mask", + ) + + return _copy_deep_shallow_helper(self, memo, vars_to_shallow_copy) + + def copy(self) -> "OTProblem": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint + + Returns + ------- + Copy of Self + """ + return copy.deepcopy(self) + def _handle_linear(self, cost: CostFn_t = None, **kwargs: Any) -> TaggedArray: if "x_attr" not in kwargs or "y_attr" not in kwargs: kwargs.setdefault("tag", Tag.COST_MATRIX) From ba44389ee16ab62d86272ba3d868300daaa61764 Mon Sep 17 00:00:00 2001 From: lautisilber Date: Thu, 30 May 2024 18:07:32 +0200 Subject: [PATCH 2/6] added tests and shallow copy flag to TranslationProblem._adata_tgt --- src/moscot/base/problems/compound_problem.py | 2 ++ src/moscot/base/problems/manager.py | 2 ++ src/moscot/base/problems/problem.py | 2 ++ .../problems/cross_modality/_translation.py | 9 +++++ tests/problems/_utils.py | 30 ++++++++++++++++ tests/problems/base/test_compound_problem.py | 27 ++++++++++++++ tests/problems/base/test_general_problem.py | 34 ++++++++++++++++++ .../test_translation_problem.py | 32 +++++++++++++++++ tests/problems/generic/test_fgw_problem.py | 31 ++++++++++++++++ tests/problems/generic/test_gw_problem.py | 30 ++++++++++++++++ .../problems/generic/test_sinkhorn_problem.py | 24 +++++++++++++ .../problems/space/test_alignment_problem.py | 28 +++++++++++++++ tests/problems/space/test_mapping_problem.py | 22 ++++++++++++ .../test_spatio_temporal_problem.py | 26 ++++++++++++++ tests/problems/time/test_lineage_problem.py | 29 +++++++++++++++ .../time/test_temporal_base_problem.py | 36 +++++++++++++++++++ tests/problems/time/test_temporal_problem.py | 27 ++++++++++++++ 17 files changed, 391 insertions(+) create mode 100644 tests/problems/_utils.py diff --git a/src/moscot/base/problems/compound_problem.py b/src/moscot/base/problems/compound_problem.py index 8ec2b73eb..96a48aeba 100644 --- a/src/moscot/base/problems/compound_problem.py +++ b/src/moscot/base/problems/compound_problem.py @@ -87,6 +87,8 @@ def copy(self) -> "BaseCompoundProblem[K, B]": ------- Copy of Self """ + if self.stage == "solved": + raise copy.Error("Cannot copy problem that has already been solved.") return copy.deepcopy(self) @abc.abstractmethod diff --git a/src/moscot/base/problems/manager.py b/src/moscot/base/problems/manager.py index 23cd7ec25..1cff50334 100644 --- a/src/moscot/base/problems/manager.py +++ b/src/moscot/base/problems/manager.py @@ -53,6 +53,8 @@ def copy(self) -> "ProblemManager[K, B]": ------- Copy of Self """ + if self._compound_problem.stage == "solved": + raise copy.Error("Cannot copy problem that has already been solved.") return copy.deepcopy(self) def add_problem( diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index 624248983..181bd836c 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -297,6 +297,8 @@ def copy(self) -> "OTProblem": ------- Copy of Self """ + if self.stage == "solved": + raise copy.Error("Cannot copy problem that has already been solved.") return copy.deepcopy(self) def _handle_linear(self, cost: CostFn_t = None, **kwargs: Any) -> TaggedArray: diff --git a/src/moscot/problems/cross_modality/_translation.py b/src/moscot/problems/cross_modality/_translation.py index 0f26857df..284b8181d 100644 --- a/src/moscot/problems/cross_modality/_translation.py +++ b/src/moscot/problems/cross_modality/_translation.py @@ -13,6 +13,7 @@ QuadInitializer_t, ScaleCost_t, ) +from moscot.base.problems._utils import _copy_deep_shallow_helper from moscot.base.problems.compound_problem import B, CompoundProblem, K from moscot.base.problems.problem import OTProblem from moscot.problems._utils import handle_cost, handle_joint_attr @@ -69,6 +70,14 @@ def _create_problem( **kwargs, ) + def __deepcopy__(self, memo) -> "TranslationProblem[K]": + vars_to_shallow_copy = ( + "_adata", + "_adata_tgt", + ) + + return _copy_deep_shallow_helper(self, memo, vars_to_shallow_copy) + def prepare( self, src_attr: Union[str, Mapping[str, Any]], diff --git a/tests/problems/_utils.py b/tests/problems/_utils.py new file mode 100644 index 000000000..095b6a697 --- /dev/null +++ b/tests/problems/_utils.py @@ -0,0 +1,30 @@ +from itertools import combinations + + +def _check_is_copy(o1: object, o2: object, shallow_copy: tuple[str, ...]) -> bool: + if type(o1) is not type(o2): + return False + + for k in o1.__dict__: + v1 = getattr(o1, k) + v2 = getattr(o2, k) + if type(v1) is not type(v2): + return False + if isinstance(v1, (str, int, bool, float)) or v1 is None: + # these basic types are treated differently in python and there's no point in comparing their ids + continue + if k in shallow_copy: + if id(v1) != id(v2): + return False + else: + if id(v1) == id(v2): + return False + + return True + + +def check_is_copy_multiple(os: tuple[object, ...], shallow_copy: tuple[str, ...]) -> bool: + if len(os) < 1: + return False + combs = combinations(os, 2) + return all(_check_is_copy(o1, o2, shallow_copy) for o1, o2 in combs) diff --git a/tests/problems/base/test_compound_problem.py b/tests/problems/base/test_compound_problem.py index 12a75a69f..895837fed 100644 --- a/tests/problems/base/test_compound_problem.py +++ b/tests/problems/base/test_compound_problem.py @@ -1,3 +1,4 @@ +import copy import os from typing import Any, Literal, Optional, Tuple @@ -15,6 +16,7 @@ from moscot.base.problems import CompoundProblem, OTProblem from moscot.utils.tagged_array import Tag, TaggedArray from tests._utils import ATOL, RTOL, Problem +from tests.problems._utils import check_is_copy_multiple class TestCompoundProblem: @@ -42,6 +44,31 @@ def y_callback( assert isinstance(adata_y, AnnData) return TaggedArray(euclidean_distances(adata.X, adata_y.X), tag=Tag.COST_MATRIX) + def test_copy(self, adata_time: AnnData): + shallow_copy = ("_adata",) + + prepare_params = { + "xy": {"x_attr": "X", "y_attr": "X"}, + "key": "time", + "policy": "sequential", + } + solve_params = {"max_iterations": 2} + + prob = Problem(adata_time) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() + def test_sc_pipeline(self, adata_time: AnnData): expected_keys = [(0, 1), (1, 2)] problem = Problem(adata_time) diff --git a/tests/problems/base/test_general_problem.py b/tests/problems/base/test_general_problem.py index bdfecc7f6..beebc76d6 100644 --- a/tests/problems/base/test_general_problem.py +++ b/tests/problems/base/test_general_problem.py @@ -1,3 +1,4 @@ +import copy from typing import Literal, Optional, Tuple import pytest @@ -16,9 +17,42 @@ from moscot.base.problems import OTProblem from moscot.utils.tagged_array import Tag, TaggedArray from tests._utils import ATOL, RTOL, Geom_t, MockSolverOutput +from tests.problems._utils import check_is_copy_multiple class TestOTProblem: + def test_copy_problem(self, adata_x: AnnData, adata_y: AnnData): + shallow_copy = ( + "_adata_src", + "_adata_tgt", + "_src_obs_mask", + "_tgt_obs_mask", + "_src_var_mask", + "_tgt_var_mask", + ) + + prepare_params = { + "xy": {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"}, + "x": {"attr": "X"}, + "y": {"attr": "X"}, + } + solve_params = {"epsilon": 5e-1, "alpha": 0.5} + + prob = OTProblem(adata_x, adata_y) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() + def test_simple_run(self, adata_x: AnnData, adata_y: AnnData): prob = OTProblem(adata_x, adata_y) prob = prob.prepare( diff --git a/tests/problems/cross_modality/test_translation_problem.py b/tests/problems/cross_modality/test_translation_problem.py index f5199877a..f0fe218e8 100644 --- a/tests/problems/cross_modality/test_translation_problem.py +++ b/tests/problems/cross_modality/test_translation_problem.py @@ -1,3 +1,4 @@ +import copy from contextlib import nullcontext from typing import Any, Literal, Mapping, Optional, Tuple @@ -11,6 +12,7 @@ from moscot.backends.ott._utils import alpha_to_fused_penalty from moscot.base.output import BaseSolverOutput from moscot.problems.cross_modality import TranslationProblem +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -82,6 +84,36 @@ def test_prepare_external_star_policy( if joint_attr is not None: assert tp[prob_key].xy.data_src.shape == tp[prob_key].xy.data_tgt.shape == (n_obs, xy_n_vars) + def test_copy(self, adata_translation_split: Tuple[AnnData, AnnData]): + # shallow_copy = ("_adata",) + shallow_copy = ( + "_adata", + "_adata_tgt", + ) + + adata_src, adata_tgt = adata_translation_split + epsilon, alpha, rank, joint_attr = 1e-2, 0.9, -1, {"attr": "obsm", "key": "X_pca"} + src_attr = "emb_src" + tgt_attr = "emb_tgt" + + prepare_params = {"batch_key": "batch", "src_attr": src_attr, "tgt_attr": tgt_attr, "joint_attr": joint_attr} + solve_params = {"epsilon": epsilon, "alpha": alpha, "rank": rank} + + prob = TranslationProblem(adata_src, adata_tgt) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() + @pytest.mark.parametrize( ("epsilon", "alpha", "rank", "initializer", "joint_attr", "expect_fail"), [ diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index 52d040d51..793af8cad 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -1,3 +1,4 @@ +import copy from typing import Any, Literal, Mapping import pytest @@ -25,6 +26,7 @@ from moscot.base.problems import OTProblem from moscot.problems.generic import FGWProblem from tests._utils import _assert_marginals_set +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -83,6 +85,35 @@ def test_prepare_marginals(self, adata_time: AnnData, marginal_keys): for key in problem: _assert_marginals_set(adata_time, problem, key, marginal_keys) + def test_copy(self, adata_space_rotate: AnnData): + shallow_copy = ("_adata",) + + eps = 0.5 + adata_space_rotate = adata_space_rotate[adata_space_rotate.obs["batch"].isin(("0", "1"))].copy() + prepare_params = { + "key": "batch", + "policy": "sequential", + "joint_attr": "X_pca", + "x_attr": {"attr": "obsm", "key": "spatial"}, + "y_attr": {"attr": "obsm", "key": "spatial"}, + } + solve_params = {"alpha": 0.5, "epsilon": eps} + + prob = FGWProblem(adata=adata_space_rotate) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() + def test_solve_balanced(self, adata_space_rotate: AnnData): eps = 0.5 adata_space_rotate = adata_space_rotate[adata_space_rotate.obs["batch"].isin(("0", "1"))].copy() diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py index 5b13c8f2f..e788b7fca 100644 --- a/tests/problems/generic/test_gw_problem.py +++ b/tests/problems/generic/test_gw_problem.py @@ -1,3 +1,4 @@ +import copy from typing import Any, Literal, Mapping import pytest @@ -24,6 +25,7 @@ from moscot.base.problems import OTProblem from moscot.problems.generic import GWProblem from tests._utils import _assert_marginals_set +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( geometry_args, gw_args_1, @@ -68,6 +70,34 @@ def test_prepare(self, adata_space_rotate: AnnData, policy): assert key in expected_keys[policy] assert isinstance(problem[key], OTProblem) + def test_copy(self, adata_space_rotate: AnnData): # type: ignore[no-untyped-def] + shallow_copy = ("_adata",) + + eps = 0.5 + + prepare_params = { + "key": "batch", + "policy": "sequential", + "x_attr": {"attr": "obsm", "key": "spatial"}, + "y_attr": {"attr": "obsm", "key": "spatial"}, + } + solve_params = {"epsilon": eps} + + prob = GWProblem(adata=adata_space_rotate) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() + def test_solve_balanced(self, adata_space_rotate: AnnData): # type: ignore[no-untyped-def] eps = 0.5 expected_keys = [("0", "1"), ("1", "2")] diff --git a/tests/problems/generic/test_sinkhorn_problem.py b/tests/problems/generic/test_sinkhorn_problem.py index cc8b8af76..f7a9b9924 100644 --- a/tests/problems/generic/test_sinkhorn_problem.py +++ b/tests/problems/generic/test_sinkhorn_problem.py @@ -1,3 +1,4 @@ +import copy from typing import Any, Literal, Mapping import pytest @@ -23,6 +24,7 @@ from moscot.base.problems import OTProblem from moscot.problems.generic import SinkhornProblem from tests._utils import _assert_marginals_set +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( geometry_args, lin_prob_args, @@ -56,6 +58,28 @@ def test_prepare(self, adata_time: AnnData, policy, marginal_keys): _assert_marginals_set(adata_time, problem, key, marginal_keys) + def test_copy(self, adata_time: AnnData, marginal_keys): + shallow_copy = ("_adata",) + + eps = 0.5 + prepare_params = {"key": "time", "a": marginal_keys[0], "b": marginal_keys[1]} + solve_params = {"epsilon": eps} + + prob = SinkhornProblem(adata=adata_time) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() + def test_solve_balanced(self, adata_time: AnnData, marginal_keys): eps = 0.5 expected_keys = [(0, 1), (1, 2)] diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index d6a78c063..41d59824f 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -1,3 +1,4 @@ +import copy from pathlib import Path from typing import Any, Literal, Mapping, Optional @@ -14,6 +15,7 @@ from moscot.backends.ott._utils import alpha_to_fused_penalty from moscot.problems.space import AlignmentProblem from moscot.utils.tagged_array import Tag, TaggedArray +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -75,6 +77,32 @@ def test_prepare_star(self, adata_space_rotate: AnnData, reference: str): assert ref == reference assert isinstance(ap[prob_key], ap._base_problem_type) + def test_copy(self, adata_space_rotate: AnnData): + shallow_copy = ("_adata",) + + tau_a, tau_b = [0.8, 1] + marg_a = "a" + marg_b = "b" + adata_space_rotate.obs[marg_a] = adata_space_rotate.obs[marg_b] = np.ones(300) + + prepare_params = {"batch_key": "batch", "a": marg_a, "b": marg_b} + solve_params = {"tau_a": tau_a, "tau_b": tau_b} + + prob = AlignmentProblem(adata=adata_space_rotate) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() + @pytest.mark.skip(reason="See https://github.com/theislab/moscot/issues/678") @pytest.mark.parametrize( ("epsilon", "alpha", "rank", "initializer"), diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index 903c02611..51fbaa75b 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -16,6 +16,7 @@ from moscot.problems.space import MappingProblem from moscot.utils.tagged_array import Tag, TaggedArray from tests._utils import _adata_spatial_split +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -94,6 +95,27 @@ def test_prepare_varnames(self, adata_mapping: AnnData, var_names: Optional[List assert prob.x.data_src.shape == (n_obs, x_n_var) assert prob.y.data_src.shape == (n_obs, y_n_var) + def test_copy(self, adata_mapping: AnnData): + shallow_copy = ("_adata",) + + prepare_params = {"batch_key": "batch", "sc_attr": {"attr": "X"}} + adataref, adatasp = _adata_spatial_split(adata_mapping) + + prob = MappingProblem(adataref, adatasp) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + # with pytest.raises(copy.Error): + # prob = prob.solve(**solve_params) # type: ignore + # _ = prob.copy() + @pytest.mark.skip(reason="See https://github.com/theislab/moscot/issues/678") @pytest.mark.parametrize( ("epsilon", "alpha", "rank", "initializer"), diff --git a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py index a22d5cd35..fdf239c84 100644 --- a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py +++ b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py @@ -1,3 +1,4 @@ +import copy from typing import Any, List, Mapping import pytest @@ -13,6 +14,7 @@ from moscot.base.problems import BirthDeathProblem from moscot.problems.spatiotemporal import SpatioTemporalProblem from tests._utils import ATOL, RTOL +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -48,6 +50,30 @@ def test_prepare(self, adata_spatio_temporal: AnnData): assert key in expected_keys assert isinstance(problem[key], BirthDeathProblem) + def test_copy(self, adata_spatio_temporal: AnnData): + shallow_copy = ("_adata",) + + eps = 1 + alpha = 0.5 + + prepare_params = {"time_key": "time", "spatial_key": "spatial"} + solve_params = {"alpha": alpha, "epsilon": eps} + + prob = SpatioTemporalProblem(adata=adata_spatio_temporal) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() + def test_solve_balanced(self, adata_spatio_temporal: AnnData): eps = 1 alpha = 0.5 diff --git a/tests/problems/time/test_lineage_problem.py b/tests/problems/time/test_lineage_problem.py index fed68cf6f..cb68e6b9c 100644 --- a/tests/problems/time/test_lineage_problem.py +++ b/tests/problems/time/test_lineage_problem.py @@ -1,3 +1,4 @@ +import copy from typing import Any, List, Mapping import pytest @@ -12,6 +13,7 @@ from moscot.base.problems import BirthDeathProblem from moscot.problems.time import LineageProblem from tests._utils import ATOL, RTOL +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -46,6 +48,33 @@ def test_prepare(self, adata_time_barcodes: AnnData): assert key in expected_keys assert isinstance(problem[key], BirthDeathProblem) + def test_copy(self, adata_time_barcodes: AnnData): + shallow_copy = ("_adata",) + + eps, key = 0.5, (0, 1) + adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin(key)].copy() + prepare_params = { + "time_key": "time", + "policy": "sequential", + "lineage_attr": {"attr": "obsm", "key": "barcodes", "tag": "cost_matrix", "cost": "barcode_distance"}, + } + solve_params = {"epsilon": eps} + + prob = LineageProblem(adata=adata_time_barcodes) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() + def test_solve_balanced(self, adata_time_barcodes: AnnData): eps, key = 0.5, (0, 1) adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin(key)].copy() diff --git a/tests/problems/time/test_temporal_base_problem.py b/tests/problems/time/test_temporal_base_problem.py index 24dee7e42..c92045fe4 100644 --- a/tests/problems/time/test_temporal_base_problem.py +++ b/tests/problems/time/test_temporal_base_problem.py @@ -7,6 +7,7 @@ from anndata import AnnData from moscot.base.problems import BirthDeathProblem +from tests.problems._utils import check_is_copy_multiple # TODO(@MUCDK) put file in different folder according to moscot.problems structure @@ -32,6 +33,41 @@ def test_initialization_pipeline(self, adata_time_marginal_estimations: AnnData) assert isinstance(prob.a, np.ndarray) assert isinstance(prob.b, np.ndarray) + def test_copy(self, adata_time_marginal_estimations: AnnData): + t1, t2 = 0, 1 + adata_x = adata_time_marginal_estimations[adata_time_marginal_estimations.obs["time"] == t1] + adata_y = adata_time_marginal_estimations[adata_time_marginal_estimations.obs["time"] == t2] + + shallow_copy = ( + "_adata_src", + "_adata_tgt", + "_src_obs_mask", + "_tgt_obs_mask", + "_src_var_mask", + "_tgt_var_mask", + ) + + prepare_params = { + "xy": {}, + "x": {"attr": "X"}, + "y": {"attr": "X"}, + "a": True, + "b": True, + "proliferation_key": "proliferation", + "apoptosis_key": "apoptosis", + } + + prob = BirthDeathProblem(adata_x, adata_y, src_key=t1, tgt_key=t2) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + # TODO(MUCDK): break this test @pytest.mark.fast() @pytest.mark.parametrize( diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index 7858eb613..9ac30d615 100644 --- a/tests/problems/time/test_temporal_problem.py +++ b/tests/problems/time/test_temporal_problem.py @@ -1,3 +1,4 @@ +import copy from typing import Any, List, Mapping, Optional import pytest @@ -18,6 +19,7 @@ from moscot.problems.time import TemporalProblem from moscot.utils.tagged_array import Tag, TaggedArray from tests._utils import ATOL, RTOL +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( geometry_args, lin_prob_args, @@ -52,6 +54,31 @@ def test_prepare(self, adata_time: AnnData): assert key in expected_keys assert isinstance(problem[key], BirthDeathProblem) + @pytest.mark.parametrize("callback", ["local-pca", None]) + def test_copy(self, adata_time: AnnData, callback: Optional[str]): + shallow_copy = ("_adata",) + + eps = 0.5 + joint_attr = None if callback else "X_pca" + + prepare_params = {"time_key": "time", "cost": "cosine", "xy_callback": callback, "joint_attr": joint_attr} + solve_params = {"epsilon": eps} + + prob = TemporalProblem(adata=adata_time) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() + @pytest.mark.parametrize("callback", ["local-pca", None]) def test_solve_balanced(self, adata_time: AnnData, callback: Optional[str]): eps = 0.5 From 38239a221c11b8f1b4d473425dad1c63fca13c38 Mon Sep 17 00:00:00 2001 From: lautisilber Date: Thu, 30 May 2024 18:27:14 +0200 Subject: [PATCH 3/6] cleaning up --- src/moscot/base/problems/_utils.py | 2 +- src/moscot/base/problems/compound_problem.py | 8 +-- src/moscot/base/problems/manager.py | 4 +- src/moscot/base/problems/problem.py | 9 ++- .../problems/cross_modality/_translation.py | 4 +- tests/problems/_utils.py | 2 +- tests/problems/base/test_compound_problem.py | 50 ++++++------- tests/problems/base/test_general_problem.py | 64 ++++++++--------- .../test_translation_problem.py | 60 ++++++++-------- tests/problems/generic/test_fgw_problem.py | 58 +++++++-------- tests/problems/generic/test_gw_problem.py | 56 +++++++-------- .../problems/generic/test_sinkhorn_problem.py | 44 ++++++------ .../problems/space/test_alignment_problem.py | 52 +++++++------- tests/problems/space/test_mapping_problem.py | 38 +++++----- .../test_spatio_temporal_problem.py | 48 ++++++------- tests/problems/time/test_lineage_problem.py | 54 +++++++------- .../time/test_temporal_base_problem.py | 70 +++++++++---------- tests/problems/time/test_temporal_problem.py | 47 ++++++------- 18 files changed, 329 insertions(+), 341 deletions(-) diff --git a/src/moscot/base/problems/_utils.py b/src/moscot/base/problems/_utils.py index 25fd0f992..a346fa569 100644 --- a/src/moscot/base/problems/_utils.py +++ b/src/moscot/base/problems/_utils.py @@ -737,7 +737,7 @@ def _get_n_cores(n_cores: Optional[int], n_jobs: Optional[int]) -> int: C = TypeVar("C", bound=object) -def _copy_deep_shallow_helper( +def _copy_depth_helper( original: C, memo: dict[int, Any], shallow_copy: tuple[str, ...] = (), dont_copy: tuple[str, ...] = () ) -> C: cls = original.__class__ diff --git a/src/moscot/base/problems/compound_problem.py b/src/moscot/base/problems/compound_problem.py index 96a48aeba..e1ed20d1a 100644 --- a/src/moscot/base/problems/compound_problem.py +++ b/src/moscot/base/problems/compound_problem.py @@ -27,7 +27,7 @@ from moscot._types import ArrayLike, Policy_t, ProblemStage_t from moscot.base.output import BaseSolverOutput from moscot.base.problems._utils import ( - _copy_deep_shallow_helper, + _copy_depth_helper, attributedispatch, require_prepare, ) @@ -44,8 +44,6 @@ ) from moscot.utils.tagged_array import Tag, TaggedArray -# from moscot.base.problems._utils import _custom_copy - __all__ = ["BaseCompoundProblem", "CompoundProblem"] K = TypeVar("K", bound=Hashable) @@ -75,13 +73,13 @@ def __init__(self, adata: AnnData, **kwargs: Any): def __deepcopy__(self, memo) -> "BaseCompoundProblem[K, B]": vars_to_shallow_copy = ("_adata",) - return _copy_deep_shallow_helper(self, memo, vars_to_shallow_copy) + return _copy_depth_helper(self, memo, vars_to_shallow_copy) def copy(self) -> "BaseCompoundProblem[K, B]": """Create a copy of self. It deep-copies everything except for the data which is shallow-copied (by reference) - to improve the memory footprint + to improve the memory footprint. Returns ------- diff --git a/src/moscot/base/problems/manager.py b/src/moscot/base/problems/manager.py index 1cff50334..5481a16ff 100644 --- a/src/moscot/base/problems/manager.py +++ b/src/moscot/base/problems/manager.py @@ -1,6 +1,4 @@ import collections - -# from moscot.base.problems._utils import _custom_copy import copy from typing import ( TYPE_CHECKING, @@ -47,7 +45,7 @@ def copy(self) -> "ProblemManager[K, B]": """Create a copy of self. It deep-copies everything except for the data which is shallow-copied (by reference) - to improve the memory footprint + to improve the memory footprint. Returns ------- diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index 181bd836c..254080b24 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -34,7 +34,7 @@ TimeScalesHeatKernel, _assert_columns_and_index_match, _assert_series_match, - _copy_deep_shallow_helper, + _copy_depth_helper, require_solution, wrap_prepare, wrap_solve, @@ -76,7 +76,7 @@ def copy(self) -> "BaseProblem": """Create a copy of self. It deep-copies everything except for the data which is shallow-copied (by reference) - to improve the memory footprint + to improve the memory footprint. Returns ------- @@ -214,7 +214,6 @@ def problem_kind(self) -> ProblemKind_t: return self._problem_kind -# from moscot.base.problems._utils import _custom_copy class OTProblem(BaseProblem): """Base class for all :term:`OT` problems. @@ -285,13 +284,13 @@ def __deepcopy__(self, memo) -> "OTProblem": "_tgt_var_mask", ) - return _copy_deep_shallow_helper(self, memo, vars_to_shallow_copy) + return _copy_depth_helper(self, memo, vars_to_shallow_copy) def copy(self) -> "OTProblem": """Create a copy of self. It deep-copies everything except for the data which is shallow-copied (by reference) - to improve the memory footprint + to improve the memory footprint. Returns ------- diff --git a/src/moscot/problems/cross_modality/_translation.py b/src/moscot/problems/cross_modality/_translation.py index 284b8181d..8da0fe869 100644 --- a/src/moscot/problems/cross_modality/_translation.py +++ b/src/moscot/problems/cross_modality/_translation.py @@ -13,7 +13,7 @@ QuadInitializer_t, ScaleCost_t, ) -from moscot.base.problems._utils import _copy_deep_shallow_helper +from moscot.base.problems._utils import _copy_depth_helper from moscot.base.problems.compound_problem import B, CompoundProblem, K from moscot.base.problems.problem import OTProblem from moscot.problems._utils import handle_cost, handle_joint_attr @@ -76,7 +76,7 @@ def __deepcopy__(self, memo) -> "TranslationProblem[K]": "_adata_tgt", ) - return _copy_deep_shallow_helper(self, memo, vars_to_shallow_copy) + return _copy_depth_helper(self, memo, vars_to_shallow_copy) def prepare( self, diff --git a/tests/problems/_utils.py b/tests/problems/_utils.py index 095b6a697..343080b34 100644 --- a/tests/problems/_utils.py +++ b/tests/problems/_utils.py @@ -10,8 +10,8 @@ def _check_is_copy(o1: object, o2: object, shallow_copy: tuple[str, ...]) -> boo v2 = getattr(o2, k) if type(v1) is not type(v2): return False + # these basic types are treated differently in python and there's no point in comparing their ids if isinstance(v1, (str, int, bool, float)) or v1 is None: - # these basic types are treated differently in python and there's no point in comparing their ids continue if k in shallow_copy: if id(v1) != id(v2): diff --git a/tests/problems/base/test_compound_problem.py b/tests/problems/base/test_compound_problem.py index 895837fed..8d6ae21b5 100644 --- a/tests/problems/base/test_compound_problem.py +++ b/tests/problems/base/test_compound_problem.py @@ -44,31 +44,6 @@ def y_callback( assert isinstance(adata_y, AnnData) return TaggedArray(euclidean_distances(adata.X, adata_y.X), tag=Tag.COST_MATRIX) - def test_copy(self, adata_time: AnnData): - shallow_copy = ("_adata",) - - prepare_params = { - "xy": {"x_attr": "X", "y_attr": "X"}, - "key": "time", - "policy": "sequential", - } - solve_params = {"max_iterations": 2} - - prob = Problem(adata_time) - prob_copy_1 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) - - prob = prob.prepare(**prepare_params) # type: ignore - prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore - prob_copy_2 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) - - prob = prob.solve(**solve_params) # type: ignore - with pytest.raises(copy.Error): - _ = prob.copy() - def test_sc_pipeline(self, adata_time: AnnData): expected_keys = [(0, 1), (1, 2)] problem = Problem(adata_time) @@ -313,3 +288,28 @@ def test_save_load_solved(self, adata_time: AnnData): p = Problem.load(file) assert isinstance(p, Problem) + + def test_copy(self, adata_time: AnnData): + shallow_copy = ("_adata",) + + prepare_params = { + "xy": {"x_attr": "X", "y_attr": "X"}, + "key": "time", + "policy": "sequential", + } + solve_params = {"max_iterations": 2} + + prob = Problem(adata_time) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/base/test_general_problem.py b/tests/problems/base/test_general_problem.py index beebc76d6..c90f65240 100644 --- a/tests/problems/base/test_general_problem.py +++ b/tests/problems/base/test_general_problem.py @@ -21,38 +21,6 @@ class TestOTProblem: - def test_copy_problem(self, adata_x: AnnData, adata_y: AnnData): - shallow_copy = ( - "_adata_src", - "_adata_tgt", - "_src_obs_mask", - "_tgt_obs_mask", - "_src_var_mask", - "_tgt_var_mask", - ) - - prepare_params = { - "xy": {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"}, - "x": {"attr": "X"}, - "y": {"attr": "X"}, - } - solve_params = {"epsilon": 5e-1, "alpha": 0.5} - - prob = OTProblem(adata_x, adata_y) - prob_copy_1 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) - - prob = prob.prepare(**prepare_params) # type: ignore - prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore - prob_copy_2 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) - - prob.solve(**solve_params) # type: ignore - with pytest.raises(copy.Error): - _ = prob.copy() - def test_simple_run(self, adata_x: AnnData, adata_y: AnnData): prob = OTProblem(adata_x, adata_y) prob = prob.prepare( @@ -380,3 +348,35 @@ def test_set_graph_xy_test_t(self, adata_x: AnnData, adata_y: AnnData, t: float) assert pushed_0.shape == pushed_1.shape assert np.all(np.abs(pushed_0 - pushed_1).sum() > np.abs(pushed_2 - pushed_1).sum()) assert np.all(np.abs(pushed_0 - pushed_2).sum() > np.abs(pushed_1 - pushed_2).sum()) + + def test_copy(self, adata_x: AnnData, adata_y: AnnData): + shallow_copy = ( + "_adata_src", + "_adata_tgt", + "_src_obs_mask", + "_tgt_obs_mask", + "_src_var_mask", + "_tgt_var_mask", + ) + + prepare_params = { + "xy": {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"}, + "x": {"attr": "X"}, + "y": {"attr": "X"}, + } + solve_params = {"epsilon": 5e-1, "alpha": 0.5} + + prob = OTProblem(adata_x, adata_y) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/cross_modality/test_translation_problem.py b/tests/problems/cross_modality/test_translation_problem.py index f0fe218e8..c7c5a76f2 100644 --- a/tests/problems/cross_modality/test_translation_problem.py +++ b/tests/problems/cross_modality/test_translation_problem.py @@ -84,36 +84,6 @@ def test_prepare_external_star_policy( if joint_attr is not None: assert tp[prob_key].xy.data_src.shape == tp[prob_key].xy.data_tgt.shape == (n_obs, xy_n_vars) - def test_copy(self, adata_translation_split: Tuple[AnnData, AnnData]): - # shallow_copy = ("_adata",) - shallow_copy = ( - "_adata", - "_adata_tgt", - ) - - adata_src, adata_tgt = adata_translation_split - epsilon, alpha, rank, joint_attr = 1e-2, 0.9, -1, {"attr": "obsm", "key": "X_pca"} - src_attr = "emb_src" - tgt_attr = "emb_tgt" - - prepare_params = {"batch_key": "batch", "src_attr": src_attr, "tgt_attr": tgt_attr, "joint_attr": joint_attr} - solve_params = {"epsilon": epsilon, "alpha": alpha, "rank": rank} - - prob = TranslationProblem(adata_src, adata_tgt) - prob_copy_1 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) - - prob = prob.prepare(**prepare_params) # type: ignore - prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore - prob_copy_2 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) - - prob = prob.solve(**solve_params) # type: ignore - with pytest.raises(copy.Error): - _ = prob.copy() - @pytest.mark.parametrize( ("epsilon", "alpha", "rank", "initializer", "joint_attr", "expect_fail"), [ @@ -212,3 +182,33 @@ def test_pass_arguments(self, adata_translation_split: Tuple[AnnData, AnnData], geom = quad_prob.geom_xy for arg, val in pointcloud_args.items(): assert getattr(geom, val) == args_to_check[arg], arg + + def test_copy(self, adata_translation_split: Tuple[AnnData, AnnData]): + # shallow_copy = ("_adata",) + shallow_copy = ( + "_adata", + "_adata_tgt", + ) + + adata_src, adata_tgt = adata_translation_split + epsilon, alpha, rank, joint_attr = 1e-2, 0.9, -1, {"attr": "obsm", "key": "X_pca"} + src_attr = "emb_src" + tgt_attr = "emb_tgt" + + prepare_params = {"batch_key": "batch", "src_attr": src_attr, "tgt_attr": tgt_attr, "joint_attr": joint_attr} + solve_params = {"epsilon": epsilon, "alpha": alpha, "rank": rank} + + prob = TranslationProblem(adata_src, adata_tgt) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index 793af8cad..16efcba51 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -85,35 +85,6 @@ def test_prepare_marginals(self, adata_time: AnnData, marginal_keys): for key in problem: _assert_marginals_set(adata_time, problem, key, marginal_keys) - def test_copy(self, adata_space_rotate: AnnData): - shallow_copy = ("_adata",) - - eps = 0.5 - adata_space_rotate = adata_space_rotate[adata_space_rotate.obs["batch"].isin(("0", "1"))].copy() - prepare_params = { - "key": "batch", - "policy": "sequential", - "joint_attr": "X_pca", - "x_attr": {"attr": "obsm", "key": "spatial"}, - "y_attr": {"attr": "obsm", "key": "spatial"}, - } - solve_params = {"alpha": 0.5, "epsilon": eps} - - prob = FGWProblem(adata=adata_space_rotate) - prob_copy_1 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) - - prob = prob.prepare(**prepare_params) # type: ignore - prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore - prob_copy_2 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) - - prob = prob.solve(**solve_params) # type: ignore - with pytest.raises(copy.Error): - _ = prob.copy() - def test_solve_balanced(self, adata_space_rotate: AnnData): eps = 0.5 adata_space_rotate = adata_space_rotate[adata_space_rotate.obs["batch"].isin(("0", "1"))].copy() @@ -426,3 +397,32 @@ def test_passing_ott_kwargs_quadratic(self, adata_space_rotate: AnnData, warm_st assert solver.warm_start == warm_start assert solver.store_inner_errors == inner_errors + + def test_copy(self, adata_space_rotate: AnnData): + shallow_copy = ("_adata",) + + eps = 0.5 + adata_space_rotate = adata_space_rotate[adata_space_rotate.obs["batch"].isin(("0", "1"))].copy() + prepare_params = { + "key": "batch", + "policy": "sequential", + "joint_attr": "X_pca", + "x_attr": {"attr": "obsm", "key": "spatial"}, + "y_attr": {"attr": "obsm", "key": "spatial"}, + } + solve_params = {"alpha": 0.5, "epsilon": eps} + + prob = FGWProblem(adata=adata_space_rotate) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py index e788b7fca..99dad637e 100644 --- a/tests/problems/generic/test_gw_problem.py +++ b/tests/problems/generic/test_gw_problem.py @@ -70,34 +70,6 @@ def test_prepare(self, adata_space_rotate: AnnData, policy): assert key in expected_keys[policy] assert isinstance(problem[key], OTProblem) - def test_copy(self, adata_space_rotate: AnnData): # type: ignore[no-untyped-def] - shallow_copy = ("_adata",) - - eps = 0.5 - - prepare_params = { - "key": "batch", - "policy": "sequential", - "x_attr": {"attr": "obsm", "key": "spatial"}, - "y_attr": {"attr": "obsm", "key": "spatial"}, - } - solve_params = {"epsilon": eps} - - prob = GWProblem(adata=adata_space_rotate) - prob_copy_1 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) - - prob = prob.prepare(**prepare_params) # type: ignore - prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore - prob_copy_2 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) - - prob = prob.solve(**solve_params) # type: ignore - with pytest.raises(copy.Error): - _ = prob.copy() - def test_solve_balanced(self, adata_space_rotate: AnnData): # type: ignore[no-untyped-def] eps = 0.5 expected_keys = [("0", "1"), ("1", "2")] @@ -389,3 +361,31 @@ def test_passing_ott_kwargs_quadratic(self, adata_space_rotate: AnnData, warm_st assert solver.warm_start == warm_start assert solver.store_inner_errors == inner_errors + + def test_copy(self, adata_space_rotate: AnnData): # type: ignore[no-untyped-def] + shallow_copy = ("_adata",) + + eps = 0.5 + + prepare_params = { + "key": "batch", + "policy": "sequential", + "x_attr": {"attr": "obsm", "key": "spatial"}, + "y_attr": {"attr": "obsm", "key": "spatial"}, + } + solve_params = {"epsilon": eps} + + prob = GWProblem(adata=adata_space_rotate) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/generic/test_sinkhorn_problem.py b/tests/problems/generic/test_sinkhorn_problem.py index f7a9b9924..660340c6d 100644 --- a/tests/problems/generic/test_sinkhorn_problem.py +++ b/tests/problems/generic/test_sinkhorn_problem.py @@ -58,28 +58,6 @@ def test_prepare(self, adata_time: AnnData, policy, marginal_keys): _assert_marginals_set(adata_time, problem, key, marginal_keys) - def test_copy(self, adata_time: AnnData, marginal_keys): - shallow_copy = ("_adata",) - - eps = 0.5 - prepare_params = {"key": "time", "a": marginal_keys[0], "b": marginal_keys[1]} - solve_params = {"epsilon": eps} - - prob = SinkhornProblem(adata=adata_time) - prob_copy_1 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) - - prob = prob.prepare(**prepare_params) # type: ignore - prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore - prob_copy_2 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) - - prob = prob.solve(**solve_params) # type: ignore - with pytest.raises(copy.Error): - _ = prob.copy() - def test_solve_balanced(self, adata_time: AnnData, marginal_keys): eps = 0.5 expected_keys = [(0, 1), (1, 2)] @@ -254,3 +232,25 @@ def test_passing_ott_kwargs(self, adata_time: AnnData, memory: int, refresh: int recenter_potentials = problem[0, 1].solver.solver.recenter_potentials assert recenter_potentials == recenter + + def test_copy(self, adata_time: AnnData, marginal_keys): + shallow_copy = ("_adata",) + + eps = 0.5 + prepare_params = {"key": "time", "a": marginal_keys[0], "b": marginal_keys[1]} + solve_params = {"epsilon": eps} + + prob = SinkhornProblem(adata=adata_time) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index 41d59824f..9e21a0867 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -77,32 +77,6 @@ def test_prepare_star(self, adata_space_rotate: AnnData, reference: str): assert ref == reference assert isinstance(ap[prob_key], ap._base_problem_type) - def test_copy(self, adata_space_rotate: AnnData): - shallow_copy = ("_adata",) - - tau_a, tau_b = [0.8, 1] - marg_a = "a" - marg_b = "b" - adata_space_rotate.obs[marg_a] = adata_space_rotate.obs[marg_b] = np.ones(300) - - prepare_params = {"batch_key": "batch", "a": marg_a, "b": marg_b} - solve_params = {"tau_a": tau_a, "tau_b": tau_b} - - prob = AlignmentProblem(adata=adata_space_rotate) - prob_copy_1 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) - - prob = prob.prepare(**prepare_params) # type: ignore - prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore - prob_copy_2 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) - - prob = prob.solve(**solve_params) # type: ignore - with pytest.raises(copy.Error): - _ = prob.copy() - @pytest.mark.skip(reason="See https://github.com/theislab/moscot/issues/678") @pytest.mark.parametrize( ("epsilon", "alpha", "rank", "initializer"), @@ -254,3 +228,29 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin for arg, val in pointcloud_args.items(): assert hasattr(geom, val) assert getattr(geom, val) == args_to_check[arg] + + def test_copy(self, adata_space_rotate: AnnData): + shallow_copy = ("_adata",) + + tau_a, tau_b = [0.8, 1] + marg_a = "a" + marg_b = "b" + adata_space_rotate.obs[marg_a] = adata_space_rotate.obs[marg_b] = np.ones(300) + + prepare_params = {"batch_key": "batch", "a": marg_a, "b": marg_b} + solve_params = {"tau_a": tau_a, "tau_b": tau_b} + + prob = AlignmentProblem(adata=adata_space_rotate) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index 51fbaa75b..144437480 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -95,27 +95,6 @@ def test_prepare_varnames(self, adata_mapping: AnnData, var_names: Optional[List assert prob.x.data_src.shape == (n_obs, x_n_var) assert prob.y.data_src.shape == (n_obs, y_n_var) - def test_copy(self, adata_mapping: AnnData): - shallow_copy = ("_adata",) - - prepare_params = {"batch_key": "batch", "sc_attr": {"attr": "X"}} - adataref, adatasp = _adata_spatial_split(adata_mapping) - - prob = MappingProblem(adataref, adatasp) - prob_copy_1 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) - - prob = prob.prepare(**prepare_params) # type: ignore - prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore - prob_copy_2 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) - - # with pytest.raises(copy.Error): - # prob = prob.solve(**solve_params) # type: ignore - # _ = prob.copy() - @pytest.mark.skip(reason="See https://github.com/theislab/moscot/issues/678") @pytest.mark.parametrize( ("epsilon", "alpha", "rank", "initializer"), @@ -288,3 +267,20 @@ def test_pass_arguments(self, adata_mapping: AnnData, args_to_check: Mapping[str for arg, val in pointcloud_args.items(): assert hasattr(geom, val) assert getattr(geom, val) == args_to_check[arg] + + def test_copy(self, adata_mapping: AnnData): + shallow_copy = ("_adata",) + + prepare_params = {"batch_key": "batch", "sc_attr": {"attr": "X"}} + adataref, adatasp = _adata_spatial_split(adata_mapping) + + prob = MappingProblem(adataref, adatasp) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) diff --git a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py index fdf239c84..5dd19b9e0 100644 --- a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py +++ b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py @@ -50,30 +50,6 @@ def test_prepare(self, adata_spatio_temporal: AnnData): assert key in expected_keys assert isinstance(problem[key], BirthDeathProblem) - def test_copy(self, adata_spatio_temporal: AnnData): - shallow_copy = ("_adata",) - - eps = 1 - alpha = 0.5 - - prepare_params = {"time_key": "time", "spatial_key": "spatial"} - solve_params = {"alpha": alpha, "epsilon": eps} - - prob = SpatioTemporalProblem(adata=adata_spatio_temporal) - prob_copy_1 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) - - prob = prob.prepare(**prepare_params) # type: ignore - prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore - prob_copy_2 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) - - prob = prob.solve(**solve_params) # type: ignore - with pytest.raises(copy.Error): - _ = prob.copy() - def test_solve_balanced(self, adata_spatio_temporal: AnnData): eps = 1 alpha = 0.5 @@ -260,3 +236,27 @@ def test_pass_arguments(self, adata_spatio_temporal: AnnData, args_to_check: Map for arg, val in pointcloud_args.items(): assert hasattr(geom, val) assert getattr(geom, val) == args_to_check[arg] + + def test_copy(self, adata_spatio_temporal: AnnData): + shallow_copy = ("_adata",) + + eps = 1 + alpha = 0.5 + + prepare_params = {"time_key": "time", "spatial_key": "spatial"} + solve_params = {"alpha": alpha, "epsilon": eps} + + prob = SpatioTemporalProblem(adata=adata_spatio_temporal) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/time/test_lineage_problem.py b/tests/problems/time/test_lineage_problem.py index cb68e6b9c..04967d1d6 100644 --- a/tests/problems/time/test_lineage_problem.py +++ b/tests/problems/time/test_lineage_problem.py @@ -48,33 +48,6 @@ def test_prepare(self, adata_time_barcodes: AnnData): assert key in expected_keys assert isinstance(problem[key], BirthDeathProblem) - def test_copy(self, adata_time_barcodes: AnnData): - shallow_copy = ("_adata",) - - eps, key = 0.5, (0, 1) - adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin(key)].copy() - prepare_params = { - "time_key": "time", - "policy": "sequential", - "lineage_attr": {"attr": "obsm", "key": "barcodes", "tag": "cost_matrix", "cost": "barcode_distance"}, - } - solve_params = {"epsilon": eps} - - prob = LineageProblem(adata=adata_time_barcodes) - prob_copy_1 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) - - prob = prob.prepare(**prepare_params) # type: ignore - prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore - prob_copy_2 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) - - prob = prob.solve(**solve_params) # type: ignore - with pytest.raises(copy.Error): - _ = prob.copy() - def test_solve_balanced(self, adata_time_barcodes: AnnData): eps, key = 0.5, (0, 1) adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin(key)].copy() @@ -298,3 +271,30 @@ def test_pass_arguments(self, adata_time_barcodes: AnnData, args_to_check: Mappi for arg, val in pointcloud_args.items(): assert hasattr(geom, val) assert getattr(geom, val) == args_to_check[arg] + + def test_copy(self, adata_time_barcodes: AnnData): + shallow_copy = ("_adata",) + + eps, key = 0.5, (0, 1) + adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin(key)].copy() + prepare_params = { + "time_key": "time", + "policy": "sequential", + "lineage_attr": {"attr": "obsm", "key": "barcodes", "tag": "cost_matrix", "cost": "barcode_distance"}, + } + solve_params = {"epsilon": eps} + + prob = LineageProblem(adata=adata_time_barcodes) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/time/test_temporal_base_problem.py b/tests/problems/time/test_temporal_base_problem.py index c92045fe4..6337c32ac 100644 --- a/tests/problems/time/test_temporal_base_problem.py +++ b/tests/problems/time/test_temporal_base_problem.py @@ -33,41 +33,6 @@ def test_initialization_pipeline(self, adata_time_marginal_estimations: AnnData) assert isinstance(prob.a, np.ndarray) assert isinstance(prob.b, np.ndarray) - def test_copy(self, adata_time_marginal_estimations: AnnData): - t1, t2 = 0, 1 - adata_x = adata_time_marginal_estimations[adata_time_marginal_estimations.obs["time"] == t1] - adata_y = adata_time_marginal_estimations[adata_time_marginal_estimations.obs["time"] == t2] - - shallow_copy = ( - "_adata_src", - "_adata_tgt", - "_src_obs_mask", - "_tgt_obs_mask", - "_src_var_mask", - "_tgt_var_mask", - ) - - prepare_params = { - "xy": {}, - "x": {"attr": "X"}, - "y": {"attr": "X"}, - "a": True, - "b": True, - "proliferation_key": "proliferation", - "apoptosis_key": "apoptosis", - } - - prob = BirthDeathProblem(adata_x, adata_y, src_key=t1, tgt_key=t2) - prob_copy_1 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) - - prob = prob.prepare(**prepare_params) # type: ignore - prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore - prob_copy_2 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) - # TODO(MUCDK): break this test @pytest.mark.fast() @pytest.mark.parametrize( @@ -171,3 +136,38 @@ def test_marginal_kwargs(self, adata_time_marginal_estimations: AnnData, margina assert not np.allclose(gr1, gr2) else: assert np.allclose(gr1, gr2) + + def test_copy(self, adata_time_marginal_estimations: AnnData): + t1, t2 = 0, 1 + adata_x = adata_time_marginal_estimations[adata_time_marginal_estimations.obs["time"] == t1] + adata_y = adata_time_marginal_estimations[adata_time_marginal_estimations.obs["time"] == t2] + + shallow_copy = ( + "_adata_src", + "_adata_tgt", + "_src_obs_mask", + "_tgt_obs_mask", + "_src_var_mask", + "_tgt_var_mask", + ) + + prepare_params = { + "xy": {}, + "x": {"attr": "X"}, + "y": {"attr": "X"}, + "a": True, + "b": True, + "proliferation_key": "proliferation", + "apoptosis_key": "apoptosis", + } + + prob = BirthDeathProblem(adata_x, adata_y, src_key=t1, tgt_key=t2) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index 9ac30d615..0632feaab 100644 --- a/tests/problems/time/test_temporal_problem.py +++ b/tests/problems/time/test_temporal_problem.py @@ -54,31 +54,6 @@ def test_prepare(self, adata_time: AnnData): assert key in expected_keys assert isinstance(problem[key], BirthDeathProblem) - @pytest.mark.parametrize("callback", ["local-pca", None]) - def test_copy(self, adata_time: AnnData, callback: Optional[str]): - shallow_copy = ("_adata",) - - eps = 0.5 - joint_attr = None if callback else "X_pca" - - prepare_params = {"time_key": "time", "cost": "cosine", "xy_callback": callback, "joint_attr": joint_attr} - solve_params = {"epsilon": eps} - - prob = TemporalProblem(adata=adata_time) - prob_copy_1 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) - - prob = prob.prepare(**prepare_params) # type: ignore - prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore - prob_copy_2 = prob.copy() - - assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) - - prob = prob.solve(**solve_params) # type: ignore - with pytest.raises(copy.Error): - _ = prob.copy() - @pytest.mark.parametrize("callback", ["local-pca", None]) def test_solve_balanced(self, adata_time: AnnData, callback: Optional[str]): eps = 0.5 @@ -497,3 +472,25 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A assert type(el) == type(args_to_check[arg]) # noqa: E721 else: assert el == args_to_check[arg] + + def test_copy(self, adata_time: AnnData): + shallow_copy = ("_adata",) + + eps = 0.5 + prepare_params = {"time_key": "time", "cost": "cosine", "xy_callback": None, "joint_attr": "X_pca"} + solve_params = {"epsilon": eps} + + prob = TemporalProblem(adata=adata_time) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() From c58f1201647e5c0681c20d5e9b163a21ca90888e Mon Sep 17 00:00:00 2001 From: lautisilber Date: Thu, 30 May 2024 19:13:02 +0200 Subject: [PATCH 4/6] fixed bug with BirthDeathProblem and fixed type hints --- src/moscot/base/problems/birth_death.py | 15 ++++++++++++ .../problems/cross_modality/_translation.py | 12 ++++++++++ src/moscot/problems/generic/_generic.py | 9 +++++++ src/moscot/problems/space/_alignment.py | 12 ++++++++++ src/moscot/problems/space/_mapping.py | 12 ++++++++++ .../spatiotemporal/_spatio_temporal.py | 12 ++++++++++ src/moscot/problems/time/_lineage.py | 24 +++++++++++++++++++ 7 files changed, 96 insertions(+) diff --git a/src/moscot/base/problems/birth_death.py b/src/moscot/base/problems/birth_death.py index 6f6004e6c..aeb41771c 100644 --- a/src/moscot/base/problems/birth_death.py +++ b/src/moscot/base/problems/birth_death.py @@ -1,3 +1,4 @@ +import copy from typing import ( TYPE_CHECKING, Any, @@ -159,6 +160,20 @@ class BirthDeathProblem(BirthDeathMixin, OTProblem): Keyword arguments for :class:`~moscot.base.problems.OTProblem`. """ # noqa: D205 + def copy(self) -> "BirthDeathProblem": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + if self.stage == "solved": + raise copy.Error("Cannot copy problem that has already been solved.") + return copy.deepcopy(self) + def estimate_marginals( self, # type: BirthDeathProblemProtocol adata: AnnData, diff --git a/src/moscot/problems/cross_modality/_translation.py b/src/moscot/problems/cross_modality/_translation.py index 8da0fe869..2ac6d6d4c 100644 --- a/src/moscot/problems/cross_modality/_translation.py +++ b/src/moscot/problems/cross_modality/_translation.py @@ -78,6 +78,18 @@ def __deepcopy__(self, memo) -> "TranslationProblem[K]": return _copy_depth_helper(self, memo, vars_to_shallow_copy) + def copy(self) -> "TranslationProblem[K]": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + return super().copy() # type: ignore + def prepare( self, src_attr: Union[str, Mapping[str, Any]], diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index 1466aa998..bc1791079 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -44,6 +44,9 @@ class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # typ def __init__(self, adata: AnnData, **kwargs: Any): super().__init__(adata, **kwargs) + def copy(self) -> "SinkhornProblem[K, B]": + return super().copy() # type: ignore + def prepare( self, key: str, @@ -260,6 +263,9 @@ class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ign def __init__(self, adata: AnnData, **kwargs: Any): super().__init__(adata, **kwargs) + def copy(self) -> "GWProblem[K, B]": + return super().copy() # type: ignore + def prepare( self, key: str, @@ -477,6 +483,9 @@ class FGWProblem(GWProblem[K, B]): Keyword arguments for :class:`~moscot.base.problems.CompoundProblem`. """ + def copy(self) -> "FGWProblem[K, B]": + return super().copy() # type: ignore + def prepare( self, key: str, diff --git a/src/moscot/problems/space/_alignment.py b/src/moscot/problems/space/_alignment.py index f5a398d23..7b2a2edfe 100644 --- a/src/moscot/problems/space/_alignment.py +++ b/src/moscot/problems/space/_alignment.py @@ -34,6 +34,18 @@ class AlignmentProblem(SpatialAlignmentMixin[K, B], CompoundProblem[K, B]): def __init__(self, adata: AnnData, **kwargs: Any): super().__init__(adata, **kwargs) + def copy(self) -> "AlignmentProblem[K, B]": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + return super().copy() # type: ignore + def prepare( self, batch_key: str, diff --git a/src/moscot/problems/space/_mapping.py b/src/moscot/problems/space/_mapping.py index f3243344d..ee3b10929 100644 --- a/src/moscot/problems/space/_mapping.py +++ b/src/moscot/problems/space/_mapping.py @@ -70,6 +70,18 @@ def _create_problem( **kwargs, ) + def copy(self) -> "MappingProblem[K]": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + return super().copy() # type: ignore + def prepare( self, sc_attr: Union[str, Mapping[str, Any]], diff --git a/src/moscot/problems/spatiotemporal/_spatio_temporal.py b/src/moscot/problems/spatiotemporal/_spatio_temporal.py index 6a2e0424f..5b4c381cb 100644 --- a/src/moscot/problems/spatiotemporal/_spatio_temporal.py +++ b/src/moscot/problems/spatiotemporal/_spatio_temporal.py @@ -42,6 +42,18 @@ class SpatioTemporalProblem( # type: ignore[misc] def __init__(self, adata: AnnData, **kwargs: Any): super().__init__(adata, **kwargs) + def copy(self) -> "SpatioTemporalProblem": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + return super().copy() # type: ignore + def prepare( self, time_key: str, diff --git a/src/moscot/problems/time/_lineage.py b/src/moscot/problems/time/_lineage.py index fea290f46..c03cbf5d7 100644 --- a/src/moscot/problems/time/_lineage.py +++ b/src/moscot/problems/time/_lineage.py @@ -39,6 +39,18 @@ class TemporalProblem( # type: ignore[misc] def __init__(self, adata: AnnData, **kwargs: Any): super().__init__(adata, **kwargs) + def copy(self) -> "TemporalProblem": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + return super().copy() # type: ignore + def prepare( self, time_key: str, @@ -267,6 +279,18 @@ class LineageProblem(TemporalProblem): Keyword arguments for :class:`~moscot.problems.time.TemporalProblem`. """ + def copy(self) -> "LineageProblem": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + return super().copy() # type: ignore + def prepare( self, time_key: str, From ee567f89d7f29b4a95c5d2cbc81672d470864350 Mon Sep 17 00:00:00 2001 From: lautisilber Date: Thu, 30 May 2024 20:37:39 +0200 Subject: [PATCH 5/6] Fixed minor bug in mappin problem test wher python treats references to lists of basic types in a special way --- src/moscot/base/problems/compound_problem.py | 2 +- tests/problems/_utils.py | 5 +++++ tests/problems/space/test_mapping_problem.py | 2 +- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/moscot/base/problems/compound_problem.py b/src/moscot/base/problems/compound_problem.py index e1ed20d1a..b6d4c1b9a 100644 --- a/src/moscot/base/problems/compound_problem.py +++ b/src/moscot/base/problems/compound_problem.py @@ -71,7 +71,7 @@ def __init__(self, adata: AnnData, **kwargs: Any): self._problem_manager: Optional[ProblemManager[K, B]] = None def __deepcopy__(self, memo) -> "BaseCompoundProblem[K, B]": - vars_to_shallow_copy = ("_adata",) + vars_to_shallow_copy = ("_adata", "_adata_sc") return _copy_depth_helper(self, memo, vars_to_shallow_copy) diff --git a/tests/problems/_utils.py b/tests/problems/_utils.py index 343080b34..7b294df5a 100644 --- a/tests/problems/_utils.py +++ b/tests/problems/_utils.py @@ -17,6 +17,11 @@ def _check_is_copy(o1: object, o2: object, shallow_copy: tuple[str, ...]) -> boo if id(v1) != id(v2): return False else: + if isinstance(v1, list) and all(isinstance(v, (str, int, bool, float)) or v is None for v in v1): + # when deepcopying a list of basic types, python might decide to + # not keep the id of the list the same until it is changed, so it's + # ok to have lists of basic types the same id + continue if id(v1) == id(v2): return False diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index 144437480..793e274ad 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -269,7 +269,7 @@ def test_pass_arguments(self, adata_mapping: AnnData, args_to_check: Mapping[str assert getattr(geom, val) == args_to_check[arg] def test_copy(self, adata_mapping: AnnData): - shallow_copy = ("_adata",) + shallow_copy = ("_adata", "_adata_sc") prepare_params = {"batch_key": "batch", "sc_attr": {"attr": "X"}} adataref, adatasp = _adata_spatial_split(adata_mapping) From aaae99cc8bb8a53f808db2386391b4179fd69372 Mon Sep 17 00:00:00 2001 From: lautisilber Date: Thu, 30 May 2024 20:50:25 +0200 Subject: [PATCH 6/6] cleaned up method hierarchy --- src/moscot/base/problems/compound_problem.py | 2 +- src/moscot/problems/cross_modality/_translation.py | 9 --------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/moscot/base/problems/compound_problem.py b/src/moscot/base/problems/compound_problem.py index b6d4c1b9a..07241cfc3 100644 --- a/src/moscot/base/problems/compound_problem.py +++ b/src/moscot/base/problems/compound_problem.py @@ -71,7 +71,7 @@ def __init__(self, adata: AnnData, **kwargs: Any): self._problem_manager: Optional[ProblemManager[K, B]] = None def __deepcopy__(self, memo) -> "BaseCompoundProblem[K, B]": - vars_to_shallow_copy = ("_adata", "_adata_sc") + vars_to_shallow_copy = ("_adata", "_adata_sc", "_adata_tgt") return _copy_depth_helper(self, memo, vars_to_shallow_copy) diff --git a/src/moscot/problems/cross_modality/_translation.py b/src/moscot/problems/cross_modality/_translation.py index 2ac6d6d4c..0b3ed90bd 100644 --- a/src/moscot/problems/cross_modality/_translation.py +++ b/src/moscot/problems/cross_modality/_translation.py @@ -13,7 +13,6 @@ QuadInitializer_t, ScaleCost_t, ) -from moscot.base.problems._utils import _copy_depth_helper from moscot.base.problems.compound_problem import B, CompoundProblem, K from moscot.base.problems.problem import OTProblem from moscot.problems._utils import handle_cost, handle_joint_attr @@ -70,14 +69,6 @@ def _create_problem( **kwargs, ) - def __deepcopy__(self, memo) -> "TranslationProblem[K]": - vars_to_shallow_copy = ( - "_adata", - "_adata_tgt", - ) - - return _copy_depth_helper(self, memo, vars_to_shallow_copy) - def copy(self) -> "TranslationProblem[K]": """Create a copy of self.