From 24a9bd8ac50c33e7fd636b59da28ad1caba8d9df Mon Sep 17 00:00:00 2001 From: marmaduke woodman Date: Tue, 20 Feb 2024 10:22:16 +0100 Subject: [PATCH] raise atol for spmv tests --- vbjax/tests/test_sparse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vbjax/tests/test_sparse.py b/vbjax/tests/test_sparse.py index c8e534e..c85ebab 100644 --- a/vbjax/tests/test_sparse.py +++ b/vbjax/tests/test_sparse.py @@ -20,7 +20,7 @@ def _test_spmv(spmv, A, n): numpy.testing.assert_allclose(jb, nb, 1e-4, 1e-6) # now its gradient - jax.test_util.check_grads(spmv, (jx,), order=1, modes=('rev',)) + jax.test_util.check_grads(spmv, (jx,), order=1, modes=('rev',), atol=0.02, rtol=0.002) def test_csr_scipy():