# Patch summary (commentary only; this preamble precedes the `diff --git` header).
#
# File touched: smdebug/pytorch/hook.py
#
# 1. Imports `DataParallel` from `torch.nn.parallel.data_parallel`.
# 2. Adds the static helper `_add_module_name(module, module_name)`:
#      - if `module` is a `DataParallel` instance, `_module_name` is assigned on
#        `module.module` rather than on `module` itself;
#      - otherwise `_module_name` is assigned on `module`;
#      - `module` is returned in both cases.
# 3. `register_module` no longer assigns `_module_name` directly. It calls
#    `Hook._add_module_name(...)` for every `(name, submodule)` pair from
#    `module.named_modules()` and once more for `module` itself with
#    `module._get_name()`. Neither call site uses the helper's return value.
#
# NOTE(review): when the helper receives a `DataParallel` instance, the wrapper
#   object is not given a `_module_name` attribute by this helper - confirm that
#   nothing reads `_module_name` from the wrapper.
# NOTE(review): PyTorch documents `named_modules()` as also yielding the module
#   itself; if so, for a `DataParallel` root the wrapped module's `_module_name`
#   is written more than once and ends as `module._get_name()` - confirm that
#   this is the intended final name.
# NOTE(review): the line breaks and indentation below were reconstructed from
#   the hunk headers and Python nesting - verify the context lines against the
#   target file before applying.
diff --git a/smdebug/pytorch/hook.py b/smdebug/pytorch/hook.py
index 1eeb0636f..ebcef15c1 100644
--- a/smdebug/pytorch/hook.py
+++ b/smdebug/pytorch/hook.py
@@ -3,6 +3,7 @@
 # Third Party
 import torch
 import torch.distributed as dist
+from torch.nn.parallel.data_parallel import DataParallel
 
 # First Party
 from smdebug.core.collection import DEFAULT_PYTORCH_COLLECTIONS, CollectionKeys
@@ -197,6 +198,14 @@ def register_hook(self, module):
         # for compatibility with ZCC patches which call this
         self.register_module(module)
 
+    @staticmethod
+    def _add_module_name(module, module_name):
+        if isinstance(module, DataParallel):
+            module.module._module_name = module_name
+        else:
+            module._module_name = module_name
+        return module
+
     def register_module(self, module):
         """
         This function registers the forward hook. If user wants to register the hook
@@ -215,9 +224,9 @@ def register_module(self, module):
 
         for name, submodule in module.named_modules():
             assert submodule not in self.module_set, f"Don't register module={module} twice"
-            submodule._module_name = name
+            Hook._add_module_name(submodule, name)
             self.module_set.add(submodule)
-        module._module_name = module._get_name()
+        Hook._add_module_name(module, module._get_name())
         self.module_set.add(module)
 
         # Use `forward_pre_hook` for the entire net