Skip to content

Commit

Permalink
passing batch_size fix (#647)
Browse files Browse the repository at this point in the history
* passing batch_size fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* index not a scalar fix

* expose batch_size, parametrize it in tests

* Optional instead of | None

* if batch_size is None

* if batch_size is None

* line order

---------

Co-authored-by: Arina Danilina <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Jan 23, 2024
1 parent 05530db commit 4617c94
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 10 deletions.
9 changes: 5 additions & 4 deletions src/moscot/base/problems/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ def _annotation_mapping(
cell_transition_kwargs.setdefault("target", target)
cell_transition_kwargs.setdefault("other_adata", other_adata)
cell_transition_kwargs.setdefault("forward", not forward)
cell_transition_kwargs.setdefault("batch_size", batch_size)
if forward:
cell_transition_kwargs.setdefault("source_groups", annotation_label)
cell_transition_kwargs.setdefault("target_groups", None)
Expand All @@ -342,7 +343,7 @@ def _annotation_mapping(
out_len = self.solutions[(source, target)].shape[1]
batch_size = batch_size if batch_size is not None else out_len
for batch in range(0, out_len, batch_size):
tm_batch: ArrayLike = self.push(
tm_batch: ArrayLike = self.pull(
source=source,
target=target,
data=None,
Expand All @@ -353,7 +354,7 @@ def _annotation_mapping(
split_mass=True,
key_added=None,
)
v = np.array(tm_batch.argmax(1))
v = np.array(tm_batch.argmax(0))
out.extend(source_df[annotation_label][v[i]] for i in range(len(v)))

else:
Expand All @@ -366,7 +367,7 @@ def _annotation_mapping(
out_len = self.solutions[(source, target)].shape[0]
batch_size = batch_size if batch_size is not None else out_len
for batch in range(0, out_len, batch_size):
tm_batch: ArrayLike = self.pull( # type: ignore[no-redef]
tm_batch: ArrayLike = self.push( # type: ignore[no-redef]
source=source,
target=target,
data=None,
Expand All @@ -377,7 +378,7 @@ def _annotation_mapping(
split_mass=True,
key_added=None,
)
v = np.array(tm_batch.argmax(1))
v = np.array(tm_batch.argmax(0))
out.extend(target_df[annotation_label][v[i]] for i in range(len(v)))
categories = pd.Categorical(out)
return pd.DataFrame(categories, columns=[annotation_label])
Expand Down
8 changes: 8 additions & 0 deletions src/moscot/problems/cross_modality/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ def annotation_mapping( # type: ignore[misc]
forward: bool,
source: str = "src",
target: str = "tgt",
batch_size: Optional[int] = None,
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
**kwargs: Mapping[str, Any],
) -> pd.DataFrame:
"""Transfer annotations between distributions.
Expand All @@ -216,6 +218,10 @@ def annotation_mapping( # type: ignore[misc]
Key identifying the source distribution.
target
Key identifying the target distribution.
batch_size
Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`.
Larger value will require more memory.
If :obj:`None`, the entire cost matrix will be materialized.
cell_transition_kwargs
Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``.
Expand All @@ -231,7 +237,9 @@ def annotation_mapping( # type: ignore[misc]
key=self.batch_key,
forward=forward,
other_adata=self.adata_tgt,
batch_size=batch_size,
cell_transition_kwargs=cell_transition_kwargs,
**kwargs,
)

@property
Expand Down
16 changes: 16 additions & 0 deletions src/moscot/problems/space/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,9 @@ def annotation_mapping( # type: ignore[misc]
forward: bool,
source: str = "src",
target: str = "tgt",
batch_size: Optional[int] = None,
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
**kwargs: Mapping[str, Any],
) -> pd.DataFrame:
"""Transfer annotations between distributions.
Expand All @@ -313,6 +315,10 @@ def annotation_mapping( # type: ignore[misc]
Key identifying the source distribution.
target
Key identifying the target distribution.
batch_size
Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`.
Larger value will require more memory.
If :obj:`None`, the entire cost matrix will be materialized.
cell_transition_kwargs
Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``.
Expand All @@ -327,7 +333,9 @@ def annotation_mapping( # type: ignore[misc]
target=target,
key=self.batch_key,
forward=forward,
batch_size=batch_size,
cell_transition_kwargs=cell_transition_kwargs,
**kwargs,
)

@property
Expand Down Expand Up @@ -626,7 +634,9 @@ def annotation_mapping( # type: ignore[misc]
source: K,
target: Union[K, str] = "tgt",
forward: bool = False,
batch_size: Optional[int] = None,
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
**kwargs: Mapping[str, Any],
) -> pd.DataFrame:
"""Transfer annotations between distributions.
Expand All @@ -648,6 +658,10 @@ def annotation_mapping( # type: ignore[misc]
Key identifying the source distribution.
target
Key identifying the target distribution.
batch_size
Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`.
Larger value will require more memory.
If :obj:`None`, the entire cost matrix will be materialized.
cell_transition_kwargs
Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``.
Expand All @@ -663,7 +677,9 @@ def annotation_mapping( # type: ignore[misc]
forward=forward,
key=self.batch_key,
other_adata=self.adata_sc,
batch_size=batch_size,
cell_transition_kwargs=cell_transition_kwargs,
**kwargs,
)

@property
Expand Down
8 changes: 8 additions & 0 deletions src/moscot/problems/time/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,9 @@ def annotation_mapping(
forward: bool,
source: K,
target: K,
batch_size: Optional[int] = None,
cell_transition_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
**kwargs: Mapping[str, Any],
) -> pd.DataFrame:
"""Transfer annotations between distributions.
Expand All @@ -269,6 +271,10 @@ def annotation_mapping(
Key identifying the source distribution.
target
Key identifying the target distribution.
batch_size
Number of rows/columns of the cost matrix to materialize during :meth:`push` or :meth:`pull`.
Larger value will require more memory.
If :obj:`None`, the entire cost matrix will be materialized.
cell_transition_kwargs
Keyword arguments for :meth:`cell_transition`, used only if ``mapping_mode = 'sum'``.
Expand All @@ -284,7 +290,9 @@ def annotation_mapping(
key=self._temporal_key,
forward=forward,
other_adata=None,
batch_size=batch_size,
cell_transition_kwargs=cell_transition_kwargs,
**kwargs,
)

def sankey(
Expand Down
10 changes: 8 additions & 2 deletions tests/problems/cross_modality/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,10 @@ def test_cell_transition_pipeline(
@pytest.mark.fast()
@pytest.mark.parametrize("forward", [True, False])
@pytest.mark.parametrize("mapping_mode", ["max", "sum"])
@pytest.mark.parametrize("batch_size", [3, 7, None])
@pytest.mark.parametrize("problem_kind", ["cross_modality"])
def test_annotation_mapping(
self, adata_anno: Tuple[AnnData, AnnData], forward: bool, mapping_mode, gt_tm_annotation
self, adata_anno: Tuple[AnnData, AnnData], forward: bool, mapping_mode, batch_size, gt_tm_annotation
):
adata_src, adata_tgt = adata_anno
tp = TranslationProblem(adata_src, adata_tgt)
Expand All @@ -122,7 +123,12 @@ def test_annotation_mapping(
tp[problem_keys].set_solution(MockSolverOutput(gt_tm_annotation), overwrite=True)
annotation_label = "celltype1" if forward else "celltype2"
result = tp.annotation_mapping(
mapping_mode=mapping_mode, annotation_label=annotation_label, forward=forward, source="src", target="tgt"
mapping_mode=mapping_mode,
annotation_label=annotation_label,
forward=forward,
source="src",
target="tgt",
batch_size=batch_size,
)
if forward:
expected_result = (
Expand Down
8 changes: 6 additions & 2 deletions tests/problems/space/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ def test_cell_transition_pipeline(self, adata_space_rotate: AnnData, forward: bo
@pytest.mark.fast()
@pytest.mark.parametrize("forward", [True, False])
@pytest.mark.parametrize("mapping_mode", ["max", "sum"])
@pytest.mark.parametrize("batch_size", [3, 7, None])
@pytest.mark.parametrize("problem_kind", ["alignment"])
def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation):
def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, batch_size, gt_tm_annotation):
ap = AlignmentProblem(adata=adata_anno)
ap = ap.prepare(batch_key="batch", joint_attr={"attr": "X"})
problem_keys = ("0", "1")
Expand All @@ -110,6 +111,7 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo
source="0",
target="1",
forward=forward,
batch_size=batch_size,
)
if forward:
expected_result = (
Expand Down Expand Up @@ -207,8 +209,9 @@ def test_cell_transition_pipeline(self, adata_mapping: AnnData, forward: bool, n
@pytest.mark.fast()
@pytest.mark.parametrize("forward", [True, False])
@pytest.mark.parametrize("mapping_mode", ["max", "sum"])
@pytest.mark.parametrize("batch_size", [3, 7, None])
@pytest.mark.parametrize("problem_kind", ["mapping"])
def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation):
def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, batch_size, gt_tm_annotation):
adataref, adatasp = adata_anno
mp = MappingProblem(adataref, adatasp)
mp = mp.prepare(sc_attr={"attr": "obsm", "key": "X_pca"}, joint_attr={"attr": "X"})
Expand All @@ -221,6 +224,7 @@ def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mo
annotation_label=annotation_label,
source="src",
forward=forward,
batch_size=batch_size,
)
if not forward:
expected_result = adataref.uns["expected_max1"] if mapping_mode == "max" else adataref.uns["expected_sum1"]
Expand Down
10 changes: 8 additions & 2 deletions tests/problems/time/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,22 @@ def test_cell_transition_full_pipeline(self, gt_temporal_adata: AnnData, forward
@pytest.mark.fast()
@pytest.mark.parametrize("forward", [True, False])
@pytest.mark.parametrize("mapping_mode", ["max", "sum"])
@pytest.mark.parametrize("batch_size", [3, 7, None])
@pytest.mark.parametrize("problem_kind", ["temporal"])
def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, gt_tm_annotation):
def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, batch_size, gt_tm_annotation):
problem = TemporalProblem(adata_anno)
problem_keys = (0, 1)
problem = problem.prepare(time_key="day", joint_attr="X_pca")
assert set(problem.problems.keys()) == {problem_keys}
problem[problem_keys]._solution = MockSolverOutput(gt_tm_annotation)
annotation_label = "celltype1" if forward else "celltype2"
result = problem.annotation_mapping(
mapping_mode=mapping_mode, annotation_label=annotation_label, forward=forward, source=0, target=1
mapping_mode=mapping_mode,
annotation_label=annotation_label,
forward=forward,
source=0,
target=1,
batch_size=batch_size,
)
if forward:
expected_result = (
Expand Down

0 comments on commit 4617c94

Please sign in to comment.