Add support for multiple process groups by syncing across ranks #151

Status: Closed. Wants to merge 1 commit.
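Background for this change: `torch.distributed.new_group()` is collective over the default process group, so every rank must call it with the same `ranks` argument and in the same order, even ranks that will not be members of the new group; ranks that skip the call, or call in a different order, can hang. A minimal sketch of the required calling pattern (the helper name `build_subgroups` is illustrative, not part of this PR):

```python
import torch.distributed as dist

# Hang-prone: only member ranks call new_group(); non-members never enter the
# collective, so the member ranks block forever waiting for them.
#
# Correct: every rank creates every group, in the same order, and simply does
# not use the groups it is not a member of.
def build_subgroups(pg_ranks_by_id: dict, backend: str = "gloo") -> dict:
    groups = {}
    for pg_id in sorted(pg_ranks_by_id):
        groups[pg_id] = dist.new_group(ranks=pg_ranks_by_id[pg_id], backend=backend)
    return groups
```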
1 change: 0 additions & 1 deletion et_replay/comm/backend/base_backend.py
@@ -174,7 +174,6 @@ def alloc_ones(
         ipTensor = ipTensor * scaleFactor
         return ipTensor
 
-    @abstractmethod
     def noop(
         self,
         collectiveArgs: collectiveArgsHolder = None,
36 changes: 29 additions & 7 deletions et_replay/comm/backend/pytorch_dist_backend.py
@@ -3,8 +3,10 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import json
 import logging
 import os
+
 from itertools import cycle
 from time import sleep
 from typing import List, Optional
@@ -1008,7 +1010,8 @@ def get_new_pg(self, group_ranks, backend):
                 ranks=group_ranks, backend=backend
             )
         else:
-            return dist.new_group(ranks=group_ranks, backend=backend)
+            pg = dist.new_group(ranks=group_ranks, backend=backend)
+            return pg
 
     def tensor_list_to_numpy(self, tensorList):
         if isinstance(tensorList, list):
@@ -1070,9 +1073,30 @@ def initialize_backend(
     def initialize_groups(self, backend="gloo"):
         groups = {}
         world_size = self.get_world_size()
+        global_rank = self.get_global_rank()
+
+        # Sync pgs across ranks to fix a hang with multiple comm groups, because
+        # new_group() requires that all processes in the main group enter the call,
+        # even if they are not going to be members of the group.
+        # Assumption: pg_name is unique and consistent for all ranks
+        sync_store = dist.PrefixStore("pg_sync_r", self.tcp_store)
+        sync_store.set(str(global_rank), json.dumps(self.commsParams.groupRanks))
+        torch.distributed.barrier()
+        group_ranks_sync = self.commsParams.groupRanks.copy()
+        for i in range(self.get_world_size()):
+            if i == global_rank:
+                continue
+            json_data = sync_store.get(str(i))
+
+            # convert pg_id in json_data to int
+            pg_id2group_ranks = {
+                int(pg_id): rank for pg_id, rank in json.loads(json_data).items()
+            }
+
+            group_ranks_sync.update(pg_id2group_ranks)
 
         # create additional groups
-        for pg_id, group_ranks in self.commsParams.groupRanks.items():
+        for pg_id, group_ranks in dict(sorted(group_ranks_sync.items())).items():
             if (
                 len(group_ranks) > world_size
             ):  # this means that --auto-shrink is enabled, only use default pg
@@ -1084,11 +1108,9 @@ def initialize_groups(self, backend="gloo"):
                 pg = self.get_default_group()
             else:
                 pg = self.get_new_pg(group_ranks=group_ranks, backend=backend)
-                global_rank = self.get_global_rank()
-                if global_rank in group_ranks:
-                    logger.info(
-                        f"initialize_groups: Rank {global_rank} creates new group pg_id {pg_id} {pg} with {group_ranks}"
-                    )
+                logger.info(
+                    f"initialize_groups: created new group, pg_id = {pg_id}, group_ranks = {group_ranks}"
+                )
             groups[pg_id] = pg
 
         # if additional groups are created, overwrite the default groups list
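The hunk above makes group creation collective: each rank publishes its locally known `{pg_id: group_ranks}` mapping through the store, merges every other rank's mapping, and then calls `new_group()` over the sorted union so all ranks enter the calls in the same order. A standalone sketch of that pattern, assuming an already-initialized default process group (for the barrier) and a store shared by all ranks; the helper names `sync_group_ranks` and `create_groups` are illustrative, not part of the PR:

```python
import json
from typing import Dict, List

import torch.distributed as dist


def sync_group_ranks(
    store: dist.Store,
    rank: int,
    world_size: int,
    local_groups: Dict[int, List[int]],
) -> Dict[int, List[int]]:
    """Exchange per-rank group definitions so every rank sees the same set.

    Each rank publishes its own {pg_id: group_ranks} mapping under its rank
    key, then reads every other rank's mapping and merges them. After the
    merge, all ranks hold an identical dict and can call new_group() in
    lock-step.
    """
    sync_store = dist.PrefixStore("pg_sync_r", store)
    sync_store.set(str(rank), json.dumps(local_groups))
    dist.barrier()  # ensure every rank has published before anyone reads

    merged = dict(local_groups)
    for other in range(world_size):
        if other == rank:
            continue
        payload = json.loads(sync_store.get(str(other)))
        # JSON object keys are strings; restore the integer pg_id keys
        merged.update({int(pg_id): ranks for pg_id, ranks in payload.items()})
    return merged


def create_groups(store, rank, world_size, local_groups, backend="gloo"):
    merged = sync_group_ranks(store, rank, world_size, local_groups)
    # Sorted iteration gives every rank the same new_group() call order.
    return {
        pg_id: dist.new_group(ranks=ranks, backend=backend)
        for pg_id, ranks in sorted(merged.items())
    }
```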
5 changes: 4 additions & 1 deletion et_replay/tools/comm_replay.py
@@ -31,7 +31,8 @@
 from et_replay.comm.param_profile import paramProfile, paramTimer
 
 try:
-    from trainer_iteration_wrapper import setTrainingIteration  # @manual
+    # pyre-ignore[21]:
+    from trainer_iteration_wrapper import setTrainingIteration
 except ImportError:
     pass
 
@@ -89,6 +90,8 @@ def writeCommDetails(commsTracePerf: List, rank: int, folder: str = "./") -> None:
         json.dump(commsTracePerf, write_file, indent=2)
 
 
+# pyre-ignore[13]: Pyre complains that self.backendFuncs is never initialized;
+# it is initialized in initBackend
 class commsTraceReplayBench(paramCommsBench):
     """
     A class to replay and benchmark generated traces for collective communications.
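For reference, `pyre-ignore[13]` suppresses Pyre's "uninitialized attribute" error, which fires when a declared attribute is not assigned in `__init__`. A minimal illustration of the pattern used above; the class and method names here are hypothetical, not from this PR:

```python
# pyre-ignore[13]: `backend` is declared but not assigned in __init__;
# it is assigned later, once the concrete backend is known.
class ReplayBench:
    backend: object

    def init_backend(self, backend: object) -> None:
        # Deferred initialization, mirroring how backendFuncs is set in initBackend().
        self.backend = backend
```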