Skip to content

Commit

Permalink
Extract get_inter_thread_dep method to simplify link_ops
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed Jul 12, 2024
1 parent 7f8b892 commit b40359c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 5 deletions.
27 changes: 22 additions & 5 deletions src/trace_link/trace_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,17 +881,34 @@ def link_ops(
inclusive_dur = kineto_op.inclusive_dur
exclusive_dur = kineto_op.exclusive_dur
timestamp = kineto_op.timestamp
inter_thread_dep = None

if kineto_op.inter_thread_dep:
inter_thread_dep_kineto_op = kineto_rf_id_to_kineto_op_map[kineto_op.inter_thread_dep]
if inter_thread_dep_kineto_op.pytorch_op:
inter_thread_dep = inter_thread_dep_kineto_op.pytorch_op.id
inter_thread_dep = self.get_inter_thread_dep(kineto_op, kineto_rf_id_to_kineto_op_map)

self.link_gpu_ops(pytorch_op, linked_gpu_ops)

return linked_gpu_ops, inclusive_dur, exclusive_dur, timestamp, inter_thread_dep

def get_inter_thread_dep(self, kineto_op, kineto_rf_id_to_kineto_op_map):
"""
Retrieve the inter-thread dependency ID for a given Kineto operator.
This method finds the corresponding PyTorch operator ID for the inter-thread dependency
if it exists.
Args:
kineto_op (KinetoOperator): The Kineto operator being processed.
kineto_rf_id_to_kineto_op_map (Dict[int, KinetoOperator]): Mapping from rf_id to Kineto operators.
Returns:
Optional[int]: The PyTorch operator ID for the inter-thread dependency if it exists,
otherwise None.
"""
if kineto_op.inter_thread_dep:
inter_thread_dep_kineto_op = kineto_rf_id_to_kineto_op_map[kineto_op.inter_thread_dep]
if inter_thread_dep_kineto_op.pytorch_op:
return inter_thread_dep_kineto_op.pytorch_op.id
return None

def link_gpu_ops(self, pytorch_op: PyTorchOperator, kineto_gpu_ops: List[KinetoOperator]) -> None:
"""
Link GPU operators to a PyTorch operator.
Expand Down
22 changes: 22 additions & 0 deletions tests/trace_link/test_trace_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,28 @@ def test_find_last_cpu_node_before_timestamp(ops_by_tid, exclude_tid, timestamp,
assert result == expected_result


def test_get_inter_thread_dep(trace_linker):
kineto_op = MagicMock(spec=KinetoOperator)
kineto_op.inter_thread_dep = 1
inter_thread_dep_kineto_op = MagicMock(spec=KinetoOperator)
inter_thread_dep_kineto_op.pytorch_op = MagicMock(id=42)

kineto_rf_id_to_kineto_op_map = {1: inter_thread_dep_kineto_op}

result = trace_linker.get_inter_thread_dep(kineto_op, kineto_rf_id_to_kineto_op_map)
assert result == 42


def test_get_inter_thread_dep_none(trace_linker):
kineto_op = MagicMock(spec=KinetoOperator)
kineto_op.inter_thread_dep = None

kineto_rf_id_to_kineto_op_map = {}

result = trace_linker.get_inter_thread_dep(kineto_op, kineto_rf_id_to_kineto_op_map)
assert result is None


def test_link_gpu_ops(trace_linker):
# Create a mock PyTorch operator
pytorch_op = MagicMock(spec=PyTorchOperator)
Expand Down

0 comments on commit b40359c

Please sign in to comment.