diff --git a/tests/backends/ott/test_backend.py b/tests/backends/ott/test_backend.py index 31287975..b268bdb3 100644 --- a/tests/backends/ott/test_backend.py +++ b/tests/backends/ott/test_backend.py @@ -351,7 +351,9 @@ def test_pull( xx, yy = xy solver = solver_t() additional_kwargs = {"alpha": 1.0} if xy is None else {} - out = solver(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, xy=(xx, yy), **additional_kwargs) + out = solver( + a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, xy=(xx, yy), **additional_kwargs + ) p = out.pull(b, scale_by_marginals=False) assert isinstance(out, BaseDiscreteSolverOutput) @@ -404,5 +406,7 @@ def test_plot_errors_sink(self, x: Geom_t, y: Geom_t): out.plot_errors() def test_plot_errors_gw(self, x: Geom_t, y: Geom_t): - out = GWSolver(store_inner_errors=True)(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, alpha=1.0) + out = GWSolver(store_inner_errors=True)( + a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, alpha=1.0 + ) out.plot_errors()