From f15c4e345bac84b9ab575f837d78d3f0804efb0d Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 10 Dec 2024 18:10:39 +0100 Subject: [PATCH] fix test --- tests/backends/ott/test_backend.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/backends/ott/test_backend.py b/tests/backends/ott/test_backend.py index b268bdb3..9a0a0410 100644 --- a/tests/backends/ott/test_backend.py +++ b/tests/backends/ott/test_backend.py @@ -350,10 +350,7 @@ def test_pull( b, ndim = (b, b.shape[1]) if batched else (b[:, 0], None) xx, yy = xy solver = solver_t() - additional_kwargs = {"alpha": 1.0} if xy is None else {} - out = solver( - a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, xy=(xx, yy), **additional_kwargs - ) + out = solver(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, xy=(xx, yy), alpha=0.5) p = out.pull(b, scale_by_marginals=False) assert isinstance(out, BaseDiscreteSolverOutput)