diff --git a/src/trace_link/trace_linker.py b/src/trace_link/trace_linker.py
index cf45799e..8f1fe894 100644
--- a/src/trace_link/trace_linker.py
+++ b/src/trace_link/trace_linker.py
@@ -881,17 +881,34 @@ def link_ops(
         inclusive_dur = kineto_op.inclusive_dur
         exclusive_dur = kineto_op.exclusive_dur
         timestamp = kineto_op.timestamp
-        inter_thread_dep = None
-        if kineto_op.inter_thread_dep:
-            inter_thread_dep_kineto_op = kineto_rf_id_to_kineto_op_map[kineto_op.inter_thread_dep]
-            if inter_thread_dep_kineto_op.pytorch_op:
-                inter_thread_dep = inter_thread_dep_kineto_op.pytorch_op.id
+        inter_thread_dep = self.get_inter_thread_dep(kineto_op, kineto_rf_id_to_kineto_op_map)
 
         self.link_gpu_ops(pytorch_op, linked_gpu_ops)
 
         return linked_gpu_ops, inclusive_dur, exclusive_dur, timestamp, inter_thread_dep
 
+    def get_inter_thread_dep(self, kineto_op, kineto_rf_id_to_kineto_op_map):
+        """
+        Retrieve the inter-thread dependency ID for a given Kineto operator.
+
+        This method finds the corresponding PyTorch operator ID for the inter-thread dependency
+        if it exists.
+
+        Args:
+            kineto_op (KinetoOperator): The Kineto operator being processed.
+            kineto_rf_id_to_kineto_op_map (Dict[int, KinetoOperator]): Mapping from rf_id to Kineto operators.
+
+        Returns:
+            Optional[int]: The PyTorch operator ID for the inter-thread dependency if it exists,
+                otherwise None.
+        """
+        if kineto_op.inter_thread_dep:
+            inter_thread_dep_kineto_op = kineto_rf_id_to_kineto_op_map[kineto_op.inter_thread_dep]
+            if inter_thread_dep_kineto_op.pytorch_op:
+                return inter_thread_dep_kineto_op.pytorch_op.id
+        return None
+
     def link_gpu_ops(self, pytorch_op: PyTorchOperator, kineto_gpu_ops: List[KinetoOperator]) -> None:
         """
         Link GPU operators to a PyTorch operator.
diff --git a/tests/trace_link/test_trace_linker.py b/tests/trace_link/test_trace_linker.py
index d1117a22..bf3e595a 100644
--- a/tests/trace_link/test_trace_linker.py
+++ b/tests/trace_link/test_trace_linker.py
@@ -142,6 +142,28 @@ def test_find_last_cpu_node_before_timestamp(ops_by_tid, exclude_tid, timestamp, expected_result):
     assert result == expected_result
 
 
+def test_get_inter_thread_dep(trace_linker):
+    kineto_op = MagicMock(spec=KinetoOperator)
+    kineto_op.inter_thread_dep = 1
+    inter_thread_dep_kineto_op = MagicMock(spec=KinetoOperator)
+    inter_thread_dep_kineto_op.pytorch_op = MagicMock(id=42)
+
+    kineto_rf_id_to_kineto_op_map = {1: inter_thread_dep_kineto_op}
+
+    result = trace_linker.get_inter_thread_dep(kineto_op, kineto_rf_id_to_kineto_op_map)
+    assert result == 42
+
+
+def test_get_inter_thread_dep_none(trace_linker):
+    kineto_op = MagicMock(spec=KinetoOperator)
+    kineto_op.inter_thread_dep = None
+
+    kineto_rf_id_to_kineto_op_map = {}
+
+    result = trace_linker.get_inter_thread_dep(kineto_op, kineto_rf_id_to_kineto_op_map)
+    assert result is None
+
+
 def test_link_gpu_ops(trace_linker):
     # Create a mock PyTorch operator
     pytorch_op = MagicMock(spec=PyTorchOperator)
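
The extracted helper resolves an rf_id recorded as an inter-thread dependency into the ID of the PyTorch operator already linked to that Kineto operator, returning None when no dependency or no link exists. The sketch below illustrates that lookup flow outside the diff; `FakeKinetoOp` and `FakePyTorchOp` are hypothetical stand-ins for `KinetoOperator` and `PyTorchOperator`, not the real trace_link classes, and the standalone function simply mirrors the behavior added in the patch under those assumptions.

```python
# Minimal, self-contained sketch of the inter-thread dependency lookup.
# FakeKinetoOp / FakePyTorchOp are stand-ins, not the real chakra classes.
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class FakePyTorchOp:
    id: int  # ID of the PyTorch operator linked to a Kineto operator


@dataclass
class FakeKinetoOp:
    inter_thread_dep: Optional[int] = None     # rf_id of the Kineto op this one depends on
    pytorch_op: Optional[FakePyTorchOp] = None  # PyTorch op linked to this Kineto op, if any


def get_inter_thread_dep(
    kineto_op: FakeKinetoOp,
    kineto_rf_id_to_kineto_op_map: Dict[int, FakeKinetoOp],
) -> Optional[int]:
    """Resolve an inter-thread dependency rf_id to the linked PyTorch operator ID."""
    if kineto_op.inter_thread_dep:
        dep_op = kineto_rf_id_to_kineto_op_map[kineto_op.inter_thread_dep]
        if dep_op.pytorch_op:
            return dep_op.pytorch_op.id
    return None


# Dependency recorded under rf_id 1 resolves to the linked PyTorch op's ID (42),
# matching the first test case in the patch.
rf_map = {1: FakeKinetoOp(pytorch_op=FakePyTorchOp(id=42))}
assert get_inter_thread_dep(FakeKinetoOp(inter_thread_dep=1), rf_map) == 42

# No inter-thread dependency recorded, so the lookup returns None,
# matching the second test case.
assert get_inter_thread_dep(FakeKinetoOp(), rf_map) is None
```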