diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a373542fc..63117b699 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ default_stages: minimum_pre_commit_version: 3.0.0 repos: - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: v1.13.0 hooks: - id: mypy additional_dependencies: [numpy>=1.25.0] @@ -42,12 +42,12 @@ repos: - id: check-yaml - id: check-toml - repo: https://github.com/asottile/pyupgrade - rev: v3.18.0 + rev: v3.19.0 hooks: - id: pyupgrade args: [--py3-plus, --py38-plus, --keep-runtime-typing] - repo: https://github.com/asottile/blacken-docs - rev: 1.19.0 + rev: 1.19.1 hooks: - id: blacken-docs additional_dependencies: [black==23.1.0] @@ -63,7 +63,7 @@ repos: - id: doc8 - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.6.9 + rev: v0.7.2 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/pyproject.toml b/pyproject.toml index 2fd66c9a6..9fdbc3fa3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ dependencies = [ "scanpy>=1.9.3", "wrapt>=1.13.2", "docrep>=0.3.2", - "ott-jax[neural]>=0.4.6", + "ott-jax[neural]>=0.4.6,<=0.4.8", "cloudpickle>=2.2.0", "rich>=13.5", "docstring_inheritance>=2.0.0", diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 2a5ba2029..be15f6cd2 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -11,7 +11,6 @@ Literal, Mapping, Optional, - Protocol, Sequence, Union, ) @@ -21,11 +20,9 @@ from scipy.sparse.linalg import LinearOperator import scanpy as sc -from anndata import AnnData from moscot import _constants from moscot._types import ArrayLike, Numeric_t, Str_Dict_t -from moscot.base.output import BaseDiscreteSolverOutput from moscot.base.problems._utils import ( _check_argument_compatibility_cell_transition, _correlation_test, @@ -34,7 +31,11 @@ _validate_annotations, _validate_args_cell_transition, ) -from moscot.base.problems.compound_problem import ApplyOutput_t, B, K +from moscot.base.problems.compound_problem import B, K +from moscot.base.problems.problem import ( + AbstractPushPullAdata, + AbstractSolutionsProblems, +) from moscot.plotting._utils import set_plotting_vars from moscot.utils.data import transcription_factors from moscot.utils.subset_policy import SubsetPolicy @@ -42,96 +43,16 @@ __all__ = ["AnalysisMixin"] -class AnalysisMixinProtocol(Protocol[K, B]): - """Protocol class.""" - - adata: AnnData - _policy: SubsetPolicy[K] - solutions: dict[tuple[K, K], BaseDiscreteSolverOutput] - problems: dict[tuple[K, K], B] - - def _apply( - self, - data: Optional[Union[str, ArrayLike]] = None, - source: Optional[K] = None, - target: Optional[K] = None, - forward: bool = True, - return_all: bool = False, - scale_by_marginals: bool = False, - **kwargs: Any, - ) -> ApplyOutput_t[K]: ... - - def _interpolate_transport( - self: AnalysisMixinProtocol[K, B], - path: Sequence[tuple[K, K]], - scale_by_marginals: bool = True, - ) -> LinearOperator: ... - - def _flatten( - self: AnalysisMixinProtocol[K, B], - data: dict[K, ArrayLike], - *, - key: Optional[str], - ) -> ArrayLike: ... - - def push(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: - """Push distribution.""" - ... - - def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: - """Pull distribution.""" - ... - - def _cell_transition( - self: AnalysisMixinProtocol[K, B], - source: K, - target: K, - source_groups: Str_Dict_t, - target_groups: Str_Dict_t, - aggregation_mode: Literal["annotation", "cell"] = "annotation", - key_added: Optional[str] = _constants.CELL_TRANSITION, - **kwargs: Any, - ) -> pd.DataFrame: ... - - def _cell_transition_online( - self: AnalysisMixinProtocol[K, B], - key: Optional[str], - source: K, - target: K, - source_groups: Str_Dict_t, - target_groups: Str_Dict_t, - forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic - aggregation_mode: Literal["annotation", "cell"] = "annotation", - other_key: Optional[str] = None, - other_adata: Optional[str] = None, - batch_size: Optional[int] = None, - normalize: bool = True, - ) -> pd.DataFrame: ... - - def _annotation_mapping( - self: AnalysisMixinProtocol[K, B], - mapping_mode: Literal["sum", "max"], - annotation_label: str, - forward: bool, - source: K, - target: K, - key: str | None = None, - other_adata: Optional[str] = None, - scale_by_marginals: bool = True, - cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - ) -> pd.DataFrame: ... - - -class AnalysisMixin(Generic[K, B]): +class AnalysisMixin(Generic[K, B], AbstractPushPullAdata, AbstractSolutionsProblems): """Base Analysis Mixin.""" def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) def _cell_transition( - self: AnalysisMixinProtocol[K, B], + self, source: K, - target: K, + target: Optional[K], source_groups: Str_Dict_t, target_groups: Str_Dict_t, aggregation_mode: Literal["annotation", "cell"] = "annotation", @@ -178,7 +99,7 @@ def _cell_transition( return tm def _annotation_aggregation_transition( - self: AnalysisMixinProtocol[K, B], + self, annotations_1: list[Any], annotations_2: list[Any], df: pd.DataFrame, @@ -210,10 +131,10 @@ def _annotation_aggregation_transition( ) def _cell_transition_online( - self: AnalysisMixinProtocol[K, B], + self, key: Optional[str], source: K, - target: K, + target: Optional[K], source_groups: Str_Dict_t, target_groups: Str_Dict_t, forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic @@ -272,7 +193,7 @@ def _cell_transition_online( split_mass=False, **move_op_const_kwargs, ) - tm = self._annotation_aggregation_transition( # type: ignore[attr-defined] + tm = self._annotation_aggregation_transition( annotations_1=source_annotations_verified if forward else target_annotations_verified, annotations_2=target_annotations_verified if forward else source_annotations_verified, df=df_to, @@ -286,7 +207,7 @@ def _cell_transition_online( split_mass=True, **move_op_const_kwargs, ) - tm = self._cell_aggregation_transition( # type: ignore[attr-defined] + tm = self._cell_aggregation_transition( df_from=df_from, df_to=df_to, annotations=target_annotations_verified if forward else source_annotations_verified, @@ -309,7 +230,7 @@ def _cell_transition_online( ) def _annotation_mapping( - self: AnalysisMixinProtocol[K, B], + self, mapping_mode: Literal["sum", "max"], annotation_label: str, source: K, @@ -394,7 +315,7 @@ def _annotation_mapping( raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.") def _sample_from_tmap( - self: AnalysisMixinProtocol[K, B], + self, source: K, target: K, n_samples: int, @@ -472,7 +393,7 @@ def _sample_from_tmap( return rows, all_cols_sampled # type: ignore[return-value] def _interpolate_transport( - self: AnalysisMixinProtocol[K, B], + self, # TODO(@giovp): rename this to 'explicit_steps', pass to policy.plan() and reintroduce (source_key, target_key) path: Sequence[tuple[K, K]], scale_by_marginals: bool = True, @@ -485,7 +406,7 @@ def _interpolate_transport( fst, *rest = path return self.solutions[fst].chain([self.solutions[r] for r in rest], scale_by_marginals=scale_by_marginals) - def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike: + def _flatten(self, data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike: tmp = np.full(len(self.adata), np.nan) for k, v in data.items(): mask = self.adata.obs[key] == k @@ -493,7 +414,7 @@ def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key return tmp def _cell_aggregation_transition( - self: AnalysisMixinProtocol[K, B], + self, df_from: pd.DataFrame, df_to: pd.DataFrame, annotations: list[Any], @@ -540,7 +461,7 @@ def _cell_aggregation_transition( # adapted from: # https://github.com/theislab/cellrank/blob/master/cellrank/_utils/_utils.py#L392 def compute_feature_correlation( - self: AnalysisMixinProtocol[K, B], + self, obs_key: str, corr_method: Literal["pearson", "spearman"] = "pearson", significance_method: Literal["fisher", "perm_test"] = "fisher", @@ -649,7 +570,7 @@ def compute_feature_correlation( ) def compute_entropy( - self: AnalysisMixinProtocol[K, B], + self, source: K, target: K, forward: bool = True, @@ -706,7 +627,7 @@ def compute_entropy( split_mass=True, key_added=None, ) - df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = stats.entropy(cond_dists + c, **kwargs) # type: ignore[operator] + df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = stats.entropy(cond_dists + c, **kwargs) if key_added is not None: self.adata.obs[key_added] = df return df if key_added is None else None diff --git a/src/moscot/base/problems/birth_death.py b/src/moscot/base/problems/birth_death.py index 27d667b49..f2239e499 100644 --- a/src/moscot/base/problems/birth_death.py +++ b/src/moscot/base/problems/birth_death.py @@ -1,14 +1,5 @@ from functools import partial -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Literal, - Optional, - Protocol, - Sequence, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Sequence, Union import numpy as np @@ -17,39 +8,13 @@ from moscot._logging import logger from moscot._types import ArrayLike -from moscot.base.problems.problem import OTProblem +from moscot.base.problems.problem import AbstractAdataAccess, OTProblem from moscot.utils.data import apoptosis_markers, proliferation_markers __all__ = ["BirthDeathProblem", "BirthDeathMixin"] -class BirthDeathProtocol(Protocol): # noqa: D101 - adata: AnnData - proliferation_key: Optional[str] - apoptosis_key: Optional[str] - _proliferation_key: Optional[str] - _apoptosis_key: Optional[str] - _scaling: float - _prior_growth: Optional[ArrayLike] - - def score_genes_for_marginals( # noqa: D102 - self: "BirthDeathProtocol", - gene_set_proliferation: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None, - gene_set_apoptosis: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None, - proliferation_key: str = "proliferation", - apoptosis_key: str = "apoptosis", - **kwargs: Any, - ) -> "BirthDeathProtocol": ... - - -class BirthDeathProblemProtocol(BirthDeathProtocol, Protocol): # noqa: D101 - delta: float - adata_tgt: AnnData - a: Optional[ArrayLike] - b: Optional[ArrayLike] - - -class BirthDeathMixin: +class BirthDeathMixin(AbstractAdataAccess): """Mixin class used to estimate cell proliferation and apoptosis. Parameters @@ -68,7 +33,7 @@ def __init__(self, *args: Any, **kwargs: Any): self._prior_growth: Optional[ArrayLike] = None def score_genes_for_marginals( - self, # type: BirthDeathProblemProtocol + self, gene_set_proliferation: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None, gene_set_apoptosis: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None, proliferation_key: str = "proliferation", @@ -123,7 +88,7 @@ def score_genes_for_marginals( "At least one of `gene_set_proliferation` or `gene_set_apoptosis` must be provided to score genes." ) - return self # type: ignore[return-value] + return self @property def proliferation_key(self) -> Optional[str]: @@ -131,7 +96,7 @@ def proliferation_key(self) -> Optional[str]: return self._proliferation_key @proliferation_key.setter - def proliferation_key(self: BirthDeathProtocol, key: Optional[str]) -> None: + def proliferation_key(self, key: Optional[str]) -> None: if key is not None and key not in self.adata.obs: raise KeyError(f"Unable to find proliferation data in `adata.obs[{key!r}]`.") self._proliferation_key = key @@ -142,7 +107,7 @@ def apoptosis_key(self) -> Optional[str]: return self._apoptosis_key @apoptosis_key.setter - def apoptosis_key(self: BirthDeathProtocol, key: Optional[str]) -> None: + def apoptosis_key(self, key: Optional[str]) -> None: if key is not None and key not in self.adata.obs: raise KeyError(f"Unable to find apoptosis data in `adata.obs[{key!r}]`.") self._apoptosis_key = key @@ -161,7 +126,7 @@ class BirthDeathProblem(BirthDeathMixin, OTProblem): """ # noqa: D205 def estimate_marginals( - self, # type: BirthDeathProblemProtocol + self, adata: AnnData, source: bool, proliferation_key: Optional[str] = None, diff --git a/src/moscot/base/problems/compound_problem.py b/src/moscot/base/problems/compound_problem.py index 6070b34c7..bd168201c 100644 --- a/src/moscot/base/problems/compound_problem.py +++ b/src/moscot/base/problems/compound_problem.py @@ -47,7 +47,7 @@ ApplyOutput_t = Union[ArrayLike, Dict[K, ArrayLike]] -class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]): +class BaseCompoundProblem(BaseProblem, Generic[K, B], abc.ABC): """Base class for all biological problems. This class translates a biological problem to multiple :term:`OT` problems. diff --git a/src/moscot/base/problems/manager.py b/src/moscot/base/problems/manager.py index bdf4c59e9..432dc8b9e 100644 --- a/src/moscot/base/problems/manager.py +++ b/src/moscot/base/problems/manager.py @@ -72,7 +72,7 @@ def add_problem( if isinstance(self._compound_problem, CompoundProblem) else OTProblem ) - if not isinstance(problem, clazz): # type:ignore[arg-type] + if not isinstance(problem, clazz): raise TypeError(f"Expected problem of type `{OTProblem}`, found `{type(problem)}`.") self.problems[key] = problem diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index 83e1d13d7..8cebbb639 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -1,6 +1,7 @@ import abc import pathlib import types +from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, Any, @@ -222,6 +223,93 @@ def problem_kind(self) -> ProblemKind_t: return self._problem_kind +class AbstractAdataAccess(ABC, metaclass=CombinedMeta): + + @property + @abstractmethod + def adata(self) -> AnnData: + """Annotated data object.""" + pass + + +class AbstractPushPullAdata(AbstractAdataAccess): + + @abstractmethod + def pull( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + pass + + @abstractmethod + def push( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + pass + + @abstractmethod + def _apply( + self, + *args: Any, + **kwargs: Any, + ) -> Any: + pass + + +class AbstractSolutionsProblems(ABC, metaclass=CombinedMeta): + + @property + @abstractmethod + def solutions(self) -> Any: + """Solutions.""" + pass + + @property + @abstractmethod + def problems(self) -> Any: + """Problems.""" + pass + + @property + @abstractmethod + def _policy(self) -> Any: + """Subset policy.""" + pass + + +class AbstractSrcTgt(ABC, metaclass=CombinedMeta): + + @property + @abstractmethod + def adata_src(self) -> AnnData: + """Annotated data object.""" + pass + + @property + @abstractmethod + def adata_tgt(self) -> AnnData: + """Annotated data object.""" + pass + + +class AbstractSpSc(ABC, metaclass=CombinedMeta): + + @property + @abstractmethod + def adata_sp(self) -> AnnData: + """Annotated data object.""" + pass + + @property + @abstractmethod + def adata_sc(self) -> AnnData: + """Annotated data object.""" + pass + + class OTProblem(BaseProblem): """Base class for all :term:`OT` problems. diff --git a/src/moscot/datasets.py b/src/moscot/datasets.py index fb34a573a..b0101bd32 100644 --- a/src/moscot/datasets.py +++ b/src/moscot/datasets.py @@ -18,14 +18,15 @@ from moscot._types import PathLike __all__ = [ - "mosta", - "hspc", - "drosophila", - "c_elegans", - "zebrafish", "bone_marrow", + "c_elegans", + "drosophila", + "hspc", + "mosta", + "pancreas_multiome", "sim_align", "simulate_data", + "zebrafish", ] @@ -306,7 +307,7 @@ def pancreas_multiome( return _load_dataset_from_url( path, file_type="h5ad", - backup_url="https://figshare.com/ndownloader/files/48785320", + backup_url="https://figshare.com/ndownloader/files/49725087", expected_shape=(22604, 20242), force_download=force_download, **kwargs, diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index 5a3967e1f..051e60ff4 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -7,27 +7,14 @@ from moscot import _constants from moscot._types import ArrayLike, Str_Dict_t -from moscot.base.problems._mixins import AnalysisMixin, AnalysisMixinProtocol +from moscot.base.problems._mixins import AnalysisMixin from moscot.base.problems.compound_problem import B, K +from moscot.base.problems.problem import AbstractSrcTgt __all__ = ["CrossModalityTranslationMixin"] -class CrossModalityTranslationMixinProtocol(AnalysisMixinProtocol[K, B]): - """Protocol class.""" - - adata_src: AnnData - adata_tgt: AnnData - _src_attr: Optional[Dict[str, Any]] - _tgt_attr: Optional[Dict[str, Any]] - batch_key: Optional[str] - - def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... - - def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... - - -class CrossModalityTranslationMixin(AnalysisMixin[K, B]): +class CrossModalityTranslationMixin(AnalysisMixin[K, B], AbstractSrcTgt): """Cross modality translation analysis mixin class.""" def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -36,8 +23,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._tgt_attr: Optional[Dict[str, Any]] = None self._batch_key: Optional[str] = None - def translate( # type: ignore[misc] - self: CrossModalityTranslationMixinProtocol[K, B], + def translate( + self, source: K, target: K, forward: bool = True, @@ -106,8 +93,8 @@ def _get_features( adata_src = self.adata_src if self.batch_key is None else prob.adata_src return prob.push(_get_features(adata_src, attr=src_attr), **kwargs) - def cell_transition( # type: ignore[misc] - self: CrossModalityTranslationMixinProtocol[K, B], + def cell_transition( + self, source: K, target: Optional[K] = None, source_groups: Optional[Str_Dict_t] = None, @@ -185,16 +172,16 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) - def annotation_mapping( # type: ignore[misc] - self: CrossModalityTranslationMixinProtocol[K, B], + def annotation_mapping( + self, mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, - source: str = "src", - target: str = "tgt", + source: K = "src", + target: K = "tgt", batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - **kwargs: Mapping[str, Any], + scale_by_marginals: bool = True, ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -222,6 +209,9 @@ def annotation_mapping( # type: ignore[misc] If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. + scale_by_marginals + Whether to scale by the source/target :term:`marginals`. + Returns ------- @@ -237,7 +227,7 @@ def annotation_mapping( # type: ignore[misc] other_adata=self.adata_tgt, batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, - **kwargs, + scale_by_marginals=scale_by_marginals, ) @property @@ -247,6 +237,6 @@ def batch_key(self) -> Optional[str]: @batch_key.setter def batch_key(self, key: Optional[str]) -> None: - if key is not None and key not in self.adata.obs: # type: ignore[attr-defined] + if key is not None and key not in self.adata.obs: raise KeyError(f"Unable to find batch data in `adata.obs[{key!r}]`.") self._batch_key = key diff --git a/src/moscot/problems/cross_modality/_translation.py b/src/moscot/problems/cross_modality/_translation.py index 5c8f9081c..390a01e11 100644 --- a/src/moscot/problems/cross_modality/_translation.py +++ b/src/moscot/problems/cross_modality/_translation.py @@ -22,7 +22,7 @@ __all__ = ["TranslationProblem"] -class TranslationProblem(CrossModalityTranslationMixin[K, OTProblem], CompoundProblem[K, OTProblem]): +class TranslationProblem(CompoundProblem[K, OTProblem], CrossModalityTranslationMixin[K, OTProblem]): """Class for integrating single-cell multi-omics data, based on :cite:`demetci-scot:22`. Parameters diff --git a/src/moscot/problems/generic/_generic.py b/src/moscot/problems/generic/_generic.py index 8d727d0d0..fcd3d8b2e 100644 --- a/src/moscot/problems/generic/_generic.py +++ b/src/moscot/problems/generic/_generic.py @@ -128,7 +128,7 @@ def prepare( - :attr:`stage` - set to ``'prepared'``. - :attr:`problem_kind` - set to ``'linear'``. """ - self.batch_key = key # type: ignore[misc] + self.batch_key = key xy, xy_callback, xy_callback_kwargs = handle_joint_attr(joint_attr, xy_callback, xy_callback_kwargs) xy, _, _ = handle_cost( xy=xy, @@ -366,7 +366,7 @@ def prepare( - :attr:`stage` - set to ``'prepared'``. - :attr:`problem_kind` - set to ``'quadratic'``. """ - self.batch_key = key # type: ignore[misc] + self.batch_key = key x = set_quad_defaults(x_attr) if x_callback is None else {} y = set_quad_defaults(y_attr) if y_callback is None else {} @@ -630,7 +630,7 @@ def prepare( - :attr:`stage` - set to ``'prepared'``. - :attr:`problem_kind` - set to ``'quadratic'``. """ - self.batch_key = key # type: ignore[misc] + self.batch_key = key x = set_quad_defaults(x_attr) if x_callback is None else {} y = set_quad_defaults(y_attr) if y_callback is None else {} xy, xy_callback, xy_callback_kwargs = handle_joint_attr(joint_attr, xy_callback, xy_callback_kwargs) @@ -776,7 +776,7 @@ def _valid_policies(self) -> Tuple[Policy_t, ...]: return _constants.SEQUENTIAL, _constants.EXPLICIT, _constants.STAR # type: ignore[return-value] -class GENOTLinProblem(CondOTProblem, GenericAnalysisMixin[K, B]): +class GENOTLinProblem(CondOTProblem): """Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems.""" def prepare( @@ -790,9 +790,9 @@ def prepare( cost: OttCostFn_t = "sq_euclidean", cost_kwargs: CostKwargs_t = types.MappingProxyType({}), **kwargs: Any, - ) -> "GENOTLinProblem[K, B]": + ) -> "GENOTLinProblem": """Prepare the :class:`moscot.problems.generic.GENOTLinProblem`.""" - self.batch_key = key # type:ignore[misc] + self.batch_key = key xy, kwargs = handle_joint_attr_tmp(joint_attr, kwargs) conditions = handle_conditional_attr(conditional_attr) xy, xx = handle_cost_tmp(xy=xy, x={}, y={}, cost=cost, cost_kwargs=cost_kwargs) @@ -816,7 +816,7 @@ def solve( valid_sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}), train_size: float = 1.0, **kwargs: Any, - ) -> "GENOTLinProblem[K, B]": + ) -> "GENOTLinProblem": """Solve.""" return super().solve( batch_size=batch_size, diff --git a/src/moscot/problems/generic/_mixins.py b/src/moscot/problems/generic/_mixins.py index ed74ebffd..a2a90a0a8 100644 --- a/src/moscot/problems/generic/_mixins.py +++ b/src/moscot/problems/generic/_mixins.py @@ -1,32 +1,16 @@ -from typing import TYPE_CHECKING, Any, List, Literal, Optional, Protocol, Tuple, Union +from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union import pandas as pd -from anndata import AnnData - from moscot import _constants from moscot._types import ArrayLike, Str_Dict_t -from moscot.base.problems._mixins import AnalysisMixin, AnalysisMixinProtocol +from moscot.base.problems._mixins import AnalysisMixin from moscot.base.problems.compound_problem import ApplyOutput_t, B, K from moscot.plotting._utils import set_plotting_vars __all__ = ["GenericAnalysisMixin"] -class GenericAnalysisMixinProtocol(AnalysisMixinProtocol[K, B], Protocol[K, B]): - """Protocol class.""" - - _batch_key: Optional[str] - batch_key: Optional[str] - adata: AnnData - - def _cell_transition( - self: AnalysisMixinProtocol[K, B], - *args: Any, - **kwargs: Any, - ) -> pd.DataFrame: ... - - class GenericAnalysisMixin(AnalysisMixin[K, B]): """Generic Analysis Mixin.""" @@ -35,7 +19,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._batch_key: Optional[str] = None def cell_transition( - self: GenericAnalysisMixinProtocol[K, B], + self, source: K, target: K, source_groups: Optional[Str_Dict_t] = None, @@ -112,7 +96,7 @@ def cell_transition( ) def push( - self: GenericAnalysisMixinProtocol[K, B], + self, source: K, target: K, data: Optional[Union[str, ArrayLike]] = None, @@ -181,7 +165,7 @@ def push( return result def pull( - self: GenericAnalysisMixinProtocol[K, B], + self, source: K, target: K, data: Optional[Union[str, ArrayLike]] = None, @@ -247,12 +231,12 @@ def pull( return result @property - def batch_key(self: GenericAnalysisMixinProtocol[K, B]) -> Optional[str]: + def batch_key(self) -> Optional[str]: """Batch key in :attr:`~anndata.AnnData.obs`.""" return self._batch_key @batch_key.setter - def batch_key(self: GenericAnalysisMixinProtocol[K, B], key: Optional[str]) -> None: + def batch_key(self, key: Optional[str]) -> None: if key is not None and key not in self.adata.obs: raise KeyError(f"Unable to find batch data in `adata.obs[{key!r}]`.") self._batch_key = key diff --git a/src/moscot/problems/space/_mapping.py b/src/moscot/problems/space/_mapping.py index abe440e28..2994de5cd 100644 --- a/src/moscot/problems/space/_mapping.py +++ b/src/moscot/problems/space/_mapping.py @@ -359,7 +359,7 @@ def filtered_vars(self) -> Optional[Sequence[str]]: @filtered_vars.setter def filtered_vars(self, value: Optional[Sequence[str]]) -> None: - self._filtered_vars = self._filter_vars(var_names=value) # type: ignore[misc] + self._filtered_vars = self._filter_vars(var_names=value) @property def _base_problem_type(self) -> Type[B]: diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index d72c79c2b..6181de0cf 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -28,65 +28,14 @@ from moscot import _constants from moscot._logging import logger from moscot._types import ArrayLike, Device_t, Str_Dict_t -from moscot.base.problems._mixins import AnalysisMixin, AnalysisMixinProtocol +from moscot.base.problems._mixins import AnalysisMixin from moscot.base.problems.compound_problem import B, K +from moscot.base.problems.problem import AbstractSpSc from moscot.utils.subset_policy import StarPolicy __all__ = ["SpatialAlignmentMixin", "SpatialMappingMixin"] -class SpatialAlignmentMixinProtocol(AnalysisMixinProtocol[K, B]): - """Protocol class.""" - - spatial_key: Optional[str] - _spatial_key: Optional[str] - batch_key: Optional[str] - - def _subset_spatial( # type:ignore[empty-body] - self: "SpatialAlignmentMixinProtocol[K, B]", - k: K, - spatial_key: str, - ) -> ArrayLike: ... - - def _interpolate_scheme( # type:ignore[empty-body] - self: "SpatialAlignmentMixinProtocol[K, B]", - reference: K, - mode: Literal["warp", "affine"], - spatial_key: str, - ) -> Tuple[Dict[K, ArrayLike], Optional[Dict[K, Optional[ArrayLike]]]]: ... - - def _cell_transition( - self: AnalysisMixinProtocol[K, B], - *args: Any, - **kwargs: Any, - ) -> pd.DataFrame: ... - - def _annotation_mapping( - self: AnalysisMixinProtocol[K, B], - *args: Any, - **kwargs: Any, - ) -> pd.DataFrame: ... - - -class SpatialMappingMixinProtocol(AnalysisMixinProtocol[K, B]): - """Protocol class.""" - - adata_sc: AnnData - adata_sp: AnnData - batch_key: Optional[str] - spatial_key: Optional[str] - _spatial_key: Optional[str] - - def _filter_vars( - self: "SpatialMappingMixinProtocol[K, B]", - var_names: Optional[Sequence[str]] = None, - ) -> Optional[List[str]]: ... - - def _cell_transition(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... - - def _annotation_mapping(self: AnalysisMixinProtocol[K, B], *args: Any, **kwargs: Any) -> pd.DataFrame: ... - - class SpatialAlignmentMixin(AnalysisMixin[K, B]): """Spatial alignment mixin class.""" @@ -95,8 +44,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._spatial_key: Optional[str] = None self._batch_key: Optional[str] = None - def _interpolate_scheme( # type: ignore[misc] - self: SpatialAlignmentMixinProtocol[K, B], + def _interpolate_scheme( + self, reference: K, mode: Literal["warp", "affine"], spatial_key: str, @@ -113,7 +62,7 @@ def _interpolate_scheme( # type: ignore[misc] # get the reference reference_ = [reference] if isinstance(reference, str) else reference full_steps = self._policy._graph - starts = set(itertools.chain.from_iterable(full_steps)) - set(reference_) # type: ignore[call-overload] + starts = set(itertools.chain.from_iterable(full_steps)) - set(reference_) # type: ignore[arg-type] if mode == "affine": _transport = _affine @@ -139,8 +88,8 @@ def _interpolate_scheme( # type: ignore[misc] # TODO(michalk8): always return the metadata? return transport_maps, (transport_metadata if mode == "affine" else None) - def align( # type: ignore[misc] - self: SpatialAlignmentMixinProtocol[K, B], + def align( + self, reference: Optional[K] = None, mode: Literal["warp", "affine"] = "warp", spatial_key: Optional[str] = None, @@ -199,8 +148,8 @@ def align( # type: ignore[misc] self.adata.uns.setdefault(key_added, {}) self.adata.uns[key_added]["alignment_metadata"] = aligned_metadata # noqa: RET503 - def cell_transition( # type: ignore[misc] - self: SpatialAlignmentMixinProtocol[K, B], + def cell_transition( + self, source: K, target: K, source_groups: Optional[Str_Dict_t] = None, @@ -277,16 +226,17 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) - def annotation_mapping( # type: ignore[misc] - self: SpatialAlignmentMixinProtocol[K, B], + def annotation_mapping( + self, mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, - source: str = "src", - target: str = "tgt", + source: K = "src", + target: K = "tgt", batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - **kwargs: Mapping[str, Any], + other_adata: Optional[AnnData] = None, + scale_by_marginals: bool = True, ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -314,6 +264,10 @@ def annotation_mapping( # type: ignore[misc] If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. + other_adata + Other adata object to use for the cell transitions. + scale_by_marginals + Whether to scale by the source/target :term:`marginals`. Returns ------- @@ -328,7 +282,8 @@ def annotation_mapping( # type: ignore[misc] forward=forward, batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, - **kwargs, + other_adata=other_adata, + scale_by_marginals=scale_by_marginals, ) @property @@ -337,7 +292,7 @@ def spatial_key(self) -> Optional[str]: return self._spatial_key @spatial_key.setter - def spatial_key(self: SpatialAlignmentMixinProtocol[K, B], key: Optional[str]) -> None: # type: ignore[misc] + def spatial_key(self, key: Optional[str]) -> None: if key is not None and key not in self.adata.obsm: raise KeyError(f"Unable to find spatial data in `adata.obsm[{key!r}]`.") self._spatial_key = key @@ -349,12 +304,12 @@ def batch_key(self) -> Optional[str]: @batch_key.setter def batch_key(self, key: Optional[str]) -> None: - if key is not None and key not in self.adata.obs: # type: ignore[attr-defined] + if key is not None and key not in self.adata.obs: raise KeyError(f"Unable to find batch data in `adata.obs[{key!r}]`.") self._batch_key = key - def _subset_spatial( # type: ignore[misc] - self: SpatialAlignmentMixinProtocol[K, B], + def _subset_spatial( + self, k: K, spatial_key: str, ) -> ArrayLike: @@ -362,7 +317,7 @@ def _subset_spatial( # type: ignore[misc] return self.adata[mask].obsm[spatial_key].astype(float, copy=True) -class SpatialMappingMixin(AnalysisMixin[K, B]): +class SpatialMappingMixin(AnalysisMixin[K, B], AbstractSpSc): """Spatial mapping analysis mixin class.""" def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -370,8 +325,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._batch_key: Optional[str] = None self._spatial_key: Optional[str] = None - def _filter_vars( # type: ignore[misc] - self: SpatialMappingMixinProtocol[K, B], + def _filter_vars( + self, var_names: Optional[Sequence[str]] = None, ) -> Optional[List[str]]: """Filter variables for the linear term.""" @@ -392,8 +347,8 @@ def _filter_vars( # type: ignore[misc] raise ValueError("Some variable are missing in the single-cell or the spatial `AnnData`.") - def correlate( # type: ignore[misc] - self: SpatialMappingMixinProtocol[K, B], + def correlate( + self, var_names: Optional[Sequence[str]] = None, corr_method: Literal["pearson", "spearman"] = "pearson", device: Optional[Device_t] = None, @@ -491,8 +446,8 @@ def correlate( # type: ignore[misc] return corrs - def impute( # type: ignore[misc] - self: SpatialMappingMixinProtocol[K, B], + def impute( + self, var_names: Optional[Sequence[str]] = None, device: Optional[Device_t] = None, batch_size: Optional[int] = None, @@ -546,8 +501,8 @@ def impute( # type: ignore[misc] return adata_pred - def spatial_correspondence( # type: ignore[misc] - self: SpatialMappingMixinProtocol[K, B], + def spatial_correspondence( + self, interval: Union[int, ArrayLike] = 10, max_dist: Optional[int] = None, attr: Optional[Dict[str, Optional[str]]] = None, @@ -607,8 +562,8 @@ def _get_features( res[self.batch_key] = res[self.batch_key].astype("category") # type: ignore[call-overload] return res - def cell_transition( # type: ignore[misc] - self: SpatialMappingMixinProtocol[K, B], + def cell_transition( + self, source: K, target: Optional[K] = None, source_groups: Optional[Str_Dict_t] = None, @@ -672,7 +627,7 @@ def cell_transition( # type: ignore[misc] return self._cell_transition( key=self.batch_key, source=source, - target=target, + target=target or None, source_groups=source_groups, target_groups=target_groups, forward=forward, @@ -684,16 +639,16 @@ def cell_transition( # type: ignore[misc] key_added=key_added, ) - def annotation_mapping( # type: ignore[misc] - self: SpatialMappingMixinProtocol[K, B], + def annotation_mapping( + self, mapping_mode: Literal["sum", "max"], annotation_label: str, source: K, - target: Union[K, str] = "tgt", + target: K = "tgt", forward: bool = False, batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - **kwargs: Mapping[str, Any], + scale_by_marginals: bool = True, ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -721,6 +676,8 @@ def annotation_mapping( # type: ignore[misc] If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. + scale_by_marginals + Whether to scale by the source/target :term:`marginals`. Returns ------- @@ -736,7 +693,7 @@ def annotation_mapping( # type: ignore[misc] other_adata=self.adata_sc, batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, - **kwargs, + scale_by_marginals=scale_by_marginals, ) @property @@ -746,7 +703,7 @@ def batch_key(self) -> Optional[str]: @batch_key.setter def batch_key(self, key: Optional[str]) -> None: - if key is not None and key not in self.adata.obs: # type: ignore[attr-defined] + if key is not None and key not in self.adata.obs: raise KeyError(f"Unable to find batch data in `adata.obs[{key!r}]`.") self._batch_key = key @@ -756,7 +713,7 @@ def spatial_key(self) -> Optional[str]: return self._spatial_key @spatial_key.setter - def spatial_key(self: SpatialAlignmentMixinProtocol[K, B], key: Optional[str]) -> None: # type: ignore[misc] + def spatial_key(self, key: Optional[str]) -> None: if key is not None and key not in self.adata.obsm: raise KeyError(f"Unable to find spatial data in `adata.obsm[{key!r}]`.") self._spatial_key = key diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 0fde8a334..c2537f6b4 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -1,20 +1,8 @@ from __future__ import annotations import itertools -import pathlib import types -from typing import ( - TYPE_CHECKING, - Any, - Iterable, - Iterator, - Literal, - Mapping, - Optional, - Protocol, - Sequence, - Union, -) +from typing import TYPE_CHECKING, Any, Literal, Mapping, Optional, Sequence, Union import numpy as np import pandas as pd @@ -24,128 +12,16 @@ from moscot import _constants from moscot._types import ArrayLike, Str_Dict_t -from moscot.base.problems._mixins import AnalysisMixin, AnalysisMixinProtocol -from moscot.base.problems.birth_death import BirthDeathProblem +from moscot.base.problems._mixins import AnalysisMixin from moscot.base.problems.compound_problem import ApplyOutput_t, B, K +from moscot.base.problems.problem import AbstractSolutionsProblems from moscot.plotting._utils import set_plotting_vars from moscot.utils.tagged_array import Tag __all__ = ["TemporalMixin"] -class TemporalMixinProtocol(AnalysisMixinProtocol[K, B], Protocol[K, B]): # type: ignore[misc] - adata: AnnData - problems: dict[tuple[K, K], BirthDeathProblem] - temporal_key: Optional[str] - _temporal_key: Optional[str] - - def cell_transition( # noqa: D102 - self: TemporalMixinProtocol[K, B], - source: K, - target: K, - source_groups: Str_Dict_t, - target_groups: Str_Dict_t, - forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic - aggregation_mode: Literal["annotation", "cell"] = "annotation", - batch_size: Optional[int] = None, - normalize: bool = True, - key_added: Optional[str] = _constants.CELL_TRANSITION, - ) -> pd.DataFrame: ... - - def push(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: ... - - def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]: ... - - def _cell_transition( - self: AnalysisMixinProtocol[K, B], - *args: Any, - **kwargs: Any, - ) -> pd.DataFrame: ... - - def _annotation_mapping( - self: AnalysisMixinProtocol[K, B], - *args: Any, - **kwargs: Any, - ) -> pd.DataFrame: ... - - def _sample_from_tmap( - self: TemporalMixinProtocol[K, B], - source: K, - target: K, - n_samples: int, - source_dim: int, - target_dim: int, - batch_size: int = 256, - account_for_unbalancedness: bool = False, - interpolation_parameter: Optional[float] = None, - seed: Optional[int] = None, - ) -> tuple[list[Any], list[ArrayLike]]: ... - - def _compute_wasserstein_distance( - self: TemporalMixinProtocol[K, B], - point_cloud_1: ArrayLike, - point_cloud_2: ArrayLike, - a: Optional[ArrayLike] = None, - b: Optional[ArrayLike] = None, - backend: Literal["ott"] = "ott", - **kwargs: Any, - ) -> float: ... - - def _interpolate_gex_with_ot( - self: TemporalMixinProtocol[K, B], - number_cells: int, - source_data: ArrayLike, - target_data: ArrayLike, - source: K, - target: K, - interpolation_parameter: float, - account_for_unbalancedness: bool = True, - batch_size: int = 256, - seed: Optional[int] = None, - ) -> ArrayLike: ... - - def _get_data( - self: TemporalMixinProtocol[K, B], - source: K, - intermediate: Optional[K] = None, - target: Optional[K] = None, - posterior_marginals: bool = True, - *, - only_start: bool = False, - ) -> Union[tuple[ArrayLike, AnnData], tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]: ... - - def _interpolate_gex_randomly( - self: TemporalMixinProtocol[K, B], - number_cells: int, - source_data: ArrayLike, - target_data: ArrayLike, - interpolation_parameter: float, - growth_rates: Optional[ArrayLike] = None, - seed: Optional[int] = None, - ) -> ArrayLike: ... - - def _plot_temporal( - self: TemporalMixinProtocol[K, B], - data: dict[K, ArrayLike], - source: K, - target: K, - time_points: Optional[Iterable[K]] = None, - basis: str = "umap", - result_key: Optional[str] = None, - fill_value: float = 0.0, - save: Optional[Union[str, pathlib.Path]] = None, - **kwargs: Any, - ) -> None: ... - - @staticmethod - def _get_interp_param( - source: K, intermediate: K, target: K, interpolation_parameter: Optional[float] = None - ) -> float: ... - - def __iter__(self) -> Iterator[tuple[K, K]]: ... - - -class TemporalMixin(AnalysisMixin[K, B]): +class TemporalMixin(AnalysisMixin[K, B], AbstractSolutionsProblems): """Analysis Mixin for all problems involving a temporal dimension.""" def __init__(self, *args: Any, **kwargs: Any) -> None: @@ -153,7 +29,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._temporal_key: Optional[str] = None def cell_transition( - self: TemporalMixinProtocol[K, B], + self, source: K, target: K, source_groups: Str_Dict_t, @@ -228,7 +104,7 @@ def cell_transition( ) def annotation_mapping( - self: TemporalMixinProtocol[K, B], + self, mapping_mode: Literal["sum", "max"], annotation_label: str, forward: bool, @@ -236,7 +112,7 @@ def annotation_mapping( target: K, batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), - **kwargs: Mapping[str, Any], + scale_by_marginals: bool = True, ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -264,6 +140,8 @@ def annotation_mapping( If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. + scale_by_marginals + Whether to scale by the source/target :term:`marginals`. Returns ------- @@ -279,11 +157,11 @@ def annotation_mapping( other_adata=None, batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, - **kwargs, + scale_by_marginals=scale_by_marginals, ) def sankey( - self: TemporalMixinProtocol[K, B], + self, source: K, target: K, source_groups: Str_Dict_t, @@ -345,7 +223,7 @@ def sankey( :attr:`uns['moscot_results']['sankey']['{key_added}'] ` """ tuples = self._policy.plan(start=source, end=target) - cell_transitions = [] + cell_transitions: list[Any] = [] for src, tgt in tuples: cell_transitions.append( self.cell_transition( @@ -405,7 +283,7 @@ def sankey( set_plotting_vars(self.adata, _constants.SANKEY, key=key_added, value=plot_vars) # noqa: RET503 def push( - self: TemporalMixinProtocol[K, B], + self, source: K, target: K, data: Optional[Union[str, ArrayLike]] = None, @@ -472,7 +350,7 @@ def push( return result def pull( - self: TemporalMixinProtocol[K, B], + self, source: K, target: K, data: Optional[Union[str, ArrayLike]] = None, @@ -538,7 +416,7 @@ def pull( return result @property - def prior_growth_rates(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]: + def prior_growth_rates(self) -> Optional[pd.DataFrame]: """Prior estimate of the source growth rates.""" computed = [isinstance(p.prior_growth_rates, np.ndarray) for p in self.problems.values()] if not np.sum(computed): @@ -556,7 +434,7 @@ def prior_growth_rates(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFra return pd.concat([df_1, df_2], verify_integrity=True) @property - def posterior_growth_rates(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]: + def posterior_growth_rates(self) -> Optional[pd.DataFrame]: """Posterior estimate of the source growth rates.""" computed = [isinstance(p.posterior_growth_rates, np.ndarray) for p in self.problems.values()] if not np.sum(computed): @@ -574,7 +452,7 @@ def posterior_growth_rates(self: TemporalMixinProtocol[K, B]) -> Optional[pd.Dat return pd.concat([df_1, df_2], verify_integrity=True) @property - def cell_costs_source(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]: + def cell_costs_source(self) -> Optional[pd.DataFrame]: """Cell cost obtained by the :term:`first dual potential `. Only available for subproblems with :attr:`problem_kind = 'linear' `. @@ -587,7 +465,7 @@ def cell_costs_source(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFram # TODO(michalk8): `[1]` will fail if potentials is None df_list = [ pd.DataFrame( - np.asarray(problem.solution.potentials[0]), # type: ignore[union-attr,index] + np.asarray(problem.solution.potentials[0]), index=problem.adata_src.obs_names, columns=cols, ) @@ -599,7 +477,7 @@ def cell_costs_source(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFram return pd.concat([df_1, df_2], verify_integrity=True) @property - def cell_costs_target(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFrame]: + def cell_costs_target(self) -> Optional[pd.DataFrame]: """Cell cost obtained by the :term:`second dual potential `. Only available for subproblems with :attr:`problem_kind = 'linear' `. @@ -612,7 +490,7 @@ def cell_costs_target(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFram # TODO(michalk8): `[1]` will fail if potentials is None df_list = [ pd.DataFrame( - np.array(problem.solution.potentials[1]), # type: ignore[union-attr,index] + np.array(problem.solution.potentials[1]), index=problem.adata_tgt.obs_names, columns=cols, ) @@ -624,7 +502,7 @@ def cell_costs_target(self: TemporalMixinProtocol[K, B]) -> Optional[pd.DataFram return pd.concat([df_1, df_2], verify_integrity=True) def _get_data( - self: TemporalMixinProtocol[K, B], + self, source: K, intermediate: Optional[K] = None, target: Optional[K] = None, @@ -634,14 +512,11 @@ def _get_data( ) -> Union[tuple[ArrayLike, AnnData], tuple[ArrayLike, ArrayLike, ArrayLike, AnnData, ArrayLike]]: # TODO: use .items() for src, tgt in self.problems: - tag = self.problems[src, tgt].xy.tag # type: ignore[union-attr] + tag = self.problems[src, tgt].xy.tag if tag != Tag.POINT_CLOUD: - raise ValueError( - f"Expected `tag={Tag.POINT_CLOUD}`, " - f"found `tag={self.problems[src, tgt].xy.tag}`." # type: ignore[union-attr] - ) + raise ValueError(f"Expected `tag={Tag.POINT_CLOUD}`, " f"found `tag={self.problems[src, tgt].xy.tag}`.") if src == source: - source_data = self.problems[src, tgt].xy.data_src # type: ignore[union-attr] + source_data = self.problems[src, tgt].xy.data_src if only_start: return source_data, self.problems[src, tgt].adata_src # TODO(michalk8): posterior marginals @@ -652,19 +527,19 @@ def _get_data( raise ValueError(f"No data found for `{source}` time point.") for src, tgt in self.problems: if src == intermediate: - intermediate_data = self.problems[src, tgt].xy.data_src # type: ignore[union-attr] + intermediate_data = self.problems[src, tgt].xy.data_src intermediate_adata = self.problems[src, tgt].adata_src break else: raise ValueError(f"No data found for `{intermediate}` time point.") for src, tgt in self.problems: if tgt == target: - target_data = self.problems[src, tgt].xy.data_tgt # type: ignore[union-attr] + target_data = self.problems[src, tgt].xy.data_tgt break else: raise ValueError(f"No data found for `{target}` time point.") - return ( # type:ignore[return-value] + return ( source_data, growth_rates_source, intermediate_data, @@ -673,7 +548,7 @@ def _get_data( ) def compute_interpolated_distance( - self: TemporalMixinProtocol[K, B], + self, source: K, intermediate: K, target: K, @@ -757,7 +632,7 @@ def compute_interpolated_distance( return self._compute_wasserstein_distance(intermediate_data, interpolation, backend=backend, **kwargs) def compute_random_distance( - self: TemporalMixinProtocol[K, B], + self, source: K, intermediate: K, target: K, @@ -831,7 +706,7 @@ def compute_random_distance( return self._compute_wasserstein_distance(intermediate_data, random_interpolation, backend=backend, **kwargs) def compute_time_point_distances( - self: TemporalMixinProtocol[K, B], + self, source: K, intermediate: K, target: K, @@ -884,7 +759,7 @@ def compute_time_point_distances( return distance_source_intermediate, distance_intermediate_target def compute_batch_distances( - self: TemporalMixinProtocol[K, B], + self, time: K, batch_key: str, posterior_marginals: bool = True, @@ -936,7 +811,7 @@ def compute_batch_distances( # TODO(@MUCDK) possibly offer two alternatives, once exact EMD with POT backend and once approximate, # faster with same solver as used for original problems def _compute_wasserstein_distance( - self: TemporalMixinProtocol[K, B], + self, point_cloud_1: ArrayLike, point_cloud_2: ArrayLike, a: Optional[ArrayLike] = None, @@ -951,7 +826,7 @@ def _compute_wasserstein_distance( raise NotImplementedError("Only `ott` available as backend.") def _interpolate_gex_with_ot( - self: TemporalMixinProtocol[K, B], + self, number_cells: int, source_data: ArrayLike, target_data: ArrayLike, @@ -979,7 +854,7 @@ def _interpolate_gex_with_ot( ) def _interpolate_gex_randomly( - self: TemporalMixinProtocol[K, B], + self, number_cells: int, source_data: ArrayLike, target_data: ArrayLike, @@ -1026,7 +901,7 @@ def temporal_key(self) -> Optional[str]: return self._temporal_key @temporal_key.setter - def temporal_key(self: TemporalMixinProtocol[K, B], key: Optional[str]) -> None: + def temporal_key(self, key: Optional[str]) -> None: if key is None: self._temporal_key = key return