Skip to content

Commit

Permalink
Merge pull request #766 from theislab/refactor/remove-protocol-classes
Browse files Browse the repository at this point in the history
Refactor/remove protocol classes
  • Loading branch information
selmanozleyen authored Nov 26, 2024
2 parents ed48242 + 901507d commit 4080a94
Show file tree
Hide file tree
Showing 11 changed files with 231 additions and 451 deletions.
121 changes: 21 additions & 100 deletions src/moscot/base/problems/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Literal,
Mapping,
Optional,
Protocol,
Sequence,
Union,
)
Expand All @@ -21,11 +20,9 @@
from scipy.sparse.linalg import LinearOperator

import scanpy as sc
from anndata import AnnData

from moscot import _constants
from moscot._types import ArrayLike, Numeric_t, Str_Dict_t
from moscot.base.output import BaseDiscreteSolverOutput
from moscot.base.problems._utils import (
_check_argument_compatibility_cell_transition,
_correlation_test,
Expand All @@ -34,104 +31,28 @@
_validate_annotations,
_validate_args_cell_transition,
)
from moscot.base.problems.compound_problem import ApplyOutput_t, B, K
from moscot.base.problems.compound_problem import B, K
from moscot.base.problems.problem import (
AbstractPushPullAdata,
AbstractSolutionsProblems,
)
from moscot.plotting._utils import set_plotting_vars
from moscot.utils.data import transcription_factors
from moscot.utils.subset_policy import SubsetPolicy

__all__ = ["AnalysisMixin"]


# Structural-typing stub, generic over the key type ``K`` and problem type ``B``.
# It lists the attributes and method signatures that the mixin methods in this
# file assume when they annotate ``self`` as ``AnalysisMixinProtocol[K, B]``.
# Every method below is signature-only (its body is ``...``).
class AnalysisMixinProtocol(Protocol[K, B]):
"""Protocol class."""

# Attributes expected on the implementing class.
adata: AnnData
_policy: SubsetPolicy[K]
# Both mappings are keyed by a pair of ``K`` keys.
solutions: dict[tuple[K, K], BaseDiscreteSolverOutput]
problems: dict[tuple[K, K], B]

# Signature-only stub; returns ``ApplyOutput_t[K]``.
def _apply(
self,
data: Optional[Union[str, ArrayLike]] = None,
source: Optional[K] = None,
target: Optional[K] = None,
forward: bool = True,
return_all: bool = False,
scale_by_marginals: bool = False,
**kwargs: Any,
) -> ApplyOutput_t[K]: ...

# Signature-only stub; takes a sequence of key pairs and returns a ``LinearOperator``.
def _interpolate_transport(
self: AnalysisMixinProtocol[K, B],
path: Sequence[tuple[K, K]],
scale_by_marginals: bool = True,
) -> LinearOperator: ...

# Signature-only stub; ``key`` is keyword-only.
def _flatten(
self: AnalysisMixinProtocol[K, B],
data: dict[K, ArrayLike],
*,
key: Optional[str],
) -> ArrayLike: ...

def push(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]:
"""Push distribution."""
...

def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]:
"""Pull distribution."""
...

# Signature-only stub; returns a ``pd.DataFrame``.
def _cell_transition(
self: AnalysisMixinProtocol[K, B],
source: K,
target: K,
source_groups: Str_Dict_t,
target_groups: Str_Dict_t,
aggregation_mode: Literal["annotation", "cell"] = "annotation",
key_added: Optional[str] = _constants.CELL_TRANSITION,
**kwargs: Any,
) -> pd.DataFrame: ...

# Signature-only stub; returns a ``pd.DataFrame``.
# NOTE(review): ``other_adata`` is annotated ``Optional[str]`` although the name
# suggests an AnnData object — confirm the intended type.
def _cell_transition_online(
self: AnalysisMixinProtocol[K, B],
key: Optional[str],
source: K,
target: K,
source_groups: Str_Dict_t,
target_groups: Str_Dict_t,
forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic
aggregation_mode: Literal["annotation", "cell"] = "annotation",
other_key: Optional[str] = None,
other_adata: Optional[str] = None,
batch_size: Optional[int] = None,
normalize: bool = True,
) -> pd.DataFrame: ...

# Signature-only stub; returns a ``pd.DataFrame``.
# ``cell_transition_kwargs`` defaults to an empty read-only mapping proxy rather
# than a mutable dict.
# NOTE(review): ``key`` uses the ``str | None`` spelling while the other
# parameters in this class use ``Optional[...]``.
def _annotation_mapping(
self: AnalysisMixinProtocol[K, B],
mapping_mode: Literal["sum", "max"],
annotation_label: str,
forward: bool,
source: K,
target: K,
key: str | None = None,
other_adata: Optional[str] = None,
scale_by_marginals: bool = True,
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> pd.DataFrame: ...


class AnalysisMixin(Generic[K, B]):
class AnalysisMixin(Generic[K, B], AbstractPushPullAdata, AbstractSolutionsProblems):
"""Base Analysis Mixin."""

# Cooperative constructor: forwards every positional and keyword argument,
# unchanged, to the next ``__init__`` in the method resolution order.
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)

