diff --git a/examples/retrieval/tests/test_two_tower_retrieval.py b/examples/retrieval/tests/test_two_tower_retrieval.py index f4d95b7a6..77952e0fa 100644 --- a/examples/retrieval/tests/test_two_tower_retrieval.py +++ b/examples/retrieval/tests/test_two_tower_retrieval.py @@ -21,8 +21,8 @@ class InferTest(unittest.TestCase): @skip_if_asan # pyre-ignore[56] @unittest.skipIf( - not torch.cuda.is_available(), - "this test requires a GPU", + torch.cuda.device_count() <= 1, + "Not enough GPUs, this test requires at least two GPUs", ) def test_infer_function(self) -> None: infer(