From fad07999e762e2cf934d3f2f8092007a12d0f58d Mon Sep 17 00:00:00 2001 From: Yong Hoon Shin Date: Tue, 5 Mar 2024 06:51:24 -0800 Subject: [PATCH] Fix ManagedCollisionEmbeddingCollection not recognised as unpooled if it is a submodule (#1728) Summary: When `ManagedCollisionEmbeddingCollection` is part of a module's submodules, the parent module is not recognised as unpooled. Doesn't matter in practice, since EC is always a submodule of MC_EC, but it is more correct this way Reviewed By: henrylhtsang Differential Revision: D53920749 --- torchrec/distributed/planner/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index 02f786d79..862943537 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -371,7 +371,7 @@ def module_pooled(module: nn.Module, sharding_option_name: str) -> bool: for submodule in module.modules(): if isinstance(submodule, EmbeddingCollectionInterface) or isinstance( - module, ManagedCollisionEmbeddingCollection + submodule, ManagedCollisionEmbeddingCollection ): for name, _ in submodule.named_parameters(): if sharding_option_name in name: