From 655c12a1c5a1bb923f3b1f4b7eac9de9ee1b14df Mon Sep 17 00:00:00 2001
From: leej3
Date: Tue, 21 May 2024 12:50:10 +0100
Subject: [PATCH] skip tests when mps not functional

---
 tests/ignite/distributed/utils/test_serial.py |  5 +++--
 tests/ignite/engine/test_create_supervised.py |  8 +++++---
 tests/utils_for_tests.py                      | 12 ++++++++++++
 3 files changed, 20 insertions(+), 5 deletions(-)
 create mode 100644 tests/utils_for_tests.py

diff --git a/tests/ignite/distributed/utils/test_serial.py b/tests/ignite/distributed/utils/test_serial.py
index fdbf26e83608..d61fa8b75695 100644
--- a/tests/ignite/distributed/utils/test_serial.py
+++ b/tests/ignite/distributed/utils/test_serial.py
@@ -12,13 +12,14 @@
     _test_distrib_new_group,
     _test_sync,
 )
+from ....utils_for_tests import is_mps_available_and_functional
 
 
 def test_no_distrib(capsys):
     assert idist.backend() is None
     if torch.cuda.is_available():
         assert idist.device().type == "cuda"
-    elif _torch_version_le_112 and torch.backends.mps.is_available():
+    elif _torch_version_le_112 and is_mps_available_and_functional():
         assert idist.device().type == "mps"
     else:
         assert idist.device().type == "cpu"
@@ -41,7 +42,7 @@ def test_no_distrib(capsys):
     assert "ignite.distributed.utils INFO: backend: None" in out[-1]
     if torch.cuda.is_available():
         assert "ignite.distributed.utils INFO: device: cuda" in out[-1]
-    elif _torch_version_le_112 and torch.backends.mps.is_available():
+    elif _torch_version_le_112 and is_mps_available_and_functional():
         assert "ignite.distributed.utils INFO: device: mps" in out[-1]
     else:
         assert "ignite.distributed.utils INFO: device: cpu" in out[-1]
diff --git a/tests/ignite/engine/test_create_supervised.py b/tests/ignite/engine/test_create_supervised.py
index 31ca43f4bbf7..2ecf7438b458 100644
--- a/tests/ignite/engine/test_create_supervised.py
+++ b/tests/ignite/engine/test_create_supervised.py
@@ -25,6 +25,8 @@
 )
 from ignite.metrics import MeanSquaredError
 
+from ...utils_for_tests import is_mps_available_and_functional  # type: ignore
+
 
 class DummyModel(torch.nn.Module):
     def __init__(self, output_as_list=False):
@@ -485,7 +487,7 @@ def test_create_supervised_trainer_on_cuda():
     _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
 
 
-@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
+@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
 def test_create_supervised_trainer_on_mps():
     model_device = trainer_device = "mps"
     _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
@@ -666,14 +668,14 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu():
     _test_mocked_supervised_evaluator(evaluator_device="cuda")
 
 
-@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
+@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
 def test_create_supervised_evaluator_on_mps():
     model_device = evaluator_device = "mps"
     _test_create_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
     _test_mocked_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
 
 
-@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
+@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
 def test_create_supervised_evaluator_on_mps_with_model_on_cpu():
     _test_create_supervised_evaluator(evaluator_device="mps")
     _test_mocked_supervised_evaluator(evaluator_device="mps")
diff --git a/tests/utils_for_tests.py b/tests/utils_for_tests.py
new file mode 100644
index 000000000000..3076d080bfb9
--- /dev/null
+++ b/tests/utils_for_tests.py
@@ -0,0 +1,12 @@
+import torch
+
+
+def is_mps_available_and_functional():
+    if not torch.backends.mps.is_available():
+        return False
+    try:
+        # Try to allocate a small tensor on the MPS device
+        torch.tensor([1.0], device="mps")
+        return True
+    except RuntimeError:
+        return False
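
Usage note: the helper guards against the case where `torch.backends.mps.is_available()` returns True but the MPS backend fails at tensor allocation (the situation this patch skips tests for). Below is a minimal sketch of how a test module might reuse it; the `require_mps` marker name is illustrative only and is not introduced by this patch:

    import pytest
    import torch

    from tests.utils_for_tests import is_mps_available_and_functional

    # Reusable skip marker; the condition is evaluated once, when the
    # test module is imported. `require_mps` is a hypothetical name.
    require_mps = pytest.mark.skipif(
        not is_mps_available_and_functional(), reason="Skip if no functional MPS"
    )

    @require_mps
    def test_add_on_mps():
        # Runs only when MPS is reported available and a small tensor
        # can actually be allocated on the device.
        x = torch.ones(2, device="mps")
        assert (x + x).sum().item() == 4.0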