Skip to content

Commit

Permalink
Fix ManagedCollisionEmbeddingCollection not recognised as unpooled if…
Browse files Browse the repository at this point in the history
… it is a submodule (pytorch#1728)

Summary:

When `ManagedCollisionEmbeddingCollection` is part of a module's submodules, the parent module is not recognised as unpooled. Doesn't matter in practice, since EC is always a submodule of MC_EC, but it is more correct this way

Reviewed By: henrylhtsang

Differential Revision: D53920749
  • Loading branch information
sarckk authored and facebook-github-bot committed Mar 5, 2024
1 parent e96e10d commit fad0799
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchrec/distributed/planner/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def module_pooled(module: nn.Module, sharding_option_name: str) -> bool:

for submodule in module.modules():
if isinstance(submodule, EmbeddingCollectionInterface) or isinstance(
module, ManagedCollisionEmbeddingCollection
submodule, ManagedCollisionEmbeddingCollection
):
for name, _ in submodule.named_parameters():
if sharding_option_name in name:
Expand Down

0 comments on commit fad0799

Please sign in to comment.