Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Copy problems #706

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/moscot/base/problems/_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import functools
import multiprocessing
import threading
Expand All @@ -16,6 +17,7 @@
Sequence,
Tuple,
Type,
TypeVar,
Union,
)

Expand Down Expand Up @@ -730,3 +732,21 @@ def _get_n_cores(n_cores: Optional[int], n_jobs: Optional[int]) -> int:
return multiprocessing.cpu_count() + 1 + n_cores

return n_cores


C = TypeVar("C", bound=object)


def _copy_depth_helper(
    original: C, memo: dict[int, Any], shallow_copy: tuple[str, ...] = (), dont_copy: tuple[str, ...] = ()
) -> C:
    """Copy an object attribute by attribute, choosing the copy depth per attribute.

    Intended to be called from a class' ``__deepcopy__``. A new instance is created
    with ``cls.__new__`` (``__init__`` is not run) and the entries of
    ``original.__dict__`` are transferred one by one, in their original order.

    Parameters
    ----------
    original
        Object to copy. Must expose a ``__dict__``.
    memo
        Memo dictionary of the running :func:`copy.deepcopy` call. It is updated in place.
    shallow_copy
        Names of attributes whose values are shared with the copy by reference.
        Takes precedence over ``dont_copy`` when a name is listed in both.
    dont_copy
        Names of attributes that are left out of the copy entirely.

    Returns
    -------
    The new instance of the same class as ``original``.
    """
    cls = original.__class__
    result = cls.__new__(cls)
    # Register the copy before descending so that reference cycles resolve to it.
    memo[id(original)] = result

    attributes = original.__dict__

    # Register every shared value *before* anything is deep-copied. Otherwise a
    # deep-copied attribute that precedes a shared one in ``__dict__`` and holds a
    # reference to the same object would receive its own private deep copy of it.
    for name in shallow_copy:
        if name in attributes:
            shared = attributes[name]
            memo[id(shared)] = shared

    for name, value in attributes.items():
        if name in shallow_copy:
            setattr(result, name, value)
        elif name not in dont_copy:
            setattr(result, name, copy.deepcopy(value, memo))
    return result
