From 8f8a88517dc1e56e0a7124942ea31aeaa5507e2a Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Fri, 17 Jan 2025 11:59:56 -0800 Subject: [PATCH] fix test in OSS env without CUDA device (#2688) Summary: # context * to fix OSS CPU test failure due to lack of CUDA device. Reviewed By: dstaay-fb Differential Revision: D68340773 --- torchrec/sparse/tests/test_tensor_dict.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/torchrec/sparse/tests/test_tensor_dict.py b/torchrec/sparse/tests/test_tensor_dict.py index dc25a7c28..13a532c9a 100644 --- a/torchrec/sparse/tests/test_tensor_dict.py +++ b/torchrec/sparse/tests/test_tensor_dict.py @@ -11,16 +11,20 @@ import unittest import torch +from hypothesis import given, settings, strategies as st, Verbosity from tensordict import TensorDict from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.sparse.tensor_dict import maybe_td_to_kjt -from torchrec.sparse.tests.utils import repeat_test class TestTensorDIct(unittest.TestCase): - @repeat_test(device_str=["cpu", "cuda", "meta"]) - def test_kjt_input(self, device_str: str) -> None: - device = torch.device(device_str) + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "CUDA is not available", + ) + @given(device=st.sampled_from([torch.device(d) for d in ["cpu", "cuda", "meta"]])) + @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) + def test_kjt_input(self, device:torch.device) -> None: values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device) kjt = KeyedJaggedTensor.from_offsets_sync( keys=["f1", "f2", "f3"], @@ -30,9 +34,13 @@ def test_kjt_input(self, device_str: str) -> None: features = maybe_td_to_kjt(kjt) self.assertEqual(features, kjt) - @repeat_test(device_str=["cpu", "cuda", "meta"]) - def test_td_kjt(self, device_str: str) -> None: - device = torch.device(device_str) + @unittest.skipIf( + torch.cuda.device_count() <= 0, + "CUDA is not available", + ) + @given(device=st.sampled_from([torch.device(d) for d in ["cpu", "cuda", "meta"]])) + @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None) + def test_td_kjt(self, device: torch.device) -> None: values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device) lengths = torch.tensor([2, 0, 1, 1, 1, 2], device=device) data = {