Skip to content

Commit

Permalink
refactor some functions
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed Oct 9, 2024
1 parent 8f70c1e commit 159eba0
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions src/moscot/base/problems/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,18 +264,20 @@ def _cell_transition_online(
"scale_by_marginals": False,
"key_added": None,
}
df_to, df_from = (df_target, df_source) if forward else (df_source, df_target)
df_to = df_to[res_annotation_key]

if aggregation_mode == "annotation":
func = partial(
move_op,
data=source_annotation_key if forward else target_annotation_key,
split_mass=False,
**move_op_const_kwargs,
)
df = (df_target if forward else df_source)[res_annotation_key]
tm = self._annotation_aggregation_transition( # type: ignore[attr-defined]
annotations_1=source_annotations_verified if forward else target_annotations_verified,
annotations_2=target_annotations_verified if forward else source_annotations_verified,
df=df,
df=df_to,
func=func,
)

Expand All @@ -287,9 +289,8 @@ def _cell_transition_online(
**move_op_const_kwargs,
)
tm = self._cell_aggregation_transition( # type: ignore[attr-defined]
df_from=df_source if forward else df_target,
df_to=df_target if forward else df_source,
annotation_key=res_annotation_key,
df_from=df_from,
df_to=df_to,
annotations=target_annotations_verified if forward else source_annotations_verified,
batch_size=batch_size,
func=func,
Expand Down Expand Up @@ -495,7 +496,6 @@ def _flatten(self: AnalysisMixinProtocol[K, B], data: dict[K, ArrayLike], *, key

def _cell_aggregation_transition(
self: AnalysisMixinProtocol[K, B],
annotation_key: str,
df_from: pd.DataFrame,
df_to: pd.DataFrame,
annotations: list[Any],
Expand All @@ -504,7 +504,7 @@ def _cell_aggregation_transition(
) -> pd.DataFrame:

# Factorize annotations in df_to
annotations_in_df_to = df_to[annotation_key].values
annotations_in_df_to = df_to.values
codes_to, uniques_to = pd.factorize(annotations_in_df_to)
# Map annotations in 'annotations' to codes
annotations_to_code = {annotation: idx for idx, annotation in enumerate(uniques_to)}
Expand Down

0 comments on commit 159eba0

Please sign in to comment.