Skip to content

Commit

Permalink
Add keywords for CUDA runtime ops in is_cuda_launch_op method
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed Jul 2, 2024
1 parent 42d30f8 commit f4eaf47
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/trace_link/kineto_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,14 @@ def is_cuda_launch_op(self) -> bool:
"""
cuda_launch_categories = {"cuda_runtime", "cuda_driver"}
cuda_launch_operations = {
"cuLaunchKernel",
"cuLaunchKernelEx",
"cudaLaunchKernel",
"cudaLaunchKernelExC",
"cudaMemcpy",
"cudaMemcpyAsync",
"cudaMemcpyToSymbol",
"cudaMemcpyFromSymbol",
"cudaMemcpyToSymbol",
}
return self.category in cuda_launch_categories and self.name in cuda_launch_operations

Expand Down
31 changes: 31 additions & 0 deletions tests/trace_link/test_kineto_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,34 @@ def test_repr_method(sample_operator_data):
"correlation=99)"
)
assert repr(operator) == expected_repr

@pytest.mark.parametrize("category, name, expected", [
("cuda_driver", "cuLaunchKernel", True),
("cuda_driver", "cuLaunchKernelEx", True),
("cuda_driver", "cudaLaunchKernel", True),
("cuda_driver", "cudaLaunchKernelExC", True),
("cuda_runtime", "cuLaunchKernel", True),
("cuda_runtime", "cuLaunchKernelEx", True),
("cuda_runtime", "cudaLaunchKernel", True),
("cuda_runtime", "cudaLaunchKernelExC", True),
("cuda_runtime", "cudaMemcpy", True),
("cuda_runtime", "cudaMemcpyAsync", True),
("cuda_runtime", "cudaMemcpyFromSymbol", True),
("cuda_runtime", "cudaMemcpyToSymbol", True),
("cpu_op", "cudaLaunchKernel", False),
("cuda_runtime", "someOtherOperation", False),
("some_other_category", "cudaLaunchKernel", False)
])
def test_is_cuda_launch_op(category, name, expected):
"""Test the is_cuda_launch_op method with various inputs."""
operator_data = {
"cat": category,
"name": name,
"ph": "X",
"dur": 100,
"ts": 1590000000,
"tid": 1234,
"args": {"External id": "123", "Ev Idx": "456", "stream": 7, "Record function id": 12, "correlation": 99},
}
operator = KinetoOperator(operator_data)
assert operator.is_cuda_launch_op() == expected

0 comments on commit f4eaf47

Please sign in to comment.