-
Notifications
You must be signed in to change notification settings - Fork 171
Make a few memory management objects public + Miscellaneous doc updates #693
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7066082
a6387a0
d8e0f08
cc0f6ce
58323ac
5abf37b
7f34922
f65b644
f65c54c
c0db6b3
bbc1c65
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,11 +2,12 @@ | |
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import Union | ||
|
||
from cuda.core.experimental._kernel_arg_handler import ParamHolder | ||
from cuda.core.experimental._launch_config import LaunchConfig, _to_native_launch_config | ||
from cuda.core.experimental._module import Kernel | ||
from cuda.core.experimental._stream import Stream | ||
from cuda.core.experimental._stream import IsStreamT, Stream, _try_to_get_stream_ptr | ||
from cuda.core.experimental._utils.clear_error_support import assert_type | ||
from cuda.core.experimental._utils.cuda_utils import ( | ||
_reduce_3_tuple, | ||
|
@@ -34,7 +35,7 @@ def _lazy_init(): | |
_inited = True | ||
|
||
|
||
def launch(stream, config, kernel, *kernel_args): | ||
def launch(stream: Union[Stream, IsStreamT], config: LaunchConfig, kernel: Kernel, *kernel_args): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's a bit inconsistent where only this API supports Maybe push this to a follow up PR? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We already have public examples for PyTorch and CuPy showcasing the use of this protocol. With this PR we made it slightly faster in favor of our own |
||
"""Launches a :obj:`~_module.Kernel` | ||
object with launch-time configuration. | ||
|
||
|
@@ -43,7 +44,7 @@ def launch(stream, config, kernel, *kernel_args): | |
stream : :obj:`~_stream.Stream` | ||
The stream establishing the stream ordering semantic of a | ||
launch. | ||
config : :obj:`~_launcher.LaunchConfig` | ||
config : :obj:`LaunchConfig` | ||
Launch configurations inline with options provided by | ||
:obj:`~_launcher.LaunchConfig` dataclass. | ||
kernel : :obj:`~_module.Kernel` | ||
|
@@ -55,13 +56,15 @@ def launch(stream, config, kernel, *kernel_args): | |
""" | ||
if stream is None: | ||
raise ValueError("stream cannot be None, stream must either be a Stream object or support __cuda_stream__") | ||
if not isinstance(stream, Stream): | ||
try: | ||
stream_handle = stream.handle | ||
except AttributeError: | ||
try: | ||
stream = Stream._init(stream) | ||
except Exception as e: | ||
stream_handle = _try_to_get_stream_ptr(stream) | ||
except Exception: | ||
raise ValueError( | ||
f"stream must either be a Stream object or support __cuda_stream__ (got {type(stream)})" | ||
) from e | ||
) from None | ||
assert_type(kernel, Kernel) | ||
_lazy_init() | ||
config = check_or_create_options(LaunchConfig, config, "launch config") | ||
|
@@ -78,15 +81,15 @@ def launch(stream, config, kernel, *kernel_args): | |
# rich. | ||
if _use_ex: | ||
drv_cfg = _to_native_launch_config(config) | ||
drv_cfg.hStream = stream.handle | ||
drv_cfg.hStream = stream_handle | ||
if config.cooperative_launch: | ||
_check_cooperative_launch(kernel, config, stream) | ||
handle_return(driver.cuLaunchKernelEx(drv_cfg, int(kernel._handle), args_ptr, 0)) | ||
else: | ||
# TODO: check if config has any unsupported attrs | ||
handle_return( | ||
driver.cuLaunchKernel( | ||
int(kernel._handle), *config.grid, *config.block, config.shmem_size, stream.handle, args_ptr, 0 | ||
int(kernel._handle), *config.grid, *config.block, config.shmem_size, stream_handle, args_ptr, 0 | ||
) | ||
) | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.