diff --git a/src/trace_link/kineto_operator.py b/src/trace_link/kineto_operator.py index 12c7228a..2016d410 100644 --- a/src/trace_link/kineto_operator.py +++ b/src/trace_link/kineto_operator.py @@ -95,12 +95,14 @@ def is_cuda_launch_op(self) -> bool: """ cuda_launch_categories = {"cuda_runtime", "cuda_driver"} cuda_launch_operations = { + "cuLaunchKernel", + "cuLaunchKernelEx", "cudaLaunchKernel", "cudaLaunchKernelExC", "cudaMemcpy", "cudaMemcpyAsync", - "cudaMemcpyToSymbol", "cudaMemcpyFromSymbol", + "cudaMemcpyToSymbol", } return self.category in cuda_launch_categories and self.name in cuda_launch_operations diff --git a/tests/trace_link/test_kineto_operator.py b/tests/trace_link/test_kineto_operator.py index 0c3f131b..cff11751 100644 --- a/tests/trace_link/test_kineto_operator.py +++ b/tests/trace_link/test_kineto_operator.py @@ -47,3 +47,34 @@ def test_repr_method(sample_operator_data): "correlation=99)" ) assert repr(operator) == expected_repr + +@pytest.mark.parametrize("category, name, expected", [ + ("cuda_driver", "cuLaunchKernel", True), + ("cuda_driver", "cuLaunchKernelEx", True), + ("cuda_driver", "cudaLaunchKernel", True), + ("cuda_driver", "cudaLaunchKernelExC", True), + ("cuda_runtime", "cuLaunchKernel", True), + ("cuda_runtime", "cuLaunchKernelEx", True), + ("cuda_runtime", "cudaLaunchKernel", True), + ("cuda_runtime", "cudaLaunchKernelExC", True), + ("cuda_runtime", "cudaMemcpy", True), + ("cuda_runtime", "cudaMemcpyAsync", True), + ("cuda_runtime", "cudaMemcpyFromSymbol", True), + ("cuda_runtime", "cudaMemcpyToSymbol", True), + ("cpu_op", "cudaLaunchKernel", False), + ("cuda_runtime", "someOtherOperation", False), + ("some_other_category", "cudaLaunchKernel", False) +]) +def test_is_cuda_launch_op(category, name, expected): + """Test the is_cuda_launch_op method with various inputs.""" + operator_data = { + "cat": category, + "name": name, + "ph": "X", + "dur": 100, + "ts": 1590000000, + "tid": 1234, + "args": {"External id": "123", "Ev Idx": "456", "stream": 7, "Record function id": 12, "correlation": 99}, + } + operator = KinetoOperator(operator_data) + assert operator.is_cuda_launch_op() == expected