Skip to content

Commit

Permalink
Add missing keyword in is_cuda_launch_op
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed Jul 10, 2024
1 parent 7f8b892 commit 8b2ddfc
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
1 change: 1 addition & 0 deletions src/trace_link/kineto_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def is_cuda_launch_op(self) -> bool:
"cudaMemcpyAsync",
"cudaMemcpyFromSymbol",
"cudaMemcpyToSymbol",
"cudaLaunchCooperativeKernel",
}
return self.category in cuda_launch_categories and self.name in cuda_launch_operations

Expand Down
40 changes: 23 additions & 17 deletions tests/trace_link/test_kineto_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,23 +48,29 @@ def test_repr_method(sample_operator_data):
)
assert repr(operator) == expected_repr

@pytest.mark.parametrize("category, name, expected", [
("cuda_driver", "cuLaunchKernel", True),
("cuda_driver", "cuLaunchKernelEx", True),
("cuda_driver", "cudaLaunchKernel", True),
("cuda_driver", "cudaLaunchKernelExC", True),
("cuda_runtime", "cuLaunchKernel", True),
("cuda_runtime", "cuLaunchKernelEx", True),
("cuda_runtime", "cudaLaunchKernel", True),
("cuda_runtime", "cudaLaunchKernelExC", True),
("cuda_runtime", "cudaMemcpy", True),
("cuda_runtime", "cudaMemcpyAsync", True),
("cuda_runtime", "cudaMemcpyFromSymbol", True),
("cuda_runtime", "cudaMemcpyToSymbol", True),
("cpu_op", "cudaLaunchKernel", False),
("cuda_runtime", "someOtherOperation", False),
("some_other_category", "cudaLaunchKernel", False)
])

@pytest.mark.parametrize(
"category, name, expected",
[
("cuda_driver", "cuLaunchKernel", True),
("cuda_driver", "cuLaunchKernelEx", True),
("cuda_driver", "cudaLaunchKernel", True),
("cuda_driver", "cudaLaunchKernelExC", True),
("cuda_driver", "cudaLaunchCooperativeKernel", True),
("cuda_runtime", "cuLaunchKernel", True),
("cuda_runtime", "cuLaunchKernelEx", True),
("cuda_runtime", "cudaLaunchKernel", True),
("cuda_runtime", "cudaLaunchKernelExC", True),
("cuda_runtime", "cudaLaunchCooperativeKernel", True),
("cuda_runtime", "cudaMemcpy", True),
("cuda_runtime", "cudaMemcpyAsync", True),
("cuda_runtime", "cudaMemcpyFromSymbol", True),
("cuda_runtime", "cudaMemcpyToSymbol", True),
("cpu_op", "cudaLaunchKernel", False),
("cuda_runtime", "someOtherOperation", False),
("some_other_category", "cudaLaunchKernel", False),
],
)
def test_is_cuda_launch_op(category, name, expected):
"""Test the is_cuda_launch_op method with various inputs."""
operator_data = {
Expand Down

0 comments on commit 8b2ddfc

Please sign in to comment.