diff --git a/tests/test_relaxation.py b/tests/test_relaxation.py index abe78617..adfbe35d 100644 --- a/tests/test_relaxation.py +++ b/tests/test_relaxation.py @@ -41,7 +41,7 @@ def test_relaxation(algorithm: Literal["legacy", "fast"]): no_cuda = mark.skipif(not torch.cuda.is_available(), reason="No CUDA device") -no_mps = mark.skipif(not hasattr(torch.backends, "mps"), reason="No MPS device") +no_mps = mark.skipif(not torch.backends.mps.is_available(), reason="No MPS device") @mark.parametrize(