diff --git a/src/trace_link/kineto_operator.py b/src/trace_link/kineto_operator.py index 9048b907..998575a1 100644 --- a/src/trace_link/kineto_operator.py +++ b/src/trace_link/kineto_operator.py @@ -103,6 +103,7 @@ def is_cuda_launch_op(self) -> bool: "cudaMemcpyAsync", "cudaMemcpyFromSymbol", "cudaMemcpyToSymbol", + "cudaLaunchCooperativeKernel", } return self.category in cuda_launch_categories and self.name in cuda_launch_operations diff --git a/tests/trace_link/test_kineto_operator.py b/tests/trace_link/test_kineto_operator.py index cff11751..3015c251 100644 --- a/tests/trace_link/test_kineto_operator.py +++ b/tests/trace_link/test_kineto_operator.py @@ -48,23 +48,29 @@ def test_repr_method(sample_operator_data): ) assert repr(operator) == expected_repr -@pytest.mark.parametrize("category, name, expected", [ - ("cuda_driver", "cuLaunchKernel", True), - ("cuda_driver", "cuLaunchKernelEx", True), - ("cuda_driver", "cudaLaunchKernel", True), - ("cuda_driver", "cudaLaunchKernelExC", True), - ("cuda_runtime", "cuLaunchKernel", True), - ("cuda_runtime", "cuLaunchKernelEx", True), - ("cuda_runtime", "cudaLaunchKernel", True), - ("cuda_runtime", "cudaLaunchKernelExC", True), - ("cuda_runtime", "cudaMemcpy", True), - ("cuda_runtime", "cudaMemcpyAsync", True), - ("cuda_runtime", "cudaMemcpyFromSymbol", True), - ("cuda_runtime", "cudaMemcpyToSymbol", True), - ("cpu_op", "cudaLaunchKernel", False), - ("cuda_runtime", "someOtherOperation", False), - ("some_other_category", "cudaLaunchKernel", False) -]) + +@pytest.mark.parametrize( + "category, name, expected", + [ + ("cuda_driver", "cuLaunchKernel", True), + ("cuda_driver", "cuLaunchKernelEx", True), + ("cuda_driver", "cudaLaunchKernel", True), + ("cuda_driver", "cudaLaunchKernelExC", True), + ("cuda_driver", "cudaLaunchCooperativeKernel", True), + ("cuda_runtime", "cuLaunchKernel", True), + ("cuda_runtime", "cuLaunchKernelEx", True), + ("cuda_runtime", "cudaLaunchKernel", True), + ("cuda_runtime", "cudaLaunchKernelExC", True), + ("cuda_runtime", "cudaLaunchCooperativeKernel", True), + ("cuda_runtime", "cudaMemcpy", True), + ("cuda_runtime", "cudaMemcpyAsync", True), + ("cuda_runtime", "cudaMemcpyFromSymbol", True), + ("cuda_runtime", "cudaMemcpyToSymbol", True), + ("cpu_op", "cudaLaunchKernel", False), + ("cuda_runtime", "someOtherOperation", False), + ("some_other_category", "cudaLaunchKernel", False), + ], +) def test_is_cuda_launch_op(category, name, expected): """Test the is_cuda_launch_op method with various inputs.""" operator_data = {