From 519f19343b28582fb6479c5d7f3cb40cf279bba3 Mon Sep 17 00:00:00 2001 From: Zain Huda Date: Wed, 22 Jan 2025 13:36:38 -0800 Subject: [PATCH] simplify 2D parallel process group init (#2694) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2694 DeviceMesh and manual PG initialization was redundant code leading to more process groups created then needed. (2x as much) In this diff we update the init to use the process groups created by the DeviceMesh init instead. Reviewed By: carlbunny, TroyGarden Differential Revision: D68495749 fbshipit-source-id: 85123a5f43f0e1c55e50e5fd52b6dbc0d2c62107 --- torchrec/distributed/model_parallel.py | 35 +++++++------------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index 5cbd2429b..f8b32106b 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -770,7 +770,7 @@ def _create_process_groups( ) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: """ Creates process groups for sharding and replication, the process groups - are created in the same exact order on all ranks as per `dist.new_group` API. + are created using the DeviceMesh API. Args: global_rank (int): The global rank of the current process. @@ -781,37 +781,12 @@ def _create_process_groups( Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: A tuple containing the device mesh, replication process group, and allreduce process group. """ - # TODO - look into local sync - https://github.com/pytorch/pytorch/commit/ad21890f8fab73a15e758c7b893e129e9db1a81a peer_matrix = [] - sharding_pg, replica_pg = None, None step = world_size // local_size - my_group_rank = global_rank % step for group_rank in range(world_size // local_size): peers = [step * r + group_rank for r in range(local_size)] - backend = dist.get_backend(self._pg) - curr_pg = dist.new_group(backend=backend, ranks=peers) peer_matrix.append(peers) - if my_group_rank == group_rank: - logger.warning( - f"[Connection] 2D sharding_group: [{global_rank}] -> [{peers}]" - ) - sharding_pg = curr_pg - assert sharding_pg is not None, "sharding_pg is not initialized!" - dist.barrier() - - my_inter_rank = global_rank // step - for inter_rank in range(local_size): - peers = [inter_rank * step + r for r in range(step)] - backend = dist.get_backend(self._pg) - curr_pg = dist.new_group(backend=backend, ranks=peers) - if my_inter_rank == inter_rank: - logger.warning( - f"[Connection] 2D replica_group: [{global_rank}] -> [{peers}]" - ) - replica_pg = curr_pg - assert replica_pg is not None, "replica_pg is not initialized!" - dist.barrier() mesh = DeviceMesh( device_type=self._device.type, @@ -819,6 +794,14 @@ def _create_process_groups( mesh_dim_names=("replicate", "shard"), ) logger.warning(f"[Connection] 2D Device Mesh created: {mesh}") + sharding_pg = mesh.get_group(mesh_dim="shard") + logger.warning( + f"[Connection] 2D sharding_group: [{global_rank}] -> [{mesh['shard']}]" + ) + replica_pg = mesh.get_group(mesh_dim="replicate") + logger.warning( + f"[Connection] 2D replica_group: [{global_rank}] -> [{mesh['replicate']}]" + ) return mesh, sharding_pg, replica_pg