Skip to content

Commit

Permalink
Fix ManagedCollisionEmbeddingCollection not recognised as unpooled if…
Browse files Browse the repository at this point in the history
… it is a submodule (#1728)

Summary:
Pull Request resolved: #1728

When `ManagedCollisionEmbeddingCollection` is part of a module's submodules, the parent module is not recognised as unpooled. Doesn't matter in practice, since EC is always a submodule of MC_EC, but it is more correct this way

Reviewed By: henrylhtsang

Differential Revision: D53920749

fbshipit-source-id: 9830c4b557a6da3fd54b88624c85c64a56c72a7a
  • Loading branch information
sarckk authored and facebook-github-bot committed Mar 6, 2024
1 parent 0b186eb commit 5c343d9
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchrec/distributed/planner/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def module_pooled(module: nn.Module, sharding_option_name: str) -> bool:

for submodule in module.modules():
if isinstance(submodule, EmbeddingCollectionInterface) or isinstance(
module, ManagedCollisionEmbeddingCollection
submodule, ManagedCollisionEmbeddingCollection
):
for name, _ in submodule.named_parameters():
if sharding_option_name in name:
Expand Down

0 comments on commit 5c343d9

Please sign in to comment.