From e1b5edd871256161410a082a7d72290669704624 Mon Sep 17 00:00:00 2001 From: Zain Huda Date: Wed, 11 Dec 2024 07:58:59 -0800 Subject: [PATCH] add DTensor to VLE and fix 2D sharding group in EBC (#2626) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2626 Add a DTensor state dict to VLE; this is identical to the EBC DTensor path. Future work is to consolidate the state dict logic of EBC/PEA/VLE into one parent class, because of their significant similarities. A revert diff (D66800554) mistakenly removed the 2D sharding logic in embedding bag collection; this diff adds it back. Differential Revision: D65555595 fbshipit-source-id: ca12b8e833f9336e7c092b32d3955745806d9b56 --- torchrec/distributed/embeddingbag.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 06ad9f26e..84e033a31 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -65,6 +65,7 @@ QuantizedCommCodecs, ShardedTensor, ShardingEnv, + ShardingEnv2D, ShardingType, ShardMetadata, ) @@ -938,7 +939,11 @@ def _initialize_torch_state(self) -> None: # noqa ShardedTensor._init_from_local_shards( local_shards, self._name_to_table_size[table_name], - process_group=self._env.process_group, + process_group=( + self._env.sharding_pg + if isinstance(self._env, ShardingEnv2D) + else self._env.process_group + ), ) )