diff --git a/src/moscot/base/problems/_utils.py b/src/moscot/base/problems/_utils.py index 40bf0a99a..a346fa569 100644 --- a/src/moscot/base/problems/_utils.py +++ b/src/moscot/base/problems/_utils.py @@ -1,3 +1,4 @@ +import copy import functools import multiprocessing import threading @@ -16,6 +17,7 @@ Sequence, Tuple, Type, + TypeVar, Union, ) @@ -730,3 +732,21 @@ def _get_n_cores(n_cores: Optional[int], n_jobs: Optional[int]) -> int: return multiprocessing.cpu_count() + 1 + n_cores return n_cores + + +C = TypeVar("C", bound=object) + + +def _copy_depth_helper( + original: C, memo: dict[int, Any], shallow_copy: tuple[str, ...] = (), dont_copy: tuple[str, ...] = () +) -> C: + cls = original.__class__ + result = cls.__new__(cls) + memo[id(original)] = result + for k, v in original.__dict__.items(): + if k in shallow_copy: + setattr(result, k, v) + memo[id(v)] = v + elif k not in dont_copy: + setattr(result, k, copy.deepcopy(v, memo)) + return result diff --git a/src/moscot/base/problems/birth_death.py b/src/moscot/base/problems/birth_death.py index 6f6004e6c..aeb41771c 100644 --- a/src/moscot/base/problems/birth_death.py +++ b/src/moscot/base/problems/birth_death.py @@ -1,3 +1,4 @@ +import copy from typing import ( TYPE_CHECKING, Any, @@ -159,6 +160,20 @@ class BirthDeathProblem(BirthDeathMixin, OTProblem): Keyword arguments for :class:`~moscot.base.problems.OTProblem`. """ # noqa: D205 + def copy(self) -> "BirthDeathProblem": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + if self.stage == "solved": + raise copy.Error("Cannot copy problem that has already been solved.") + return copy.deepcopy(self) + def estimate_marginals( self, # type: BirthDeathProblemProtocol adata: AnnData, diff --git a/src/moscot/base/problems/compound_problem.py b/src/moscot/base/problems/compound_problem.py index eef27314b..07241cfc3 100644 --- a/src/moscot/base/problems/compound_problem.py +++ b/src/moscot/base/problems/compound_problem.py @@ -1,4 +1,5 @@ import abc +import copy import types from typing import ( TYPE_CHECKING, @@ -25,7 +26,11 @@ from moscot._logging import logger from moscot._types import ArrayLike, Policy_t, ProblemStage_t from moscot.base.output import BaseSolverOutput -from moscot.base.problems._utils import attributedispatch, require_prepare +from moscot.base.problems._utils import ( + _copy_depth_helper, + attributedispatch, + require_prepare, +) from moscot.base.problems.manager import ProblemManager from moscot.base.problems.problem import BaseProblem, OTProblem from moscot.utils.subset_policy import ( @@ -65,6 +70,25 @@ def __init__(self, adata: AnnData, **kwargs: Any): self._adata = adata self._problem_manager: Optional[ProblemManager[K, B]] = None + def __deepcopy__(self, memo) -> "BaseCompoundProblem[K, B]": + vars_to_shallow_copy = ("_adata", "_adata_sc", "_adata_tgt") + + return _copy_depth_helper(self, memo, vars_to_shallow_copy) + + def copy(self) -> "BaseCompoundProblem[K, B]": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + if self.stage == "solved": + raise copy.Error("Cannot copy problem that has already been solved.") + return copy.deepcopy(self) + @abc.abstractmethod def _create_problem(self, src: K, tgt: K, src_mask: ArrayLike, tgt_mask: ArrayLike, **kwargs: Any) -> B: """Create an :term:`OT` subproblem. diff --git a/src/moscot/base/problems/manager.py b/src/moscot/base/problems/manager.py index 1994f27d8..5481a16ff 100644 --- a/src/moscot/base/problems/manager.py +++ b/src/moscot/base/problems/manager.py @@ -1,4 +1,5 @@ import collections +import copy from typing import ( TYPE_CHECKING, Dict, @@ -40,6 +41,20 @@ def __init__(self, compound_problem: "BaseCompoundProblem[K, B]", policy: Subset self._policy = policy self._problems: Dict[Tuple[K, K], B] = {} + def copy(self) -> "ProblemManager[K, B]": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + if self._compound_problem.stage == "solved": + raise copy.Error("Cannot copy problem that has already been solved.") + return copy.deepcopy(self) + def add_problem( self, key: Tuple[K, K], problem: B, *, overwrite: bool = False, verify_integrity: bool = True ) -> None: diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index b249a46f6..254080b24 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -1,4 +1,5 @@ import abc +import copy import pathlib import types from typing import ( @@ -33,6 +34,7 @@ TimeScalesHeatKernel, _assert_columns_and_index_match, _assert_series_match, + _copy_depth_helper, require_solution, wrap_prepare, wrap_solve, @@ -69,6 +71,19 @@ def prepare(self, *args: Any, **kwargs: Any) -> "BaseProblem": - :attr:`problem_kind` - kind of the :term:`OT` problem. """ + @abc.abstractmethod + def copy(self) -> "BaseProblem": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + pass + @abc.abstractmethod def solve(self, *args: Any, **kwargs: Any) -> "BaseProblem": """Solve the problem. @@ -259,6 +274,32 @@ def __init__( self._time_scales_heat_kernel = TimeScalesHeatKernel(None, None, None) + def __deepcopy__(self, memo) -> "OTProblem": + vars_to_shallow_copy = ( + "_adata_src", + "_adata_tgt", + "_src_obs_mask", + "_tgt_obs_mask", + "_src_var_mask", + "_tgt_var_mask", + ) + + return _copy_depth_helper(self, memo, vars_to_shallow_copy) + + def copy(self) -> "OTProblem": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + if self.stage == "solved": + raise copy.Error("Cannot copy problem that has already been solved.") + return copy.deepcopy(self) + def _handle_linear(self, cost: CostFn_t = None, **kwargs: Any) -> TaggedArray: if "x_attr" not in kwargs or "y_attr" not in kwargs: kwargs.setdefault("tag", Tag.COST_MATRIX) diff --git a/src/moscot/problems/cross_modality/_translation.py b/src/moscot/problems/cross_modality/_translation.py index 0f26857df..0b3ed90bd 100644 --- a/src/moscot/problems/cross_modality/_translation.py +++ b/src/moscot/problems/cross_modality/_translation.py @@ -69,6 +69,18 @@ def _create_problem( **kwargs, ) + def copy(self) -> "TranslationProblem[K]": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + return super().copy() # type: ignore + def prepare( self, src_attr: Union[str, Mapping[str, Any]], diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index 1466aa998..bc1791079 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -44,6 +44,9 @@ class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # typ def __init__(self, adata: AnnData, **kwargs: Any): super().__init__(adata, **kwargs) + def copy(self) -> "SinkhornProblem[K, B]": + return super().copy() # type: ignore + def prepare( self, key: str, @@ -260,6 +263,9 @@ class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ign def __init__(self, adata: AnnData, **kwargs: Any): super().__init__(adata, **kwargs) + def copy(self) -> "GWProblem[K, B]": + return super().copy() # type: ignore + def prepare( self, key: str, @@ -477,6 +483,9 @@ class FGWProblem(GWProblem[K, B]): Keyword arguments for :class:`~moscot.base.problems.CompoundProblem`. """ + def copy(self) -> "FGWProblem[K, B]": + return super().copy() # type: ignore + def prepare( self, key: str, diff --git a/src/moscot/problems/space/_alignment.py b/src/moscot/problems/space/_alignment.py index f5a398d23..7b2a2edfe 100644 --- a/src/moscot/problems/space/_alignment.py +++ b/src/moscot/problems/space/_alignment.py @@ -34,6 +34,18 @@ class AlignmentProblem(SpatialAlignmentMixin[K, B], CompoundProblem[K, B]): def __init__(self, adata: AnnData, **kwargs: Any): super().__init__(adata, **kwargs) + def copy(self) -> "AlignmentProblem[K, B]": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + return super().copy() # type: ignore + def prepare( self, batch_key: str, diff --git a/src/moscot/problems/space/_mapping.py b/src/moscot/problems/space/_mapping.py index f3243344d..ee3b10929 100644 --- a/src/moscot/problems/space/_mapping.py +++ b/src/moscot/problems/space/_mapping.py @@ -70,6 +70,18 @@ def _create_problem( **kwargs, ) + def copy(self) -> "MappingProblem[K]": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + return super().copy() # type: ignore + def prepare( self, sc_attr: Union[str, Mapping[str, Any]], diff --git a/src/moscot/problems/spatiotemporal/_spatio_temporal.py b/src/moscot/problems/spatiotemporal/_spatio_temporal.py index 6a2e0424f..5b4c381cb 100644 --- a/src/moscot/problems/spatiotemporal/_spatio_temporal.py +++ b/src/moscot/problems/spatiotemporal/_spatio_temporal.py @@ -42,6 +42,18 @@ class SpatioTemporalProblem( # type: ignore[misc] def __init__(self, adata: AnnData, **kwargs: Any): super().__init__(adata, **kwargs) + def copy(self) -> "SpatioTemporalProblem": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + return super().copy() # type: ignore + def prepare( self, time_key: str, diff --git a/src/moscot/problems/time/_lineage.py b/src/moscot/problems/time/_lineage.py index fea290f46..c03cbf5d7 100644 --- a/src/moscot/problems/time/_lineage.py +++ b/src/moscot/problems/time/_lineage.py @@ -39,6 +39,18 @@ class TemporalProblem( # type: ignore[misc] def __init__(self, adata: AnnData, **kwargs: Any): super().__init__(adata, **kwargs) + def copy(self) -> "TemporalProblem": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + return super().copy() # type: ignore + def prepare( self, time_key: str, @@ -267,6 +279,18 @@ class LineageProblem(TemporalProblem): Keyword arguments for :class:`~moscot.problems.time.TemporalProblem`. """ + def copy(self) -> "LineageProblem": + """Create a copy of self. + + It deep-copies everything except for the data which is shallow-copied (by reference) + to improve the memory footprint. + + Returns + ------- + Copy of Self + """ + return super().copy() # type: ignore + def prepare( self, time_key: str, diff --git a/tests/problems/_utils.py b/tests/problems/_utils.py new file mode 100644 index 000000000..7b294df5a --- /dev/null +++ b/tests/problems/_utils.py @@ -0,0 +1,35 @@ +from itertools import combinations + + +def _check_is_copy(o1: object, o2: object, shallow_copy: tuple[str, ...]) -> bool: + if type(o1) is not type(o2): + return False + + for k in o1.__dict__: + v1 = getattr(o1, k) + v2 = getattr(o2, k) + if type(v1) is not type(v2): + return False + # these basic types are treated differently in python and there's no point in comparing their ids + if isinstance(v1, (str, int, bool, float)) or v1 is None: + continue + if k in shallow_copy: + if id(v1) != id(v2): + return False + else: + if isinstance(v1, list) and all(isinstance(v, (str, int, bool, float)) or v is None for v in v1): + # when deepcopying a list of basic types, python might decide to + # not keep the id of the list the same until it is changed, so it's + # ok to have lists of basic types the same id + continue + if id(v1) == id(v2): + return False + + return True + + +def check_is_copy_multiple(os: tuple[object, ...], shallow_copy: tuple[str, ...]) -> bool: + if len(os) < 1: + return False + combs = combinations(os, 2) + return all(_check_is_copy(o1, o2, shallow_copy) for o1, o2 in combs) diff --git a/tests/problems/base/test_compound_problem.py b/tests/problems/base/test_compound_problem.py index 12a75a69f..8d6ae21b5 100644 --- a/tests/problems/base/test_compound_problem.py +++ b/tests/problems/base/test_compound_problem.py @@ -1,3 +1,4 @@ +import copy import os from typing import Any, Literal, Optional, Tuple @@ -15,6 +16,7 @@ from moscot.base.problems import CompoundProblem, OTProblem from moscot.utils.tagged_array import Tag, TaggedArray from tests._utils import ATOL, RTOL, Problem +from tests.problems._utils import check_is_copy_multiple class TestCompoundProblem: @@ -286,3 +288,28 @@ def test_save_load_solved(self, adata_time: AnnData): p = Problem.load(file) assert isinstance(p, Problem) + + def test_copy(self, adata_time: AnnData): + shallow_copy = ("_adata",) + + prepare_params = { + "xy": {"x_attr": "X", "y_attr": "X"}, + "key": "time", + "policy": "sequential", + } + solve_params = {"max_iterations": 2} + + prob = Problem(adata_time) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/base/test_general_problem.py b/tests/problems/base/test_general_problem.py index bdfecc7f6..c90f65240 100644 --- a/tests/problems/base/test_general_problem.py +++ b/tests/problems/base/test_general_problem.py @@ -1,3 +1,4 @@ +import copy from typing import Literal, Optional, Tuple import pytest @@ -16,6 +17,7 @@ from moscot.base.problems import OTProblem from moscot.utils.tagged_array import Tag, TaggedArray from tests._utils import ATOL, RTOL, Geom_t, MockSolverOutput +from tests.problems._utils import check_is_copy_multiple class TestOTProblem: @@ -346,3 +348,35 @@ def test_set_graph_xy_test_t(self, adata_x: AnnData, adata_y: AnnData, t: float) assert pushed_0.shape == pushed_1.shape assert np.all(np.abs(pushed_0 - pushed_1).sum() > np.abs(pushed_2 - pushed_1).sum()) assert np.all(np.abs(pushed_0 - pushed_2).sum() > np.abs(pushed_1 - pushed_2).sum()) + + def test_copy(self, adata_x: AnnData, adata_y: AnnData): + shallow_copy = ( + "_adata_src", + "_adata_tgt", + "_src_obs_mask", + "_tgt_obs_mask", + "_src_var_mask", + "_tgt_var_mask", + ) + + prepare_params = { + "xy": {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"}, + "x": {"attr": "X"}, + "y": {"attr": "X"}, + } + solve_params = {"epsilon": 5e-1, "alpha": 0.5} + + prob = OTProblem(adata_x, adata_y) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/cross_modality/test_translation_problem.py b/tests/problems/cross_modality/test_translation_problem.py index f5199877a..c7c5a76f2 100644 --- a/tests/problems/cross_modality/test_translation_problem.py +++ b/tests/problems/cross_modality/test_translation_problem.py @@ -1,3 +1,4 @@ +import copy from contextlib import nullcontext from typing import Any, Literal, Mapping, Optional, Tuple @@ -11,6 +12,7 @@ from moscot.backends.ott._utils import alpha_to_fused_penalty from moscot.base.output import BaseSolverOutput from moscot.problems.cross_modality import TranslationProblem +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -180,3 +182,33 @@ def test_pass_arguments(self, adata_translation_split: Tuple[AnnData, AnnData], geom = quad_prob.geom_xy for arg, val in pointcloud_args.items(): assert getattr(geom, val) == args_to_check[arg], arg + + def test_copy(self, adata_translation_split: Tuple[AnnData, AnnData]): + # shallow_copy = ("_adata",) + shallow_copy = ( + "_adata", + "_adata_tgt", + ) + + adata_src, adata_tgt = adata_translation_split + epsilon, alpha, rank, joint_attr = 1e-2, 0.9, -1, {"attr": "obsm", "key": "X_pca"} + src_attr = "emb_src" + tgt_attr = "emb_tgt" + + prepare_params = {"batch_key": "batch", "src_attr": src_attr, "tgt_attr": tgt_attr, "joint_attr": joint_attr} + solve_params = {"epsilon": epsilon, "alpha": alpha, "rank": rank} + + prob = TranslationProblem(adata_src, adata_tgt) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index 52d040d51..16efcba51 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -1,3 +1,4 @@ +import copy from typing import Any, Literal, Mapping import pytest @@ -25,6 +26,7 @@ from moscot.base.problems import OTProblem from moscot.problems.generic import FGWProblem from tests._utils import _assert_marginals_set +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -395,3 +397,32 @@ def test_passing_ott_kwargs_quadratic(self, adata_space_rotate: AnnData, warm_st assert solver.warm_start == warm_start assert solver.store_inner_errors == inner_errors + + def test_copy(self, adata_space_rotate: AnnData): + shallow_copy = ("_adata",) + + eps = 0.5 + adata_space_rotate = adata_space_rotate[adata_space_rotate.obs["batch"].isin(("0", "1"))].copy() + prepare_params = { + "key": "batch", + "policy": "sequential", + "joint_attr": "X_pca", + "x_attr": {"attr": "obsm", "key": "spatial"}, + "y_attr": {"attr": "obsm", "key": "spatial"}, + } + solve_params = {"alpha": 0.5, "epsilon": eps} + + prob = FGWProblem(adata=adata_space_rotate) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py index 5b13c8f2f..99dad637e 100644 --- a/tests/problems/generic/test_gw_problem.py +++ b/tests/problems/generic/test_gw_problem.py @@ -1,3 +1,4 @@ +import copy from typing import Any, Literal, Mapping import pytest @@ -24,6 +25,7 @@ from moscot.base.problems import OTProblem from moscot.problems.generic import GWProblem from tests._utils import _assert_marginals_set +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( geometry_args, gw_args_1, @@ -359,3 +361,31 @@ def test_passing_ott_kwargs_quadratic(self, adata_space_rotate: AnnData, warm_st assert solver.warm_start == warm_start assert solver.store_inner_errors == inner_errors + + def test_copy(self, adata_space_rotate: AnnData): # type: ignore[no-untyped-def] + shallow_copy = ("_adata",) + + eps = 0.5 + + prepare_params = { + "key": "batch", + "policy": "sequential", + "x_attr": {"attr": "obsm", "key": "spatial"}, + "y_attr": {"attr": "obsm", "key": "spatial"}, + } + solve_params = {"epsilon": eps} + + prob = GWProblem(adata=adata_space_rotate) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/generic/test_sinkhorn_problem.py b/tests/problems/generic/test_sinkhorn_problem.py index cc8b8af76..660340c6d 100644 --- a/tests/problems/generic/test_sinkhorn_problem.py +++ b/tests/problems/generic/test_sinkhorn_problem.py @@ -1,3 +1,4 @@ +import copy from typing import Any, Literal, Mapping import pytest @@ -23,6 +24,7 @@ from moscot.base.problems import OTProblem from moscot.problems.generic import SinkhornProblem from tests._utils import _assert_marginals_set +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( geometry_args, lin_prob_args, @@ -230,3 +232,25 @@ def test_passing_ott_kwargs(self, adata_time: AnnData, memory: int, refresh: int recenter_potentials = problem[0, 1].solver.solver.recenter_potentials assert recenter_potentials == recenter + + def test_copy(self, adata_time: AnnData, marginal_keys): + shallow_copy = ("_adata",) + + eps = 0.5 + prepare_params = {"key": "time", "a": marginal_keys[0], "b": marginal_keys[1]} + solve_params = {"epsilon": eps} + + prob = SinkhornProblem(adata=adata_time) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index d6a78c063..9e21a0867 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -1,3 +1,4 @@ +import copy from pathlib import Path from typing import Any, Literal, Mapping, Optional @@ -14,6 +15,7 @@ from moscot.backends.ott._utils import alpha_to_fused_penalty from moscot.problems.space import AlignmentProblem from moscot.utils.tagged_array import Tag, TaggedArray +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -226,3 +228,29 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin for arg, val in pointcloud_args.items(): assert hasattr(geom, val) assert getattr(geom, val) == args_to_check[arg] + + def test_copy(self, adata_space_rotate: AnnData): + shallow_copy = ("_adata",) + + tau_a, tau_b = [0.8, 1] + marg_a = "a" + marg_b = "b" + adata_space_rotate.obs[marg_a] = adata_space_rotate.obs[marg_b] = np.ones(300) + + prepare_params = {"batch_key": "batch", "a": marg_a, "b": marg_b} + solve_params = {"tau_a": tau_a, "tau_b": tau_b} + + prob = AlignmentProblem(adata=adata_space_rotate) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index 903c02611..793e274ad 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -16,6 +16,7 @@ from moscot.problems.space import MappingProblem from moscot.utils.tagged_array import Tag, TaggedArray from tests._utils import _adata_spatial_split +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -266,3 +267,20 @@ def test_pass_arguments(self, adata_mapping: AnnData, args_to_check: Mapping[str for arg, val in pointcloud_args.items(): assert hasattr(geom, val) assert getattr(geom, val) == args_to_check[arg] + + def test_copy(self, adata_mapping: AnnData): + shallow_copy = ("_adata", "_adata_sc") + + prepare_params = {"batch_key": "batch", "sc_attr": {"attr": "X"}} + adataref, adatasp = _adata_spatial_split(adata_mapping) + + prob = MappingProblem(adataref, adatasp) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) diff --git a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py index a22d5cd35..5dd19b9e0 100644 --- a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py +++ b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py @@ -1,3 +1,4 @@ +import copy from typing import Any, List, Mapping import pytest @@ -13,6 +14,7 @@ from moscot.base.problems import BirthDeathProblem from moscot.problems.spatiotemporal import SpatioTemporalProblem from tests._utils import ATOL, RTOL +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -234,3 +236,27 @@ def test_pass_arguments(self, adata_spatio_temporal: AnnData, args_to_check: Map for arg, val in pointcloud_args.items(): assert hasattr(geom, val) assert getattr(geom, val) == args_to_check[arg] + + def test_copy(self, adata_spatio_temporal: AnnData): + shallow_copy = ("_adata",) + + eps = 1 + alpha = 0.5 + + prepare_params = {"time_key": "time", "spatial_key": "spatial"} + solve_params = {"alpha": alpha, "epsilon": eps} + + prob = SpatioTemporalProblem(adata=adata_spatio_temporal) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/time/test_lineage_problem.py b/tests/problems/time/test_lineage_problem.py index fed68cf6f..04967d1d6 100644 --- a/tests/problems/time/test_lineage_problem.py +++ b/tests/problems/time/test_lineage_problem.py @@ -1,3 +1,4 @@ +import copy from typing import Any, List, Mapping import pytest @@ -12,6 +13,7 @@ from moscot.base.problems import BirthDeathProblem from moscot.problems.time import LineageProblem from tests._utils import ATOL, RTOL +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( fgw_args_1, fgw_args_2, @@ -269,3 +271,30 @@ def test_pass_arguments(self, adata_time_barcodes: AnnData, args_to_check: Mappi for arg, val in pointcloud_args.items(): assert hasattr(geom, val) assert getattr(geom, val) == args_to_check[arg] + + def test_copy(self, adata_time_barcodes: AnnData): + shallow_copy = ("_adata",) + + eps, key = 0.5, (0, 1) + adata_time_barcodes = adata_time_barcodes[adata_time_barcodes.obs["time"].isin(key)].copy() + prepare_params = { + "time_key": "time", + "policy": "sequential", + "lineage_attr": {"attr": "obsm", "key": "barcodes", "tag": "cost_matrix", "cost": "barcode_distance"}, + } + solve_params = {"epsilon": eps} + + prob = LineageProblem(adata=adata_time_barcodes) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy() diff --git a/tests/problems/time/test_temporal_base_problem.py b/tests/problems/time/test_temporal_base_problem.py index 24dee7e42..6337c32ac 100644 --- a/tests/problems/time/test_temporal_base_problem.py +++ b/tests/problems/time/test_temporal_base_problem.py @@ -7,6 +7,7 @@ from anndata import AnnData from moscot.base.problems import BirthDeathProblem +from tests.problems._utils import check_is_copy_multiple # TODO(@MUCDK) put file in different folder according to moscot.problems structure @@ -135,3 +136,38 @@ def test_marginal_kwargs(self, adata_time_marginal_estimations: AnnData, margina assert not np.allclose(gr1, gr2) else: assert np.allclose(gr1, gr2) + + def test_copy(self, adata_time_marginal_estimations: AnnData): + t1, t2 = 0, 1 + adata_x = adata_time_marginal_estimations[adata_time_marginal_estimations.obs["time"] == t1] + adata_y = adata_time_marginal_estimations[adata_time_marginal_estimations.obs["time"] == t2] + + shallow_copy = ( + "_adata_src", + "_adata_tgt", + "_src_obs_mask", + "_tgt_obs_mask", + "_src_var_mask", + "_tgt_var_mask", + ) + + prepare_params = { + "xy": {}, + "x": {"attr": "X"}, + "y": {"attr": "X"}, + "a": True, + "b": True, + "proliferation_key": "proliferation", + "apoptosis_key": "apoptosis", + } + + prob = BirthDeathProblem(adata_x, adata_y, src_key=t1, tgt_key=t2) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index 7858eb613..0632feaab 100644 --- a/tests/problems/time/test_temporal_problem.py +++ b/tests/problems/time/test_temporal_problem.py @@ -1,3 +1,4 @@ +import copy from typing import Any, List, Mapping, Optional import pytest @@ -18,6 +19,7 @@ from moscot.problems.time import TemporalProblem from moscot.utils.tagged_array import Tag, TaggedArray from tests._utils import ATOL, RTOL +from tests.problems._utils import check_is_copy_multiple from tests.problems.conftest import ( geometry_args, lin_prob_args, @@ -470,3 +472,25 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A assert type(el) == type(args_to_check[arg]) # noqa: E721 else: assert el == args_to_check[arg] + + def test_copy(self, adata_time: AnnData): + shallow_copy = ("_adata",) + + eps = 0.5 + prepare_params = {"time_key": "time", "cost": "cosine", "xy_callback": None, "joint_attr": "X_pca"} + solve_params = {"epsilon": eps} + + prob = TemporalProblem(adata=adata_time) + prob_copy_1 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1), shallow_copy) + + prob = prob.prepare(**prepare_params) # type: ignore + prob_copy_1 = prob_copy_1.prepare(**prepare_params) # type: ignore + prob_copy_2 = prob.copy() + + assert check_is_copy_multiple((prob, prob_copy_1, prob_copy_2), shallow_copy) + + prob = prob.solve(**solve_params) # type: ignore + with pytest.raises(copy.Error): + _ = prob.copy()