diff --git a/torchrec/sparse/tests/test_tensor_dict.py b/torchrec/sparse/tests/test_tensor_dict.py index dc25a7c28..d243fc255 100644 --- a/torchrec/sparse/tests/test_tensor_dict.py +++ b/torchrec/sparse/tests/test_tensor_dict.py @@ -11,14 +11,20 @@ import unittest import torch +from hypothesis import given, settings, strategies as st, Verbosity from tensordict import TensorDict from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.sparse.tensor_dict import maybe_td_to_kjt -from torchrec.sparse.tests.utils import repeat_test class TestTensorDIct(unittest.TestCase): - @repeat_test(device_str=["cpu", "cuda", "meta"]) + @given(device_str=st.sampled_from(["cpu", "cuda", "meta"])) + @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "CUDA is not available", + ) def test_kjt_input(self, device_str: str) -> None: device = torch.device(device_str) values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device) @@ -30,7 +36,13 @@ def test_kjt_input(self, device_str: str) -> None: features = maybe_td_to_kjt(kjt) self.assertEqual(features, kjt) - @repeat_test(device_str=["cpu", "cuda", "meta"]) + @given(device_str=st.sampled_from(["cpu", "cuda", "meta"])) + @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) + # pyre-ignore[56] + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "CUDA is not available", + ) def test_td_kjt(self, device_str: str) -> None: device = torch.device(device_str) values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)