Skip to content

Commit

Permalink
improve the tests to also use other rank solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
selmanozleyen committed Sep 23, 2024
1 parent 01c89a2 commit 49c8bbc
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions tests/problems/base/test_general_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,28 @@ def test_simple_run(self, adata_x: AnnData, adata_y: AnnData):

assert isinstance(prob.solution, BaseDiscreteSolverOutput)

@pytest.mark.parametrize("kind", ["linear", "quadratic"])
def test_unrecognized_args(self, adata_x: AnnData, adata_y: AnnData, kind: Literal["linear", "quadratic"]):
@pytest.mark.parametrize(
("kind", "rank"),
[
("linear", -1),
("linear", 5),
("quadratic", -1),
("quadratic", 5),
],
)
def test_unrecognized_args(
self, adata_x: AnnData, adata_y: AnnData, kind: Literal["linear", "quadratic"], rank: int
):
prob = OTProblem(adata_x, adata_y)
data = {
"xy": {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"},
}
if kind == "quadratic":
if "quadratic" in kind:
data["x"] = {"attr": "X"}
data["y"] = {"attr": "X"}

with pytest.raises(TypeError):
prob.prepare(**data).solve(epsilon=5e-1, dummy=42)
prob.prepare(**data).solve(epsilon=5e-1, rank=rank, dummy=42)

@pytest.mark.fast
def test_output(self, adata_x: AnnData, x: Geom_t):
Expand Down

0 comments on commit 49c8bbc

Please sign in to comment.