From 2b5db3f1031c02643c764ee06f3fea0290ddcf32 Mon Sep 17 00:00:00 2001 From: Dark Knight Date: Wed, 4 Dec 2024 19:56:35 -0800 Subject: [PATCH] Revert D66465376 Summary: This diff reverts D66465376 seems causing NCCL error and regressing peak mem for a TorchBench test Reviewed By: TroyGarden Differential Revision: D66794877 --- torchrec/distributed/embedding.py | 8 -------- torchrec/distributed/embeddingbag.py | 7 ------- torchrec/modules/embedding_modules.py | 8 -------- torchrec/sparse/jagged_tensor.py | 7 ++----- 4 files changed, 2 insertions(+), 28 deletions(-) diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index 5f16efc1b..26ec9ae5f 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -96,14 +96,6 @@ except OSError: pass -try: - from tensordict import TensorDict -except ImportError: - - class TensorDict: - pass - - logger: logging.Logger = logging.getLogger(__name__) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index e24a695d9..23e9cbbaa 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -99,13 +99,6 @@ except OSError: pass -try: - from tensordict import TensorDict -except ImportError: - - class TensorDict: - pass - def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor: return ( diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py index 9a1878361..307d66639 100644 --- a/torchrec/modules/embedding_modules.py +++ b/torchrec/modules/embedding_modules.py @@ -21,14 +21,6 @@ from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor -try: - from tensordict import TensorDict -except ImportError: - - class TensorDict: - pass - - @torch.fx.wrap def reorder_inverse_indices( inverse_indices: Optional[Tuple[List[str], torch.Tensor]], diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index fa28309e3..4b5359f0d 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -49,12 +49,9 @@ # OSS try: - from tensordict import TensorDict + pass except ImportError: - - class TensorDict: - pass - + pass logger: logging.Logger = logging.getLogger()