From 42550e387ff48d0180c6897ed6d947211ef514e9 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 7 Oct 2024 15:19:46 +0200 Subject: [PATCH 1/4] fix the bug and set observed=False for future pandas and warning --- src/moscot/base/problems/_mixins.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 9c460eac2..c87cb373d 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -198,22 +198,24 @@ def _cell_transition_online( ) df_source = _get_df_cell_transition( self.adata, - [source_annotation_key] if aggregation_mode == "cell" else [source_annotation_key, target_annotation_key], + [source_annotation_key], key, source, ) df_target = _get_df_cell_transition( self.adata if other_adata is None else other_adata, - [target_annotation_key] if aggregation_mode == "cell" else [source_annotation_key, target_annotation_key], + [target_annotation_key], key if other_adata is None else other_key, target, ) - + df_source = df_source.rename(columns={source_annotation_key: "res_annotation"}) + df_target = df_target.rename(columns={target_annotation_key: "res_annotation"}) + res_annotation_key = "res_annotation" source_annotations_verified, target_annotations_verified = _validate_annotations( df_source=df_source, df_target=df_target, - source_annotation_key=source_annotation_key, - target_annotation_key=target_annotation_key, + source_annotation_key=res_annotation_key, + target_annotation_key=res_annotation_key, source_annotations=source_annotations, target_annotations=target_annotations, aggregation_mode=aggregation_mode, @@ -236,6 +238,7 @@ def _cell_transition_online( annotations_1=source_annotations_verified, annotations_2=target_annotations_verified, df=df_target, + df_key=res_annotation_key, tm=tm, forward=True, ) @@ -247,6 +250,7 @@ def _cell_transition_online( annotations_1=target_annotations_verified, annotations_2=source_annotations_verified, df=df_source, + df_key=res_annotation_key, tm=tm, forward=False, ) @@ -256,7 +260,7 @@ def _cell_transition_online( tm = self._cell_aggregation_transition( # type: ignore[attr-defined] source=source, target=target, - annotation_key=target_annotation_key, + annotation_key=res_annotation_key, annotations_1=source_annotations_verified, annotations_2=target_annotations_verified, df_1=df_target, @@ -269,7 +273,7 @@ def _cell_transition_online( tm = self._cell_aggregation_transition( # type: ignore[attr-defined] source=source, target=target, - annotation_key=source_annotation_key, + annotation_key=res_annotation_key, annotations_1=target_annotations_verified, annotations_2=source_annotations_verified, df_1=df_source, @@ -483,6 +487,7 @@ def _annotation_aggregation_transition( annotation_key: str, annotations_1: list[Any], annotations_2: list[Any], + df_key: str, df: pd.DataFrame, tm: pd.DataFrame, forward: bool, @@ -503,7 +508,7 @@ def _annotation_aggregation_transition( key_added=None, ) df["distribution"] = result - cell_dist = df[df[annotation_key].isin(annotations_2)].groupby(annotation_key).sum(numeric_only=True) + cell_dist = df[df[df_key].isin(annotations_2)].groupby(df_key, observed=False).sum(numeric_only=True) cell_dist /= cell_dist.sum() tm.loc[subset, :] = [ cell_dist.loc[annotation, "distribution"] if annotation in cell_dist.distribution.index else 0 From 8f70c1ebe959ea69339c6d47f7ecf3f45a624147 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 9 Oct 2024 17:35:34 +0200 Subject: [PATCH 2/4] update _cell_transition_online --- src/moscot/base/problems/_mixins.py | 237 +++++++++++++--------------- 1 file changed, 112 insertions(+), 125 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index c87cb373d..059a112bb 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -1,9 +1,11 @@ from __future__ import annotations import types +from functools import partial from typing import ( TYPE_CHECKING, Any, + Callable, Generic, Iterable, Literal, @@ -175,6 +177,38 @@ def _cell_transition( ) return tm + def _annotation_aggregation_transition( + self: AnalysisMixinProtocol[K, B], + annotations_1: list[Any], + annotations_2: list[Any], + df: pd.DataFrame, + func: Callable[..., ArrayLike], + ) -> pd.DataFrame: + n1 = len(annotations_1) + n2 = len(annotations_2) + tm_arr = np.zeros((n1, n2)) + + # Factorize annotations in df_res_annotation + codes, uniques = pd.factorize(df.values) + # Map annotations in 'annotations_2' to indices in 'uniques' + annotations_in_df_to_idx = {annotation: idx for idx, annotation in enumerate(uniques)} + annotations_2_codes = [annotations_in_df_to_idx.get(annotation, -1) for annotation in annotations_2] + + for i, subset in enumerate(annotations_1): + result = func( + subset=subset, + ) + # Compute sums over 'codes' weighted by 'result' + sums = np.bincount(codes, weights=result.squeeze(), minlength=len(uniques)) + dist = [sums[code] if code != -1 else 0 for code in annotations_2_codes] + tm_arr[i, :] = dist + + return pd.DataFrame( + tm_arr, + index=annotations_1, + columns=annotations_2, + ) + def _cell_transition_online( self: AnalysisMixinProtocol[K, B], key: Optional[str], @@ -208,9 +242,9 @@ def _cell_transition_online( key if other_adata is None else other_key, target, ) - df_source = df_source.rename(columns={source_annotation_key: "res_annotation"}) - df_target = df_target.rename(columns={target_annotation_key: "res_annotation"}) res_annotation_key = "res_annotation" + df_source = df_source.rename(columns={source_annotation_key: res_annotation_key}) + df_target = df_target.rename(columns={target_annotation_key: res_annotation_key}) source_annotations_verified, target_annotations_verified = _validate_annotations( df_source=df_source, df_target=df_target, @@ -221,67 +255,46 @@ def _cell_transition_online( aggregation_mode=aggregation_mode, forward=forward, ) - + move_op = self.push if forward else self.pull + move_op_const_kwargs = { + "source": source, + "target": target, + "normalize": True, + "return_all": False, + "scale_by_marginals": False, + "key_added": None, + } if aggregation_mode == "annotation": - df_target["distribution"] = 0 - df_source["distribution"] = 0 - tm = pd.DataFrame( - np.zeros((len(source_annotations_verified), len(target_annotations_verified))), - index=source_annotations_verified, - columns=target_annotations_verified, + func = partial( + move_op, + data=source_annotation_key if forward else target_annotation_key, + split_mass=False, + **move_op_const_kwargs, ) - if forward: - tm = self._annotation_aggregation_transition( # type: ignore[attr-defined] - source=source, - target=target, - annotation_key=source_annotation_key, - annotations_1=source_annotations_verified, - annotations_2=target_annotations_verified, - df=df_target, - df_key=res_annotation_key, - tm=tm, - forward=True, - ) - else: - tm = self._annotation_aggregation_transition( # type: ignore[attr-defined] - source=source, - target=target, - annotation_key=target_annotation_key, - annotations_1=target_annotations_verified, - annotations_2=source_annotations_verified, - df=df_source, - df_key=res_annotation_key, - tm=tm, - forward=False, - ) + df = (df_target if forward else df_source)[res_annotation_key] + tm = self._annotation_aggregation_transition( # type: ignore[attr-defined] + annotations_1=source_annotations_verified if forward else target_annotations_verified, + annotations_2=target_annotations_verified if forward else source_annotations_verified, + df=df, + func=func, + ) + elif aggregation_mode == "cell": - tm = pd.DataFrame(columns=target_annotations_verified if forward else source_annotations_verified) - if forward: - tm = self._cell_aggregation_transition( # type: ignore[attr-defined] - source=source, - target=target, - annotation_key=res_annotation_key, - annotations_1=source_annotations_verified, - annotations_2=target_annotations_verified, - df_1=df_target, - df_2=df_source, - tm=tm, - batch_size=batch_size, - forward=True, - ) - else: - tm = self._cell_aggregation_transition( # type: ignore[attr-defined] - source=source, - target=target, - annotation_key=res_annotation_key, - annotations_1=target_annotations_verified, - annotations_2=source_annotations_verified, - df_1=df_source, - df_2=df_target, - tm=tm, - batch_size=batch_size, - forward=False, - ) + func = partial( + move_op, + data=None, + split_mass=True, + **move_op_const_kwargs, + ) + tm = self._cell_aggregation_transition( # type: ignore[attr-defined] + df_from=df_source if forward else df_target, + df_to=df_target if forward else df_source, + annotation_key=res_annotation_key, + annotations=target_annotations_verified if forward else source_annotations_verified, + batch_size=batch_size, + func=func, + ) + else: raise NotImplementedError(f"Aggregation mode `{aggregation_mode!r}` is not yet implemented.") @@ -480,77 +493,51 @@ def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key tmp[mask] = np.squeeze(v) return tmp - def _annotation_aggregation_transition( - self: AnalysisMixinProtocol[K, B], - source: K, - target: K, - annotation_key: str, - annotations_1: list[Any], - annotations_2: list[Any], - df_key: str, - df: pd.DataFrame, - tm: pd.DataFrame, - forward: bool, - ) -> pd.DataFrame: - if not forward: - tm = tm.T - func = self.push if forward else self.pull - for subset in annotations_1: - result = func( # TODO(@MUCDK) check how to make compatible with all policies - source=source, - target=target, - data=annotation_key, - subset=subset, - normalize=True, - return_all=False, - scale_by_marginals=False, - split_mass=False, - key_added=None, - ) - df["distribution"] = result - cell_dist = df[df[df_key].isin(annotations_2)].groupby(df_key, observed=False).sum(numeric_only=True) - cell_dist /= cell_dist.sum() - tm.loc[subset, :] = [ - cell_dist.loc[annotation, "distribution"] if annotation in cell_dist.distribution.index else 0 - for annotation in annotations_2 - ] - return tm - def _cell_aggregation_transition( self: AnalysisMixinProtocol[K, B], - source: str, - target: str, annotation_key: str, - # TODO(MUCDK): unused variables, del below - annotations_1: list[Any], - annotations_2: list[Any], - df_1: pd.DataFrame, - df_2: pd.DataFrame, - tm: pd.DataFrame, + df_from: pd.DataFrame, + df_to: pd.DataFrame, + annotations: list[Any], batch_size: Optional[int], - forward: bool, + func: Callable[..., ArrayLike], ) -> pd.DataFrame: - func = self.push if forward else self.pull + + # Factorize annotations in df_to + annotations_in_df_to = df_to[annotation_key].values + codes_to, uniques_to = pd.factorize(annotations_in_df_to) + # Map annotations in 'annotations' to codes + annotations_to_code = {annotation: idx for idx, annotation in enumerate(uniques_to)} + annotations_codes = [annotations_to_code.get(annotation, -1) for annotation in annotations] + n_annotations = len(annotations) + n_from_cells = len(df_from) + if batch_size is None: - batch_size = len(df_2) - for batch in range(0, len(df_2), batch_size): - result = func( # TODO(@MUCDK) check how to make compatible with all policies - source=source, - target=target, - data=None, - subset=(batch, batch_size), - normalize=True, - return_all=False, - scale_by_marginals=False, - split_mass=True, - key_added=None, - ) - current_cells = df_2.iloc[range(batch, min(batch + batch_size, len(df_2)))].index.tolist() - df_1.loc[:, current_cells] = result - to_app = df_1[df_1[annotation_key].isin(annotations_2)].groupby(annotation_key).sum().transpose() - tm = pd.concat([tm, to_app], verify_integrity=True, axis=0) - df_1 = df_1.drop(current_cells, axis=1) - return tm + batch_size = n_from_cells + + tm_arr = np.zeros((n_from_cells, n_annotations)) + index = df_from.index + + # Process in batches + for batch_start in range(0, n_from_cells, batch_size): + batch_end = min(batch_start + batch_size, n_from_cells) + subset = (batch_start, batch_end - batch_start) + result = func(subset=subset) + # Result shape: (n_to_cells, batch_size) + # For each cell in the batch, we compute the sum over annotations + for i in range(batch_end - batch_start): + cell_distribution = result[:, i] + # Aggregate over annotations using bincount + sums = np.bincount( + codes_to, + weights=cell_distribution, + minlength=len(uniques_to), + ) + # Map sums to annotations_verified_codes + dist = [sums[code] if code != -1 else 0 for code in annotations_codes] + tm_arr[batch_start + i, :] = dist + + return pd.DataFrame(tm_arr, index=index, columns=annotations) # adapted from: # https://github.com/theislab/cellrank/blob/master/cellrank/_utils/_utils.py#L392 From 159eba0b38e7759d57fdb154aac559b8da1b3c9b Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 9 Oct 2024 18:04:23 +0200 Subject: [PATCH 3/4] refactor some functions --- src/moscot/base/problems/_mixins.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 059a112bb..aaa7dd459 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -264,6 +264,9 @@ def _cell_transition_online( "scale_by_marginals": False, "key_added": None, } + df_to, df_from = (df_target, df_source) if forward else (df_source, df_target) + df_to = df_to[res_annotation_key] + if aggregation_mode == "annotation": func = partial( move_op, @@ -271,11 +274,10 @@ def _cell_transition_online( split_mass=False, **move_op_const_kwargs, ) - df = (df_target if forward else df_source)[res_annotation_key] tm = self._annotation_aggregation_transition( # type: ignore[attr-defined] annotations_1=source_annotations_verified if forward else target_annotations_verified, annotations_2=target_annotations_verified if forward else source_annotations_verified, - df=df, + df=df_to, func=func, ) @@ -287,9 +289,8 @@ def _cell_transition_online( **move_op_const_kwargs, ) tm = self._cell_aggregation_transition( # type: ignore[attr-defined] - df_from=df_source if forward else df_target, - df_to=df_target if forward else df_source, - annotation_key=res_annotation_key, + df_from=df_from, + df_to=df_to, annotations=target_annotations_verified if forward else source_annotations_verified, batch_size=batch_size, func=func, @@ -495,7 +496,6 @@ def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key def _cell_aggregation_transition( self: AnalysisMixinProtocol[K, B], - annotation_key: str, df_from: pd.DataFrame, df_to: pd.DataFrame, annotations: list[Any], @@ -504,7 +504,7 @@ def _cell_aggregation_transition( ) -> pd.DataFrame: # Factorize annotations in df_to - annotations_in_df_to = df_to[annotation_key].values + annotations_in_df_to = df_to.values codes_to, uniques_to = pd.factorize(annotations_in_df_to) # Map annotations in 'annotations' to codes annotations_to_code = {annotation: idx for idx, annotation in enumerate(uniques_to)} From 92acfbe018c4a148b5f7b317c2ef7f04025d0d16 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Wed, 9 Oct 2024 18:18:49 +0200 Subject: [PATCH 4/4] remove some unnecesary lines --- src/moscot/base/problems/_mixins.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index aaa7dd459..2a5ba2029 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -230,31 +230,31 @@ def _cell_transition_online( target_annotation_key, target_annotations, target_annotations_ordered = _validate_args_cell_transition( self.adata if other_adata is None else other_adata, target_groups ) + new_annotation_key = "new_annotation" df_source = _get_df_cell_transition( self.adata, [source_annotation_key], key, source, - ) + ).rename(columns={source_annotation_key: new_annotation_key}) df_target = _get_df_cell_transition( self.adata if other_adata is None else other_adata, [target_annotation_key], key if other_adata is None else other_key, target, - ) - res_annotation_key = "res_annotation" - df_source = df_source.rename(columns={source_annotation_key: res_annotation_key}) - df_target = df_target.rename(columns={target_annotation_key: res_annotation_key}) + ).rename(columns={target_annotation_key: new_annotation_key}) source_annotations_verified, target_annotations_verified = _validate_annotations( df_source=df_source, df_target=df_target, - source_annotation_key=res_annotation_key, - target_annotation_key=res_annotation_key, + source_annotation_key=new_annotation_key, + target_annotation_key=new_annotation_key, source_annotations=source_annotations, target_annotations=target_annotations, aggregation_mode=aggregation_mode, forward=forward, ) + df_to, df_from = (df_target, df_source) if forward else (df_source, df_target) + df_to = df_to[new_annotation_key] move_op = self.push if forward else self.pull move_op_const_kwargs = { "source": source, @@ -264,8 +264,6 @@ def _cell_transition_online( "scale_by_marginals": False, "key_added": None, } - df_to, df_from = (df_target, df_source) if forward else (df_source, df_target) - df_to = df_to[res_annotation_key] if aggregation_mode == "annotation": func = partial(