15 changes: 15 additions & 0 deletions src/moscot/base/problems/birth_death.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -159,6 +160,20 @@ class BirthDeathProblem(BirthDeathMixin, OTProblem):
Keyword arguments for :class:`~moscot.base.problems.OTProblem`.
""" # noqa: D205

def copy(self) -> "BirthDeathProblem":
    """Return a copy of this problem.

    The copy is produced with :func:`copy.deepcopy`. Copying is refused once
    the problem has reached the ``"solved"`` stage.

    Returns
    -------
    The copy of this problem.

    Raises
    ------
    copy.Error
        If :attr:`stage` is ``"solved"``.
    """
    if self.stage != "solved":
        return copy.deepcopy(self)
    raise copy.Error("Cannot copy problem that has already been solved.")

def estimate_marginals(
self, # type: BirthDeathProblemProtocol
adata: AnnData,
Expand Down
26 changes: 25 additions & 1 deletion src/moscot/base/problems/compound_problem.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import copy
import types
from typing import (
TYPE_CHECKING,
Expand All @@ -25,7 +26,11 @@
from moscot._logging import logger
from moscot._types import ArrayLike, Policy_t, ProblemStage_t
from moscot.base.output import BaseSolverOutput
from moscot.base.problems._utils import attributedispatch, require_prepare
from moscot.base.problems._utils import (
_copy_depth_helper,
attributedispatch,
require_prepare,
)
from moscot.base.problems.manager import ProblemManager
from moscot.base.problems.problem import BaseProblem, OTProblem
from moscot.utils.subset_policy import (
Expand Down Expand Up @@ -65,6 +70,25 @@ def __init__(self, adata: AnnData, **kwargs: Any):
self._adata = adata
self._problem_manager: Optional[ProblemManager[K, B]] = None

def __deepcopy__(self, memo) -> "BaseCompoundProblem[K, B]":
    """Deep-copy hook that delegates to :func:`_copy_depth_helper`.

    The attributes ``_adata``, ``_adata_sc`` and ``_adata_tgt`` are handed to the
    helper as the names to shallow-copy, together with the ``memo`` of the running copy.
    """
    return _copy_depth_helper(
        self,
        memo,
        ("_adata", "_adata_sc", "_adata_tgt"),
    )

def copy(self) -> "BaseCompoundProblem[K, B]":
    """Return a copy of this compound problem.

    The copy is produced with :func:`copy.deepcopy`. Copying is refused once
    the problem has reached the ``"solved"`` stage.

    Returns
    -------
    The copy of this problem.

    Raises
    ------
    copy.Error
        If :attr:`stage` is ``"solved"``.
    """
    already_solved = self.stage == "solved"
    if already_solved:
        raise copy.Error("Cannot copy problem that has already been solved.")
    duplicate = copy.deepcopy(self)
    return duplicate

@abc.abstractmethod
def _create_problem(self, src: K, tgt: K, src_mask: ArrayLike, tgt_mask: ArrayLike, **kwargs: Any) -> B:
"""Create an :term:`OT` subproblem.
Expand Down
15 changes: 15 additions & 0 deletions src/moscot/base/problems/manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import collections
import copy
from typing import (
TYPE_CHECKING,
Dict,
Expand Down Expand Up @@ -40,6 +41,20 @@ def __init__(self, compound_problem: "BaseCompoundProblem[K, B]", policy: Subset
self._policy = policy
self._problems: Dict[Tuple[K, K], B] = {}

def copy(self) -> "ProblemManager[K, B]":
    """Return a copy of this manager.

    The copy is produced with :func:`copy.deepcopy`. Copying is refused when
    the compound problem this manager belongs to is in the ``"solved"`` stage.

    Returns
    -------
    The copy of this manager.

    Raises
    ------
    copy.Error
        If the stage of the owning compound problem is ``"solved"``.
    """
    owner_stage = self._compound_problem.stage
    if owner_stage != "solved":
        return copy.deepcopy(self)
    raise copy.Error("Cannot copy problem that has already been solved.")

def add_problem(
self, key: Tuple[K, K], problem: B, *, overwrite: bool = False, verify_integrity: bool = True
) -> None:
Expand Down
41 changes: 41 additions & 0 deletions src/moscot/base/problems/problem.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import copy
import pathlib
import types
from typing import (
Expand Down Expand Up @@ -33,6 +34,7 @@
TimeScalesHeatKernel,
_assert_columns_and_index_match,
_assert_series_match,
_copy_depth_helper,
require_solution,
wrap_prepare,
wrap_solve,
Expand Down Expand Up @@ -69,6 +71,19 @@ def prepare(self, *args: Any, **kwargs: Any) -> "BaseProblem":
- :attr:`problem_kind` - kind of the :term:`OT` problem.
"""

@abc.abstractmethod
def copy(self) -> "BaseProblem":
    """Return a copy of the problem.

    Abstract method without a default implementation. Implementations are
    expected to deep-copy everything except the underlying data, which is
    shared by reference to keep the memory footprint low.

    Returns
    -------
    The copy of this problem.
    """

@abc.abstractmethod
def solve(self, *args: Any, **kwargs: Any) -> "BaseProblem":
"""Solve the problem.
Expand Down Expand Up @@ -259,6 +274,32 @@ def __init__(

self._time_scales_heat_kernel = TimeScalesHeatKernel(None, None, None)

def __deepcopy__(self, memo) -> "OTProblem":
    """Deep-copy hook that delegates to :func:`_copy_depth_helper`.

    The source/target data attributes and the observation/variable mask attributes
    listed below are handed to the helper as the names to shallow-copy, together
    with the ``memo`` of the running copy.
    """
    shared_by_reference = (
        "_adata_src",
        "_adata_tgt",
        "_src_obs_mask",
        "_tgt_obs_mask",
        "_src_var_mask",
        "_tgt_var_mask",
    )
    duplicate = _copy_depth_helper(self, memo, shared_by_reference)
    return duplicate

def copy(self) -> "OTProblem":
    """Return a copy of this problem.

    The copy is produced with :func:`copy.deepcopy`. Copying is refused once
    the problem has reached the ``"solved"`` stage.

    Returns
    -------
    The copy of this problem.

    Raises
    ------
    copy.Error
        If :attr:`stage` is ``"solved"``.
    """
    is_solved = self.stage == "solved"
    if not is_solved:
        return copy.deepcopy(self)
    raise copy.Error("Cannot copy problem that has already been solved.")

def _handle_linear(self, cost: CostFn_t = None, **kwargs: Any) -> TaggedArray:
if "x_attr" not in kwargs or "y_attr" not in kwargs:
kwargs.setdefault("tag", Tag.COST_MATRIX)
Expand Down
12 changes: 12 additions & 0 deletions src/moscot/problems/cross_modality/_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,18 @@ def _create_problem(
**kwargs,
)

def copy(self) -> "TranslationProblem[K]":
    """Return a copy of this problem.

    The work is done by the parent class' ``copy``; this override only
    narrows the annotated return type.

    Returns
    -------
    The copy produced by the parent implementation.
    """
    duplicate = super().copy()
    return duplicate  # type: ignore

def prepare(
self,
src_attr: Union[str, Mapping[str, Any]],
Expand Down
9 changes: 9 additions & 0 deletions src/moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class SinkhornProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # typ
def __init__(self, adata: AnnData, **kwargs: Any):
    """Initialize the problem by forwarding ``adata`` and ``kwargs`` to the parent class."""
    # No additional state is set up at this level.
    super().__init__(adata, **kwargs)

def copy(self) -> "SinkhornProblem[K, B]":
    """Return a copy of this problem.

    The work is done by the parent class' ``copy``; this override only
    narrows the annotated return type.

    Returns
    -------
    The copy produced by the parent implementation.
    """
    duplicate = super().copy()
    return duplicate  # type: ignore

def prepare(
self,
key: str,
Expand Down Expand Up @@ -260,6 +263,9 @@ class GWProblem(GenericAnalysisMixin[K, B], CompoundProblem[K, B]): # type: ign
def __init__(self, adata: AnnData, **kwargs: Any):
    """Initialize the problem by forwarding ``adata`` and ``kwargs`` to the parent class."""
    # No additional state is set up at this level.
    super().__init__(adata, **kwargs)

def copy(self) -> "GWProblem[K, B]":
    """Return a copy of this problem.

    The work is done by the parent class' ``copy``; this override only
    narrows the annotated return type.

    Returns
    -------
    The copy produced by the parent implementation.
    """
    duplicate = super().copy()
    return duplicate  # type: ignore

def prepare(
self,
key: str,
Expand Down Expand Up @@ -477,6 +483,9 @@ class FGWProblem(GWProblem[K, B]):
Keyword arguments for :class:`~moscot.base.problems.CompoundProblem`.
"""

def copy(self) -> "FGWProblem[K, B]":
    """Return a copy of this problem.

    The work is done by the parent class' ``copy``; this override only
    narrows the annotated return type.

    Returns
    -------
    The copy produced by the parent implementation.
    """
    duplicate = super().copy()
    return duplicate  # type: ignore

def prepare(
self,
key: str,
Expand Down
12 changes: 12 additions & 0 deletions src/moscot/problems/space/_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ class AlignmentProblem(SpatialAlignmentMixin[K, B], CompoundProblem[K, B]):
def __init__(self, adata: AnnData, **kwargs: Any):
    """Initialize the problem by forwarding ``adata`` and ``kwargs`` to the parent class."""
    # No additional state is set up at this level.
    super().__init__(adata, **kwargs)

def copy(self) -> "AlignmentProblem[K, B]":
    """Return a copy of this problem.

    The work is done by the parent class' ``copy``; this override only
    narrows the annotated return type.

    Returns
    -------
    The copy produced by the parent implementation.
    """
    duplicate = super().copy()
    return duplicate  # type: ignore

def prepare(
self,
batch_key: str,
Expand Down
12 changes: 12 additions & 0 deletions src/moscot/problems/space/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ def _create_problem(
**kwargs,
)

def copy(self) -> "MappingProblem[K]":
    """Return a copy of this problem.

    The work is done by the parent class' ``copy``; this override only
    narrows the annotated return type.

    Returns
    -------
    The copy produced by the parent implementation.
    """
    duplicate = super().copy()
    return duplicate  # type: ignore

def prepare(
self,
sc_attr: Union[str, Mapping[str, Any]],
Expand Down
12 changes: 12 additions & 0 deletions src/moscot/problems/spatiotemporal/_spatio_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ class SpatioTemporalProblem( # type: ignore[misc]
def __init__(self, adata: AnnData, **kwargs: Any):
    """Initialize the problem by forwarding ``adata`` and ``kwargs`` to the parent class."""
    # No additional state is set up at this level.
    super().__init__(adata, **kwargs)

def copy(self) -> "SpatioTemporalProblem":
    """Return a copy of this problem.

    The work is done by the parent class' ``copy``; this override only
    narrows the annotated return type.

    Returns
    -------
    The copy produced by the parent implementation.
    """
    duplicate = super().copy()
    return duplicate  # type: ignore

def prepare(
self,
time_key: str,
Expand Down
24 changes: 24 additions & 0 deletions src/moscot/problems/time/_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@ class TemporalProblem( # type: ignore[misc]
def __init__(self, adata: AnnData, **kwargs: Any):
    """Initialize the problem by forwarding ``adata`` and ``kwargs`` to the parent class."""
    # No additional state is set up at this level.
    super().__init__(adata, **kwargs)

def copy(self) -> "TemporalProblem":
    """Return a copy of this problem.

    The work is done by the parent class' ``copy``; this override only
    narrows the annotated return type.

    Returns
    -------
    The copy produced by the parent implementation.
    """
    duplicate = super().copy()
    return duplicate  # type: ignore

def prepare(
self,
time_key: str,
Expand Down Expand Up @@ -267,6 +279,18 @@ class LineageProblem(TemporalProblem):
Keyword arguments for :class:`~moscot.problems.time.TemporalProblem`.
"""

def copy(self) -> "LineageProblem":
    """Return a copy of this problem.

    The work is done by the parent class' ``copy``; this override only
    narrows the annotated return type.

    Returns
    -------
    The copy produced by the parent implementation.
    """
    duplicate = super().copy()
    return duplicate  # type: ignore

def prepare(
self,
time_key: str,
Expand Down
35 changes: 35 additions & 0 deletions tests/problems/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from itertools import combinations


def _is_atomic(value: object) -> bool:
    """Return whether ``value`` is ``None`` or a basic scalar (str, int, bool, float)."""
    return value is None or isinstance(value, (str, int, bool, float))


def _is_immutable_tuple(value: object) -> bool:
    """Return whether ``value`` is a plain tuple built only from scalars and such tuples."""
    # Exact type check: tuple subclasses (e.g. named tuples) are rebuilt by deepcopy.
    return type(value) is tuple and all(_is_atomic(item) or _is_immutable_tuple(item) for item in value)


def _check_is_copy(o1: object, o2: object, shallow_copy: tuple[str, ...]) -> bool:
    """Check whether two objects look like copies of each other.

    Every instance attribute of ``o1`` is compared with the attribute of the same
    name on ``o2``. Attributes named in ``shallow_copy`` must be the very same
    object on both sides; all others must be distinct objects.

    Parameters
    ----------
    o1
        First object; its ``__dict__`` determines which attributes are checked.
    o2
        Second object.
    shallow_copy
        Names of the attributes that are expected to be shared by reference.

    Returns
    -------
    ``True`` if all checks pass, ``False`` otherwise. This includes the case where
    ``o2`` lacks one of the attributes of ``o1``.
    """
    if type(o1) is not type(o2):
        return False

    missing = object()
    for k in o1.__dict__:
        v1 = getattr(o1, k)
        v2 = getattr(o2, k, missing)
        # an attribute that exists on only one side means the objects are not copies
        if v2 is missing:
            return False
        if type(v1) is not type(v2):
            return False
        # these basic types are treated differently in python and there's no point in comparing their ids
        if _is_atomic(v1):
            continue
        if k in shallow_copy:
            if v1 is not v2:
                return False
            continue
        if v1 is not v2:
            continue
        # Same object on both sides of an attribute that should have been deep-copied.
        # This is tolerated only for lists of basic types (kept from the previous
        # behaviour) and for plain tuples of immutable values, which ``copy.deepcopy``
        # returns unchanged instead of duplicating.
        if isinstance(v1, list) and all(_is_atomic(v) for v in v1):
            continue
        if _is_immutable_tuple(v1):
            continue
        return False

    return True


def check_is_copy_multiple(os: tuple[object, ...], shallow_copy: tuple[str, ...]) -> bool:
    """Check that all given objects are pairwise copies of each other.

    Parameters
    ----------
    os
        Objects to compare. At least two are required.
    shallow_copy
        Names of the attributes that are expected to be shared by reference;
        forwarded to :func:`_check_is_copy`.

    Returns
    -------
    ``True`` if :func:`_check_is_copy` succeeds for every pair drawn from ``os``.
    ``False`` if any pair fails or if fewer than two objects were given.
    """
    # With fewer than two objects there is no pair to compare, and ``all`` over an
    # empty iterable would report success without having checked anything.
    if len(os) < 2:
        return False
    return all(_check_is_copy(o1, o2, shallow_copy) for o1, o2 in combinations(os, 2))
Loading