diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index 02f786d79..862943537 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -371,7 +371,7 @@ def module_pooled(module: nn.Module, sharding_option_name: str) -> bool: for submodule in module.modules(): if isinstance(submodule, EmbeddingCollectionInterface) or isinstance( - module, ManagedCollisionEmbeddingCollection + submodule, ManagedCollisionEmbeddingCollection ): for name, _ in submodule.named_parameters(): if sharding_option_name in name: