diff --git a/torchrec/distributed/sharding/twrw_sharding.py b/torchrec/distributed/sharding/twrw_sharding.py
index 22651f75a..ae8a0c782 100644
--- a/torchrec/distributed/sharding/twrw_sharding.py
+++ b/torchrec/distributed/sharding/twrw_sharding.py
@@ -13,7 +13,13 @@
 
 import torch
 import torch.distributed as dist
-from torchrec.distributed.comm import get_local_size, intra_and_cross_node_pg
+from torch.distributed._tensor import Shard
+from torch.distributed.distributed_c10d import get_process_group_ranks
+from torchrec.distributed.comm import (
+    get_local_size,
+    intra_and_cross_node_pg,
+    intra_and_cross_node_pg_2D,
+)
 from torchrec.distributed.dist_data import (
     KJTAllToAll,
     PooledEmbeddingsAllToAll,
@@ -34,6 +40,7 @@
 )
 from torchrec.distributed.embedding_types import (
     BaseGroupedFeatureProcessor,
+    DTensorMetadata,
     EmbeddingComputeKernel,
     GroupedEmbeddingConfig,
     ShardedEmbeddingTable,
@@ -44,6 +51,7 @@
     QuantizedCommCodecs,
     ShardedTensorMetadata,
     ShardingEnv,
+    ShardingEnv2D,
     ShardingType,
     ShardMetadata,
 )
@@ -71,14 +79,26 @@ def __init__(
     ) -> None:
         super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
         self._env = env
-        self._pg: Optional[dist.ProcessGroup] = self._env.process_group
+        self._is_2D_parallel: bool = isinstance(env, ShardingEnv2D)
+        self._pg: Optional[dist.ProcessGroup] = (
+            self._env.sharding_pg  # pyre-ignore[16]
+            if self._is_2D_parallel
+            else self._env.process_group
+        )
         self._world_size: int = self._env.world_size
         self._rank: int = self._env.rank
         self._device = device
         self._need_pos = need_pos
-        intra_pg, cross_pg = intra_and_cross_node_pg(
-            device, backend=dist.get_backend(self._pg)
-        )
+        if self._is_2D_parallel:
+            intra_pg, cross_pg = intra_and_cross_node_pg_2D(
+                # pyre-fixme[6]
+                self._env,
+                device=device,
+            )
+        else:
+            intra_pg, cross_pg = intra_and_cross_node_pg(
+                device, backend=dist.get_backend(self._pg)
+            )
         self._intra_pg: Optional[dist.ProcessGroup] = intra_pg
         self._cross_pg: Optional[dist.ProcessGroup] = cross_pg
         self._local_size: int = (
@@ -112,11 +132,23 @@ def _shard(
         world_size = self._world_size
         local_size = self._local_size
         tables_per_rank: List[List[ShardedEmbeddingTable]] = [
-            [] for i in range(world_size)
+            [] for _ in range(world_size)
         ]
+        peer_group = (
+            # pyre-ignore [6]
+            get_process_group_ranks(self._pg)
+            if self._is_2D_parallel
+            else None
+        )
         for info in sharding_infos:
-            # pyre-ignore [16]
-            table_node = info.param_sharding.ranks[0] // local_size
+            # Under 2D parallelism we transform rank to the logical ordering in a regular parallelism scheme
+            rank = (
+                # pyre-ignore [16]
+                peer_group.index(info.param_sharding.ranks[0])
+                if peer_group is not None
+                else info.param_sharding.ranks[0]
+            )
+            table_node = rank // local_size
             # pyre-fixme [16]
             shards = info.param_sharding.sharding_spec.shards
 
@@ -131,6 +163,21 @@ def _shard(
                 ),
             )
 
+            dtensor_metadata = None
+            if info.fused_params.get("output_dtensor", False):  # pyre-ignore[16]
+                placements = (Shard(0),)
+                dtensor_metadata = DTensorMetadata(
+                    mesh=self._env.device_mesh,
+                    placements=placements,
+                    size=(
+                        info.embedding_config.num_embeddings,
+                        info.embedding_config.embedding_dim,
+                    ),
+                    stride=info.param.stride(),
+                )
+            # to not pass onto TBE
+            info.fused_params.pop("output_dtensor", None)  # pyre-ignore[16]
+
             for rank in range(
                 table_node * local_size,
                 (table_node + 1) * local_size,
@@ -154,6 +201,7 @@ def _shard(
                         ),
                         local_metadata=shards[rank_idx],
                         global_metadata=global_metadata,
+                        dtensor_metadata=dtensor_metadata,
                         weight_init_max=info.embedding_config.weight_init_max,
                         weight_init_min=info.embedding_config.weight_init_min,
                         fused_params=info.fused_params,
diff --git a/torchrec/distributed/tests/test_2d_sharding.py b/torchrec/distributed/tests/test_2d_sharding.py
index 4d0ce7b41..c76e8d4cf 100644
--- a/torchrec/distributed/tests/test_2d_sharding.py
+++ b/torchrec/distributed/tests/test_2d_sharding.py
@@ -402,3 +402,90 @@ def test_sharding_rw_2D(
             variable_batch_size=variable_batch_size,
             pooling=pooling,
         )
+
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 7,
+        "Not enough GPUs, this test requires at least eight GPUs",
+    )
+    # pyre-fixme[56]
+    @given(
+        sharder_type=st.sampled_from(
+            [
+                SharderType.EMBEDDING_BAG_COLLECTION.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED.value,
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+                EmbeddingComputeKernel.FUSED_UVM.value,
+            ],
+        ),
+        qcomms_config=st.sampled_from(
+            [
+                # None,
+                QCommsConfig(
+                    forward_precision=CommType.FP16, backward_precision=CommType.BF16
+                ),
+            ]
+        ),
+        apply_optimizer_in_backward_config=st.sampled_from(
+            [
+                None,
+                {
+                    "embedding_bags": (
+                        torch.optim.SGD,
+                        {
+                            "lr": 0.01,
+                        },
+                    ),
+                },
+            ]
+        ),
+        pooling=st.sampled_from([PoolingType.SUM]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
+    def test_sharding_twrw_2D(
+        self,
+        sharder_type: str,
+        kernel_type: str,
+        qcomms_config: Optional[QCommsConfig],
+        apply_optimizer_in_backward_config: Optional[
+            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
+        ],
+        pooling: PoolingType,
+    ) -> None:
+        if (
+            self.device == torch.device("cpu")
+            and kernel_type != EmbeddingComputeKernel.FUSED.value
+        ):
+            self.skipTest("CPU does not support uvm.")
+
+        sharding_type = ShardingType.TABLE_ROW_WISE.value
+        assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value)
+
+        self._test_sharding(
+            world_size=self.WORLD_SIZE,
+            local_size=self.WORLD_SIZE_2D // 2,
+            world_size_2D=self.WORLD_SIZE_2D,
+            sharders=[
+                cast(
+                    ModuleSharder[nn.Module],
+                    create_test_sharder(
+                        sharder_type,
+                        sharding_type,
+                        kernel_type,
+                        qcomms_config=qcomms_config,
+                        device=self.device,
+                    ),
+                ),
+            ],
+            qcomms_config=qcomms_config,
+            constraints={
+                table.name: ParameterConstraints(min_partition=2)
+                for table in self.tables
+            },
+            backend=self.backend,
+            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            pooling=pooling,
+        )