diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index 7f3878161..b267a3278 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -311,8 +311,12 @@ def check_dtype(dtype, margin, no_processing=False): gc.collect() +@pytest.mark.skipif( + torch.backends.mps.is_available() or not torch.cuda.is_available(), + reason="some operations unsupported by MPS: https://github.com/pytorch/pytorch/issues/77754 or no GPU", +) @pytest.mark.parametrize("dtype", [torch.float64, torch.float32]) -def test_dtypes(dtype): +def test_dtype_float(dtype): check_dtype(dtype, margin=5e-4)