From d40b7affbb60cc66647e70e75dc6fae6f5a919d8 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 20 Nov 2024 15:16:24 +0100 Subject: [PATCH 1/8] remove protocol classes --- src/moscot/base/problems/_mixins.py | 105 ++---------- src/moscot/base/problems/birth_death.py | 45 +---- src/moscot/problems/cross_modality/_mixins.py | 22 +-- src/moscot/problems/generic/_mixins.py | 29 +--- src/moscot/problems/space/_mixins.py | 80 ++------- src/moscot/problems/time/_mixins.py | 155 +++--------------- 6 files changed, 60 insertions(+), 376 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 2a5ba2029..fc762d233 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -11,7 +11,6 @@ Literal, Mapping, Optional, - Protocol, Sequence, Union, ) @@ -21,11 +20,9 @@ from scipy.sparse.linalg import LinearOperator import scanpy as sc -from anndata import AnnData from moscot import _constants from moscot._types import ArrayLike, Numeric_t, Str_Dict_t -from moscot.base.output import BaseDiscreteSolverOutput from moscot.base.problems._utils import ( _check_argument_compatibility_cell_transition, _correlation_test, @@ -34,7 +31,7 @@ _validate_annotations, _validate_args_cell_transition, ) -from moscot.base.problems.compound_problem import ApplyOutput_t, B, K +from moscot.base.problems.compound_problem import B, K from moscot.plotting._utils import set_plotting_vars from moscot.utils.data import transcription_factors from moscot.utils.subset_policy import SubsetPolicy @@ -42,86 +39,6 @@ __all__ = ["AnalysisMixin"] -class AnalysisMixinProtocol(Protocol[K, B]): - """Protocol class.""" - - adata: AnnData - _policy: SubsetPolicy[K] - solutions: dict[tuple[K, K], BaseDiscreteSolverOutput] - problems: dict[tuple[K, K], B] - - def _apply( - self, - data: Optional[Union[str, ArrayLike]] = None, - source: Optional[K] = None, - target: Optional[K] = None, - forward: bool = True, - return_all: bool = False, - scale_by_marginals: bool = False, - **kwargs: Any, - ) -> ApplyOutput_t[K]: ... - - def _interpolate_transport( - self: AnalysisMixinProtocol[K, B], - path: Sequence[tuple[K, K]], - scale_by_marginals: bool = True, - ) -> LinearOperator: ... - - def _flatten( - self: AnalysisMixinProtocol[K, B], - data: dict[K, ArrayLike], - *, - key: Optional[str], - ) -> ArrayLike: ... - - def push(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: - """Push distribution.""" - ... - - def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: - """Pull distribution.""" - ... - - def _cell_transition( - self: AnalysisMixinProtocol[K, B], - source: K, - target: K, - source_groups: Str_Dict_t, - target_groups: Str_Dict_t, - aggregation_mode: Literal["annotation", "cell"] = "annotation", - key_added: Optional[str] = _constants.CELL_TRANSITION, - **kwargs: Any, - ) -> pd.DataFrame: ... - - def _cell_transition_online( - self: AnalysisMixinProtocol[K, B], - key: Optional[str], - source: K, - target: K, - source_groups: Str_Dict_t, - target_groups: Str_Dict_t, - forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic - aggregation_mode: Literal["annotation", "cell"] = "annotation", - other_key: Optional[str] = None, - other_adata: Optional[str] = None, - batch_size: Optional[int] = None, - normalize: bool = True, - ) -> pd.DataFrame: ... - - def _annotation_mapping( - self: AnalysisMixinProtocol[K, B], - mapping_mode: Literal["sum", "max"], - annotation_label: str, - forward: bool, - source: K, - target: K, - key: str | None = None, - other_adata: Optional[str] = None, - scale_by_marginals: bool = True, - cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - ) -> pd.DataFrame: ... - - class AnalysisMixin(Generic[K, B]): """Base Analysis Mixin.""" @@ -129,7 +46,7 @@ def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) def _cell_transition( - self: AnalysisMixinProtocol[K, B], + self, source: K, target: K, source_groups: Str_Dict_t, @@ -178,7 +95,7 @@ def _cell_transition( return tm def _annotation_aggregation_transition( - self: AnalysisMixinProtocol[K, B], + self, annotations_1: list[Any], annotations_2: list[Any], df: pd.DataFrame, @@ -210,7 +127,7 @@ def _annotation_aggregation_transition( ) def _cell_transition_online( - self: AnalysisMixinProtocol[K, B], + self, key: Optional[str], source: K, target: K, @@ -309,7 +226,7 @@ def _cell_transition_online( ) def _annotation_mapping( - self: AnalysisMixinProtocol[K, B], + self, mapping_mode: Literal["sum", "max"], annotation_label: str, source: K, @@ -394,7 +311,7 @@ def _annotation_mapping( raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") def _sample_from_tmap( - self: AnalysisMixinProtocol[K, B], + self, source: K, target: K, n_samples: int, @@ -472,7 +389,7 @@ def _sample_from_tmap( return rows, all_cols_sampled # type: ignore[return-value] def _interpolate_transport( - self: AnalysisMixinProtocol[K, B], + self, # TODO(@giovp): rename this to 'explicit_steps', pass to policy.plan() and reintroduce (source_key, target_key) path: Sequence[tuple[K, K]], scale_by_marginals: bool = True, @@ -485,7 +402,7 @@ def _interpolate_transport( fst, *rest = path return self.solutions[fst].chain([self.solutions[r] for r in rest], scale_by_marginals=scale_by_marginals) - def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike: + def _flatten(self, data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike: tmp = np.full(len(self.adata), np.nan) for k, v in data.items(): mask = self.adata.obs[key] == k @@ -493,7 +410,7 @@ def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key return tmp def _cell_aggregation_transition( - self: AnalysisMixinProtocol[K, B], + self, df_from: pd.DataFrame, df_to: pd.DataFrame, annotations: list[Any], @@ -540,7 +457,7 @@ def _cell_aggregation_transition( # adapted from: # https://github.com/theislab/cellrank/blob/master/cellrank/_utils/_utils.py#L392 def compute_feature_correlation( - self: AnalysisMixinProtocol[K, B], + self, obs_key: str, corr_method: Literal["pearson", "spearman"] = "pearson", significance_method: Literal["fisher", "perm_test"] = "fisher", @@ -649,7 +566,7 @@ def compute_feature_correlation( ) def compute_entropy( - self: AnalysisMixinProtocol[K, B], + self, source: K, target: K, forward: bool = True, diff --git a/src/moscot/base/problems/birth_death.py b/src/moscot/base/problems/birth_death.py index 27d667b49..cec80c8a8 100644 --- a/src/moscot/base/problems/birth_death.py +++ b/src/moscot/base/problems/birth_death.py @@ -1,14 +1,5 @@ from functools import partial -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Literal, - Optional, - Protocol, - Sequence, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Sequence, Union import numpy as np @@ -23,32 +14,6 @@ __all__ = ["BirthDeathProblem", "BirthDeathMixin"] -class BirthDeathProtocol(Protocol): # noqa: D101 - adata: AnnData - proliferation_key: Optional[str] - apoptosis_key: Optional[str] - _proliferation_key: Optional[str] - _apoptosis_key: Optional[str] - _scaling: float - _prior_growth: Optional[ArrayLike] - - def score_genes_for_marginals( # noqa: D102 - self: "BirthDeathProtocol", - gene_set_proliferation: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None, - gene_set_apoptosis: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None, - proliferation_key: str = "proliferation", - apoptosis_key: str = "apoptosis", - **kwargs: Any, - ) -> "BirthDeathProtocol": ... - - -class BirthDeathProblemProtocol(BirthDeathProtocol, Protocol): # noqa: D101 - delta: float - adata_tgt: AnnData - a: Optional[ArrayLike] - b: Optional[ArrayLike] - - class BirthDeathMixin: """Mixin class used to estimate cell proliferation and apoptosis. @@ -68,7 +33,7 @@ def __init__(self, *args: Any, **kwargs: Any): self._prior_growth: Optional[ArrayLike] = None def score_genes_for_marginals( - self, # type: BirthDeathProblemProtocol + self, gene_set_proliferation: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None, gene_set_apoptosis: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None, proliferation_key: str = "proliferation", @@ -131,7 +96,7 @@ def proliferation_key(self) -> Optional[str]: return self._proliferation_key @proliferation_key.setter - def proliferation_key(self: BirthDeathProtocol, key: Optional[str]) -> None: + def proliferation_key(self, key: Optional[str]) -> None: if key is not None and key not in self.adata.obs: raise KeyError(f"Unable to find proliferation data in `adata.obs[{key!r}]`.") self._proliferation_key = key @@ -142,7 +107,7 @@ def apoptosis_key(self) -> Optional[str]: return self._apoptosis_key @apoptosis_key.setter - def apoptosis_key(self: BirthDeathProtocol, key: Optional[str]) -> None: + def apoptosis_key(self, key: Optional[str]) -> None: if key is not None and key not in self.adata.obs: raise KeyError(f"Unable to find apoptosis data in `adata.obs[{key!r}]`.") self._apoptosis_key = key @@ -161,7 +126,7 @@ class BirthDeathProblem(BirthDeathMixin, OTProblem): """ # noqa: D205 def estimate_marginals( - self, # type: BirthDeathProblemProtocol + self, adata: AnnData, source: bool, proliferation_key: Optional[str] = None, diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 5a3967e1f..d436fe1e6 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -7,26 +7,12 @@ from moscot import _constants from moscot._types import ArrayLike, Str_Dict_t -from moscot.base.problems._mixins import AnalysisMixin, AnalysisMixinProtocol +from moscot.base.problems._mixins import AnalysisMixin from moscot.base.problems.compound_problem import B, K __all__ = ["CrossModalityTranslationMixin"] -class CrossModalityTranslationMixinProtocol(AnalysisMixinProtocol[K, B]): - """Protocol class.""" - - adata_src: AnnData - adata_tgt: AnnData - _src_attr: Optional[Dict[str, Any]] - _tgt_attr: Optional[Dict[str, Any]] - batch_key: Optional[str] - - def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... - - def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... - - class CrossModalityTranslationMixin(AnalysisMixin[K, B]): """Cross modality translation analysis mixin class.""" @@ -37,7 +23,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._batch_key: Optional[str] = None def translate( # type: ignore[misc] - self: CrossModalityTranslationMixinProtocol[K, B], + self, source: K, target: K, forward: bool = True, @@ -107,7 +93,7 @@ def _get_features( return prob.push(_get_features(adata_src, attr=src_attr), **kwargs) def cell_transition( # type: ignore[misc] - self: CrossModalityTranslationMixinProtocol[K, B], + self, source: K, target: Optional[K] = None, source_groups: Optional[Str_Dict_t] = None, @@ -186,7 +172,7 @@ def cell_transition( # type: ignore[misc] ) def annotation_mapping( # type: ignore[misc] - self: CrossModalityTranslationMixinProtocol[K, B], + self, mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, diff --git a/src/moscot/problems/generic/_mixins.py b/src/moscot/problems/generic/_mixins.py index ed74ebffd..a2f18c176 100644 --- a/src/moscot/problems/generic/_mixins.py +++ b/src/moscot/problems/generic/_mixins.py @@ -1,32 +1,17 @@ -from typing import TYPE_CHECKING, Any, List, Literal, Optional, Protocol, Tuple, Union +from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union import pandas as pd -from anndata import AnnData from moscot import _constants from moscot._types import ArrayLike, Str_Dict_t -from moscot.base.problems._mixins import AnalysisMixin, AnalysisMixinProtocol +from moscot.base.problems._mixins import AnalysisMixin from moscot.base.problems.compound_problem import ApplyOutput_t, B, K from moscot.plotting._utils import set_plotting_vars __all__ = ["GenericAnalysisMixin"] -class GenericAnalysisMixinProtocol(AnalysisMixinProtocol[K, B], Protocol[K, B]): - """Protocol class.""" - - _batch_key: Optional[str] - batch_key: Optional[str] - adata: AnnData - - def _cell_transition( - self: AnalysisMixinProtocol[K, B], - *args: Any, - **kwargs: Any, - ) -> pd.DataFrame: ... - - class GenericAnalysisMixin(AnalysisMixin[K, B]): """Generic Analysis Mixin.""" @@ -35,7 +20,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._batch_key: Optional[str] = None def cell_transition( - self: GenericAnalysisMixinProtocol[K, B], + self, source: K, target: K, source_groups: Optional[Str_Dict_t] = None, @@ -112,7 +97,7 @@ def cell_transition( ) def push( - self: GenericAnalysisMixinProtocol[K, B], + self, source: K, target: K, data: Optional[Union[str, ArrayLike]] = None, @@ -181,7 +166,7 @@ def push( return result def pull( - self: GenericAnalysisMixinProtocol[K, B], + self, source: K, target: K, data: Optional[Union[str, ArrayLike]] = None, @@ -247,12 +232,12 @@ def pull( return result @property - def batch_key(self: GenericAnalysisMixinProtocol[K, B]) -> Optional[str]: + def batch_key(self) -> Optional[str]: """Batch key in :attr:`~anndata.AnnData.obs`.""" return self._batch_key @batch_key.setter - def batch_key(self: GenericAnalysisMixinProtocol[K, B], key: Optional[str]) -> None: + def batch_key(self, key: Optional[str]) -> None: if key is not None and key not in self.adata.obs: raise KeyError(f"Unable to find batch data in `adata.obs[{key!r}]`.") self._batch_key = key diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 3dec5ce54..1938ae06e 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -28,65 +28,13 @@ from moscot import _constants from moscot._logging import logger from moscot._types import ArrayLike, Device_t, Str_Dict_t -from moscot.base.problems._mixins import AnalysisMixin, AnalysisMixinProtocol +from moscot.base.problems._mixins import AnalysisMixin from moscot.base.problems.compound_problem import B, K from moscot.utils.subset_policy import StarPolicy __all__ = ["SpatialAlignmentMixin", "SpatialMappingMixin"] -class SpatialAlignmentMixinProtocol(AnalysisMixinProtocol[K, B]): - """Protocol class.""" - - spatial_key: Optional[str] - _spatial_key: Optional[str] - batch_key: Optional[str] - - def _subset_spatial( # type:ignore[empty-body] - self: "SpatialAlignmentMixinProtocol[K, B]", - k: K, - spatial_key: str, - ) -> ArrayLike: ... - - def _interpolate_scheme( # type:ignore[empty-body] - self: "SpatialAlignmentMixinProtocol[K, B]", - reference: K, - mode: Literal["warp", "affine"], - spatial_key: str, - ) -> Tuple[Dict[K, ArrayLike], Optional[Dict[K, Optional[ArrayLike]]]]: ... - - def _cell_transition( - self: AnalysisMixinProtocol[K, B], - *args: Any, - **kwargs: Any, - ) -> pd.DataFrame: ... - - def _annotation_mapping( - self: AnalysisMixinProtocol[K, B], - *args: Any, - **kwargs: Any, - ) -> pd.DataFrame: ... - - -class SpatialMappingMixinProtocol(AnalysisMixinProtocol[K, B]): - """Protocol class.""" - - adata_sc: AnnData - adata_sp: AnnData - batch_key: Optional[str] - spatial_key: Optional[str] - _spatial_key: Optional[str] - - def _filter_vars( - self: "SpatialMappingMixinProtocol[K, B]", - var_names: Optional[Sequence[str]] = None, - ) -> Optional[List[str]]: ... - - def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... - - def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... - - class SpatialAlignmentMixin(AnalysisMixin[K, B]): """Spatial alignment mixin class.""" @@ -96,7 +44,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._batch_key: Optional[str] = None def _interpolate_scheme( # type: ignore[misc] - self: SpatialAlignmentMixinProtocol[K, B], + self, reference: K, mode: Literal["warp", "affine"], spatial_key: str, @@ -140,7 +88,7 @@ def _interpolate_scheme( # type: ignore[misc] return transport_maps, (transport_metadata if mode == "affine" else None) def align( # type: ignore[misc] - self: SpatialAlignmentMixinProtocol[K, B], + self, reference: Optional[K] = None, mode: Literal["warp", "affine"] = "warp", spatial_key: Optional[str] = None, @@ -200,7 +148,7 @@ def align( # type: ignore[misc] self.adata.uns[key_added]["alignment_metadata"] = aligned_metadata # noqa: RET503 def cell_transition( # type: ignore[misc] - self: SpatialAlignmentMixinProtocol[K, B], + self, source: K, target: K, source_groups: Optional[Str_Dict_t] = None, @@ -278,7 +226,7 @@ def cell_transition( # type: ignore[misc] ) def annotation_mapping( # type: ignore[misc] - self: SpatialAlignmentMixinProtocol[K, B], + self, mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, @@ -337,7 +285,7 @@ def spatial_key(self) -> Optional[str]: return self._spatial_key @spatial_key.setter - def spatial_key(self: SpatialAlignmentMixinProtocol[K, B], key: Optional[str]) -> None: # type: ignore[misc] + def spatial_key(self, key: Optional[str]) -> None: # type: ignore[misc] if key is not None and key not in self.adata.obsm: raise KeyError(f"Unable to find spatial data in `adata.obsm[{key!r}]`.") self._spatial_key = key @@ -354,7 +302,7 @@ def batch_key(self, key: Optional[str]) -> None: self._batch_key = key def _subset_spatial( # type: ignore[misc] - self: SpatialAlignmentMixinProtocol[K, B], + self, k: K, spatial_key: str, ) -> ArrayLike: @@ -371,7 +319,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._spatial_key: Optional[str] = None def _filter_vars( # type: ignore[misc] - self: SpatialMappingMixinProtocol[K, B], + self, var_names: Optional[Sequence[str]] = None, ) -> Optional[List[str]]: """Filter variables for the linear term.""" @@ -393,7 +341,7 @@ def _filter_vars( # type: ignore[misc] raise ValueError("Some variable are missing in the single-cell or the spatial `AnnData`.") def correlate( # type: ignore[misc] - self: SpatialMappingMixinProtocol[K, B], + self, var_names: Optional[Sequence[str]] = None, corr_method: Literal["pearson", "spearman"] = "pearson", device: Optional[Device_t] = None, @@ -492,7 +440,7 @@ def correlate( # type: ignore[misc] return corrs def impute( # type: ignore[misc] - self: SpatialMappingMixinProtocol[K, B], + self, var_names: Optional[Sequence[str]] = None, device: Optional[Device_t] = None, batch_size: Optional[int] = None, @@ -547,7 +495,7 @@ def impute( # type: ignore[misc] return adata_pred def spatial_correspondence( # type: ignore[misc] - self: SpatialMappingMixinProtocol[K, B], + self, interval: Union[int, ArrayLike] = 10, max_dist: Optional[int] = None, attr: Optional[Dict[str, Optional[str]]] = None, @@ -608,7 +556,7 @@ def _get_features( return res def cell_transition( # type: ignore[misc] - self: SpatialMappingMixinProtocol[K, B], + self, source: K, target: Optional[K] = None, source_groups: Optional[Str_Dict_t] = None, @@ -685,7 +633,7 @@ def cell_transition( # type: ignore[misc] ) def annotation_mapping( # type: ignore[misc] - self: SpatialMappingMixinProtocol[K, B], + self, mapping_mode: Literal["sum", "max"], annotation_label: str, source: K, @@ -756,7 +704,7 @@ def spatial_key(self) -> Optional[str]: return self._spatial_key @spatial_key.setter - def spatial_key(self: SpatialAlignmentMixinProtocol[K, B], key: Optional[str]) -> None: # type: ignore[misc] + def spatial_key(self, key: Optional[str]) -> None: # type: ignore[misc] if key is not None and key not in self.adata.obsm: raise KeyError(f"Unable to find spatial data in `adata.obsm[{key!r}]`.") self._spatial_key = key diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 0fde8a334..aae926a22 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -1,17 +1,13 @@ from __future__ import annotations import itertools -import pathlib import types from typing import ( TYPE_CHECKING, Any, - Iterable, - Iterator, Literal, Mapping, Optional, - Protocol, Sequence, Union, ) @@ -24,8 +20,7 @@ from moscot import _constants from moscot._types import ArrayLike, Str_Dict_t -from moscot.base.problems._mixins import AnalysisMixin, AnalysisMixinProtocol -from moscot.base.problems.birth_death import BirthDeathProblem +from moscot.base.problems._mixins import AnalysisMixin from moscot.base.problems.compound_problem import ApplyOutput_t, B, K from moscot.plotting._utils import set_plotting_vars from moscot.utils.tagged_array import Tag @@ -33,118 +28,6 @@ __all__ = ["TemporalMixin"] -class TemporalMixinProtocol(AnalysisMixinProtocol[K, B], Protocol[K, B]): # type: ignore[misc] - adata: AnnData - problems: dict[tuple[K, K], BirthDeathProblem] - temporal_key: Optional[str] - _temporal_key: Optional[str] - - def cell_transition( # noqa: D102 - self: TemporalMixinProtocol[K, B], - source: K, - target: K, - source_groups: Str_Dict_t, - target_groups: Str_Dict_t, - forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic - aggregation_mode: Literal["annotation", "cell"] = "annotation", - batch_size: Optional[int] = None, - normalize: bool = True, - key_added: Optional[str] = _constants.CELL_TRANSITION, - ) -> pd.DataFrame: ... - - def push(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: ... - - def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: ... - - def _cell_transition( - self: AnalysisMixinProtocol[K, B], - *args: Any, - **kwargs: Any, - ) -> pd.DataFrame: ... - - def _annotation_mapping( - self: AnalysisMixinProtocol[K, B], - *args: Any, - **kwargs: Any, - ) -> pd.DataFrame: ... - - def _sample_from_tmap( - self: TemporalMixinProtocol[K, B], - source: K, - target: K, - n_samples: int, - source_dim: int, - target_dim: int, - batch_size: int = 256, - account_for_unbalancedness: bool = False, - interpolation_parameter: Optional[float] = None, - seed: Optional[int] = None, - ) -> tuple[list[Any], list[ArrayLike]]: ... - - def _compute_wasserstein_distance( - self: TemporalMixinProtocol[K, B], - point_cloud_1: ArrayLike, - point_cloud_2: ArrayLike, - a: Optional[ArrayLike] = None, - b: Optional[ArrayLike] = None, - backend: Literal["ott"] = "ott", - **kwargs: Any, - ) -> float: ... - - def _interpolate_gex_with_ot( - self: TemporalMixinProtocol[K, B], - number_cells: int, - source_data: ArrayLike, - target_data: ArrayLike, - source: K, - target: K, - interpolation_parameter: float, - account_for_unbalancedness: bool = True, - batch_size: int = 256, - seed: Optional[int] = None, - ) -> ArrayLike: ... - - def _get_data( - self: TemporalMixinProtocol[K, B], - source: K, - intermediate: Optional[K] = None, - target: Optional[K] = None, - posterior_marginals: bool = True, - *, - only_start: bool = False, - ) -> Union[tuple[ArrayLike, AnnData], tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]: ... - - def _interpolate_gex_randomly( - self: TemporalMixinProtocol[K, B], - number_cells: int, - source_data: ArrayLike, - target_data: ArrayLike, - interpolation_parameter: float, - growth_rates: Optional[ArrayLike] = None, - seed: Optional[int] = None, - ) -> ArrayLike: ... - - def _plot_temporal( - self: TemporalMixinProtocol[K, B], - data: dict[K, ArrayLike], - source: K, - target: K, - time_points: Optional[Iterable[K]] = None, - basis: str = "umap", - result_key: Optional[str] = None, - fill_value: float = 0.0, - save: Optional[Union[str, pathlib.Path]] = None, - **kwargs: Any, - ) -> None: ... - - @staticmethod - def _get_interp_param( - source: K, intermediate: K, target: K, interpolation_parameter: Optional[float] = None - ) -> float: ... - - def __iter__(self) -> Iterator[tuple[K, K]]: ... - - class TemporalMixin(AnalysisMixin[K, B]): """Analysis Mixin for all problems involving a temporal dimension.""" @@ -153,7 +36,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._temporal_key: Optional[str] = None def cell_transition( - self: TemporalMixinProtocol[K, B], + self, source: K, target: K, source_groups: Str_Dict_t, @@ -228,7 +111,7 @@ def cell_transition( ) def annotation_mapping( - self: TemporalMixinProtocol[K, B], + self, mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, @@ -283,7 +166,7 @@ def annotation_mapping( ) def sankey( - self: TemporalMixinProtocol[K, B], + self, source: K, target: K, source_groups: Str_Dict_t, @@ -405,7 +288,7 @@ def sankey( set_plotting_vars(self.adata, _constants.SANKEY, key=key_added, value=plot_vars) # noqa: RET503 def push( - self: TemporalMixinProtocol[K, B], + self, source: K, target: K, data: Optional[Union[str, ArrayLike]] = None, @@ -472,7 +355,7 @@ def push( return result def pull( - self: TemporalMixinProtocol[K, B], + self, source: K, target: K, data: Optional[Union[str, ArrayLike]] = None, @@ -538,7 +421,7 @@ def pull( return result @property - def prior_growth_rates(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]: + def prior_growth_rates(self) -> Optional[pd.DataFrame]: """Prior estimate of the source growth rates.""" computed = [isinstance(p.prior_growth_rates, np.ndarray) for p in self.problems.values()] if not np.sum(computed): @@ -556,7 +439,7 @@ def prior_growth_rates(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFra return pd.concat([df_1, df_2], verify_integrity=True) @property - def posterior_growth_rates(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]: + def posterior_growth_rates(self) -> Optional[pd.DataFrame]: """Posterior estimate of the source growth rates.""" computed = [isinstance(p.posterior_growth_rates, np.ndarray) for p in self.problems.values()] if not np.sum(computed): @@ -574,7 +457,7 @@ def posterior_growth_rates(self: TemporalMixinProtocol[K, B]) -> Optional[pd.Dat return pd.concat([df_1, df_2], verify_integrity=True) @property - def cell_costs_source(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]: + def cell_costs_source(self) -> Optional[pd.DataFrame]: """Cell cost obtained by the :term:`first dual potential `. Only available for subproblems with :attr:`problem_kind = 'linear' `. @@ -599,7 +482,7 @@ def cell_costs_source(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFram return pd.concat([df_1, df_2], verify_integrity=True) @property - def cell_costs_target(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]: + def cell_costs_target(self) -> Optional[pd.DataFrame]: """Cell cost obtained by the :term:`second dual potential `. Only available for subproblems with :attr:`problem_kind = 'linear' `. @@ -624,7 +507,7 @@ def cell_costs_target(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFram return pd.concat([df_1, df_2], verify_integrity=True) def _get_data( - self: TemporalMixinProtocol[K, B], + self, source: K, intermediate: Optional[K] = None, target: Optional[K] = None, @@ -673,7 +556,7 @@ def _get_data( ) def compute_interpolated_distance( - self: TemporalMixinProtocol[K, B], + self, source: K, intermediate: K, target: K, @@ -757,7 +640,7 @@ def compute_interpolated_distance( return self._compute_wasserstein_distance(intermediate_data, interpolation, backend=backend, **kwargs) def compute_random_distance( - self: TemporalMixinProtocol[K, B], + self, source: K, intermediate: K, target: K, @@ -831,7 +714,7 @@ def compute_random_distance( return self._compute_wasserstein_distance(intermediate_data, random_interpolation, backend=backend, **kwargs) def compute_time_point_distances( - self: TemporalMixinProtocol[K, B], + self, source: K, intermediate: K, target: K, @@ -884,7 +767,7 @@ def compute_time_point_distances( return distance_source_intermediate, distance_intermediate_target def compute_batch_distances( - self: TemporalMixinProtocol[K, B], + self, time: K, batch_key: str, posterior_marginals: bool = True, @@ -936,7 +819,7 @@ def compute_batch_distances( # TODO(@MUCDK) possibly offer two alternatives, once exact EMD with POT backend and once approximate, # faster with same solver as used for original problems def _compute_wasserstein_distance( - self: TemporalMixinProtocol[K, B], + self, point_cloud_1: ArrayLike, point_cloud_2: ArrayLike, a: Optional[ArrayLike] = None, @@ -951,7 +834,7 @@ def _compute_wasserstein_distance( raise NotImplementedError("Only `ott` available as backend.") def _interpolate_gex_with_ot( - self: TemporalMixinProtocol[K, B], + self, number_cells: int, source_data: ArrayLike, target_data: ArrayLike, @@ -979,7 +862,7 @@ def _interpolate_gex_with_ot( ) def _interpolate_gex_randomly( - self: TemporalMixinProtocol[K, B], + self, number_cells: int, source_data: ArrayLike, target_data: ArrayLike, @@ -1026,7 +909,7 @@ def temporal_key(self) -> Optional[str]: return self._temporal_key @temporal_key.setter - def temporal_key(self: TemporalMixinProtocol[K, B], key: Optional[str]) -> None: + def temporal_key(self, key: Optional[str]) -> None: if key is None: self._temporal_key = key return From 55166d79fc1f43a39c3a58dc0ca392d44e3f5d09 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 20 Nov 2024 15:17:19 +0100 Subject: [PATCH 2/8] formatting --- src/moscot/problems/generic/_mixins.py | 1 - src/moscot/problems/time/_mixins.py | 10 +--------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/moscot/problems/generic/_mixins.py b/src/moscot/problems/generic/_mixins.py index a2f18c176..a2a90a0a8 100644 --- a/src/moscot/problems/generic/_mixins.py +++ b/src/moscot/problems/generic/_mixins.py @@ -2,7 +2,6 @@ import pandas as pd - from moscot import _constants from moscot._types import ArrayLike, Str_Dict_t from moscot.base.problems._mixins import AnalysisMixin diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index aae926a22..68b2900b3 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -2,15 +2,7 @@ import itertools import types -from typing import ( - TYPE_CHECKING, - Any, - Literal, - Mapping, - Optional, - Sequence, - Union, -) +from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional, Sequence, Union import numpy as np import pandas as pd From be8460a7414bc290452cd7de51afd88e7418ece1 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 25 Nov 2024 15:34:42 +0100 Subject: [PATCH 3/8] a demonstration of how ABC's would look like --- src/moscot/base/problems/_mixins.py | 16 ++-- src/moscot/base/problems/birth_death.py | 6 +- src/moscot/base/problems/compound_problem.py | 2 +- src/moscot/base/problems/problem.py | 88 +++++++++++++++++++ src/moscot/problems/cross_modality/_mixins.py | 22 +++-- .../problems/cross_modality/_translation.py | 2 +- src/moscot/problems/generic/_generic.py | 8 +- src/moscot/problems/space/_mapping.py | 2 +- src/moscot/problems/space/_mixins.py | 59 +++++++------ src/moscot/problems/time/_mixins.py | 30 +++---- 10 files changed, 170 insertions(+), 65 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index fc762d233..be15f6cd2 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -32,6 +32,10 @@ _validate_args_cell_transition, ) from moscot.base.problems.compound_problem import B, K +from moscot.base.problems.problem import ( + AbstractPushPullAdata, + AbstractSolutionsProblems, +) from moscot.plotting._utils import set_plotting_vars from moscot.utils.data import transcription_factors from moscot.utils.subset_policy import SubsetPolicy @@ -39,7 +43,7 @@ __all__ = ["AnalysisMixin"] -class AnalysisMixin(Generic[K, B]): +class AnalysisMixin(Generic[K, B], AbstractPushPullAdata, AbstractSolutionsProblems): """Base Analysis Mixin.""" def __init__(self, *args: Any, **kwargs: Any): @@ -48,7 +52,7 @@ def __init__(self, *args: Any, **kwargs: Any): def _cell_transition( self, source: K, - target: K, + target: Optional[K], source_groups: Str_Dict_t, target_groups: Str_Dict_t, aggregation_mode: Literal["annotation", "cell"] = "annotation", @@ -130,7 +134,7 @@ def _cell_transition_online( self, key: Optional[str], source: K, - target: K, + target: Optional[K], source_groups: Str_Dict_t, target_groups: Str_Dict_t, forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic @@ -189,7 +193,7 @@ def _cell_transition_online( split_mass=False, **move_op_const_kwargs, ) - tm = self._annotation_aggregation_transition( # type: ignore[attr-defined] + tm = self._annotation_aggregation_transition( annotations_1=source_annotations_verified if forward else target_annotations_verified, annotations_2=target_annotations_verified if forward else source_annotations_verified, df=df_to, @@ -203,7 +207,7 @@ def _cell_transition_online( split_mass=True, **move_op_const_kwargs, ) - tm = self._cell_aggregation_transition( # type: ignore[attr-defined] + tm = self._cell_aggregation_transition( df_from=df_from, df_to=df_to, annotations=target_annotations_verified if forward else source_annotations_verified, @@ -623,7 +627,7 @@ def compute_entropy( split_mass=True, key_added=None, ) - df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = stats.entropy(cond_dists + c, **kwargs) # type: ignore[operator] + df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = stats.entropy(cond_dists + c, **kwargs) if key_added is not None: self.adata.obs[key_added] = df return df if key_added is None else None diff --git a/src/moscot/base/problems/birth_death.py b/src/moscot/base/problems/birth_death.py index cec80c8a8..f2239e499 100644 --- a/src/moscot/base/problems/birth_death.py +++ b/src/moscot/base/problems/birth_death.py @@ -8,13 +8,13 @@ from moscot._logging import logger from moscot._types import ArrayLike -from moscot.base.problems.problem import OTProblem +from moscot.base.problems.problem import AbstractAdataAccess, OTProblem from moscot.utils.data import apoptosis_markers, proliferation_markers __all__ = ["BirthDeathProblem", "BirthDeathMixin"] -class BirthDeathMixin: +class BirthDeathMixin(AbstractAdataAccess): """Mixin class used to estimate cell proliferation and apoptosis. Parameters @@ -88,7 +88,7 @@ def score_genes_for_marginals( "At least one of `gene_set_proliferation` or `gene_set_apoptosis` must be provided to score genes." ) - return self # type: ignore[return-value] + return self @property def proliferation_key(self) -> Optional[str]: diff --git a/src/moscot/base/problems/compound_problem.py b/src/moscot/base/problems/compound_problem.py index 6070b34c7..bd168201c 100644 --- a/src/moscot/base/problems/compound_problem.py +++ b/src/moscot/base/problems/compound_problem.py @@ -47,7 +47,7 @@ ApplyOutput_t = Union[ArrayLike, Dict[K, ArrayLike]] -class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]): +class BaseCompoundProblem(BaseProblem, Generic[K, B], abc.ABC): """Base class for all biological problems. This class translates a biological problem to multiple :term:`OT` problems. diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index 708b35d44..1c71c0beb 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -1,6 +1,7 @@ import abc import pathlib import types +from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, Any, @@ -222,6 +223,93 @@ def problem_kind(self) -> ProblemKind_t: return self._problem_kind +class AbstractAdataAccess(ABC, metaclass=CombinedMeta): + + @property + @abstractmethod + def adata(self) -> AnnData: + """Annotated data object.""" + pass + + +class AbstractPushPullAdata(AbstractAdataAccess): + + @abstractmethod + def pull( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + pass + + @abstractmethod + def push( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + pass + + @abstractmethod + def _apply( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + pass + + +class AbstractSolutionsProblems(ABC, metaclass=CombinedMeta): + + @property + @abstractmethod + def solutions(self) -> Any: + """Solutions.""" + pass + + @property + @abstractmethod + def problems(self) -> Any: + """Problems.""" + pass + + @property + @abstractmethod + def _policy(self) -> Any: + """Subset policy.""" + pass + + +class AbstractSrcTgt(ABC, metaclass=CombinedMeta): + + @property + @abstractmethod + def adata_src(self) -> AnnData: + """Annotated data object.""" + pass + + @property + @abstractmethod + def adata_tgt(self) -> AnnData: + """Annotated data object.""" + pass + + +class AbstractSpSc(ABC, metaclass=CombinedMeta): + + @property + @abstractmethod + def adata_sp(self) -> AnnData: + """Annotated data object.""" + pass + + @property + @abstractmethod + def adata_sc(self) -> AnnData: + """Annotated data object.""" + pass + + class OTProblem(BaseProblem): """Base class for all :term:`OT` problems. diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index d436fe1e6..044b16fce 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -9,11 +9,12 @@ from moscot._types import ArrayLike, Str_Dict_t from moscot.base.problems._mixins import AnalysisMixin from moscot.base.problems.compound_problem import B, K +from moscot.base.problems.problem import AbstractSrcTgt __all__ = ["CrossModalityTranslationMixin"] -class CrossModalityTranslationMixin(AnalysisMixin[K, B]): +class CrossModalityTranslationMixin(AnalysisMixin[K, B], AbstractSrcTgt): """Cross modality translation analysis mixin class.""" def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -22,7 +23,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._tgt_attr: Optional[Dict[str, Any]] = None self._batch_key: Optional[str] = None - def translate( # type: ignore[misc] + def translate( self, source: K, target: K, @@ -92,7 +93,7 @@ def _get_features( adata_src = self.adata_src if self.batch_key is None else prob.adata_src return prob.push(_get_features(adata_src, attr=src_attr), **kwargs) - def cell_transition( # type: ignore[misc] + def cell_transition( self, source: K, target: Optional[K] = None, @@ -171,16 +172,16 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) - def annotation_mapping( # type: ignore[misc] + def annotation_mapping( self, mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, - source: str = "src", - target: str = "tgt", + source: K = "src", + target: K = "tgt", batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - **kwargs: Mapping[str, Any], + scale_by_marginals: bool = True, ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -208,6 +209,9 @@ def annotation_mapping( # type: ignore[misc] If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. + scale_by_marginals + todo + Returns ------- @@ -223,7 +227,7 @@ def annotation_mapping( # type: ignore[misc] other_adata=self.adata_tgt, batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, - **kwargs, + scale_by_marginals=True, ) @property @@ -233,6 +237,6 @@ def batch_key(self) -> Optional[str]: @batch_key.setter def batch_key(self, key: Optional[str]) -> None: - if key is not None and key not in self.adata.obs: # type: ignore[attr-defined] + if key is not None and key not in self.adata.obs: raise KeyError(f"Unable to find batch data in `adata.obs[{key!r}]`.") self._batch_key = key diff --git a/src/moscot/problems/cross_modality/_translation.py b/src/moscot/problems/cross_modality/_translation.py index 5c8f9081c..390a01e11 100644 --- a/src/moscot/problems/cross_modality/_translation.py +++ b/src/moscot/problems/cross_modality/_translation.py @@ -22,7 +22,7 @@ __all__ = ["TranslationProblem"] -class TranslationProblem(CrossModalityTranslationMixin[K, OTProblem], CompoundProblem[K, OTProblem]): +class TranslationProblem(CompoundProblem[K, OTProblem], CrossModalityTranslationMixin[K, OTProblem]): """Class for integrating single-cell multi-omics data, based on :cite:`demetci-scot:22`. Parameters diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index 8d727d0d0..720dc4521 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -128,7 +128,7 @@ def prepare( - :attr:`stage` - set to ``'prepared'``. - :attr:`problem_kind` - set to ``'linear'``. """ - self.batch_key = key # type: ignore[misc] + self.batch_key = key xy, xy_callback, xy_callback_kwargs = handle_joint_attr(joint_attr, xy_callback, xy_callback_kwargs) xy, _, _ = handle_cost( xy=xy, @@ -366,7 +366,7 @@ def prepare( - :attr:`stage` - set to ``'prepared'``. - :attr:`problem_kind` - set to ``'quadratic'``. """ - self.batch_key = key # type: ignore[misc] + self.batch_key = key x = set_quad_defaults(x_attr) if x_callback is None else {} y = set_quad_defaults(y_attr) if y_callback is None else {} @@ -630,7 +630,7 @@ def prepare( - :attr:`stage` - set to ``'prepared'``. - :attr:`problem_kind` - set to ``'quadratic'``. """ - self.batch_key = key # type: ignore[misc] + self.batch_key = key x = set_quad_defaults(x_attr) if x_callback is None else {} y = set_quad_defaults(y_attr) if y_callback is None else {} xy, xy_callback, xy_callback_kwargs = handle_joint_attr(joint_attr, xy_callback, xy_callback_kwargs) @@ -792,7 +792,7 @@ def prepare( **kwargs: Any, ) -> "GENOTLinProblem[K, B]": """Prepare the :class:`moscot.problems.generic.GENOTLinProblem`.""" - self.batch_key = key # type:ignore[misc] + self.batch_key = key xy, kwargs = handle_joint_attr_tmp(joint_attr, kwargs) conditions = handle_conditional_attr(conditional_attr) xy, xx = handle_cost_tmp(xy=xy, x={}, y={}, cost=cost, cost_kwargs=cost_kwargs) diff --git a/src/moscot/problems/space/_mapping.py b/src/moscot/problems/space/_mapping.py index abe440e28..2994de5cd 100644 --- a/src/moscot/problems/space/_mapping.py +++ b/src/moscot/problems/space/_mapping.py @@ -359,7 +359,7 @@ def filtered_vars(self) -> Optional[Sequence[str]]: @filtered_vars.setter def filtered_vars(self, value: Optional[Sequence[str]]) -> None: - self._filtered_vars = self._filter_vars(var_names=value) # type: ignore[misc] + self._filtered_vars = self._filter_vars(var_names=value) @property def _base_problem_type(self) -> Type[B]: diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 1938ae06e..d0590b47b 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -30,12 +30,13 @@ from moscot._types import ArrayLike, Device_t, Str_Dict_t from moscot.base.problems._mixins import AnalysisMixin from moscot.base.problems.compound_problem import B, K +from moscot.base.problems.problem import AbstractSpSc, AbstractSrcTgt from moscot.utils.subset_policy import StarPolicy __all__ = ["SpatialAlignmentMixin", "SpatialMappingMixin"] -class SpatialAlignmentMixin(AnalysisMixin[K, B]): +class SpatialAlignmentMixin(AnalysisMixin[K, B], AbstractSrcTgt): """Spatial alignment mixin class.""" def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -43,7 +44,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._spatial_key: Optional[str] = None self._batch_key: Optional[str] = None - def _interpolate_scheme( # type: ignore[misc] + def _interpolate_scheme( self, reference: K, mode: Literal["warp", "affine"], @@ -87,7 +88,7 @@ def _interpolate_scheme( # type: ignore[misc] # TODO(michalk8): always return the metadata? return transport_maps, (transport_metadata if mode == "affine" else None) - def align( # type: ignore[misc] + def align( self, reference: Optional[K] = None, mode: Literal["warp", "affine"] = "warp", @@ -147,7 +148,7 @@ def align( # type: ignore[misc] self.adata.uns.setdefault(key_added, {}) self.adata.uns[key_added]["alignment_metadata"] = aligned_metadata # noqa: RET503 - def cell_transition( # type: ignore[misc] + def cell_transition( self, source: K, target: K, @@ -225,16 +226,17 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) - def annotation_mapping( # type: ignore[misc] + def annotation_mapping( self, mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, - source: str = "src", - target: str = "tgt", + source: K = "src", + target: K = "tgt", batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - **kwargs: Mapping[str, Any], + other_adata: Optional[AnnData] = None, + scale_by_marginals: bool = True, ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -262,6 +264,10 @@ def annotation_mapping( # type: ignore[misc] If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. + other_adata + The other :class:`~anndata.AnnData` object to use for the cell transition. + scale_by_marginals + If :obj:`True`, scale the transition matrix by marginals. Returns ------- @@ -276,7 +282,8 @@ def annotation_mapping( # type: ignore[misc] forward=forward, batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, - **kwargs, + other_adata=other_adata, + scale_by_marginals=scale_by_marginals, ) @property @@ -285,7 +292,7 @@ def spatial_key(self) -> Optional[str]: return self._spatial_key @spatial_key.setter - def spatial_key(self, key: Optional[str]) -> None: # type: ignore[misc] + def spatial_key(self, key: Optional[str]) -> None: if key is not None and key not in self.adata.obsm: raise KeyError(f"Unable to find spatial data in `adata.obsm[{key!r}]`.") self._spatial_key = key @@ -297,11 +304,11 @@ def batch_key(self) -> Optional[str]: @batch_key.setter def batch_key(self, key: Optional[str]) -> None: - if key is not None and key not in self.adata.obs: # type: ignore[attr-defined] + if key is not None and key not in self.adata.obs: raise KeyError(f"Unable to find batch data in `adata.obs[{key!r}]`.") self._batch_key = key - def _subset_spatial( # type: ignore[misc] + def _subset_spatial( self, k: K, spatial_key: str, @@ -310,7 +317,7 @@ def _subset_spatial( # type: ignore[misc] return self.adata[mask].obsm[spatial_key].astype(float, copy=True) -class SpatialMappingMixin(AnalysisMixin[K, B]): +class SpatialMappingMixin(AnalysisMixin[K, B], AbstractSpSc): """Spatial mapping analysis mixin class.""" def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -318,7 +325,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._batch_key: Optional[str] = None self._spatial_key: Optional[str] = None - def _filter_vars( # type: ignore[misc] + def _filter_vars( self, var_names: Optional[Sequence[str]] = None, ) -> Optional[List[str]]: @@ -340,7 +347,7 @@ def _filter_vars( # type: ignore[misc] raise ValueError("Some variable are missing in the single-cell or the spatial `AnnData`.") - def correlate( # type: ignore[misc] + def correlate( self, var_names: Optional[Sequence[str]] = None, corr_method: Literal["pearson", "spearman"] = "pearson", @@ -439,7 +446,7 @@ def correlate( # type: ignore[misc] return corrs - def impute( # type: ignore[misc] + def impute( self, var_names: Optional[Sequence[str]] = None, device: Optional[Device_t] = None, @@ -494,7 +501,7 @@ def impute( # type: ignore[misc] return adata_pred - def spatial_correspondence( # type: ignore[misc] + def spatial_correspondence( self, interval: Union[int, ArrayLike] = 10, max_dist: Optional[int] = None, @@ -555,7 +562,7 @@ def _get_features( res[self.batch_key] = res[self.batch_key].astype("category") # type: ignore[call-overload] return res - def cell_transition( # type: ignore[misc] + def cell_transition( self, source: K, target: Optional[K] = None, @@ -620,7 +627,7 @@ def cell_transition( # type: ignore[misc] return self._cell_transition( key=self.batch_key, source=source, - target=target, + target=target or None, source_groups=source_groups, target_groups=target_groups, forward=forward, @@ -632,16 +639,16 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) - def annotation_mapping( # type: ignore[misc] + def annotation_mapping( self, mapping_mode: Literal["sum", "max"], annotation_label: str, source: K, - target: Union[K, str] = "tgt", + target: K = "tgt", forward: bool = False, batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - **kwargs: Mapping[str, Any], + scale_by_marginals: bool = True, ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -669,6 +676,8 @@ def annotation_mapping( # type: ignore[misc] If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. + scale_by_marginals + todo Returns ------- @@ -684,7 +693,7 @@ def annotation_mapping( # type: ignore[misc] other_adata=self.adata_sc, batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, - **kwargs, + scale_by_marginals=scale_by_marginals, ) @property @@ -694,7 +703,7 @@ def batch_key(self) -> Optional[str]: @batch_key.setter def batch_key(self, key: Optional[str]) -> None: - if key is not None and key not in self.adata.obs: # type: ignore[attr-defined] + if key is not None and key not in self.adata.obs: raise KeyError(f"Unable to find batch data in `adata.obs[{key!r}]`.") self._batch_key = key @@ -704,7 +713,7 @@ def spatial_key(self) -> Optional[str]: return self._spatial_key @spatial_key.setter - def spatial_key(self, key: Optional[str]) -> None: # type: ignore[misc] + def spatial_key(self, key: Optional[str]) -> None: if key is not None and key not in self.adata.obsm: raise KeyError(f"Unable to find spatial data in `adata.obsm[{key!r}]`.") self._spatial_key = key diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 68b2900b3..8c7275bd7 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -14,13 +14,14 @@ from moscot._types import ArrayLike, Str_Dict_t from moscot.base.problems._mixins import AnalysisMixin from moscot.base.problems.compound_problem import ApplyOutput_t, B, K +from moscot.base.problems.problem import AbstractSolutionsProblems from moscot.plotting._utils import set_plotting_vars from moscot.utils.tagged_array import Tag __all__ = ["TemporalMixin"] -class TemporalMixin(AnalysisMixin[K, B]): +class TemporalMixin(AnalysisMixin[K, B], AbstractSolutionsProblems): """Analysis Mixin for all problems involving a temporal dimension.""" def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -111,7 +112,7 @@ def annotation_mapping( target: K, batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - **kwargs: Mapping[str, Any], + scale_by_marginals: bool = True, ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -139,6 +140,8 @@ def annotation_mapping( If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. + scale_by_marginals + todo Returns ------- @@ -154,7 +157,7 @@ def annotation_mapping( other_adata=None, batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, - **kwargs, + scale_by_marginals=scale_by_marginals, ) def sankey( @@ -220,7 +223,7 @@ def sankey( :attr:`uns['moscot_results']['sankey']['{key_added}'] ` """ tuples = self._policy.plan(start=source, end=target) - cell_transitions = [] + cell_transitions: list[Any] = [] for src, tgt in tuples: cell_transitions.append( self.cell_transition( @@ -462,7 +465,7 @@ def cell_costs_source(self) -> Optional[pd.DataFrame]: # TODO(michalk8): `[1]` will fail if potentials is None df_list = [ pd.DataFrame( - np.asarray(problem.solution.potentials[0]), # type: ignore[union-attr,index] + np.asarray(problem.solution.potentials[0]), index=problem.adata_src.obs_names, columns=cols, ) @@ -487,7 +490,7 @@ def cell_costs_target(self) -> Optional[pd.DataFrame]: # TODO(michalk8): `[1]` will fail if potentials is None df_list = [ pd.DataFrame( - np.array(problem.solution.potentials[1]), # type: ignore[union-attr,index] + np.array(problem.solution.potentials[1]), index=problem.adata_tgt.obs_names, columns=cols, ) @@ -509,14 +512,11 @@ def _get_data( ) -> Union[tuple[ArrayLike, AnnData], tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]: # TODO: use .items() for src, tgt in self.problems: - tag = self.problems[src, tgt].xy.tag # type: ignore[union-attr] + tag = self.problems[src, tgt].xy.tag if tag != Tag.POINT_CLOUD: - raise ValueError( - f"Expected `tag={Tag.POINT_CLOUD}`, " - f"found `tag={self.problems[src, tgt].xy.tag}`." # type: ignore[union-attr] - ) + raise ValueError(f"Expected `tag={Tag.POINT_CLOUD}`, " f"found `tag={self.problems[src, tgt].xy.tag}`.") if src == source: - source_data = self.problems[src, tgt].xy.data_src # type: ignore[union-attr] + source_data = self.problems[src, tgt].xy.data_src if only_start: return source_data, self.problems[src, tgt].adata_src # TODO(michalk8): posterior marginals @@ -527,19 +527,19 @@ def _get_data( raise ValueError(f"No data found for `{source}` time point.") for src, tgt in self.problems: if src == intermediate: - intermediate_data = self.problems[src, tgt].xy.data_src # type: ignore[union-attr] + intermediate_data = self.problems[src, tgt].xy.data_src intermediate_adata = self.problems[src, tgt].adata_src break else: raise ValueError(f"No data found for `{intermediate}` time point.") for src, tgt in self.problems: if tgt == target: - target_data = self.problems[src, tgt].xy.data_tgt # type: ignore[union-attr] + target_data = self.problems[src, tgt].xy.data_tgt break else: raise ValueError(f"No data found for `{target}` time point.") - return ( # type:ignore[return-value] + return ( source_data, growth_rates_source, intermediate_data, From 90fa16d0954c504e7ddae846ee4cf1e7f5e215f1 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 25 Nov 2024 15:49:51 +0100 Subject: [PATCH 4/8] fix some bug --- src/moscot/problems/space/_mixins.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index d0590b47b..83e69bad4 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -30,13 +30,13 @@ from moscot._types import ArrayLike, Device_t, Str_Dict_t from moscot.base.problems._mixins import AnalysisMixin from moscot.base.problems.compound_problem import B, K -from moscot.base.problems.problem import AbstractSpSc, AbstractSrcTgt +from moscot.base.problems.problem import AbstractSpSc from moscot.utils.subset_policy import StarPolicy __all__ = ["SpatialAlignmentMixin", "SpatialMappingMixin"] -class SpatialAlignmentMixin(AnalysisMixin[K, B], AbstractSrcTgt): +class SpatialAlignmentMixin(AnalysisMixin[K, B]): """Spatial alignment mixin class.""" def __init__(self, *args: Any, **kwargs: Any) -> None: From f84f8be91e09a174edfbd1d4985e2dd575c0eafc Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 25 Nov 2024 16:06:59 +0100 Subject: [PATCH 5/8] rewrite the TODOs to clarify --- src/moscot/problems/cross_modality/_mixins.py | 4 ++-- src/moscot/problems/space/_mixins.py | 4 ++-- src/moscot/problems/time/_mixins.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 044b16fce..8c6024895 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -210,7 +210,7 @@ def annotation_mapping( cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. scale_by_marginals - todo + TODO Returns @@ -227,7 +227,7 @@ def annotation_mapping( other_adata=self.adata_tgt, batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, - scale_by_marginals=True, + scale_by_marginals=scale_by_marginals, ) @property diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 83e69bad4..cfa0574c1 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -265,9 +265,9 @@ def annotation_mapping( cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. other_adata - The other :class:`~anndata.AnnData` object to use for the cell transition. + TODO scale_by_marginals - If :obj:`True`, scale the transition matrix by marginals. + TODO Returns ------- diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 8c7275bd7..90d6cda92 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -141,7 +141,7 @@ def annotation_mapping( cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. scale_by_marginals - todo + TODO Returns ------- From baa03624be269ae9504d425b5497e456732bf7d3 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 25 Nov 2024 16:17:56 +0100 Subject: [PATCH 6/8] Turns out GENOTLinProblems doesn't need to inherit the mixins --- src/moscot/problems/generic/_generic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index 720dc4521..d18ab7948 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -776,7 +776,7 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]: return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value] -class GENOTLinProblem(CondOTProblem, GenericAnalysisMixin[K, B]): +class GENOTLinProblem(CondOTProblem): """Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems.""" def prepare( From 0fcd620dba413fd28f0a1d1f046c720f13045174 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 25 Nov 2024 16:20:14 +0100 Subject: [PATCH 7/8] fix generic error w genot --- src/moscot/problems/generic/_generic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index d18ab7948..fcd3d8b2e 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -790,7 +790,7 @@ def prepare( cost: OttCostFn_t = "sq_euclidean", cost_kwargs: CostKwargs_t = types.MappingProxyType({}), **kwargs: Any, - ) -> "GENOTLinProblem[K, B]": + ) -> "GENOTLinProblem": """Prepare the :class:`moscot.problems.generic.GENOTLinProblem`.""" self.batch_key = key xy, kwargs = handle_joint_attr_tmp(joint_attr, kwargs) @@ -816,7 +816,7 @@ def solve( valid_sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}), train_size: float = 1.0, **kwargs: Any, - ) -> "GENOTLinProblem[K, B]": + ) -> "GENOTLinProblem": """Solve.""" return super().solve( batch_size=batch_size, From 901507dbb4c2615d1980efa7aae61f3357e3dbab Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 25 Nov 2024 16:47:56 +0100 Subject: [PATCH 8/8] do the todos --- src/moscot/problems/cross_modality/_mixins.py | 2 +- src/moscot/problems/space/_mixins.py | 6 +++--- src/moscot/problems/time/_mixins.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 8c6024895..051e60ff4 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -210,7 +210,7 @@ def annotation_mapping( cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. scale_by_marginals - TODO + Whether to scale by the source/target :term:`marginals`. Returns diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index cfa0574c1..6181de0cf 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -265,9 +265,9 @@ def annotation_mapping( cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. other_adata - TODO + Other adata object to use for the cell transitions. scale_by_marginals - TODO + Whether to scale by the source/target :term:`marginals`. Returns ------- @@ -677,7 +677,7 @@ def annotation_mapping( cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. scale_by_marginals - todo + Whether to scale by the source/target :term:`marginals`. Returns ------- diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 90d6cda92..c2537f6b4 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -141,7 +141,7 @@ def annotation_mapping( cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. scale_by_marginals - TODO + Whether to scale by the source/target :term:`marginals`. Returns -------