Skip to content

Commit

Permalink
fix test in OSS env without CUDA device (#2688)
Browse files Browse the repository at this point in the history
Summary:

# context
* to fix OSS CPU test failure due to lack of CUDA device.

Reviewed By: dstaay-fb

Differential Revision: D68340773
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Jan 17, 2025
1 parent 33168a1 commit 8f8a885
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions torchrec/sparse/tests/test_tensor_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,20 @@
import unittest

import torch
from hypothesis import given, settings, strategies as st, Verbosity
from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.tensor_dict import maybe_td_to_kjt
from torchrec.sparse.tests.utils import repeat_test


class TestTensorDIct(unittest.TestCase):
@repeat_test(device_str=["cpu", "cuda", "meta"])
def test_kjt_input(self, device_str: str) -> None:
device = torch.device(device_str)
@unittest.skipIf(
torch.cuda.device_count() <= 0,
"CUDA is not available",
)
@given(device=st.sampled_from([torch.device(d) for d in ["cpu", "cuda", "meta"]]))
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
def test_kjt_input(self, device:torch.device) -> None:
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
kjt = KeyedJaggedTensor.from_offsets_sync(
keys=["f1", "f2", "f3"],
Expand All @@ -30,9 +34,13 @@ def test_kjt_input(self, device_str: str) -> None:
features = maybe_td_to_kjt(kjt)
self.assertEqual(features, kjt)

@repeat_test(device_str=["cpu", "cuda", "meta"])
def test_td_kjt(self, device_str: str) -> None:
device = torch.device(device_str)
@unittest.skipIf(
torch.cuda.device_count() <= 0,
"CUDA is not available",
)
@given(device=st.sampled_from([torch.device(d) for d in ["cpu", "cuda", "meta"]]))
@settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
def test_td_kjt(self, device: torch.device) -> None:
values = torch.tensor([0, 1, 2, 3, 2, 3, 4], device=device)
lengths = torch.tensor([2, 0, 1, 1, 1, 2], device=device)
data = {
Expand Down

0 comments on commit 8f8a885

Please sign in to comment.