Skip to content

Commit

Permalink
Leverage HTA to support synchronization dependency in trace linking
Browse files Browse the repository at this point in the history
Co-authored-by: Joongun Park <[email protected]>
  • Loading branch information
TaekyungHeo and JoongunPark committed Oct 7, 2024
1 parent a01f265 commit aacfcff
Show file tree
Hide file tree
Showing 9 changed files with 456 additions and 9 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/python_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ jobs:
git checkout 7b19f586dd8b267333114992833a0d7e0d601630
pip install .
- name: Install HTA
run: |
git clone https://github.com/facebookresearch/HolisticTraceAnalysis.git
cd HolisticTraceAnalysis
git checkout d731cc2e2249976c97129d409a83bd53d93051f6
git submodule update --init
pip install -r requirements.txt
pip install -e .
- name: Install Dependencies
run: |
pip install -r requirements-dev.txt
Expand Down
12 changes: 10 additions & 2 deletions src/converter/pytorch_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,15 +244,23 @@ def convert_json_to_protobuf_nodes(
[
ChakraAttr(name="comm_type", int64_val=collective_comm_type),
ChakraAttr(name="comm_size", int64_val=pytorch_gpu_node.comm_size),
*( [ChakraAttr(name="pg_name", string_val=pytorch_gpu_node.pg_name)] if pytorch_gpu_node.pg_name != "" else [] ),
*(
[ChakraAttr(name="pg_name", string_val=pytorch_gpu_node.pg_name)]
if pytorch_gpu_node.pg_name != ""
else []
),
]
)

elif chakra_gpu_node.type in {COMM_SEND_NODE, COMM_RECV_NODE}:
chakra_gpu_node.attr.extend(
[
ChakraAttr(name="comm_size", int64_val=pytorch_gpu_node.comm_size),
*( [ChakraAttr(name="pg_name", string_val=pytorch_gpu_node.pg_name)] if pytorch_gpu_node.pg_name != "" else [] ),
*(
[ChakraAttr(name="pg_name", string_val=pytorch_gpu_node.pg_name)]
if pytorch_gpu_node.pg_name != ""
else []
),
]
)

Expand Down
2 changes: 1 addition & 1 deletion src/converter/pytorch_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class PyTorchNodeType(Enum):
CPU_OP = 1
GPU_OP = 2
LABEL = 3 # Non-operator nodes
METADATA = 4 # Metadata nodes
METADATA = 4 # Metadata nodes


