fix process group init logic and add pg_desc parsing (#190)
Summary:
Pull Request resolved: #190

This diff includes the following changes:
    Parse the pg description (pg_desc) from the execution trace.
    Fix the process group creation logic.

Pull Request resolved: #186

Test Plan: buck2 run  -c fbcode.nvcc_arch='h100a' mode/opt -c hpc_comms.use_ncclx=2.21.5 param_bench/train/comms/pt:launcher -- --launcher mast --dp networkai_mast_job_identity --cluster MastProdCluster --hw grandteton --nnode 8 --ppn 8 --module commsTraceReplay_v2 --trace-path manifold://pytorch_execution_trace/tree/traces/shengfu/nv/pattern2-64gpu --trace-type et --json_mast_flex_pool_id_override_map ~/flex_pool_long_cable.json --reuse-tensors --warmup-iter 5 --num-replays 10

Reviewed By: briancoutinho

Differential Revision: D67071961

Pulled By: shengfukevin

fbshipit-source-id: 3d52ba67244b89b3a53f46552cc2062002ce5f21
TaekyungHeo authored and facebook-github-bot committed Dec 12, 2024
1 parent b691d61 commit ca00ca3
Showing 6 changed files with 61 additions and 63 deletions.
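For context on the changes listed in the summary: the pg_desc string extracted from the trace ends up as the group_desc argument of torch.distributed.new_group(), which names the process group for logging and debugging. A minimal, self-contained sketch of that call follows; it assumes a PyTorch build whose new_group() accepts group_desc, and the ranks and description below are illustrative, not taken from the repository.

    import torch.distributed as dist

    def create_described_pg(group_ranks, backend, pg_desc=""):
        # Every rank in the default group must make this call, even ranks that
        # are not members of the new group, otherwise new_group() can hang.
        pg = dist.new_group(ranks=group_ranks, backend=backend, group_desc=pg_desc)
        # Non-member ranks get GroupMember.NON_GROUP_MEMBER back; treat that as "no group".
        return pg if pg is not dist.GroupMember.NON_GROUP_MEMBER else None

    # e.g. on an 8-rank job: create_described_pg([0, 2, 4, 6], "nccl", pg_desc="dp_pg")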
5 changes: 0 additions & 5 deletions et_replay/comm/backend/base_backend.py
@@ -48,7 +48,6 @@ class collectiveArgsHolder:
def __init__(self) -> None:
self.group = None
self.groups = {} # {pg_id, pg}
self.num_pgs = 0
self.device = {}
self.world_size = 0
self.data_type = ""
@@ -291,10 +290,6 @@ def get_default_group(self) -> ProcessGroup:
def get_groups(self) -> list[ProcessGroup]:
pass

@abstractmethod
def get_num_pgs(self) -> int:
pass

# Init functions
@abstractmethod
def initialize_backend(
77 changes: 34 additions & 43 deletions et_replay/comm/backend/pytorch_dist_backend.py
@@ -7,8 +7,6 @@
import logging
import os
from collections import defaultdict

from itertools import cycle
from time import sleep

import numpy as np
@@ -37,6 +35,7 @@
has_ext_dist = False

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def _downcast(input, bitwidth):
@@ -831,12 +830,6 @@ def get_default_group(self):
def get_groups(self):
return self.groups

def get_num_pgs(self):
return self.num_pgs

def get_next_group(self):
return next(self.round_robin_group)

def set_device(self, local_rank, global_rank):
"""set current device: 'cpu' or 'cuda'"""
dev_str = (
@@ -943,14 +936,14 @@ def __init__(self, bootstrap_info, commsParams):
except ImportError:
raise RuntimeError("Unable to import Fairring")

def get_new_pg(self, group_ranks, backend):
def get_new_pg(self, group_ranks, backend, pg_desc=""):
if self.use_ext_dist:
return extend_distributed.new_extend_process_group(
ranks=group_ranks, backend=backend
)
else:
pg = dist.new_group(ranks=group_ranks, backend=backend)
return pg
pg = dist.new_group(ranks=group_ranks, backend=backend, group_desc=pg_desc)
return pg if pg is not dist.GroupMember.NON_GROUP_MEMBER else None

def tensor_list_to_numpy(self, tensorList):
if isinstance(tensorList, list):
@@ -1006,75 +999,73 @@ def initialize_backend(
# default 1 group, maybe overwritten by user created groups via initialize_groups
self.groups = {}
self.groups[0] = self.get_default_group()
self.num_pgs = len(self.groups)
self.round_robin_group = cycle(list(self.groups.values()))

def initialize_groups(self, backend="gloo"):
groups = {}
world_size = self.get_world_size()
global_rank = self.get_global_rank()

# map from group_rank to pgId, pgId of the groups in current rank is the pgId defined in
# ET, pgId of the groups from other ranks is -1.
group_rank_to_pgId: dict[tuple[int], list[int]] = defaultdict(list)
for pg_id, group_ranks in self.commsParams.groupRanks.items():
if group_ranks is None or len(group_ranks) == 0:
group_ranks = list(range(world_size))
group_ranks.sort()
rank_tuple = tuple(group_ranks)
group_rank_to_pgId[rank_tuple].append(pg_id)

# sync pgs across ranks to fix hang with multiple comm groups
# because new_group() function requires that all processes in the default group call it,
# even if they are not going to be members of the group.
sync_store = dist.PrefixStore("pg_sync_r", self.tcp_store)
sync_store.set(str(global_rank), json.dumps(self.commsParams.groupRanks))
torch.distributed.barrier()

idxed_group_ranks_to_pgId: dict[tuple[int], list[int]] = defaultdict(list)
for i in range(self.get_world_size()):
if i == global_rank:
continue
json_data = sync_store.get(str(i))

# convert pg_id in json_data to int
pg_id_to_group_ranks = {
int(pg_id): rank for pg_id, rank in json.loads(json_data).items()
}

for _, group_ranks in pg_id_to_group_ranks.items():
if group_ranks is None or len(group_ranks) == 0:
group_ranks = list(range(world_size))
# map from indexed group_ranks to pgId, pgId of the group in current rank is the pgId defined in
# ET, pgId of the group from other ranks is -1.
# index is used to differentiate several groups with the same ranks.
group_ranks_count: dict[tuple[int], int] = defaultdict(int)
for pg_id, group_ranks in dict(
sorted(pg_id_to_group_ranks.items())
).items():
group_ranks.sort()
rank_tuple = tuple(group_ranks)
group_rank_to_pgId[rank_tuple].append(-1)
count = group_ranks_count[rank_tuple]
group_ranks_count[rank_tuple] = count + 1
idxed_group_ranks_to_pgId[tuple(group_ranks + [count])].append(
pg_id if global_rank == i else -1
)

# create additional groups, sort it to make sure pg are created in the same order for all ranks
for group_ranks, pg_ids in dict(sorted(group_rank_to_pgId.items())).items():
for idxed_group_ranks, pg_ids in dict(
sorted(idxed_group_ranks_to_pgId.items())
).items():
if (
len(group_ranks) > world_size
len(idxed_group_ranks[:-1]) > world_size
): # this means that --auto-shrink is enabled, only use default pg
groups.clear()
break
if (
len(group_ranks) == world_size
): # this is the default group, it has already been created

pg_id = next((i for i in pg_ids if i != -1), -1)

if len(idxed_group_ranks[:-1]) == world_size and idxed_group_ranks[-1] == 0:
pg = self.get_default_group()
else:
pg = self.get_new_pg(group_ranks=list(group_ranks), backend=backend)
pg = self.get_new_pg(
group_ranks=list(idxed_group_ranks[:-1]),
backend=backend,
pg_desc=self.commsParams.pgsDesc.get(pg_id, ""),
)
logger.info(
f"initialized_group: create new group, pg_ids = {pg_ids}, group_ranks = {group_ranks}"
f"initialized_group: create new group, pg_ids = {pg_ids}, idxed_group_ranks = {idxed_group_ranks}"
)
for pg_id in pg_ids:
if pg_id != -1:
groups[pg_id] = pg
if pg_id != -1:
groups[pg_id] = pg

# if additional groups are created, overwrite the default groups list
if len(groups):
self.groups = groups

self.num_pgs = len(self.groups)

self.round_robin_group = cycle(list(self.groups.values()))

def benchmark_comms(self, benchTime, commsParams):
index = 0 # used in TPU, where it is not initialized!
if commsParams.init_only:
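The hunk above publishes each rank's {pg_id: ranks} map through a PrefixStore on the shared TCPStore and then reads every other rank's map, so that all ranks call new_group() collectively and in the same order (otherwise replay hangs when multiple comm groups exist). A minimal sketch of that synchronization pattern, assuming an already-initialized default process group and a TCPStore handle; the function and variable names here are illustrative, not the module's API.

    import json
    import torch.distributed as dist

    def sync_pg_configs(tcp_store, my_rank, world_size, my_group_ranks):
        # Publish this rank's {pg_id: ranks} map under its rank key ...
        store = dist.PrefixStore("pg_sync_r", tcp_store)
        store.set(str(my_rank), json.dumps(my_group_ranks))
        dist.barrier()  # ensure every rank has published before anyone reads
        # ... then collect the maps from all other ranks, so every rank knows
        # the full set of process groups it must help create.
        all_configs = {my_rank: my_group_ranks}
        for rank in range(world_size):
            if rank == my_rank:
                continue
            all_configs[rank] = {
                int(pg_id): ranks
                for pg_id, ranks in json.loads(store.get(str(rank))).items()
            }
        return all_configs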
3 changes: 0 additions & 3 deletions et_replay/comm/backend/pytorch_tpu_backend.py
@@ -160,9 +160,6 @@ def get_default_group(self):
def get_groups(self):
pass

def get_num_pgs(self):
pass

def tensor_list_to_numpy(self, tensorList):
tensorList = torch.transpose(tensorList.view(-1, 1), 0, 1)[0]
return tensorList.cpu().detach().numpy()
35 changes: 24 additions & 11 deletions et_replay/comm/commsTraceParser.py
@@ -58,58 +58,68 @@ def _parseExecutionTrace(
f"Only support trace version >1.0.3, but current trace version is {in_trace.schema.split('-')[0]}"
)

pg_ranks_map = _parse_proc_group_info(
in_trace
) # key is pg id, value is global ranks in this pg
# pg_ranks_map: key is pg id, value is global ranks in this pg
# pg_desc_map: key is pg id, value is pg desc
pg_ranks_map, pg_desc_map = _parse_proc_group_info(in_trace)
comms_op_list = _parse_comms_op_node(
in_trace, pg_ranks_map, target_rank, total_ranks
in_trace, pg_ranks_map, pg_desc_map, target_rank, total_ranks
)

return comms_op_list


def _parse_proc_group_info(in_trace: ExecutionTrace):
pg_ranks_map = {} # {node_id : {process_group_id : [ranks] } }
pg_desc_map = {} # {node_id : {process_group_id : pg_desc } }
pg_init_nodes = (
node for node in in_trace.nodes.values() if "process_group:init" in node.name
)
for node in pg_init_nodes:
# info of this node is dumped using torch.distributed.distributed_c10d._world.pg_config_info
# at the start of profiling, but not not callback to torch.distributed.init_process_group()
# at the start of profiling, but not callback to torch.distributed.init_process_group()
# Pre-Assumption: all process groups have been created before profiling starts.
try:
pg_objs = json.loads(node.inputs[0])
except json.decoder.JSONDecodeError: # skip if pg_config_info is truncated
break

pg_ranks_map[node.id] = {}
pg_desc_map[node.id] = {}
for pg in pg_objs:
if not pg["pg_name"].isdecimal():
# TODO support local synchronization pg
logger.warning(
f"Process group name is {pg['pg_name']} in node {node.id}, which is not supported. Skip."
)
continue
(pg_id, ranks, group_size, group_count) = (
pg[k] for k in ["pg_name", "ranks", "group_size", "group_count"]
(pg_id, pg_desc, ranks, group_size, group_count) = (
pg[k]
for k in ["pg_name", "pg_desc", "ranks", "group_size", "group_count"]
)
pg_id = int(pg_id)
pg_ranks_map[node.id][pg_id] = (
ranks if len(ranks) > 0 else list(range(group_size))
# rank list is empty when all ranks are in a pg
)
pg_desc_map[node.id][pg_id] = pg_desc
break # only one process_group init node per trace
return pg_ranks_map
return pg_ranks_map, pg_desc_map


def _parse_comms_op_node( # noqa: C901
in_trace: ExecutionTrace, pg_ranks_map: dict, target_rank: int, total_ranks: int
in_trace: ExecutionTrace,
pg_ranks_map: dict,
pg_desc_map: dict,
target_rank: int,
total_ranks: int,
):
comms_op_list = []

for node_id in pg_ranks_map:
for pg_id, ranks in pg_ranks_map[node_id].items():
comm_args = _create_pg_init_node(node_id, pg_id, ranks, len(ranks))
comm_args = _create_pg_init_node(
node_id, pg_id, ranks, pg_desc_map[node_id][pg_id], len(ranks)
)
comms_op_list.append(comm_args)

pg_ranks_map_flatten = {}
@@ -192,11 +202,14 @@ def _parse_comms_op_node( # noqa: C901
return comms_op_list


def _create_pg_init_node(node_id: int, pg_id: int, ranks: list[int], world_size: int):
def _create_pg_init_node(
node_id: int, pg_id: int, ranks: list[int], pg_desc: str, world_size: int
):
comm_args = commsArgs()
comm_args.id = node_id
comm_args.comms = "init"
comm_args.pgId = pg_id
comm_args.pgDesc = pg_desc
comm_args.req = -1
comm_args.groupRanks = ranks
comm_args.worldSize = world_size
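_parse_proc_group_info above reads the pg_config_info blob that the trace records on its process_group:init node (dumped from torch.distributed.distributed_c10d._world.pg_config_info). A hypothetical example of that input and a simplified standalone version of the parsing it goes through; the key names match the parser above, the values are made up.

    import json

    # Hypothetical node.inputs[0] payload for a "process_group:init" trace node.
    pg_config_info = json.dumps([
        {"pg_name": "0", "pg_desc": "default_pg", "ranks": [], "group_size": 8, "group_count": 2},
        {"pg_name": "1", "pg_desc": "dp_pg", "ranks": [0, 2, 4, 6], "group_size": 4, "group_count": 2},
    ])

    pg_ranks, pg_descs = {}, {}
    for pg in json.loads(pg_config_info):
        if not pg["pg_name"].isdecimal():
            continue  # non-numeric names (e.g. local synchronization pgs) are skipped
        pg_id = int(pg["pg_name"])
        # An empty rank list means the pg spans all group_size ranks.
        pg_ranks[pg_id] = pg["ranks"] or list(range(pg["group_size"]))
        pg_descs[pg_id] = pg["pg_desc"]
    # pg_ranks  -> {0: [0, 1, ..., 7], 1: [0, 2, 4, 6]}
    # pg_descs  -> {0: "default_pg", 1: "dp_pg"}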
2 changes: 2 additions & 0 deletions et_replay/comm/comms_utils.py
@@ -486,6 +486,7 @@ def __init__(self, **kwargs) -> None:
self.outSplit = kwargs["outSplit"] if "outSplit" in kwargs else None
self.startTimeNs = kwargs["startTimeNs"] if "startTimeNs" in kwargs else None
self.pgId = kwargs["pgId"] if "pgId" in kwargs else None
self.pgDesc = kwargs["pgDesc"] if "pgDesc" in kwargs else None
self.groupRanks = kwargs["groupRanks"] if "groupRanks" in kwargs else None
self.worldSize = kwargs["worldSize"] if "worldSize" in kwargs else None
self.markerStack = kwargs["markerStack"] if "markerStack" in kwargs else None
@@ -692,6 +693,7 @@ def __init__(self, args: Namespace) -> None:
self.quant_threshold = args.quant_threshold
self.dcheck = args.c
self.groupRanks = {} # record what ranks each process group will work on {pg_id, ranks}
self.pgsDesc = {} # {pg_id: pg_desc}
self.use_ext_dist = args.use_ext_dist
self.size_from_trace = False
self.init_method = args.init_method
2 changes: 1 addition & 1 deletion et_replay/tools/comm_replay.py
@@ -1403,6 +1403,7 @@ def setBench(
# record process group info
if curComm.comms == "init":
commsParams.groupRanks[curComm.pgId] = curComm.groupRanks
commsParams.pgsDesc[curComm.pgId] = curComm.pgDesc
self.backendFuncs.initialize_groups(commsParams.backend)

# set basic collective info
@@ -1419,7 +1420,6 @@

self.collectiveArgs.group = group # default group
self.collectiveArgs.groups = self.backendFuncs.get_groups()
self.collectiveArgs.num_pgs = self.backendFuncs.get_num_pgs()
self.collectiveArgs.device = curDevice
self.collectiveArgs.world_size = world_size
self.collectiveArgs.global_rank = global_rank
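During replay, the "init" entries produced by the parser are consumed first; the hunk above now records the description alongside the ranks before handing both to initialize_groups(). A rough sketch of that hand-off, with commsParams, backendFuncs, and comms_op_list standing in for the objects used in setBench above; this is illustrative only, not the replay tool's exact control flow.

    # Feed pg info from parsed "init" ops into the backend, then build the groups.
    for curComm in comms_op_list:
        if curComm.comms != "init":
            continue
        commsParams.groupRanks[curComm.pgId] = curComm.groupRanks  # {pg_id: ranks}
        commsParams.pgsDesc[curComm.pgId] = curComm.pgDesc         # {pg_id: description}
    backendFuncs.initialize_groups(commsParams.backend)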
