diff --git a/torchrec/modules/itep_modules.py b/torchrec/modules/itep_modules.py
index 3a626887f..14a479b2e 100644
--- a/torchrec/modules/itep_modules.py
+++ b/torchrec/modules/itep_modules.py
@@ -12,6 +12,7 @@
 
 import torch
 from torch import nn
+from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.embedding_types import ShardedEmbeddingTable
 from torchrec.modules.embedding_modules import reorder_inverse_indices
 from torchrec.sparse.jagged_tensor import _pin_and_move, _to_offsets, KeyedJaggedTensor
@@ -200,6 +201,8 @@ def init_itep_state(self) -> None:
         # Iterate over all tables
         # pyre-ignore
         for lookup in self.lookups:
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
             for emb in lookup._emb_modules:
                 emb_tables: List[ShardedEmbeddingTable] = emb._config.embedding_tables
 
@@ -283,6 +286,8 @@ def reset_weight_momentum(
         if self.lookups is not None:
             # pyre-ignore
             for lookup in self.lookups:
+                while isinstance(lookup, DistributedDataParallel):
+                    lookup = lookup.module
                 for emb in lookup._emb_modules:
                     emb_tables: List[ShardedEmbeddingTable] = (
                         emb._config.embedding_tables
@@ -322,6 +327,8 @@ def flush_uvm_cache(self) -> None:
         if self.lookups is not None:
             # pyre-ignore
             for lookup in self.lookups:
+                while isinstance(lookup, DistributedDataParallel):
+                    lookup = lookup.module
                 for emb in lookup._emb_modules:
                     emb.emb_module.flush()
                     emb.emb_module.reset_cache_states()
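
Note (reviewer sketch, not part of the patch): each added `while isinstance(lookup, DistributedDataParallel)` loop unwraps possibly nested DDP wrappers, presumably so that `lookup._emb_modules` remains reachable when a lookup has been wrapped; DDP keeps the original module under its `.module` attribute and does not proxy arbitrary attributes of the wrapped module. A minimal standalone sketch of the same unwrapping pattern, using a hypothetical `unwrap_ddp` helper name:

    from torch import nn
    from torch.nn.parallel import DistributedDataParallel

    def unwrap_ddp(module: nn.Module) -> nn.Module:
        # Peel off any (possibly nested) DDP wrappers; DDP stores the
        # original module under the `.module` attribute.
        while isinstance(module, DistributedDataParallel):
            module = module.module
        return module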