From 9a9ec9b8c235cc581db603227af18c2eaa23ca8f Mon Sep 17 00:00:00 2001
From: Huanyu He <hhy@meta.com>
Date: Sat, 18 Jan 2025 00:22:11 -0800
Subject: [PATCH] fix hypothesis strategy that skips entire test without CUDA

Summary:
# context
* original implementation will skip the entire test set if there is no cuda available
* the actual intention is to loop all devices (cpu, meta, cuda) and only skip cuda if not available.

Differential Revision: D68373224
---
 torchrec/sparse/tests/test_tensor_dict.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/torchrec/sparse/tests/test_tensor_dict.py b/torchrec/sparse/tests/test_tensor_dict.py
index d243fc255..2fbcc0a66 100644
--- a/torchrec/sparse/tests/test_tensor_dict.py
+++ b/torchrec/sparse/tests/test_tensor_dict.py
@@ -17,14 +17,14 @@
 from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 
-class TestTensorDIct(unittest.TestCase):
-    @given(device_str=st.sampled_from(["cpu", "cuda", "meta"]))
-    @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
+class TestTensorDict(unittest.TestCase):
     # pyre-ignore[56]
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 0,
-        "CUDA is not available",
+    @given(
+        device_str=st.sampled_from(
+            ["cpu", "meta"] + (["cuda"] if torch.cuda.device_count() > 0 else [])
+        )
     )
+    @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
     def test_kjt_input(self, device_str: str) -> None:
         device = torch.device(device_str)
         values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
@@ -36,13 +36,13 @@ def test_kjt_input(self, device_str: str) -> None:
         features = maybe_td_to_kjt(kjt)
         self.assertEqual(features, kjt)
 
-    @given(device_str=st.sampled_from(["cpu", "cuda", "meta"]))
-    @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
     # pyre-ignore[56]
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 0,
-        "CUDA is not available",
+    @given(
+        device_str=st.sampled_from(
+            ["cpu", "meta"] + (["cuda"] if torch.cuda.device_count() > 0 else [])
+        )
     )
+    @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
     def test_td_kjt(self, device_str: str) -> None:
         device = torch.device(device_str)
         values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)