diff --git a/src/moscot/base/problems/_mixins.py b/src/moscot/base/problems/_mixins.py index cd09cd9b9..dc76889fd 100644 --- a/src/moscot/base/problems/_mixins.py +++ b/src/moscot/base/problems/_mixins.py @@ -320,6 +320,7 @@ def _annotation_mapping( cell_transition_kwargs.setdefault("target", target) cell_transition_kwargs.setdefault("other_adata", other_adata) cell_transition_kwargs.setdefault("forward", not forward) + cell_transition_kwargs.setdefault("batch_size", batch_size) if forward: cell_transition_kwargs.setdefault("source_groups", annotation_label) cell_transition_kwargs.setdefault("target_groups", None) @@ -342,7 +343,7 @@ def _annotation_mapping( out_len = self.solutions[(source, target)].shape[1] batch_size = batch_size if batch_size is not None else out_len for batch in range(0, out_len, batch_size): - tm_batch: ArrayLike = self.push( + tm_batch: ArrayLike = self.pull( source=source, target=target, data=None, @@ -353,7 +354,7 @@ def _annotation_mapping( split_mass=True, key_added=None, ) - v = np.array(tm_batch.argmax(1)) + v = np.array(tm_batch.argmax(0)) out.extend(source_df[annotation_label][v[i]] for i in range(len(v))) else: @@ -366,7 +367,7 @@ def _annotation_mapping( out_len = self.solutions[(source, target)].shape[0] batch_size = batch_size if batch_size is not None else out_len for batch in range(0, out_len, batch_size): - tm_batch: ArrayLike = self.pull( # type: ignore[no-redef] + tm_batch: ArrayLike = self.push( # type: ignore[no-redef] source=source, target=target, data=None, @@ -377,7 +378,7 @@ def _annotation_mapping( split_mass=True, key_added=None, ) - v = np.array(tm_batch.argmax(1)) + v = np.array(tm_batch.argmax(0)) out.extend(target_df[annotation_label][v[i]] for i in range(len(v))) categories = pd.Categorical(out) return pd.DataFrame(categories, columns=[annotation_label]) diff --git a/src/moscot/problems/cross_modality/_mixins.py b/src/moscot/problems/cross_modality/_mixins.py index ce58f84a4..5299e1dab 100644 --- a/src/moscot/problems/cross_modality/_mixins.py +++ b/src/moscot/problems/cross_modality/_mixins.py @@ -194,7 +194,9 @@ def annotation_mapping( # type: ignore[misc] forward: bool, source: str = "src", target: str = "tgt", + batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), + **kwargs: Mapping[str, Any], ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -216,6 +218,10 @@ def annotation_mapping( # type: ignore[misc] Key identifying the source distribution. target Key identifying the target distribution. + batch_size + Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. + Larger value will require more memory. + If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. @@ -231,7 +237,9 @@ def annotation_mapping( # type: ignore[misc] key=self.batch_key, forward=forward, other_adata=self.adata_tgt, + batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, + **kwargs, ) @property diff --git a/src/moscot/problems/space/_mixins.py b/src/moscot/problems/space/_mixins.py index 0aa84c326..d351eb821 100644 --- a/src/moscot/problems/space/_mixins.py +++ b/src/moscot/problems/space/_mixins.py @@ -291,7 +291,9 @@ def annotation_mapping( # type: ignore[misc] forward: bool, source: str = "src", target: str = "tgt", + batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), + **kwargs: Mapping[str, Any], ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -313,6 +315,10 @@ def annotation_mapping( # type: ignore[misc] Key identifying the source distribution. target Key identifying the target distribution. + batch_size + Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. + Larger value will require more memory. + If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. @@ -327,7 +333,9 @@ def annotation_mapping( # type: ignore[misc] target=target, key=self.batch_key, forward=forward, + batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, + **kwargs, ) @property @@ -626,7 +634,9 @@ def annotation_mapping( # type: ignore[misc] source: K, target: Union[K, str] = "tgt", forward: bool = False, + batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), + **kwargs: Mapping[str, Any], ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -648,6 +658,10 @@ def annotation_mapping( # type: ignore[misc] Key identifying the source distribution. target Key identifying the target distribution. + batch_size + Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. + Larger value will require more memory. + If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. @@ -663,7 +677,9 @@ def annotation_mapping( # type: ignore[misc] forward=forward, key=self.batch_key, other_adata=self.adata_sc, + batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, + **kwargs, ) @property diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index 597757901..c2e940e79 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -247,7 +247,9 @@ def annotation_mapping( forward: bool, source: K, target: K, + batch_size: Optional[int] = None, cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}), + **kwargs: Mapping[str, Any], ) -> pd.DataFrame: """Transfer annotations between distributions. @@ -269,6 +271,10 @@ def annotation_mapping( Key identifying the source distribution. target Key identifying the target distribution. + batch_size + Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`. + Larger value will require more memory. + If :obj:`None`, the entire cost matrix will be materialized. cell_transition_kwargs Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``. @@ -284,7 +290,9 @@ def annotation_mapping( key=self._temporal_key, forward=forward, other_adata=None, + batch_size=batch_size, cell_transition_kwargs=cell_transition_kwargs, + **kwargs, ) def sankey( diff --git a/tests/problems/cross_modality/test_mixins.py b/tests/problems/cross_modality/test_mixins.py index 079e153a4..781c7c8d1 100644 --- a/tests/problems/cross_modality/test_mixins.py +++ b/tests/problems/cross_modality/test_mixins.py @@ -110,9 +110,10 @@ def test_cell_transition_pipeline( @pytest.mark.fast() @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) + @pytest.mark.parametrize("batch_size", [3, 7, None]) @pytest.mark.parametrize("problem_kind", ["cross_modality"]) def test_annotation_mapping( - self, adata_anno: Tuple[AnnData, AnnData], forward: bool, mapping_mode, gt_tm_annotation + self, adata_anno: Tuple[AnnData, AnnData], forward: bool, mapping_mode, batch_size, gt_tm_annotation ): adata_src, adata_tgt = adata_anno tp = TranslationProblem(adata_src, adata_tgt) @@ -122,7 +123,12 @@ def test_annotation_mapping( tp[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation), overwrite=True) annotation_label = "celltype1" if forward else "celltype2" result = tp.annotation_mapping( - mapping_mode=mapping_mode, annotation_label=annotation_label, forward=forward, source="src", target="tgt" + mapping_mode=mapping_mode, + annotation_label=annotation_label, + forward=forward, + source="src", + target="tgt", + batch_size=batch_size, ) if forward: expected_result = ( diff --git a/tests/problems/space/test_mixins.py b/tests/problems/space/test_mixins.py index a6b70031c..c105683a9 100644 --- a/tests/problems/space/test_mixins.py +++ b/tests/problems/space/test_mixins.py @@ -96,8 +96,9 @@ def test_cell_transition_pipeline(self, adata_space_rotate: AnnData, forward: bo @pytest.mark.fast() @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) + @pytest.mark.parametrize("batch_size", [3, 7, None]) @pytest.mark.parametrize("problem_kind", ["alignment"]) - def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): + def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, batch_size, gt_tm_annotation): ap = AlignmentProblem(adata=adata_anno) ap = ap.prepare(batch_key="batch", joint_attr={"attr": "X"}) problem_keys = ("0", "1") @@ -110,6 +111,7 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo source="0", target="1", forward=forward, + batch_size=batch_size, ) if forward: expected_result = ( @@ -207,8 +209,9 @@ def test_cell_transition_pipeline(self, adata_mapping: AnnData, forward: bool, n @pytest.mark.fast() @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) + @pytest.mark.parametrize("batch_size", [3, 7, None]) @pytest.mark.parametrize("problem_kind", ["mapping"]) - def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): + def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, batch_size, gt_tm_annotation): adataref, adatasp = adata_anno mp = MappingProblem(adataref, adatasp) mp = mp.prepare(sc_attr={"attr": "obsm", "key": "X_pca"}, joint_attr={"attr": "X"}) @@ -221,6 +224,7 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo annotation_label=annotation_label, source="src", forward=forward, + batch_size=batch_size, ) if not forward: expected_result = adataref.uns["expected_max1"] if mapping_mode == "max" else adataref.uns["expected_sum1"] diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index cb2d9ea2a..e5c9e36f8 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -53,8 +53,9 @@ def test_cell_transition_full_pipeline(self, gt_temporal_adata: AnnData, forward @pytest.mark.fast() @pytest.mark.parametrize("forward", [True, False]) @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) + @pytest.mark.parametrize("batch_size", [3, 7, None]) @pytest.mark.parametrize("problem_kind", ["temporal"]) - def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation): + def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, batch_size, gt_tm_annotation): problem = TemporalProblem(adata_anno) problem_keys = (0, 1) problem = problem.prepare(time_key="day", joint_attr="X_pca") @@ -62,7 +63,12 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo problem[problem_keys]._solution = MockSolverOutput(gt_tm_annotation) annotation_label = "celltype1" if forward else "celltype2" result = problem.annotation_mapping( - mapping_mode=mapping_mode, annotation_label=annotation_label, forward=forward, source=0, target=1 + mapping_mode=mapping_mode, + annotation_label=annotation_label, + forward=forward, + source=0, + target=1, + batch_size=batch_size, ) if forward: expected_result = (