From eda80ce805deeb7f404ce08972b2d9247a6fb1ae Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Sat, 14 Dec 2024 16:33:20 -0800 Subject: [PATCH] Re-land D66465376 (#2637) Summary: Re-land diff D66465376 NOTE: use jit.ignore on the forward function to get rid of jit script error with `TensorDict` ``` def test_td_scripting(self) -> None: class TestModule(torch.nn.Module): torch.jit.ignore # <----- test fails without this ignore def forward(self, x: Union[TensorDict, KeyedJaggedTensor]) -> torch.Tensor: if isinstance(x, TensorDict): keys = list(x.keys()) return torch.cat([x[key]._values for key in keys], dim=0) else: return x._values m = TestModule() gm = torch.fx.symbolic_trace(m) jm = torch.jit.script(gm) values = torch.tensor([0, 1, 2, 3, 2, 3, 4]) kjt = KeyedJaggedTensor.from_offsets_sync( keys=["f1", "f2", "f3"], values=values, offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]), ) torch.testing.assert_allclose(jm(kjt), values) ``` Differential Revision: D66460392 --- torchrec/distributed/embedding.py | 8 ++++++++ torchrec/distributed/embeddingbag.py | 7 +++++++ torchrec/modules/embedding_modules.py | 8 ++++++++ torchrec/sparse/jagged_tensor.py | 7 +++++-- 4 files changed, 28 insertions(+), 2 deletions(-) diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 26ec9ae5f..5f16efc1b 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -96,6 +96,14 @@ except OSError: pass +try: + from tensordict import TensorDict +except ImportError: + + class TensorDict: + pass + + logger: logging.Logger = logging.getLogger(__name__) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 84e033a31..874834748 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -98,6 +98,13 @@ except OSError: pass +try: + from tensordict import TensorDict +except ImportError: + + class TensorDict: + pass + def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor: return ( diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py index 307d66639..9a1878361 100644 --- a/torchrec/modules/embedding_modules.py +++ b/torchrec/modules/embedding_modules.py @@ -21,6 +21,14 @@ from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor +try: + from tensordict import TensorDict +except ImportError: + + class TensorDict: + pass + + @torch.fx.wrap def reorder_inverse_indices( inverse_indices: Optional[Tuple[List[str], torch.Tensor]], diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 15952bfa5..8468c9977 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -49,9 +49,12 @@ # OSS try: - pass + from tensordict import TensorDict except ImportError: - pass + + class TensorDict: + pass + logger: logging.Logger = logging.getLogger()