def _cell_transition(
self: AnalysisMixinProtocol[K, B],
self,
source: K,
target: K,
target: Optional[K],
source_groups: Str_Dict_t,
target_groups: Str_Dict_t,
aggregation_mode: Literal["annotation", "cell"] = "annotation",
Expand Down Expand Up @@ -178,7 +99,7 @@ def _cell_transition(
return tm

def _annotation_aggregation_transition(
self: AnalysisMixinProtocol[K, B],
self,
annotations_1: list[Any],
annotations_2: list[Any],
df: pd.DataFrame,
Expand Down Expand Up @@ -210,10 +131,10 @@ def _annotation_aggregation_transition(
)

def _cell_transition_online(
self: AnalysisMixinProtocol[K, B],
self,
key: Optional[str],
source: K,
target: K,
target: Optional[K],
source_groups: Str_Dict_t,
target_groups: Str_Dict_t,
forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic
Expand Down Expand Up @@ -272,7 +193,7 @@ def _cell_transition_online(
split_mass=False,
**move_op_const_kwargs,
)
tm = self._annotation_aggregation_transition( # type: ignore[attr-defined]
tm = self._annotation_aggregation_transition(
annotations_1=source_annotations_verified if forward else target_annotations_verified,
annotations_2=target_annotations_verified if forward else source_annotations_verified,
df=df_to,
Expand All @@ -286,7 +207,7 @@ def _cell_transition_online(
split_mass=True,
**move_op_const_kwargs,
)
tm = self._cell_aggregation_transition( # type: ignore[attr-defined]
tm = self._cell_aggregation_transition(
df_from=df_from,
df_to=df_to,
annotations=target_annotations_verified if forward else source_annotations_verified,
Expand All @@ -309,7 +230,7 @@ def _cell_transition_online(
)

def _annotation_mapping(
self: AnalysisMixinProtocol[K, B],
self,
mapping_mode: Literal["sum", "max"],
annotation_label: str,
source: K,
Expand Down Expand Up @@ -394,7 +315,7 @@ def _annotation_mapping(
raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.")

def _sample_from_tmap(
self: AnalysisMixinProtocol[K, B],
self,
source: K,
target: K,
n_samples: int,
Expand Down Expand Up @@ -472,7 +393,7 @@ def _sample_from_tmap(
return rows, all_cols_sampled # type: ignore[return-value]

def _interpolate_transport(
self: AnalysisMixinProtocol[K, B],
self,
# TODO(@giovp): rename this to 'explicit_steps', pass to policy.plan() and reintroduce (source_key, target_key)
path: Sequence[tuple[K, K]],
scale_by_marginals: bool = True,
Expand All @@ -485,15 +406,15 @@ def _interpolate_transport(
fst, *rest = path
return self.solutions[fst].chain([self.solutions[r] for r in rest], scale_by_marginals=scale_by_marginals)

def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike:
# Scatter the per-key arrays in ``data`` into one array with a slot for each
# row of ``self.adata``. Rows whose ``self.adata.obs[key]`` value matches no
# key of ``data`` keep the NaN they were initialised with.
# NOTE(review): ``key`` is annotated ``Optional[str]`` but is used directly to
# index ``obs`` — confirm callers never pass ``None``.
def _flatten(self, data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike:
# Start from an all-NaN vector of length ``len(self.adata)``.
tmp = np.full(len(self.adata), np.nan)
for k, v in data.items():
# Boolean selector for the rows whose ``obs[key]`` equals this dictionary key.
mask = self.adata.obs[key] == k
# ``np.squeeze`` removes length-1 axes from ``v`` before the masked assignment.
tmp[mask] = np.squeeze(v)
return tmp

def _cell_aggregation_transition(
self: AnalysisMixinProtocol[K, B],
self,
df_from: pd.DataFrame,
df_to: pd.DataFrame,
annotations: list[Any],
Expand Down Expand Up @@ -540,7 +461,7 @@ def _cell_aggregation_transition(
# adapted from:
# https://github.com/theislab/cellrank/blob/master/cellrank/_utils/_utils.py#L392
def compute_feature_correlation(
self: AnalysisMixinProtocol[K, B],
self,
obs_key: str,
corr_method: Literal["pearson", "spearman"] = "pearson",
significance_method: Literal["fisher", "perm_test"] = "fisher",
Expand Down Expand Up @@ -649,7 +570,7 @@ def compute_feature_correlation(
)

def compute_entropy(
self: AnalysisMixinProtocol[K, B],
self,
source: K,
target: K,
forward: bool = True,
Expand Down Expand Up @@ -706,7 +627,7 @@ def compute_entropy(
split_mass=True,
key_added=None,
)
df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = stats.entropy(cond_dists + c, **kwargs) # type: ignore[operator]
df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = stats.entropy(cond_dists + c, **kwargs)
if key_added is not None:
self.adata.obs[key_added] = df
return df if key_added is None else None
51 changes: 8 additions & 43 deletions src/moscot/base/problems/birth_death.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Optional,
Protocol,
Sequence,
Union,
)
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Sequence, Union

import numpy as np

Expand All @@ -17,39 +8,13 @@

from moscot._logging import logger
from moscot._types import ArrayLike
from moscot.base.problems.problem import OTProblem
from moscot.base.problems.problem import AbstractAdataAccess, OTProblem
from moscot.utils.data import apoptosis_markers, proliferation_markers

__all__ = ["BirthDeathProblem", "BirthDeathMixin"]


# Structural-typing stub listing the attributes and the one method signature
# that code annotating ``self`` as ``BirthDeathProtocol`` relies on.
class BirthDeathProtocol(Protocol): # noqa: D101
adata: AnnData
# Public key attributes and their underscore-prefixed counterparts.
proliferation_key: Optional[str]
apoptosis_key: Optional[str]
_proliferation_key: Optional[str]
_apoptosis_key: Optional[str]
_scaling: float
_prior_growth: Optional[ArrayLike]

# Signature-only stub (body is ``...``); the return annotation is the protocol
# type itself.
def score_genes_for_marginals( # noqa: D102
self: "BirthDeathProtocol",
gene_set_proliferation: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None,
gene_set_apoptosis: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None,
proliferation_key: str = "proliferation",
apoptosis_key: str = "apoptosis",
**kwargs: Any,
) -> "BirthDeathProtocol": ...


# Extends ``BirthDeathProtocol`` with four further attribute declarations;
# it declares no methods of its own.
class BirthDeathProblemProtocol(BirthDeathProtocol, Protocol): # noqa: D101
delta: float
adata_tgt: AnnData
# NOTE(review): ``a`` and ``b`` are presumably the source/target marginals — confirm.
a: Optional[ArrayLike]
b: Optional[ArrayLike]


class BirthDeathMixin:
class BirthDeathMixin(AbstractAdataAccess):
"""Mixin class used to estimate cell proliferation and apoptosis.
Parameters
Expand All @@ -68,7 +33,7 @@ def __init__(self, *args: Any, **kwargs: Any):
self._prior_growth: Optional[ArrayLike] = None

def score_genes_for_marginals(
self, # type: BirthDeathProblemProtocol
self,
gene_set_proliferation: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None,
gene_set_apoptosis: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None,
proliferation_key: str = "proliferation",
Expand Down Expand Up @@ -123,15 +88,15 @@ def score_genes_for_marginals(
"At least one of `gene_set_proliferation` or `gene_set_apoptosis` must be provided to score genes."
)

return self # type: ignore[return-value]
return self

# Getter: returns the private ``_proliferation_key`` attribute as stored,
# without any lookup in ``adata.obs``.
@property
def proliferation_key(self) -> Optional[str]:
"""Key in :attr:`~anndata.AnnData.obs` where cell proliferation is stored."""
return self._proliferation_key

@proliferation_key.setter
def proliferation_key(self: BirthDeathProtocol, key: Optional[str]) -> None:
# Setter: stores ``key`` in ``_proliferation_key`` after validating it.
# Raises ``KeyError`` when a non-``None`` key is absent from ``self.adata.obs``.
def proliferation_key(self, key: Optional[str]) -> None:
# ``None`` skips the membership check; any other key must already be in ``adata.obs``.
if key is not None and key not in self.adata.obs:
raise KeyError(f"Unable to find proliferation data in `adata.obs[{key!r}]`.")
self._proliferation_key = key
Expand All @@ -142,7 +107,7 @@ def apoptosis_key(self) -> Optional[str]:
return self._apoptosis_key

@apoptosis_key.setter
def apoptosis_key(self: BirthDeathProtocol, key: Optional[str]) -> None:
# Setter: stores ``key`` in ``_apoptosis_key`` after validating it.
# Raises ``KeyError`` when a non-``None`` key is absent from ``self.adata.obs``.
def apoptosis_key(self, key: Optional[str]) -> None:
# ``None`` skips the membership check; any other key must already be in ``adata.obs``.
if key is not None and key not in self.adata.obs:
raise KeyError(f"Unable to find apoptosis data in `adata.obs[{key!r}]`.")
self._apoptosis_key = key
Expand All @@ -161,7 +126,7 @@ class BirthDeathProblem(BirthDeathMixin, OTProblem):
""" # noqa: D205

def estimate_marginals(
self, # type: BirthDeathProblemProtocol
self,
adata: AnnData,
source: bool,
proliferation_key: Optional[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/base/problems/compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
ApplyOutput_t = Union[ArrayLike, Dict[K, ArrayLike]]


class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
class BaseCompoundProblem(BaseProblem, Generic[K, B], abc.ABC):
"""Base class for all biological problems.
This class translates a biological problem to multiple :term:`OT` problems.
Expand Down
Loading

0 comments on commit 4080a94

Please sign in to comment.