class PyTorchNode:
Expand Down
13 changes: 13 additions & 0 deletions src/trace_link/chakra_device_trace_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def load(
) -> Tuple[
List[KinetoOperator],
Dict[int, List[KinetoOperator]],
Dict[int, List[KinetoOperator]],
Dict[int, KinetoOperator],
List[KinetoOperator],
Dict[int, KinetoOperator],
Expand All @@ -26,6 +27,7 @@ def load(
Dict[int, KinetoOperator],
List[KinetoOperator],
List[int],
Dict[int, KinetoOperator],
]:
"""
Load and process the Chakra device trace.
Expand Down Expand Up @@ -57,6 +59,7 @@ def load(
logging.debug("Chakra device trace has been loaded and processed successfully.")
return (
dev_data["kineto_cpu_ops"],
dev_data["kineto_tid_ops_map"],
dev_data["kineto_tid_cpu_ops_map"],
dev_data["kineto_correlation_cuda_runtime_map"],
dev_data["kineto_gpu_ops"],
Expand All @@ -68,6 +71,7 @@ def load(
dev_data["kineto_rf_id_to_kineto_op_map"],
dev_data["sorted_kineto_cpu_ops"],
dev_data["sorted_kineto_cpu_op_ts"],
dev_data["kineto_external_id_to_kineto_op_map"],
)

def construct_dev_data_structures(self, kineto_ops: List[KinetoOperator], trace_file: str) -> Dict:
Expand All @@ -90,13 +94,17 @@ def construct_dev_data_structures(self, kineto_ops: List[KinetoOperator], trace_
thread_info = {}

kineto_cpu_ops = []
kineto_tid_ops_map = {}
kineto_tid_cpu_ops_map = {}
kineto_correlation_cuda_runtime_map = {}
kineto_gpu_ops = []
kineto_id_arrow_op_map = {}
kineto_id_cuda_launch_op_map = {}
kineto_external_id_to_kineto_op_map = {}

for op in kineto_ops:
kineto_tid_ops_map.setdefault(op.tid, []).append(op)

if op.is_cpu_op():
kineto_cpu_ops.append(op)
kineto_tid_cpu_ops_map.setdefault(op.tid, []).append(op)
Expand Down Expand Up @@ -144,10 +152,14 @@ def construct_dev_data_structures(self, kineto_ops: List[KinetoOperator], trace_
thread_start_end[0] = min(thread_start_end[0], op.timestamp)
thread_start_end[1] = max(thread_start_end[1], op.timestamp + op.inclusive_dur)

if op.external_id is not None:
kineto_external_id_to_kineto_op_map[op.external_id] = op

kineto_rf_id_to_kineto_op_map = {op.rf_id: op for op in kineto_cpu_ops if op.rf_id is not None}

return {
"kineto_cpu_ops": kineto_cpu_ops,
"kineto_tid_ops_map": kineto_tid_ops_map,
"kineto_tid_cpu_ops_map": kineto_tid_cpu_ops_map,
"kineto_correlation_cuda_runtime_map": kineto_correlation_cuda_runtime_map,
"kineto_gpu_ops": kineto_gpu_ops,
Expand All @@ -159,6 +171,7 @@ def construct_dev_data_structures(self, kineto_ops: List[KinetoOperator], trace_
"kineto_rf_id_to_kineto_op_map": kineto_rf_id_to_kineto_op_map,
"sorted_kineto_cpu_ops": [],
"sorted_kineto_cpu_op_ts": [],
"kineto_external_id_to_kineto_op_map": kineto_external_id_to_kineto_op_map,
}

def calculate_exclusive_dur(self, kineto_tid_cpu_ops_map: Dict[int, List[KinetoOperator]]) -> None:
Expand Down
7 changes: 5 additions & 2 deletions src/trace_link/kineto_operator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from et_replay.execution_trace import Node as PyTorchOperator

Expand All @@ -22,6 +22,7 @@ class KinetoOperator:
host_op (Optional[PyTorchOperator]): Corresponding PyTorch operator object.
parent_host_op_id (Optional[int]): ID of the parent PyTorch operator.
inter_thread_dep (Optional[int]): Identifier for inter-thread dependencies.
sync_dep (List[KinetoOperator]): List of KinetoOperator objects that have dependencies on this operator.
stream (Optional[int]): CUDA stream identifier associated with the operator.
rf_id (Optional[int]): Record function identifier.
correlation (int): Identifier used to correlate CUDA runtime and GPU operations.
Expand Down Expand Up @@ -49,6 +50,7 @@ def __init__(self, kineto_op: Dict[str, Any]) -> None:
self.host_op: Optional[PyTorchOperator] = None
self.parent_host_op_id: Optional[int] = None
self.inter_thread_dep: Optional[int] = None
self.sync_dep: List[KinetoOperator] = []
self.stream: Optional[int] = kineto_op.get("args", {}).get("stream", None)
self.rf_id: Optional[int] = kineto_op.get("args", {}).get("Record function id", None)
self.correlation: int = kineto_op.get("args", {}).get("correlation", -1)
Expand All @@ -61,13 +63,14 @@ def __repr__(self) -> str:
Returns
str: A string representation of the KinetoOperator.
"""
sync_dep_ids = [op.id for op in self.sync_dep]
return (
f"KinetoOperator(id={self.id}, category={self.category}, name={self.name}, "
f"phase={self.phase}, inclusive_dur={self.inclusive_dur}, "
f"exclusive_dur={self.exclusive_dur}, timestamp={self.timestamp}, "
f"external_id={self.external_id}, ev_idx={self.ev_idx}, tid={self.tid}, "
f"parent_host_op_id={self.parent_host_op_id}, inter_thread_dep={self.inter_thread_dep}, "
f"stream={self.stream}, rf_id={self.rf_id}, correlation={self.correlation})"
f"sync_dep={sync_dep_ids}, stream={self.stream}, rf_id={self.rf_id}, correlation={self.correlation})"
)

def is_cpu_op(self) -> bool:
Expand Down
4 changes: 3 additions & 1 deletion src/trace_link/trace_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def main() -> None:
"Merging-PyTorch-and-Kineto-Traces"
)
)
parser.add_argument("--rank", type=int, required=True, help="Rank for the input traces")
parser.add_argument(
"--chakra-host-trace",
type=str,
Expand All @@ -43,10 +44,11 @@ def main() -> None:
logging.basicConfig(level=args.log_level.upper())

linker = TraceLinker()
linker.link(args.chakra_host_trace, args.chakra_device_trace, args.output_file)
linker.link(args.rank, args.chakra_host_trace, args.chakra_device_trace, args.output_file)

logging.info(f"Linking process successful. Output file is available at {args.output_file}.")
logging.info("Please run the chakra_converter for further postprocessing.")


if __name__ == "__main__":
main()
Loading

0 comments on commit aacfcff

Please sign in to comment.