diff --git a/tests/problems/base/test_general_problem.py b/tests/problems/base/test_general_problem.py index d7d785893..869b0ea7f 100644 --- a/tests/problems/base/test_general_problem.py +++ b/tests/problems/base/test_general_problem.py @@ -30,18 +30,28 @@ def test_simple_run(self, adata_x: AnnData, adata_y: AnnData): assert isinstance(prob.solution, BaseDiscreteSolverOutput) - @pytest.mark.parametrize("kind", ["linear", "quadratic"]) - def test_unrecognized_args(self, adata_x: AnnData, adata_y: AnnData, kind: Literal["linear", "quadratic"]): + @pytest.mark.parametrize( + ("kind", "rank"), + [ + ("linear", -1), + ("linear", 5), + ("quadratic", -1), + ("quadratic", 5), + ], + ) + def test_unrecognized_args( + self, adata_x: AnnData, adata_y: AnnData, kind: Literal["linear", "quadratic"], rank: int + ): prob = OTProblem(adata_x, adata_y) data = { "xy": {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"}, } - if kind == "quadratic": + if "quadratic" in kind: data["x"] = {"attr": "X"} data["y"] = {"attr": "X"} with pytest.raises(TypeError): - prob.prepare(**data).solve(epsilon=5e-1, dummy=42) + prob.prepare(**data).solve(epsilon=5e-1, rank=rank, dummy=42) @pytest.mark.fast def test_output(self, adata_x: AnnData, x: Geom_t):