diff --git a/heat/core/tests/test_vmap.py b/heat/core/tests/test_vmap.py index 8fd1f4734..7cbdacdff 100644 --- a/heat/core/tests/test_vmap.py +++ b/heat/core/tests/test_vmap.py @@ -1,5 +1,6 @@ import heat as ht import torch +import os from .test_suites.basic_test import TestCase @@ -79,8 +80,6 @@ def func(x0, m=1, scale=2): vfunc_torch = torch.vmap(func, (0,), (0,)) y0_torch = vfunc_torch(x0_torch, m=2, scale=3) - print(y0.resplit(None).larray, y0_torch) - self.assertTrue(torch.allclose(y0.resplit(None).larray, y0_torch)) def test_vmap_with_chunks(self): @@ -123,7 +122,8 @@ def func(x0, x1, k=2, scale=1e-2): y0_torch, y1_torch = vfunc_torch(x0_torch, x1_torch, k=5, scale=2.2) self.assertTrue(torch.allclose(y0.resplit(None).larray, y0_torch)) - self.assertTrue(torch.allclose(y1.resplit(None).larray, y1_torch)) + tol = 1e-4 + self.assertTrue(torch.allclose(y1.resplit(None).larray, y1_torch, atol=tol, rtol=tol)) def test_vmap_catch_errors(self): # not a callable