Skip to content

Commit

Permalink
Merge branch 'main' into refactor/handle-solve-args
Browse files: browse the repository at this point in the history
  • Loading branch information
selmanozleyen authored Nov 27, 2024
2 parents 9cb9c36 + 4080a94 commit aa593c6
Show file tree
Hide file tree
Showing 15 changed files with 245 additions and 464 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ default_stages:
minimum_pre_commit_version: 3.0.0
repos:
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
rev: v1.13.0
hooks:
- id: mypy
additional_dependencies: [numpy>=1.25.0]
Expand Down Expand Up @@ -42,12 +42,12 @@ repos:
- id: check-yaml
- id: check-toml
- repo: https://github.com/asottile/pyupgrade
rev: v3.18.0
rev: v3.19.0
hooks:
- id: pyupgrade
args: [--py3-plus, --py38-plus, --keep-runtime-typing]
- repo: https://github.com/asottile/blacken-docs
rev: 1.19.0
rev: 1.19.1
hooks:
- id: blacken-docs
additional_dependencies: [black==23.1.0]
Expand All @@ -63,7 +63,7 @@ repos:
- id: doc8
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.9
rev: v0.7.2
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ dependencies = [
"scanpy>=1.9.3",
"wrapt>=1.13.2",
"docrep>=0.3.2",
"ott-jax[neural]>=0.4.6",
"ott-jax[neural]>=0.4.6,<=0.4.8",
"cloudpickle>=2.2.0",
"rich>=13.5",
"docstring_inheritance>=2.0.0",
Expand Down
121 changes: 21 additions & 100 deletions src/moscot/base/problems/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
Literal,
Mapping,
Optional,
Protocol,
Sequence,
Union,
)
Expand All @@ -21,11 +20,9 @@
from scipy.sparse.linalg import LinearOperator

import scanpy as sc
from anndata import AnnData

from moscot import _constants
from moscot._types import ArrayLike, Numeric_t, Str_Dict_t
from moscot.base.output import BaseDiscreteSolverOutput
from moscot.base.problems._utils import (
_check_argument_compatibility_cell_transition,
_correlation_test,
Expand All @@ -34,104 +31,28 @@
_validate_annotations,
_validate_args_cell_transition,
)
from moscot.base.problems.compound_problem import ApplyOutput_t, B, K
from moscot.base.problems.compound_problem import B, K
from moscot.base.problems.problem import (
AbstractPushPullAdata,
AbstractSolutionsProblems,
)
from moscot.plotting._utils import set_plotting_vars
from moscot.utils.data import transcription_factors
from moscot.utils.subset_policy import SubsetPolicy

__all__ = ["AnalysisMixin"]


class AnalysisMixinProtocol(Protocol[K, B]):
"""Protocol class."""

adata: AnnData
_policy: SubsetPolicy[K]
solutions: dict[tuple[K, K], BaseDiscreteSolverOutput]
problems: dict[tuple[K, K], B]

def _apply(
self,
data: Optional[Union[str, ArrayLike]] = None,
source: Optional[K] = None,
target: Optional[K] = None,
forward: bool = True,
return_all: bool = False,
scale_by_marginals: bool = False,
**kwargs: Any,
) -> ApplyOutput_t[K]: ...

def _interpolate_transport(
self: AnalysisMixinProtocol[K, B],
path: Sequence[tuple[K, K]],
scale_by_marginals: bool = True,
) -> LinearOperator: ...

def _flatten(
self: AnalysisMixinProtocol[K, B],
data: dict[K, ArrayLike],
*,
key: Optional[str],
) -> ArrayLike: ...

def push(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]:
"""Push distribution."""
...

def pull(self, *args: Any, **kwargs: Any) -> Optional[ApplyOutput_t[K]]:
"""Pull distribution."""
...

def _cell_transition(
self: AnalysisMixinProtocol[K, B],
source: K,
target: K,
source_groups: Str_Dict_t,
target_groups: Str_Dict_t,
aggregation_mode: Literal["annotation", "cell"] = "annotation",
key_added: Optional[str] = _constants.CELL_TRANSITION,
**kwargs: Any,
) -> pd.DataFrame: ...

def _cell_transition_online(
self: AnalysisMixinProtocol[K, B],
key: Optional[str],
source: K,
target: K,
source_groups: Str_Dict_t,
target_groups: Str_Dict_t,
forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic
aggregation_mode: Literal["annotation", "cell"] = "annotation",
other_key: Optional[str] = None,
other_adata: Optional[str] = None,
batch_size: Optional[int] = None,
normalize: bool = True,
) -> pd.DataFrame: ...

