Skip to content

Commit

Permalink
add tests where alpha is used
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed Apr 9, 2024
1 parent a93c7a0 commit fc857f9
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions tests/problems/cross_modality/test_translation_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,23 @@ def test_prepare_external_star_policy(
assert tp[prob_key].xy.data_src.shape == tp[prob_key].xy.data_tgt.shape == (n_obs, xy_n_vars)

@pytest.mark.parametrize(
("epsilon", "rank", "initializer"),
[(1e-2, -1, None), (2, -1, "random"), (2, -1, "rank2"), (2, -1, None)],
("epsilon", "alpha", "rank", "initializer"),
[
(1e-2, 0.9, -1, None),
(2, 0.5, -1, "random"),
(2, 1.0, -1, "rank2"),
(2, 0.1, -1, None),
(2, 1.0, -1, None),
(1.3, 1.0, -1, "random"),
],
)
@pytest.mark.parametrize("src_attr", ["emb_src", {"attr": "obsm", "key": "emb_src"}])
@pytest.mark.parametrize("tgt_attr", ["emb_tgt", {"attr": "obsm", "key": "emb_tgt"}])
def test_solve_balanced(
self,
adata_translation_split: Tuple[AnnData, AnnData],
epsilon: float,
alpha: float,
rank: int,
src_attr: Mapping[str, str],
tgt_attr: Mapping[str, str],
Expand All @@ -103,8 +111,9 @@ def test_solve_balanced(
kwargs["initializer"] = initializer

tp = TranslationProblem(adata_src, adata_tgt)
tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr)
tp = tp.solve(epsilon=epsilon, rank=rank, **kwargs)
joint_attr = None if alpha == 1.0 else {"attr": "obsm", "key": "X_pca"}
tp = tp.prepare(batch_key="batch", src_attr=src_attr, tgt_attr=tgt_attr, joint_attr=joint_attr)
tp = tp.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs)

for key, subsol in tp.solutions.items():
assert isinstance(subsol, BaseSolverOutput)
Expand Down

0 comments on commit fc857f9

Please sign in to comment.