diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index 61f5641f3..ab2c6a103 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -16,6 +16,7 @@
 from fbgemm_gpu.split_table_batched_embeddings_ops_training import EmbeddingLocation
 from torch import fx, nn
 from torch.nn.modules.module import _addindent
+from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.types import (
     get_tensor_size_bytes,
     ModuleSharder,
@@ -337,6 +338,8 @@ def prefetch(
         """
 
         for feature, emb_lookup in zip(dist_input, self._lookups):
+            while isinstance(emb_lookup, DistributedDataParallel):
+                emb_lookup = emb_lookup.module
             emb_lookup.prefetch(sparse_features=feature, forward_stream=forward_stream)
 
     def extra_repr(self) -> str:
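
Why the unwrap loop: DistributedDataParallel is an nn.Module wrapper that proxies only forward/__call__; a custom method such as prefetch() is not forwarded, so calling it on a DDP-wrapped lookup raises AttributeError. The while loop peels off any (possibly nested) DDP wrappers via .module before prefetching. Below is a minimal, self-contained sketch of the same pattern, not part of the patch: ToyLookup and call_prefetch are hypothetical names, and the single-process gloo group exists only so DDP can be constructed without a distributed launcher.

    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel


    class ToyLookup(nn.Module):
        """Stand-in for an embedding lookup module that exposes prefetch()."""

        def __init__(self) -> None:
            super().__init__()
            # DDP refuses modules without trainable parameters, so give it one.
            self.weight = nn.Parameter(torch.ones(1))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * self.weight

        def prefetch(self, sparse_features: torch.Tensor) -> None:
            print("prefetching features of shape", tuple(sparse_features.shape))


    def call_prefetch(emb_lookup: nn.Module, feature: torch.Tensor) -> None:
        # Same pattern as the patch: peel off any (possibly nested) DDP
        # wrappers, since DDP does not proxy custom methods like prefetch().
        while isinstance(emb_lookup, DistributedDataParallel):
            emb_lookup = emb_lookup.module
        emb_lookup.prefetch(sparse_features=feature)


    if __name__ == "__main__":
        # Single-process gloo group so DDP can be built in a plain script.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29501")
        dist.init_process_group("gloo", rank=0, world_size=1)

        ddp_lookup = DistributedDataParallel(ToyLookup())
        # ddp_lookup.prefetch(...) would raise AttributeError here;
        # call_prefetch() unwraps to the underlying ToyLookup first.
        call_prefetch(ddp_lookup, torch.zeros(4, 8))
        dist.destroy_process_group()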