diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 26ec9ae5f..5f16efc1b 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -96,6 +96,14 @@ except OSError: pass +try: + from tensordict import TensorDict +except ImportError: + + class TensorDict: + pass + + logger: logging.Logger = logging.getLogger(__name__) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 84e033a31..874834748 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -98,6 +98,13 @@ except OSError: pass +try: + from tensordict import TensorDict +except ImportError: + + class TensorDict: + pass + def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor: return ( diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py index 307d66639..9a1878361 100644 --- a/torchrec/modules/embedding_modules.py +++ b/torchrec/modules/embedding_modules.py @@ -21,6 +21,14 @@ from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor +try: + from tensordict import TensorDict +except ImportError: + + class TensorDict: + pass + + @torch.fx.wrap def reorder_inverse_indices( inverse_indices: Optional[Tuple[List[str], torch.Tensor]], diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 15952bfa5..8468c9977 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -49,9 +49,12 @@ # OSS try: - pass + from tensordict import TensorDict except ImportError: - pass + + class TensorDict: + pass + logger: logging.Logger = logging.getLogger()