Enable comm_replay in PARAM by Integrating and Refactoring Comm Code (facebookresearch#112)

Summary:
X-link: facebookresearch/HolisticTraceAnalysis#137

- **Code Migration**: Copied all comm_replay-related code from train/comms/pt to et_replay/lib/comm. Copying, rather than creating symbolic links, was chosen to avoid dependency issues and keep the code self-contained, ensuring that the et_replay project remains functional even if the source files change.
- **Code Cleanup**: Removed obsolete files such as dlrm.py and comms.py to streamline the codebase.
- **Configuration Update**: Modified import statements and updated pyproject.toml to align with the new directory structure, ensuring proper package management.
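
As an illustration of the import-path updates, call sites now pull the comm helpers from the new package location. The "before" line below is a hypothetical sketch of the pre-move layout (the exact old module path is not part of this diff); the "after" imports are taken from the migrated files:

```python
# Before: comm helpers lived under train/comms/pt (this old import path is a
# hypothetical sketch, not taken from this diff).
# from param_bench.train.comms.pt import comms_utils

# After: comm helpers are imported from the new et_replay package layout.
from et_replay.lib.comm import comms_utils
from et_replay.lib.comm.comms_utils import commsArgs
```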

Pull Request resolved: facebookresearch#112

Test Plan:
```
$ pip install .
Processing /Users/theo/param/et_replay
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: et_replay
  Building wheel for et_replay (pyproject.toml) ... done
  Created wheel for et_replay: filename=et_replay-1.0.0-py3-none-any.whl size=61490 sha256=d4e4433c55487d790e6bb1bb892eca268348a148f3365d3587fac90aa38692ee
  Stored in directory: /private/var/folders/z0/c9mq5j4s6n14n0_gs7nlt6mc0000gp/T/pip-ephem-wheel-cache-jxux47rn/wheels/3b/3f/aa/d3fc853f83c22c6f3eeb09763570c2cc8031a1a414cb3c18b6
Successfully built et_replay
Installing collected packages: et_replay
  Attempting uninstall: et_replay
    Found existing installation: et_replay 1.0.0
    Uninstalling et_replay-1.0.0:
      Successfully uninstalled et_replay-1.0.0
Successfully installed et_replay-1.0.0

$ comm_replay
[BLOCKED as expected]
```

Run on MAST:
buck2 run mode/opt -c hpc_comms.use_ncclx=2.18.3 param_bench/train/comms/pt:launcher -- --launcher mast --cluster=MastProdCluster --dp networkai_mast_job_identity --hw tc_any --nnode 8 --ppn 8 --z=0 --module commsTraceReplay --trace-path manifold://param/tree/shengbao/et/torchx-conda-xlformers_ncclexp_70b_fp8_fsdp_pp_ctran_ag-tgqvxwkz --trace-type et --reuse-tensors

https://www.internalfb.com/mlhub/pipelines/runs/mast/torchx-param-commsTraceReplay-64gpus-allreduce-5f66a4?job_attempt=0&version=0&tab=scheduling&env=PRODUCTION

Reviewed By: shengbao-zheng

Differential Revision: D57354772

Pulled By: briancoutinho

fbshipit-source-id: f4563f6f4823e8f8b097d68aa35da3461aa4c0a0
TaekyungHeo authored and facebook-github-bot committed May 17, 2024
1 parent 675754e commit e99ef20
Showing 14 changed files with 5,705 additions and 90 deletions.
2 changes: 2 additions & 0 deletions et_replay/.gitignore
@@ -0,0 +1,2 @@
build/
et_replay.egg-info/
362 changes: 362 additions & 0 deletions et_replay/lib/comm/commsTraceParser.py
@@ -0,0 +1,362 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
from __future__ import annotations

import json

from typing import List, Tuple

from et_replay.lib.comm import comms_utils
from et_replay.lib.comm.comms_utils import commsArgs
from et_replay.lib.comm.pytorch_backend_utils import supportedP2pOps

from et_replay.lib.execution_trace import ExecutionTrace

tensorDtypeMap = {
    "Tensor(int)": "int",
    "Tensor(float)": "float",
    "Tensor(bool)": "bool",
    "Tensor(long)": "long",
    "Tensor(long int)": "long",
    "Tensor(double)": "double",
    "Tensor(half)": "half",
    "Tensor(byte)": "byte",
    "Tensor(c10::Half)": "half",
    "Tensor(c10::BFloat16)": "bfloat16",
    "Tensor(unsigned char)": "char",
    "Tensor(signed char)": "char",
}


def parseTrace(
    in_trace: List, trace_type: str, target_rank: int, total_ranks: int
) -> List:
    """
    Parse trace files to be compatible with PARAM replay-mode.
    Currently supports: Basic Trace, Kineto Unitrace, and PyTorch ET trace.
    Args:
        in_trace: Trace file to be parsed.
        trace_type: Trace type to be parsed with.
        target_rank: The current rank of the device.
        total_ranks: Total number of ranks.
    Returns:
        parsed_trace: Parsed trace that is compatible with PARAM replay-mode.
    """

    if trace_type == "basic":  # Basic Trace
        parsed_trace = _parseBasicTrace(in_trace)
    elif trace_type == "et":  # Execution Trace (e.g. PyTorch ET, Chakra)
        parsed_trace = _parseExecutionTrace(
            ExecutionTrace(in_trace), target_rank, total_ranks
        )
    elif trace_type == "kineto":  # Kineto Unitrace
        parsed_trace = _parseKinetoUnitrace(in_trace, target_rank)
    else:
        raise ValueError("Unrecognized trace format.")

    return parsed_trace


def _parseBasicTrace(in_trace: List):
    """
    Convert Basic Trace to comms trace format.
    """
    newCommsTrace = []
    for cnt, curComm in enumerate(in_trace):
        newComm = commsArgs()
        newComm.id = cnt
        newComm.markerStack = curComm.get("markers")
        if "comms" in curComm:
            _parseBasicTraceComms(curComm, newComm)
        elif "compute" in curComm:
            _parseBasicTraceCompute(curComm, newComm)

        if newComm.comms is not None or newComm.compute is not None:
            newCommsTrace.append(newComm)
        else:
            raise ValueError(
                "Trace file contains an element that is not supported in PARAM! Please format all elements as comms or compute for replay."
            )

    return newCommsTrace


def _parseBasicTraceComms(curComm, newComm: commsArgs) -> None:
    newComm.comms = comms_utils.paramToCommName(curComm["comms"].lower())
    if newComm.markerStack is None:
        newComm.markerStack = [newComm.comms]
    newComm.req = curComm.get("req")
    newComm.startTimeNs = curComm.get("startTime_ns")
    newComm.worldSize = curComm.get("world_size")
    newComm.root = curComm.get("root")
    newComm.pgId = curComm.get("pg_id")
    newComm.groupRanks = curComm.get("global_ranks")

    if newComm.comms not in ("wait", "barrier", "init", "batch_isend_irecv"):
        newComm.inMsgSize = curComm["in_msg_size"]
        newComm.outMsgSize = curComm["out_msg_size"]
        newComm.dtype = curComm["dtype"].lower()

    if newComm.comms == "all_to_allv":
        newComm.inSplit = curComm["in_split"]
        newComm.outSplit = curComm["out_split"]

    if newComm.comms in supportedP2pOps:
        newComm.src_rank = curComm["src_rank"]
        newComm.dst_rank = curComm["dst_rank"]
        newComm.batch_p2p = curComm["use_batch"]


def _parseBasicTraceCompute(curComm, newComm: commsArgs) -> None:
    newComm.compute = curComm["compute"].lower()
    if newComm.markerStack is None:
        newComm.markerStack = [newComm.compute]
    # count = number of times to call the compute kernel
    if "count" in curComm:
        newComm.count = curComm["count"]
    # if no count is specified, assume 1
    else:
        newComm.count = 1
    if newComm.compute == "gemm":
        if "mm_dim" in curComm:
            newComm.mm0_dim0 = curComm.get("mm_dim")
            newComm.mm0_dim1 = curComm.get("mm_dim")
            newComm.mm1_dim0 = curComm.get("mm_dim")
            newComm.mm1_dim1 = curComm.get("mm_dim")
        else:
            newComm.mm0_dim0 = curComm.get("mm0_dim0")
            newComm.mm0_dim1 = curComm.get("mm0_dim1")
            newComm.mm1_dim0 = curComm.get("mm1_dim0")
            newComm.mm1_dim1 = curComm.get("mm1_dim1")
        newComm.dtype = curComm.get("dtype").lower()
    elif newComm.compute == "emb_lookup":
        if "direction" in curComm:
            newComm.direction = curComm["direction"]
        else:
            newComm.direction = "forward"
        newComm.emb_dim = curComm.get("emb_dim")
        newComm.num_embs = curComm.get("num_embs")
        newComm.batch_size = curComm.get("batch_size")
        newComm.num_emb_tables_per_device = curComm.get("num_emb_tables")
        newComm.num_emb_tables_batched = -1
        newComm.bag_size = curComm.get("bag_size")
    else:
        raise ValueError(
            f"Trace file contains {str(newComm.compute)} compute element that is not supported in PARAM!"
        )


def _parseKinetoUnitrace(in_trace: List, target_rank: int) -> List:
    """
    Convert the Kineto unitrace w/ comms metadata to the clean common trace format for replay.
    """
    newCommsTrace = []
    commsCnt = 0
    for entry in in_trace:
        # TODO: figure out the current marker stack if present
        marker = "unknown"

        if (
            "name" in entry
            and entry["name"] == "record_param_comms"
            and entry["args"]["rank"] == target_rank
        ):
            newComm = commsArgs()
            newComm.comms = comms_utils.paramToCommName(entry["args"]["comms"].lower())
            newComm.id = commsCnt
            newComm.inMsgSize = entry["args"]["in_msg_size"]
            newComm.outMsgSize = entry["args"]["out_msg_size"]
            newComm.dtype = entry["args"]["dtype"].lower()
            newComm.inSplit = entry["args"]["in_split"]
            newComm.outSplit = entry["args"]["out_split"]
            newComm.markerStack = marker

            newCommsTrace.append(newComm)
            commsCnt += 1

    return newCommsTrace


def _getTensorInfoFromPyTorchETEntry(
    tensor_container: List, container_type: str
) -> Tuple[int, str]:
    """
    Extract the total message size and data type from a PyTorch ET entry's inputs/outputs field.
    NOTE: This format can be changed at any time. TODO: When an extract/parsing tool is available in ATC, switch to it.
    """
    list_count = container_type.count("GenericList")
    tensors = []
    if list_count == 2:
        # GenericList[GenericList[Tensor(), Tensor()]]
        tensors = tensor_container[0][0]
        dtype = container_type.replace("GenericList[", "").split(",", 1)[0]
    elif list_count == 1:
        # GenericList[Tensor()]
        tensors = tensor_container[0]
        dtype = container_type.replace("GenericList[", "").replace("]", "")
    else:
        tensors.append(tensor_container[0])
        dtype = container_type

    msg_size = 0
    for tensor in tensors:
        msg_size += tensor[3]

    return msg_size, dtype


def _parseExecutionTrace(
    in_trace: ExecutionTrace, target_rank: int, total_ranks: int
) -> List:
    """
    Convert the Execution Trace comms metadata to the common trace format for replay.
    """
    # Execution Trace PG_ID types availability
    ET_PG_NAME_TUPLE = in_trace.schema_pytorch() >= (1, 0, 3)
    ET_BACKENDID = in_trace.schema_pytorch() < (1, 0, 3)

    initOps = []
    newCommsTrace = []
    backendIdToPgid = {}
    pgRanksMap = {}
    groupCnt = -1

    # Parse PG info from ET
    for node in in_trace.nodes.values():
        if "process_group:init" in node.name:
            pgJson = node.inputs[0]
            try:
                pgObj = json.loads(pgJson)
            except json.decoder.JSONDecodeError:  # skip if pg_config_info is truncated
                break

            for pg in pgObj:
                if not pg["pg_name"].isdecimal():
                    # TODO support local synchronization pg
                    continue
                pgId = int(pg["pg_name"])
                ranks = pg["ranks"]
                groupCnt = pg["group_count"]
                pgRanksMap[pgId] = (
                    ranks
                    if len(ranks) > 0
                    else list(range(pg["group_size"]))
                    # rank list is empty when all ranks are in a pg
                )
                if ET_BACKENDID:
                    backendId = pg["uid"] if "uid" in pg else pg["backend_id"]
                    backendIdToPgid[backendId] = pgId
            break  # only one process_group init node per trace

    # Parse comms nodes
    for node in in_trace.nodes.values():
        if node.name == "record_param_comms":
            shift = (
                0 if len(node.inputs) == 8 or len(node.inputs) == 10 else 1
            )  # wait/barrier ops do not have an input tensor (len=7), shift index one over
            newComm = commsArgs()
            newComm.id = node.id
            newComm.comms = comms_utils.paramToCommName(
                node.inputs[4 - shift].lower()
            )  # 5th value of inputs is colName
            if newComm.comms == "init":
                continue
            newComm.req = node.inputs[
                1 - shift
            ]  # 2nd value of inputs is the req id of the collective

            pgIdentifier = node.inputs[
                2 - shift
            ]  # 3rd value of inputs is the pg identifier of the collective
            # Assign pg_id info for PGs that were created.
            if ET_BACKENDID and pgIdentifier in backendIdToPgid:
                newComm.pgId = backendIdToPgid[pgIdentifier]
                newComm.groupRanks = pgRanksMap[newComm.pgId]
                newComm.worldSize = len(newComm.groupRanks)
            elif ET_PG_NAME_TUPLE and pgIdentifier[0].isdecimal():
                newComm.pgId = int(pgIdentifier[0])
                newComm.groupRanks = pgRanksMap[newComm.pgId]
                newComm.worldSize = len(newComm.groupRanks)

            if newComm.comms not in ("wait", "barrier"):
                (
                    newComm.inMsgSize,
                    inMsgType,
                ) = _getTensorInfoFromPyTorchETEntry(node.inputs, node.input_types[0])
                (
                    newComm.outMsgSize,
                    _,
                ) = _getTensorInfoFromPyTorchETEntry(node.outputs, node.output_types[0])
                newComm.dtype = tensorDtypeMap[
                    inMsgType
                ]  # 1st value of input_types is the data type for the tensors

            if newComm.comms in supportedP2pOps:
                if "send" in newComm.comms:
                    newComm.src_rank = target_rank
                    local_dst_rank = node.inputs[3 - shift]
                    newComm.dst_rank = newComm.groupRanks[local_dst_rank]
                if "recv" in newComm.comms:
                    local_src_rank = node.inputs[3 - shift]
                    newComm.src_rank = newComm.groupRanks[local_src_rank]
                    newComm.dst_rank = target_rank

            if newComm.comms == "broadcast":
                newComm.root = newComm.groupRanks[0]
                newComm.srcOrDst = newComm.groupRanks[0]

            if newComm.comms == "all_to_allv":
                # 6th value of inputs is in_split, split evenly if not provided
                if not newComm.worldSize:
                    # if no pg info provided, use total ranks as world size
                    newComm.worldSize = total_ranks
                newComm.inSplit = (
                    node.inputs[5]
                    if node.inputs[5]
                    else [int(newComm.inMsgSize / newComm.worldSize)]
                    * newComm.worldSize
                )
                # 7th value of inputs is out_split, split evenly if not provided
                newComm.outSplit = (
                    node.inputs[6]
                    if node.inputs[6]
                    else [int(newComm.outMsgSize / newComm.worldSize)]
                    * newComm.worldSize
                )
            newCommsTrace.append(newComm)

    # Build init node
    initOps = []
    if groupCnt < 0:
        # old format: To be removed
        for pgId, ranks in pgRanksMap.items():
            newComm = create_pg_init_node(pgId, ranks, len(ranks))
            initOps.append(newComm)
    else:
        for pgId in range(groupCnt):
            if pgId in pgRanksMap:
                ranks = pgRanksMap[pgId]
            else:
                # create a dummy pg that the current rank is not part of
                ranks = [0] if target_rank != 0 else [1]

            newComm = create_pg_init_node(pgId, ranks, len(ranks))
            initOps.append(newComm)

    return initOps + newCommsTrace


def create_pg_init_node(pg_id: int, ranks: List[int], world_size: int):
    newComm = commsArgs()
    newComm.comms = "init"
    newComm.pgId = pg_id
    newComm.req = -1
    newComm.groupRanks = ranks
    newComm.worldSize = world_size
    return newComm
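
For context, a minimal usage sketch of the new parser follows; the trace path, rank values, and printed fields are illustrative assumptions rather than part of this change:

```python
# Hypothetical usage sketch of the migrated parser (file path and ranks are made up).
import json

from et_replay.lib.comm.commsTraceParser import parseTrace

# Assume a per-rank PyTorch Execution Trace has been dumped locally for rank 0.
with open("/tmp/rank0_et.json") as f:
    et_json = json.load(f)

# Convert the ET into PARAM's comms-replay format for rank 0 of an 8-rank job.
comms_trace = parseTrace(et_json, trace_type="et", target_rank=0, total_ranks=8)

for comm in comms_trace[:5]:
    print(comm.comms, comm.inMsgSize, comm.outMsgSize)
```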