diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index f8b32106b..492b8e348 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -709,7 +709,9 @@ def __init__(
         )
 
         self._remap_sharding_plan(
-            plan, self._global_rank, world_size // sharding_group_size
+            plan=plan,
+            rank=self._global_rank,
+            num_nodes=world_size // sharding_group_size,
         )
         super().__init__(
             module,
@@ -733,7 +735,7 @@ def sync(self, include_optimizer_state: bool = True) -> None:
         """
         Syncs the DMP weights across the allreduce (inter) process group
 
-        This method is called after each forward pass to synchronize the weights of the sharded modules.
+        This method is called after each train step to synchronize the weights of the sharded modules.
         It uses the `dist.AllreduceCoalescedOptions` to perform an all-reduce operation on the weights,
         which averages the weights across all processes in the inter-process group.
 
@@ -782,10 +784,10 @@ def _create_process_groups(
             replication process group, and allreduce process group.
         """
         peer_matrix = []
-        step = world_size // local_size
+        num_nodes = world_size // local_size
 
         for group_rank in range(world_size // local_size):
-            peers = [step * r + group_rank for r in range(local_size)]
+            peers = [num_nodes * r + group_rank for r in range(local_size)]
             peer_matrix.append(peers)
 
         mesh = DeviceMesh(
@@ -805,7 +807,9 @@ def _create_process_groups(
 
         return mesh, sharding_pg, replica_pg
 
-    def _remap_sharding_plan(self, plan: ShardingPlan, rank: int, step: int) -> None:
+    def _remap_sharding_plan(
+        self, plan: ShardingPlan, rank: int, num_nodes: int
+    ) -> None:
         """
         Remaps the sharding plan to the local replica process group ranks
         ShardingPlan is remapped inplace.
@@ -816,22 +820,22 @@ def _remap_sharding_plan(self, plan: ShardingPlan, rank: int, step: int) -> None
         Args:
             plan (ShardingPlan): The original sharding plan.
             global_rank (int): The global rank of the current process.
-            step (int): The number of nodes.
+            num_nodes (int): The number of nodes.
         """
 
-        group_start = rank % step
+        group_start = rank % num_nodes
         for key in plan.plan:
             # pyre-ignore[16]
             for _, param_sharding in plan.plan[key].items():
                 new_ranks = []
                 for shard_rank in param_sharding.ranks:
-                    new_ranks.append(shard_rank * step + group_start)
+                    new_ranks.append(shard_rank * num_nodes + group_start)
                 param_sharding.ranks = new_ranks
                 if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec):
                     shards = param_sharding.sharding_spec.shards
                     if shards is not None:
                         for shard in shards:
-                            shard_rank = shard.placement._rank * step + group_start
+                            shard_rank = shard.placement._rank * num_nodes + group_start
                             shard.placement = _remote_device(
                                 f"rank:{shard_rank}/cuda:{shard_rank % get_local_size()}"
                             )
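For reference, the rank arithmetic behind the `step` → `num_nodes` rename can be sketched standalone. The snippet below is illustrative only: `build_peer_matrix` and `remap_ranks` are hypothetical helpers, not TorchRec functions, and they simply mirror the list comprehension in `_create_process_groups` and the loop in `_remap_sharding_plan`. With `world_size = 8` and `local_size = 4` (so `num_nodes = 2`), replica peers are strided by `num_nodes`, and a plan written for ranks `[0, 1, 2, 3]` lands on `[1, 3, 5, 7]` when the current global rank is odd.

```python
# Illustrative sketch of the rank arithmetic in this patch; helper names are
# hypothetical and not part of TorchRec.
from typing import List


def build_peer_matrix(world_size: int, local_size: int) -> List[List[int]]:
    # Mirrors the peer_matrix loop in _create_process_groups: each row holds
    # local_size ranks strided by num_nodes, offset by its group_rank.
    num_nodes = world_size // local_size
    return [
        [num_nodes * r + group_rank for r in range(local_size)]
        for group_rank in range(num_nodes)
    ]


def remap_ranks(shard_ranks: List[int], rank: int, num_nodes: int) -> List[int]:
    # Mirrors _remap_sharding_plan: a plan expressed in per-group ranks is
    # shifted onto the group offset (rank % num_nodes) of the current process.
    group_start = rank % num_nodes
    return [shard_rank * num_nodes + group_start for shard_rank in shard_ranks]


if __name__ == "__main__":
    world_size, local_size = 8, 4
    num_nodes = world_size // local_size  # 2
    print(build_peer_matrix(world_size, local_size))  # [[0, 2, 4, 6], [1, 3, 5, 7]]
    print(remap_ranks([0, 1, 2, 3], rank=1, num_nodes=num_nodes))  # [1, 3, 5, 7]
```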