Skip to content

Commit

Permalink
increase allclose tolerance for test_inv
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaudiaComito committed Dec 11, 2024
1 parent c7cdf34 commit dc2ab84
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions heat/core/linalg/tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,8 @@ def test_inv(self):
self.assertTrue(ht.allclose(ainv, ares, atol=1e-6))

# pivoting row change
dtype = ht.float if self.is_mps else ht.double
atol = 1e-6 if dtype == ht.float else 1e-12
dtype = ht.floa32 if self.is_mps else ht.float64
atol = 1e-6 if dtype == ht.float32 else 1e-12

ares = ht.array([[-1, 0, 2], [2, 0, -1], [-6, 3, 0]], dtype=dtype, split=0) / 3.0
a = ht.array([[1, 2, 0], [2, 4, 1], [2, 1, 0]], dtype=dtype, split=0)
Expand Down Expand Up @@ -318,7 +318,7 @@ def test_inv(self):
a = ht.random.random((20, 20), dtype=dtype, split=0)
ainv = ht.linalg.inv(a)
i = ht.eye(a.shape, split=0, dtype=a.dtype)
self.assertTrue(ht.allclose(a @ ainv, i, atol=1e-5 if self.is_mps else atol))
self.assertTrue(ht.allclose(a @ ainv, i, atol=1e-5 if self.is_mps else atol * 10))

with self.assertRaises(RuntimeError):
ht.linalg.inv(ht.array([1, 2, 3], split=0))
Expand Down

0 comments on commit dc2ab84

Please sign in to comment.