From 4aab671a0c2293354da3a6e5012630675a9b6ce3 Mon Sep 17 00:00:00 2001 From: Dominik Klein Date: Mon, 9 Oct 2023 14:31:32 +0200 Subject: [PATCH] adapt test --- tests/backends/ott/test_backend.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/backends/ott/test_backend.py b/tests/backends/ott/test_backend.py index ee75bd200..173770f44 100644 --- a/tests/backends/ott/test_backend.py +++ b/tests/backends/ott/test_backend.py @@ -129,12 +129,11 @@ def test_epsilon(self, x_cost: jnp.ndarray, y_cost: jnp.ndarray, eps: Optional[f assert isinstance(solver.y, Geometry) np.testing.assert_allclose(gt.matrix, pred.transport_matrix, rtol=RTOL, atol=ATOL) - @pytest.mark.skip(reason="TODO") @pytest.mark.parametrize("rank", [-1, 7]) def test_solver_rank(self, x: Geom_t, y: Geom_t, rank: int) -> None: thresh, eps = 1e-2, 1e-2 if rank > -1: - gt = LRGromovWasserstein(epsilon=eps, rank=rank, threshold=thresh)( + gt = LRGromovWasserstein(epsilon=eps, rank=rank, threshold=thresh, initializer="rank2")( QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps)) )