From 3c23d82477a25d383da2ae8f0f557c19fc044ca9 Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Mon, 8 Jan 2024 07:40:23 -0800 Subject: [PATCH] Add EC ZCH to default sharder (#1610) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/1610 Added the EC ZCH module to the get_default_sharders method Reviewed By: henrylhtsang Differential Revision: D52541735 fbshipit-source-id: da055bbb5c3eee517f13e9156327218249810869 --- torchrec/distributed/sharding_plan.py | 2 ++ torchrec/distributed/tests/test_sharding_plan.py | 12 +++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py index 7c4a96ecd..d0f96fbed 100644 --- a/torchrec/distributed/sharding_plan.py +++ b/torchrec/distributed/sharding_plan.py @@ -21,6 +21,7 @@ FeatureProcessedEmbeddingBagCollectionSharder, ) from torchrec.distributed.fused_embeddingbag import FusedEmbeddingBagCollectionSharder +from torchrec.distributed.mc_embedding import ManagedCollisionEmbeddingCollectionSharder from torchrec.distributed.mc_embeddingbag import ( ManagedCollisionEmbeddingBagCollectionSharder, ) @@ -47,6 +48,7 @@ def get_default_sharders() -> List[ModuleSharder[nn.Module]]: cast(ModuleSharder[nn.Module], QuantEmbeddingBagCollectionSharder()), cast(ModuleSharder[nn.Module], QuantEmbeddingCollectionSharder()), cast(ModuleSharder[nn.Module], ManagedCollisionEmbeddingBagCollectionSharder()), + cast(ModuleSharder[nn.Module], ManagedCollisionEmbeddingCollectionSharder()), ] diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py index 2ac19715e..a075b0b41 100644 --- a/torchrec/distributed/tests/test_sharding_plan.py +++ b/torchrec/distributed/tests/test_sharding_plan.py @@ -23,6 +23,7 @@ FusedEmbeddingBagCollectionSharder, get_module_to_default_sharders, ManagedCollisionEmbeddingBagCollectionSharder, + ManagedCollisionEmbeddingCollectionSharder, ParameterShardingGenerator, QuantEmbeddingBagCollectionSharder, QuantEmbeddingCollectionSharder, @@ -52,7 +53,10 @@ ) from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection from torchrec.modules.fused_embedding_modules import FusedEmbeddingBagCollection -from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection +from torchrec.modules.mc_embedding_modules import ( + ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, +) from torchrec.quant.embedding_modules import ( EmbeddingBagCollection as QuantEmbeddingBagCollection, EmbeddingCollection as QuantEmbeddingCollection, @@ -710,6 +714,7 @@ def test_module_to_default_sharders(self) -> None: QuantEmbeddingBagCollection, QuantEmbeddingCollection, ManagedCollisionEmbeddingBagCollection, + ManagedCollisionEmbeddingCollection, ], ) self.assertIsInstance( @@ -738,3 +743,8 @@ def test_module_to_default_sharders(self) -> None: default_sharder_map[ManagedCollisionEmbeddingBagCollection], ManagedCollisionEmbeddingBagCollectionSharder, ) + + self.assertIsInstance( + default_sharder_map[ManagedCollisionEmbeddingCollection], + ManagedCollisionEmbeddingCollectionSharder, + )