Commit
refactor process group init to:
1. parse pg description.
2. fix pg creation logic.
sanshang-nv committed Nov 1, 2024
1 parent eb32a65 commit 61009fb
Showing 6 changed files with 52 additions and 74 deletions.
5 changes: 0 additions & 5 deletions et_replay/comm/backend/base_backend.py
@@ -49,7 +49,6 @@ class collectiveArgsHolder:
def __init__(self) -> None:
self.group = None
self.groups = {} # {pg_id, pg}
self.num_pgs = 0
self.device = {}
self.world_size = 0
self.data_type = ""
@@ -294,10 +293,6 @@ def get_default_group(self) -> ProcessGroup:
def get_groups(self) -> List[ProcessGroup]:
pass

@abstractmethod
def get_num_pgs(self) -> int:
pass

# Init functions
@abstractmethod
def initialize_backend(
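With get_num_pgs() dropped from the abstract backend, any caller that still needs the group count can derive it from get_groups(), which returns the {pg_id: ProcessGroup} mapping. A minimal sketch (backend_funcs is a hypothetical handle to any concrete backend):

    # Equivalent of the removed get_num_pgs(): count the entries returned by get_groups().
    num_pgs = len(backend_funcs.get_groups())
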
88 changes: 33 additions & 55 deletions et_replay/comm/backend/pytorch_dist_backend.py
@@ -38,6 +38,7 @@
has_ext_dist = False

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def _downcast(input, bitwidth):
@@ -832,12 +833,6 @@ def get_default_group(self):
def get_groups(self):
return self.groups

def get_num_pgs(self):
return self.num_pgs

def get_next_group(self):
return next(self.round_robin_group)

def set_device(self, local_rank, global_rank):
"""set current device: 'cpu' or 'cuda'"""
dev_str = (
@@ -944,14 +939,14 @@ def __init__(self, bootstrap_info, commsParams):
except ImportError:
raise RuntimeError("Unable to import Fairring")

def get_new_pg(self, group_ranks, backend):
def get_new_pg(self, group_ranks, backend, pg_desc=''):
if self.use_ext_dist:
return extend_distributed.new_extend_process_group(
ranks=group_ranks, backend=backend
)
else:
pg = dist.new_group(ranks=group_ranks, backend=backend)
return pg
pg = dist.new_group(ranks=group_ranks, backend=backend, group_desc=pg_desc)
return pg if pg is not dist.GroupMember.NON_GROUP_MEMBER else None

def tensor_list_to_numpy(self, tensorList):
if isinstance(tensorList, list):
@@ -1007,75 +1002,58 @@ def initialize_backend(
# default 1 group, maybe overwritten by user created groups via initialize_groups
self.groups = {}
self.groups[0] = self.get_default_group()
self.num_pgs = len(self.groups)
self.round_robin_group = cycle(list(self.groups.values()))

def initialize_groups(self, backend="gloo"):
groups = {}
world_size = self.get_world_size()
global_rank = self.get_global_rank()

# map from group_rank to pgId, pgId of the groups in current rank is the pgId defined in
# ET, pgId of the groups from other ranks is -1.
group_rank_to_pgId: Dict[Tuple[int], List[int]] = defaultdict(list)
for pg_id, group_ranks in self.commsParams.groupRanks.items():
if group_ranks is None or len(group_ranks) == 0:
group_ranks = list(range(world_size))
group_ranks.sort()
rank_tuple = tuple(group_ranks)
group_rank_to_pgId[rank_tuple].append(pg_id)

# sync pgs across ranks to fix hang with multiple comm groups
# because new_group() function requires that all processes in the default group call it,
# even if they are not going to be members of the group.

sync_store = dist.PrefixStore("pg_sync_r", self.tcp_store)
sync_store.set(str(global_rank), json.dumps(self.commsParams.groupRanks))
torch.distributed.barrier()

global_pg_id_to_ranks = {k: v for k,v in self.commsParams.groupRanks.items()}
for i in range(self.get_world_size()):
if i == global_rank:
continue
json_data = sync_store.get(str(i))

# convert pg_id in json_data to int
pg_id_to_group_ranks = {
int(pg_id): rank for pg_id, rank in json.loads(json_data).items()
int(pg_id): ranks for pg_id, ranks in json.loads(json_data).items()
}

for _, group_ranks in pg_id_to_group_ranks.items():
if group_ranks is None or len(group_ranks) == 0:
group_ranks = list(range(world_size))
group_ranks.sort()
rank_tuple = tuple(group_ranks)
group_rank_to_pgId[rank_tuple].append(-1)

# create additional groups, sort it to make sure pg are created in the same order for all ranks
for group_ranks, pg_ids in dict(sorted(group_rank_to_pgId.items())).items():
if (
len(group_ranks) > world_size
): # this means that --auto-shrink is enabled, only use default pg
for k,v in pg_id_to_group_ranks.items():
if k not in global_pg_id_to_ranks:
global_pg_id_to_ranks[k] = v
logger.info(f'local process groups are: {self.commsParams.groupRanks}')
logger.info(f'synchronized global process groups are: {global_pg_id_to_ranks}')

# pg_id = 0 is always reserved for default pg.
# if you create another pg including all ranks, the ranks list in trace is empty, but
# it's another pg with different pg_id.
for pg_id in sorted(list(global_pg_id_to_ranks.keys())):
group_ranks = global_pg_id_to_ranks[pg_id]
if len(group_ranks) > world_size:
# this means that --auto-shrink is enabled, only use default pg
groups.clear()
break
if (
len(group_ranks) == world_size
): # this is the default group, it has already been created
pg = self.get_default_group()
if pg_id == 0:
groups[pg_id] = self.get_default_group()
else:
pg = self.get_new_pg(group_ranks=list(group_ranks), backend=backend)
logger.info(
f"initialized_group: create new group, pg_ids = {pg_ids}, group_ranks = {group_ranks}"
)
for pg_id in pg_ids:
if pg_id != -1:
# include padding invocation to torch.distributed.new_group() for its requirements:
# This function requires that all processes in the main group enter this function,
# even if they are not going to be members of the group. Additionally, groups should
# be created in the same order in all processes.
# https://pytorch.org/docs/stable/distributed.html#torch.distributed.new_group
pg = self.get_new_pg(group_ranks=group_ranks, backend=backend, pg_desc=self.commsParams.pgsDesc.get(pg_id, ''))
if pg is not None:
assert pg_id in self.commsParams.groupRanks, \
f"global pg (pg_id={pg_id}, pg_ranks={group_ranks}) is not recorded in rank_{global_rank}'s init node list"
logger.info(f"initialized_group: create new group, pg_id={pg_id}, ranks = {group_ranks}")
groups[pg_id] = pg

# if additional groups are created, overwrite the default groups list
if len(groups):
self.groups = groups

self.num_pgs = len(self.groups)

self.round_robin_group = cycle(list(self.groups.values()))

def benchmark_comms(self, benchTime, commsParams):
index = 0 # used in TPU, where it is not initialized!
if commsParams.init_only:
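The reworked initialize_groups() depends on every rank assembling the same global view of process groups before any dist.new_group() call, since new_group() must be invoked by all ranks, in the same order, even for groups they will not join. A condensed sketch of that pattern, assuming torch.distributed is already initialized, tcp_store is the rendezvous TCPStore, and local_pgs maps pg_id to ranks (the function and variable names are hypothetical):

    import json
    import torch.distributed as dist

    def sync_and_create_groups(tcp_store, local_pgs, global_rank, world_size, backend="nccl"):
        # Publish this rank's {pg_id: ranks} view so every rank can build the same global map.
        store = dist.PrefixStore("pg_sync_r", tcp_store)
        store.set(str(global_rank), json.dumps(local_pgs))
        dist.barrier()

        # Merge the views from all ranks; JSON keys arrive as strings, so cast pg_id back to int.
        global_pgs = dict(local_pgs)
        for rank in range(world_size):
            if rank == global_rank:
                continue
            remote = json.loads(store.get(str(rank)))
            for pg_id, ranks in remote.items():
                global_pgs.setdefault(int(pg_id), ranks)

        # Create groups in ascending pg_id order so all ranks issue identical new_group() calls;
        # pg_id 0 is the default group and needs no new_group() call.
        groups = {0: dist.group.WORLD}
        for pg_id in sorted(k for k in global_pgs if k != 0):
            ranks = global_pgs[pg_id] or list(range(world_size))
            pg = dist.new_group(ranks=ranks, backend=backend)
            if pg is not dist.GroupMember.NON_GROUP_MEMBER:
                groups[pg_id] = pg
        return groups
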
3 changes: 0 additions & 3 deletions et_replay/comm/backend/pytorch_tpu_backend.py
@@ -160,9 +160,6 @@ def get_default_group(self):
def get_groups(self):
pass

def get_num_pgs(self):
pass

def tensor_list_to_numpy(self, tensorList):
tensorList = torch.transpose(tensorList.view(-1, 1), 0, 1)[0]
return tensorList.cpu().detach().numpy()
26 changes: 16 additions & 10 deletions et_replay/comm/commsTraceParser.py
@@ -59,40 +59,44 @@ def _parseExecutionTrace(
f"Only support trace version >1.0.3, but current trace version is {in_trace.schema.split('-')[0]}"
)

pg_ranks_map = _parse_proc_group_info(
# pg_ranks_map: key is pg id, value is global ranks in this pg
# pg_desc_map: key is pg id, value is pg desc
pg_ranks_map, pg_desc_map = _parse_proc_group_info(
in_trace
) # key is pg id, value is global ranks in this pg
)
comms_op_list = _parse_comms_op_node(
in_trace, pg_ranks_map, target_rank, total_ranks
in_trace, pg_ranks_map, pg_desc_map, target_rank, total_ranks
)

return comms_op_list


def _parse_proc_group_info(in_trace: ExecutionTrace):
pg_ranks_map = {} # {node_id : {process_group_id : [ranks] } }
pg_desc_map = {} # {node_id : {process_group_id : pg_desc }
pg_init_nodes = (
node for node in in_trace.nodes.values() if "process_group:init" in node.name
)
for node in pg_init_nodes:
# info of this node is dumped using torch.distributed.distributed_c10d._world.pg_config_info
# at the start of profiling, but not not callback to torch.distributed.init_process_group()
# at the start of profiling, but not callback to torch.distributed.init_process_group()
# Pre-Assumption: all process groups have been created before profiling starts.
try:
pg_objs = json.loads(node.inputs[0])
except json.decoder.JSONDecodeError: # skip if pg_config_info is truncated
break

pg_ranks_map[node.id] = {}
pg_desc_map[node.id] = {}
for pg in pg_objs:
if not pg["pg_name"].isdecimal():
# TODO support local synchronization pg
logger.warning(
f"Process group name is {pg['pg_name']} in node {node.id}, which is not supported. Skip."
)
continue
(pg_id, ranks, group_size, group_count) = [
pg[k] for k in ["pg_name", "ranks", "group_size", "group_count"]
(pg_id, pg_desc, ranks, group_size, group_count) = [
pg[k] for k in ["pg_name", "pg_desc", "ranks", "group_size", "group_count"]
]
pg_id = int(pg_id)
pg_ranks_map[node.id][pg_id] = (
@@ -101,18 +105,19 @@ def _parse_proc_group_info(in_trace: ExecutionTrace):
else list(range(group_size))
# rank list is empty when all ranks are in a pg
)
pg_desc_map[node.id][pg_id] = pg_desc
break # only one process_group init node per trace
return pg_ranks_map
return pg_ranks_map, pg_desc_map


def _parse_comms_op_node( # noqa: C901
in_trace: ExecutionTrace, pg_ranks_map: dict, target_rank: int, total_ranks: int
in_trace: ExecutionTrace, pg_ranks_map: dict, pg_desc_map: dict, target_rank: int, total_ranks: int
):
comms_op_list = []

for node_id in pg_ranks_map:
for pg_id, ranks in pg_ranks_map[node_id].items():
comm_args = _create_pg_init_node(node_id, pg_id, ranks, len(ranks))
comm_args = _create_pg_init_node(node_id, pg_id, ranks, pg_desc_map[node_id][pg_id], len(ranks))
comms_op_list.append(comm_args)

pg_ranks_map_flatten = {}
@@ -195,11 +200,12 @@ def _parse_comms_op_node(  # noqa: C901
return comms_op_list


def _create_pg_init_node(node_id: int, pg_id: int, ranks: List[int], world_size: int):
def _create_pg_init_node(node_id: int, pg_id: int, ranks: List[int], pg_desc: str, world_size: int):
comm_args = commsArgs()
comm_args.id = node_id
comm_args.comms = "init"
comm_args.pgId = pg_id
comm_args.pgDesc = pg_desc
comm_args.req = -1
comm_args.groupRanks = ranks
comm_args.worldSize = world_size
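For reference, _parse_proc_group_info() reads the single process_group:init node whose first input is the JSON dump of torch.distributed.distributed_c10d._world.pg_config_info. A sketch of that payload and of the two maps the parser now returns; the field values below are invented for illustration:

    import json

    # Invented example of node.inputs[0]; a real trace carries one entry per process group.
    example_input = json.dumps([
        {"pg_name": "0", "pg_desc": "default_pg", "ranks": [], "group_size": 8, "group_count": 2},
        {"pg_name": "1", "pg_desc": "tensor_parallel", "ranks": [0, 1, 2, 3], "group_size": 4, "group_count": 2},
    ])

    pg_ranks, pg_descs = {}, {}
    for pg in json.loads(example_input):
        if not pg["pg_name"].isdecimal():   # non-numeric names (local sync pgs) are skipped
            continue
        pg_id = int(pg["pg_name"])
        # An empty ranks list means the group spans all ranks.
        pg_ranks[pg_id] = pg["ranks"] if pg["ranks"] else list(range(pg["group_size"]))
        pg_descs[pg_id] = pg["pg_desc"]

    print(pg_ranks)   # {0: [0, 1, 2, 3, 4, 5, 6, 7], 1: [0, 1, 2, 3]}
    print(pg_descs)   # {0: 'default_pg', 1: 'tensor_parallel'}
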
2 changes: 2 additions & 0 deletions et_replay/comm/comms_utils.py
@@ -491,6 +491,7 @@ def __init__(self, **kwargs) -> None:
self.outSplit = kwargs["outSplit"] if "outSplit" in kwargs else None
self.startTimeNs = kwargs["startTimeNs"] if "startTimeNs" in kwargs else None
self.pgId = kwargs["pgId"] if "pgId" in kwargs else None
self.pgDesc = kwargs["pgDesc"] if "pgDesc" in kwargs else None
self.groupRanks = kwargs["groupRanks"] if "groupRanks" in kwargs else None
self.worldSize = kwargs["worldSize"] if "worldSize" in kwargs else None
self.markerStack = kwargs["markerStack"] if "markerStack" in kwargs else None
@@ -699,6 +700,7 @@ def __init__(self, args: Namespace) -> None:
self.groupRanks = (
{}
) # record what ranks each process group will work on {pg_id, ranks}
self.pgsDesc = {} # {pg_id: pg_desc}
self.use_ext_dist = args.use_ext_dist
self.size_from_trace = False
self.init_method = args.init_method
2 changes: 1 addition & 1 deletion et_replay/tools/comm_replay.py
@@ -1399,6 +1399,7 @@ def setBench(
# record process group info
if curComm.comms == "init":
commsParams.groupRanks[curComm.pgId] = curComm.groupRanks
commsParams.pgsDesc[curComm.pgId] = curComm.pgDesc
self.backendFuncs.initialize_groups(commsParams.backend)

# set basic collective info
@@ -1415,7 +1416,6 @@

self.collectiveArgs.group = group # default group
self.collectiveArgs.groups = self.backendFuncs.get_groups()
self.collectiveArgs.num_pgs = self.backendFuncs.get_num_pgs()
self.collectiveArgs.device = curDevice
self.collectiveArgs.world_size = world_size
self.collectiveArgs.global_rank = global_rank
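On the replay side, the new pgDesc field travels with each "init" op, so setBench() can hand both the ranks and the description of every recorded group to initialize_groups(). A minimal sketch of that accumulation step, where ops stands for a hypothetical list of parsed commsArgs objects:

    # Collect process-group metadata from the trace's "init" ops before creating groups.
    group_ranks, pgs_desc = {}, {}
    for op in ops:
        if op.comms == "init":
            group_ranks[op.pgId] = op.groupRanks   # mirrors commsParams.groupRanks
            pgs_desc[op.pgId] = op.pgDesc          # mirrors commsParams.pgsDesc

    # initialize_groups() then synchronizes these maps across ranks and creates the groups.
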