def _annotation_mapping(
self: AnalysisMixinProtocol[K, B],
mapping_mode: Literal["sum", "max"],
annotation_label: str,
forward: bool,
source: K,
target: K,
key: str | None = None,
other_adata: Optional[str] = None,
scale_by_marginals: bool = True,
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
) -> pd.DataFrame: ...


class AnalysisMixin(Generic[K, B]):
class AnalysisMixin(Generic[K, B], AbstractPushPullAdata, AbstractSolutionsProblems):
"""Base Analysis Mixin."""

def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)

def _cell_transition(
self: AnalysisMixinProtocol[K, B],
self,
source: K,
target: K,
target: Optional[K],
source_groups: Str_Dict_t,
target_groups: Str_Dict_t,
aggregation_mode: Literal["annotation", "cell"] = "annotation",
Expand Down Expand Up @@ -178,7 +99,7 @@ def _cell_transition(
return tm

def _annotation_aggregation_transition(
self: AnalysisMixinProtocol[K, B],
self,
annotations_1: list[Any],
annotations_2: list[Any],
df: pd.DataFrame,
Expand Down Expand Up @@ -210,10 +131,10 @@ def _annotation_aggregation_transition(
)

def _cell_transition_online(
self: AnalysisMixinProtocol[K, B],
self,
key: Optional[str],
source: K,
target: K,
target: Optional[K],
source_groups: Str_Dict_t,
target_groups: Str_Dict_t,
forward: bool = False, # return value will be row-stochastic if forward=True, else column-stochastic
Expand Down Expand Up @@ -272,7 +193,7 @@ def _cell_transition_online(
split_mass=False,
**move_op_const_kwargs,
)
tm = self._annotation_aggregation_transition( # type: ignore[attr-defined]
tm = self._annotation_aggregation_transition(
annotations_1=source_annotations_verified if forward else target_annotations_verified,
annotations_2=target_annotations_verified if forward else source_annotations_verified,
df=df_to,
Expand All @@ -286,7 +207,7 @@ def _cell_transition_online(
split_mass=True,
**move_op_const_kwargs,
)
tm = self._cell_aggregation_transition( # type: ignore[attr-defined]
tm = self._cell_aggregation_transition(
df_from=df_from,
df_to=df_to,
annotations=target_annotations_verified if forward else source_annotations_verified,
Expand All @@ -309,7 +230,7 @@ def _cell_transition_online(
)

def _annotation_mapping(
self: AnalysisMixinProtocol[K, B],
self,
mapping_mode: Literal["sum", "max"],
annotation_label: str,
source: K,
Expand Down Expand Up @@ -394,7 +315,7 @@ def _annotation_mapping(
raise NotImplementedError(f"Mapping mode `{mapping_mode!r}` is not yet implemented.")

def _sample_from_tmap(
self: AnalysisMixinProtocol[K, B],
self,
source: K,
target: K,
n_samples: int,
Expand Down Expand Up @@ -472,7 +393,7 @@ def _sample_from_tmap(
return rows, all_cols_sampled # type: ignore[return-value]

def _interpolate_transport(
self: AnalysisMixinProtocol[K, B],
self,
# TODO(@giovp): rename this to 'explicit_steps', pass to policy.plan() and reintroduce (source_key, target_key)
path: Sequence[tuple[K, K]],
scale_by_marginals: bool = True,
Expand All @@ -485,15 +406,15 @@ def _interpolate_transport(
fst, *rest = path
return self.solutions[fst].chain([self.solutions[r] for r in rest], scale_by_marginals=scale_by_marginals)

def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike:
def _flatten(self, data: dict[K, ArrayLike], *, key: Optional[str]) -> ArrayLike:
tmp = np.full(len(self.adata), np.nan)
for k, v in data.items():
mask = self.adata.obs[key] == k
tmp[mask] = np.squeeze(v)
return tmp

def _cell_aggregation_transition(
self: AnalysisMixinProtocol[K, B],
self,
df_from: pd.DataFrame,
df_to: pd.DataFrame,
annotations: list[Any],
Expand Down Expand Up @@ -540,7 +461,7 @@ def _cell_aggregation_transition(
# adapted from:
# https://github.com/theislab/cellrank/blob/master/cellrank/_utils/_utils.py#L392
def compute_feature_correlation(
self: AnalysisMixinProtocol[K, B],
self,
obs_key: str,
corr_method: Literal["pearson", "spearman"] = "pearson",
significance_method: Literal["fisher", "perm_test"] = "fisher",
Expand Down Expand Up @@ -649,7 +570,7 @@ def compute_feature_correlation(
)

def compute_entropy(
self: AnalysisMixinProtocol[K, B],
self,
source: K,
target: K,
forward: bool = True,
Expand Down Expand Up @@ -706,7 +627,7 @@ def compute_entropy(
split_mass=True,
key_added=None,
)
df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = stats.entropy(cond_dists + c, **kwargs) # type: ignore[operator]
df.iloc[range(batch, min(batch + batch_size, len(df))), 0] = stats.entropy(cond_dists + c, **kwargs)
if key_added is not None:
self.adata.obs[key_added] = df
return df if key_added is None else None
51 changes: 8 additions & 43 deletions src/moscot/base/problems/birth_death.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Optional,
Protocol,
Sequence,
Union,
)
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Sequence, Union

import numpy as np

Expand All @@ -17,39 +8,13 @@

from moscot._logging import logger
from moscot._types import ArrayLike
from moscot.base.problems.problem import OTProblem
from moscot.base.problems.problem import AbstractAdataAccess, OTProblem
from moscot.utils.data import apoptosis_markers, proliferation_markers

__all__ = ["BirthDeathProblem", "BirthDeathMixin"]


class BirthDeathProtocol(Protocol): # noqa: D101
adata: AnnData
proliferation_key: Optional[str]
apoptosis_key: Optional[str]
_proliferation_key: Optional[str]
_apoptosis_key: Optional[str]
_scaling: float
_prior_growth: Optional[ArrayLike]

def score_genes_for_marginals( # noqa: D102
self: "BirthDeathProtocol",
gene_set_proliferation: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None,
gene_set_apoptosis: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None,
proliferation_key: str = "proliferation",
apoptosis_key: str = "apoptosis",
**kwargs: Any,
) -> "BirthDeathProtocol": ...


class BirthDeathProblemProtocol(BirthDeathProtocol, Protocol): # noqa: D101
delta: float
adata_tgt: AnnData
a: Optional[ArrayLike]
b: Optional[ArrayLike]


class BirthDeathMixin:
class BirthDeathMixin(AbstractAdataAccess):
"""Mixin class used to estimate cell proliferation and apoptosis.
Parameters
Expand All @@ -68,7 +33,7 @@ def __init__(self, *args: Any, **kwargs: Any):
self._prior_growth: Optional[ArrayLike] = None

def score_genes_for_marginals(
self, # type: BirthDeathProblemProtocol
self,
gene_set_proliferation: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None,
gene_set_apoptosis: Optional[Union[Literal["human", "mouse"], Sequence[str]]] = None,
proliferation_key: str = "proliferation",
Expand Down Expand Up @@ -123,15 +88,15 @@ def score_genes_for_marginals(
"At least one of `gene_set_proliferation` or `gene_set_apoptosis` must be provided to score genes."
)

return self # type: ignore[return-value]
return self

@property
def proliferation_key(self) -> Optional[str]:
"""Key in :attr:`~anndata.AnnData.obs` where cell proliferation is stored."""
return self._proliferation_key

@proliferation_key.setter
def proliferation_key(self: BirthDeathProtocol, key: Optional[str]) -> None:
def proliferation_key(self, key: Optional[str]) -> None:
if key is not None and key not in self.adata.obs:
raise KeyError(f"Unable to find proliferation data in `adata.obs[{key!r}]`.")
self._proliferation_key = key
Expand All @@ -142,7 +107,7 @@ def apoptosis_key(self) -> Optional[str]:
return self._apoptosis_key

@apoptosis_key.setter
def apoptosis_key(self: BirthDeathProtocol, key: Optional[str]) -> None:
def apoptosis_key(self, key: Optional[str]) -> None:
if key is not None and key not in self.adata.obs:
raise KeyError(f"Unable to find apoptosis data in `adata.obs[{key!r}]`.")
self._apoptosis_key = key
Expand All @@ -161,7 +126,7 @@ class BirthDeathProblem(BirthDeathMixin, OTProblem):
""" # noqa: D205

def estimate_marginals(
self, # type: BirthDeathProblemProtocol
self,
adata: AnnData,
source: bool,
proliferation_key: Optional[str] = None,
Expand Down
2 changes: 1 addition & 1 deletion src/moscot/base/problems/compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
ApplyOutput_t = Union[ArrayLike, Dict[K, ArrayLike]]


class BaseCompoundProblem(BaseProblem, abc.ABC, Generic[K, B]):
class BaseCompoundProblem(BaseProblem, Generic[K, B], abc.ABC):
"""Base class for all biological problems.
This class translates a biological problem to multiple :term:`OT` problems.
Expand Down
Loading

0 comments on commit aa593c6

Please sign in to comment.