From 457854c90e54c9033a42bb3caed0cc09ea7eff64 Mon Sep 17 00:00:00 2001
From: Sergey Lebedev
Date: Tue, 6 Aug 2024 20:41:00 -0700
Subject: [PATCH] Add support for multiple process groups by syncing across
 ranks (#151)

Summary: Add support for multiple process groups by syncing across ranks.

Pull Request resolved: https://github.com/facebookresearch/param/pull/151

Test Plan: /usr/local/fbcode/platform010/bin/mpirun -np 2 path-to/comm_replay.par --trace-path param_bench/fb/integration_tests/resnet-2gpu --trace-type et

Reviewed By: briancoutinho

Differential Revision: D60788539

Pulled By: shengfukevin
---
 et_replay/comm/backend/base_backend.py    |  1 -
 .../comm/backend/pytorch_dist_backend.py  | 36 +++++++++++++++----
 et_replay/tools/comm_replay.py            |  5 ++-
 3 files changed, 33 insertions(+), 9 deletions(-)

diff --git a/et_replay/comm/backend/base_backend.py b/et_replay/comm/backend/base_backend.py
index 75de50bc..1e367dca 100644
--- a/et_replay/comm/backend/base_backend.py
+++ b/et_replay/comm/backend/base_backend.py
@@ -174,7 +174,6 @@ def alloc_ones(
             ipTensor = ipTensor * scaleFactor
         return ipTensor
 
-    @abstractmethod
     def noop(
         self,
         collectiveArgs: collectiveArgsHolder = None,
diff --git a/et_replay/comm/backend/pytorch_dist_backend.py b/et_replay/comm/backend/pytorch_dist_backend.py
index cfd75068..7bf17810 100644
--- a/et_replay/comm/backend/pytorch_dist_backend.py
+++ b/et_replay/comm/backend/pytorch_dist_backend.py
@@ -3,8 +3,10 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import json
 import logging
 import os
+
 from itertools import cycle
 from time import sleep
 from typing import List, Optional
@@ -1008,7 +1010,8 @@ def get_new_pg(self, group_ranks, backend):
                 ranks=group_ranks, backend=backend
             )
         else:
-            return dist.new_group(ranks=group_ranks, backend=backend)
+            pg = dist.new_group(ranks=group_ranks, backend=backend)
+            return pg
 
     def tensor_list_to_numpy(self, tensorList):
         if isinstance(tensorList, list):
@@ -1070,9 +1073,30 @@ def initialize_backend(
     def initialize_groups(self, backend="gloo"):
         groups = {}
         world_size = self.get_world_size()
+        global_rank = self.get_global_rank()
+
+        # sync PGs across ranks to fix a hang with multiple comm groups,
+        # because new_group() requires that all processes in the main group enter it,
+        # even if they are not going to be members of the group.
+        # Assumption: pg_name is unique and consistent for all ranks
+        sync_store = dist.PrefixStore("pg_sync_r", self.tcp_store)
+        sync_store.set(str(global_rank), json.dumps(self.commsParams.groupRanks))
+        torch.distributed.barrier()
+        group_ranks_sync = self.commsParams.groupRanks.copy()
+        for i in range(self.get_world_size()):
+            if i == global_rank:
+                continue
+            json_data = sync_store.get(str(i))
+
+            # convert pg_id in json_data to int
+            pg_id2group_ranks = {
+                int(pg_id): rank for pg_id, rank in json.loads(json_data).items()
+            }
+
+            group_ranks_sync.update(pg_id2group_ranks)
 
         # create additional groups
-        for pg_id, group_ranks in self.commsParams.groupRanks.items():
+        for pg_id, group_ranks in dict(sorted(group_ranks_sync.items())).items():
             if (
                 len(group_ranks) > world_size
             ):  # this means that --auto-shrink is enabled, only use default pg
@@ -1084,11 +1108,9 @@
                 pg = self.get_default_group()
             else:
                 pg = self.get_new_pg(group_ranks=group_ranks, backend=backend)
-                global_rank = self.get_global_rank()
-                if global_rank in group_ranks:
-                    logger.info(
-                        f"initialize_groups: Rank {global_rank} creates new group pg_id {pg_id} {pg} with {group_ranks}"
-                    )
+            logger.info(
+                f"initialize_groups: created new group, pg_id = {pg_id}, group_ranks = {group_ranks}"
+            )
             groups[pg_id] = pg
 
         # if additional groups are created, overwrite the default groups list
diff --git a/et_replay/tools/comm_replay.py b/et_replay/tools/comm_replay.py
index 96b22c88..02352857 100644
--- a/et_replay/tools/comm_replay.py
+++ b/et_replay/tools/comm_replay.py
@@ -31,7 +31,8 @@
 from et_replay.comm.param_profile import paramProfile, paramTimer
 
 try:
-    from trainer_iteration_wrapper import setTrainingIteration  # @manual
+    # pyre-ignore[21]:
+    from trainer_iteration_wrapper import setTrainingIteration
 except ImportError:
     pass
 
@@ -89,6 +90,8 @@ def writeCommDetails(commsTracePerf: List, rank: int, folder: str = "./") -> Non
         json.dump(commsTracePerf, write_file, indent=2)
 
 
+# pyre-ignore[13]: pyre complains that self.backendFuncs is never initialized;
+# it is initialized in initBackend()
 class commsTraceReplayBench(paramCommsBench):
     """
     A class to replay and benchmark generated traces for collective communications.
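
For reference, the synchronization pattern the initialize_groups() hunk relies on can be sketched as a small standalone script, roughly as below. It only illustrates the idea (publish each rank's {pg_id: ranks} map through a store, barrier, merge, then create every group in the same sorted order on every rank so dist.new_group() is entered by all processes). The helper name publish_and_create_groups and the torchrun-style environment variables are assumptions for the sketch, not code from this patch or from et_replay.

# Minimal sketch (assumed names, not from the patch): every rank publishes the
# process groups it knows about locally, reads everyone else's, and then creates
# the union of groups in one common order, so no rank skips a new_group() call.
import json
import os
from datetime import timedelta

import torch.distributed as dist


def publish_and_create_groups(local_group_ranks, store, backend="gloo"):
    """Exchange per-rank {pg_id: ranks} maps via `store`, then create every
    group on every rank in the same order so dist.new_group() cannot hang."""
    rank = dist.get_rank()
    sync_store = dist.PrefixStore("pg_sync_r", store)
    sync_store.set(str(rank), json.dumps(local_group_ranks))
    dist.barrier()  # every rank has published before anyone starts reading

    merged = dict(local_group_ranks)
    for peer in range(dist.get_world_size()):
        if peer == rank:
            continue
        peer_groups = json.loads(sync_store.get(str(peer)))
        # JSON keys are strings; normalize pg_id back to int
        merged.update({int(pg_id): ranks for pg_id, ranks in peer_groups.items()})

    # identical iteration order on all ranks -> all ranks enter every new_group()
    return {
        pg_id: dist.new_group(ranks=ranks, backend=backend)
        for pg_id, ranks in sorted(merged.items())
    }


if __name__ == "__main__":
    # e.g. launched with: torchrun --nproc_per_node=2 pg_sync_sketch.py
    store = dist.TCPStore(
        os.environ["MASTER_ADDR"],
        int(os.environ["MASTER_PORT"]) + 1,  # keep clear of init_process_group's port
        int(os.environ["WORLD_SIZE"]),
        is_master=int(os.environ["RANK"]) == 0,
        timeout=timedelta(seconds=60),
    )
    dist.init_process_group(backend="gloo")
    # each rank only knows the group it appears in, as in a per-rank trace
    local = {1: [0]} if dist.get_rank() == 0 else {2: [1]}
    groups = publish_and_create_groups(local, store)
    print(f"rank {dist.get_rank()} created pg_ids {sorted(groups)}")
    dist.destroy_process_group()

The patch itself takes the same approach but reuses the backend's existing TCPStore (self.tcp_store) behind a PrefixStore instead of opening a second port, and it sorts the merged pg_ids for the same reason as the sketch: a deterministic, identical group-creation order on every rank.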