Function refactoring
TaekyungHeo committed May 14, 2024
1 parent 5d2a8ca commit 20f7273
Showing 18 changed files with 2,052 additions and 3,876 deletions.
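
This commit renames the comms-trace container commsArgs to CommArgs, renames paramToCommName to standardize_comm_name, and moves locals and several fields from camelCase to snake_case (newComm becomes new_comm, inMsgSize becomes in_msg_size, and so on). The sketch below is a minimal, self-contained illustration of that naming change; the CommArgs class and the alias table here are stand-ins for the example, not the real et_replay.comm.comm_utils implementation.

# Illustrative sketch only: mirrors the renames in this commit
# (commsArgs -> CommArgs, paramToCommName -> standardize_comm_name,
# camelCase fields -> snake_case). Not the real et_replay code.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class CommArgs:                                  # formerly commsArgs
    id: Optional[int] = None
    comms: Optional[str] = None
    in_msg_size: Optional[int] = None            # formerly inMsgSize
    out_msg_size: Optional[int] = None           # formerly outMsgSize
    markerStack: Optional[List[str]] = None


def standardize_comm_name(name: str) -> str:     # formerly paramToCommName
    # Hypothetical alias table; the real mapping lives in comm_utils.
    aliases = {"alltoallv": "all_to_allv", "allreduce": "all_reduce"}
    return aliases.get(name, name)


new_comm = CommArgs()                            # was: newComm = commsArgs()
new_comm.comms = standardize_comm_name("allreduce")
new_comm.in_msg_size = 4096                      # was: newComm.inMsgSize
print(new_comm.comms, new_comm.in_msg_size)      # -> all_reduce 4096
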
244 changes: 101 additions & 143 deletions et_replay/lib/chakra_trace_parser.py
@@ -6,7 +6,7 @@
from typing import List, Tuple

from et_replay.comm import comm_utils
from et_replay.comm.comm_utils import commsArgs
from et_replay.comm.comm_utils import CommArgs
from et_replay.comm.pytorch_backend_utils import SupportedP2pOps

from param_bench.train.compute.python.tools.execution_trace import ExecutionTrace
@@ -61,99 +61,57 @@ def _parseBasicTrace(in_trace: List):
"""
Convert Basic Trace to comms trace format.
"""
newCommsTrace = []
for cnt, curComm in enumerate(in_trace):
new_commsTrace = []
for cnt, curr_comm in enumerate(in_trace):

newComm = commsArgs()
newComm.id = cnt
newComm.markerStack = curComm.get("markers")
if "comms" in curComm:
_parseBasicTraceComms(curComm, newComm)
new_comm = CommArgs()
new_comm.id = cnt
new_comm.markerStack = curr_comm.get("markers")
if "comms" in curr_comm:
_parseBasicTraceComms(curr_comm, new_comm)

elif "compute" in curComm:
_parseBasicTraceCompute(curComm, newComm)

if newComm.comms is not None or newComm.compute is not None:
newCommsTrace.append(newComm)
if new_comm.comms is not None or new_comm.compute is not None:
new_commsTrace.append(new_comm)
else:
raise ValueError(
"Trace file contains an element that is not a supported in PARAM! Please format all elements as comms or compute for replay."
)

return newCommsTrace

return new_commsTrace

def _parseBasicTraceComms(curComm, newComm: commsArgs) -> None:

newComm.comms = comm_utils.paramToCommName(curComm["comms"].lower())
if newComm.markerStack is None:
newComm.markerStack = [newComm.comms]
newComm.req = curComm.get("req")
newComm.startTimeNs = curComm.get("startTime_ns")
newComm.worldSize = curComm.get("world_size")
newComm.root = curComm.get("root")
newComm.pgId = curComm.get("pg_id")
newComm.groupRanks = curComm.get("global_ranks")
def _parseBasicTraceComms(curr_comm, new_comm: CommArgs) -> None:

if newComm.comms not in ("wait", "barrier", "init", "batch_isend_irecv"):
newComm.inMsgSize = curComm["in_msg_size"]
newComm.outMsgSize = curComm["out_msg_size"]
newComm.dtype = curComm["dtype"].lower()
new_comm.comms = comm_utils.standardize_comm_name(curr_comm["comms"].lower())
if new_comm.markerStack is None:
new_comm.markerStack = [new_comm.comms]
new_comm.req = curr_comm.get("req")
new_comm.startTimeNs = curr_comm.get("startTime_ns")
new_comm.worldSize = curr_comm.get("world_size")
new_comm.root = curr_comm.get("root")
new_comm.pgId = curr_comm.get("pg_id")
new_comm.groupRanks = curr_comm.get("global_ranks")

if newComm.comms == "all_to_allv":
newComm.inSplit = curComm["in_split"]
newComm.outSplit = curComm["out_split"]
if new_comm.comms not in ("wait", "barrier", "init", "batch_isend_irecv"):
new_comm.in_msg_size = curr_comm["in_msg_size"]
new_comm.out_msg_size = curr_comm["out_msg_size"]
new_comm.dtype = curr_comm["dtype"].lower()

if newComm.comms in SupportedP2pOps:
newComm.src_rank = curComm["src_rank"]
newComm.dst_rank = curComm["dst_rank"]
newComm.batch_p2p = curComm["use_batch"]
if new_comm.comms == "all_to_allv":
new_comm.inSplit = curr_comm["in_split"]
new_comm.outSplit = curr_comm["out_split"]


def _parseBasicTraceCompute(curComm, newComm: commsArgs) -> None:
newComm.compute = curComm["compute"].lower()
if newComm.markerStack is None:
newComm.markerStack = [newComm.compute]
# count = number of times to call the compute kernel
if "count" in curComm:
newComm.count = curComm["count"]
# if no count is specified, assume 1
else:
newComm.count = 1
if newComm.compute == "gemm":
if "mm_dim" in curComm:
newComm.mm0_dim0 = curComm.get("mm_dim")
newComm.mm0_dim1 = curComm.get("mm_dim")
newComm.mm1_dim0 = curComm.get("mm_dim")
newComm.mm1_dim1 = curComm.get("mm_dim")
else:
newComm.mm0_dim0 = curComm.get("mm0_dim0")
newComm.mm0_dim1 = curComm.get("mm0_dim1")
newComm.mm1_dim0 = curComm.get("mm1_dim0")
newComm.mm1_dim1 = curComm.get("mm1_dim1")
newComm.dtype = curComm.get("dtype").lower()
elif newComm.compute == "emb_lookup":
if "direction" in curComm:
newComm.direction = curComm["direction"]
else:
newComm.direction = "forward"
newComm.emb_dim = curComm.get("emb_dim")
newComm.num_embs = curComm.get("num_embs")
newComm.batch_size = curComm.get("batch_size")
newComm.num_emb_tables_per_device = curComm.get("num_emb_tables")
newComm.num_emb_tables_batched = -1
newComm.bag_size = curComm.get("bag_size")
else:
raise ValueError(
f"Trace file contains {str(newComm.compute)} compute element that is not supported in PARAM!"
)
if new_comm.comms in SupportedP2pOps:
new_comm.src_rank = curr_comm["src_rank"]
new_comm.dst_rank = curr_comm["dst_rank"]
new_comm.batch_p2p = curr_comm["use_batch"]


def _parseKinetoUnitrace(in_trace: List, target_rank: int) -> List:
"""
Convert the Kineto unitrace w/ comms metadata to the clean common trace format for replay.
"""
newCommsTrace = []
new_commsTrace = []
commsCnt = 0
for entry in in_trace:
# TODO: figure the current marker stack if present
@@ -166,20 +124,20 @@ def _parseKinetoUnitrace(in_trace: List, target_rank: int) -> List:
and entry["args"]["rank"] == target_rank
):

newComm = commsArgs()
newComm.comms = comm_utils.paramToCommName(entry["args"]["comms"].lower())
newComm.id = commsCnt
newComm.inMsgSize = entry["args"]["in_msg_size"]
newComm.outMsgSize = entry["args"]["out_msg_size"]
newComm.dtype = entry["args"]["dtype"].lower()
newComm.inSplit = entry["args"]["in_split"]
newComm.outSplit = entry["args"]["out_split"]
newComm.markerStack = marker

newCommsTrace.append(newComm)
new_comm = CommArgs()
new_comm.comms = comm_utils.standardize_comm_name(entry["args"]["comms"].lower())
new_comm.id = commsCnt
new_comm.in_msg_size = entry["args"]["in_msg_size"]
new_comm.out_msg_size = entry["args"]["out_msg_size"]
new_comm.dtype = entry["args"]["dtype"].lower()
new_comm.inSplit = entry["args"]["in_split"]
new_comm.outSplit = entry["args"]["out_split"]
new_comm.markerStack = marker

new_commsTrace.append(new_comm)
commsCnt += 1

return newCommsTrace
return new_commsTrace
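
For reference, every field _parseKinetoUnitrace reads lives under entry["args"]. The snippet below is a hypothetical event shaped to match those accesses; the full event filter is collapsed in the hunk above, so this is an illustration of the expected payload rather than a record guaranteed to pass the filter.

# Hypothetical Kineto unitrace event, shaped to the entry["args"][...] reads
# visible above; all values are invented for illustration.
example_entry = {
    "args": {
        "comms": "all_reduce",
        "rank": 0,
        "in_msg_size": 4096,
        "out_msg_size": 4096,
        "dtype": "Float",
        "in_split": [],
        "out_split": [],
    },
}
# A matching event becomes one CommArgs appended to new_commsTrace, with its
# id taken from the running commsCnt counter.
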


def _getTensorInfoFromPyTorchETEntry(
@@ -221,11 +179,11 @@ def _parseExecutionTrace(
ET_PG_NAME_TUPLE = in_trace.schema_pytorch() >= (1, 0, 3)
ET_BACKENDID = in_trace.schema_pytorch() < (1, 0, 3)

initOps = []
newCommsTrace = []
init_ops = []
new_commsTrace = []
backendIdToPgid = {}
pgRanksMap = {}
groupCnt = -1
group_cnt = -1

# Parse PG info from ET
for node in in_trace.nodes.values():
@@ -242,7 +200,7 @@
continue
pgId = int(pg["pg_name"])
ranks = pg["ranks"]
groupCnt = pg["group_count"]
group_cnt = pg["group_count"]
pgRanksMap[pgId] = (
ranks
if len(ranks) > 0
@@ -260,14 +218,14 @@
shift = (
0 if len(node.inputs) == 8 or len(node.inputs) == 10 else 1
) # wait/barrier ops do not have an input tensor (len=7), shift index one over
newComm = commsArgs()
newComm.id = node.id
newComm.comms = comm_utils.paramToCommName(
new_comm = CommArgs()
new_comm.id = node.id
new_comm.comms = comm_utils.standardize_comm_name(
node.inputs[4 - shift].lower()
) # 5th value of inputs is colName
if newComm.comms == "init":
if new_comm.comms == "init":
continue
newComm.req = node.inputs[
new_comm.req = node.inputs[
1 - shift
] # 2nd value of inputs is the req id of the collective

@@ -276,87 +234,87 @@
] # 3rd value of inputs is the pg identifier of the collective
# Assign pg_id info for PGs that were created.
if ET_BACKENDID and pgIdentifier in backendIdToPgid:
newComm.pgId = backendIdToPgid[pgIdentifier]
newComm.groupRanks = pgRanksMap[newComm.pgId]
newComm.worldSize = len(newComm.groupRanks)
new_comm.pgId = backendIdToPgid[pgIdentifier]
new_comm.groupRanks = pgRanksMap[new_comm.pgId]
new_comm.worldSize = len(new_comm.groupRanks)
elif ET_PG_NAME_TUPLE and pgIdentifier[0].isdecimal():
newComm.pgId = int(pgIdentifier[0])
newComm.groupRanks = pgRanksMap[newComm.pgId]
newComm.worldSize = len(newComm.groupRanks)
new_comm.pgId = int(pgIdentifier[0])
new_comm.groupRanks = pgRanksMap[new_comm.pgId]
new_comm.worldSize = len(new_comm.groupRanks)

if newComm.comms not in ("wait", "barrier"):
if new_comm.comms not in ("wait", "barrier"):
(
newComm.inMsgSize,
inMsgType,
new_comm.in_msg_size,
in_msg_type,
) = _getTensorInfoFromPyTorchETEntry(node.inputs, node.input_types[0])
(
newComm.outMsgSize,
new_comm.out_msg_size,
_,
) = _getTensorInfoFromPyTorchETEntry(node.outputs, node.output_types[0])
newComm.dtype = tensorDtypeMap[
inMsgType
new_comm.dtype = tensorDtypeMap[
in_msg_type
] # 1st value of input_types is the data type for the tensors

if newComm.comms in SupportedP2pOps:
if "send" in newComm.comms:
newComm.src_rank = target_rank
if new_comm.comms in SupportedP2pOps:
if "send" in new_comm.comms:
new_comm.src_rank = target_rank
local_dst_rank = node.inputs[3 - shift]
newComm.dst_rank = newComm.groupRanks[local_dst_rank]
if "recv" in newComm.comms:
new_comm.dst_rank = new_comm.groupRanks[local_dst_rank]
if "recv" in new_comm.comms:
local_src_rank = node.inputs[3 - shift]
newComm.src_rank = newComm.groupRanks[local_src_rank]
newComm.dst_rank = target_rank
new_comm.src_rank = new_comm.groupRanks[local_src_rank]
new_comm.dst_rank = target_rank

if newComm.comms == "broadcast":
newComm.root = newComm.groupRanks[0]
newComm.srcOrDst = newComm.groupRanks[0]
if new_comm.comms == "broadcast":
new_comm.root = new_comm.groupRanks[0]
new_comm.srcOrDst = new_comm.groupRanks[0]

if newComm.comms == "all_to_allv":
if new_comm.comms == "all_to_allv":
# 6th value of inputs is in_split, split evenly if not provided
if not newComm.worldSize:
if not new_comm.worldSize:
# if no pg info provided, use total ranks as world size
newComm.worldSize = total_ranks
newComm.inSplit = (
new_comm.worldSize = total_ranks
new_comm.inSplit = (
node.inputs[5]
if node.inputs[5]
else [int(newComm.inMsgSize / newComm.worldSize)]
* newComm.worldSize
else [int(new_comm.in_msg_size / new_comm.worldSize)]
* new_comm.worldSize
)
# 7th value of inputs is out_split, split evenly if not provided
newComm.outSplit = (
new_comm.outSplit = (
node.inputs[6]
if node.inputs[6]
else [int(newComm.outMsgSize / newComm.worldSize)]
* newComm.worldSize
else [int(new_comm.out_msg_size / new_comm.worldSize)]
* new_comm.worldSize
)
newCommsTrace.append(newComm)
new_commsTrace.append(new_comm)

# Build init node
initOps = []
if groupCnt < 0:
init_ops = []
if group_cnt < 0:
# old format: To be removed
for pgId, ranks in pgRanksMap.items():
newComm = create_pg_init_node(pgId, ranks, len(ranks))
initOps.append(newComm)
new_comm = create_pg_init_node(pgId, ranks, len(ranks))
init_ops.append(new_comm)
else:
for pgId in range(groupCnt):
for pgId in range(group_cnt):
if pgId in pgRanksMap:
ranks = pgRanksMap[pgId]
else:
# create a dummy pg that the current rank is not part of
ranks = [0] if target_rank != 0 else [1]

newComm = create_pg_init_node(pgId, ranks, len(ranks))
initOps.append(newComm)
new_comm = create_pg_init_node(pgId, ranks, len(ranks))
init_ops.append(new_comm)

return initOps + newCommsTrace
return init_ops + new_commsTrace


def create_pg_init_node(pg_id: int, ranks: List[int], world_size: int):
newComm = commsArgs()
newComm.comms = "init"
newComm.pgId = pg_id
newComm.req = -1
newComm.groupRanks = ranks
newComm.worldSize = world_size
return newComm
new_comm = CommArgs()
new_comm.comms = "init"
new_comm.pgId = pg_id
new_comm.req = -1
new_comm.groupRanks = ranks
new_comm.worldSize = world_size
return new_comm
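
A brief usage sketch for create_pg_init_node as defined above; the import path is assumed from the file path in the diff header, and the pg-to-ranks map is invented for illustration.

# Usage sketch (assumed module path, invented pg_ranks_map).
from et_replay.lib.chakra_trace_parser import create_pg_init_node

pg_ranks_map = {0: [0, 1, 2, 3], 1: [0, 2]}      # pg_id -> participating ranks
init_ops = [
    create_pg_init_node(pg_id, ranks, len(ranks))
    for pg_id, ranks in pg_ranks_map.items()
]
for op in init_ops:
    print(op.pgId, op.groupRanks, op.worldSize)  # fields set by the helper
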