diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index 059a112bb..aaa7dd459 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -264,6 +264,9 @@ def _cell_transition_online( "scale_by_marginals": False, "key_added": None, } + df_to, df_from = (df_target, df_source) if forward else (df_source, df_target) + df_to = df_to[res_annotation_key] + if aggregation_mode == "annotation": func = partial( move_op, @@ -271,11 +274,10 @@ def _cell_transition_online( split_mass=False, **move_op_const_kwargs, ) - df = (df_target if forward else df_source)[res_annotation_key] tm = self._annotation_aggregation_transition( # type: ignore[attr-defined] annotations_1=source_annotations_verified if forward else target_annotations_verified, annotations_2=target_annotations_verified if forward else source_annotations_verified, - df=df, + df=df_to, func=func, ) @@ -287,9 +289,8 @@ def _cell_transition_online( **move_op_const_kwargs, ) tm = self._cell_aggregation_transition( # type: ignore[attr-defined] - df_from=df_source if forward else df_target, - df_to=df_target if forward else df_source, - annotation_key=res_annotation_key, + df_from=df_from, + df_to=df_to, annotations=target_annotations_verified if forward else source_annotations_verified, batch_size=batch_size, func=func, @@ -495,7 +496,6 @@ def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key def _cell_aggregation_transition( self: AnalysisMixinProtocol[K, B], - annotation_key: str, df_from: pd.DataFrame, df_to: pd.DataFrame, annotations: list[Any], @@ -504,7 +504,7 @@ def _cell_aggregation_transition( ) -> pd.DataFrame: # Factorize annotations in df_to - annotations_in_df_to = df_to[annotation_key].values + annotations_in_df_to = df_to.values codes_to, uniques_to = pd.factorize(annotations_in_df_to) # Map annotations in 'annotations' to codes annotations_to_code = {annotation: idx for idx, annotation in enumerate(uniques_to)}