-
Notifications
You must be signed in to change notification settings - Fork 432
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
UCT/CUDA: Runtime CUDA >= 12.3 to enable VMM #10396
base: master
Are you sure you want to change the base?
Changes from 8 commits
68a5f51
9fc4430
e8c9f99
3b43d29
2161adf
6563253
ff4313c
2f5e5a5
f1601a3
8657d54
0c27f31
eb0d1fc
fe0370b
ab3d0c7
a0004c4
81d47f0
edc0028
078a6cc
4ace4b1
0c39faa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,6 +81,8 @@ static ucs_config_field_t uct_cuda_copy_md_config_table[] = { | |
{NULL} | ||
}; | ||
|
||
static CUresult (*uct_cuda_cuCtxSetFlags_func)(unsigned); | ||
|
||
static int uct_cuda_copy_md_is_dmabuf_supported() | ||
{ | ||
int dmabuf_supported = 0; | ||
|
@@ -515,22 +517,32 @@ static size_t uct_cuda_copy_md_get_total_device_mem(CUdevice cuda_device) | |
static void | ||
uct_cuda_copy_sync_memops(uct_cuda_copy_md_t *md, const void *address) | ||
{ | ||
unsigned sync_memops_value = 1; | ||
|
||
#if HAVE_CUDA_FABRIC | ||
ucs_status_t status; | ||
if (!md->sync_memops_set) { | ||
/* Synchronize future DMA operations for all memory types */ | ||
status = UCT_CUDADRV_FUNC_LOG_WARN(cuCtxSetFlags(CU_CTX_SYNC_MEMOPS)); | ||
if (status == UCS_OK) { | ||
md->sync_memops_set = 1; | ||
|
||
if (uct_cuda_cuCtxSetFlags_func != NULL) { | ||
if (!md->sync_memops_set) { | ||
/* Synchronize future DMA operations for all memory types */ | ||
status = UCT_CUDADRV_FUNC_LOG_WARN( | ||
uct_cuda_cuCtxSetFlags_func(CU_CTX_SYNC_MEMOPS)); | ||
if (status == UCS_OK) { | ||
md->sync_memops_set = 1; | ||
} | ||
} | ||
|
||
return; | ||
} | ||
#else | ||
unsigned value = 1; | ||
(void)uct_cuda_cuCtxSetFlags_func; | ||
#endif | ||
|
||
/* Synchronize for DMA for legacy memory types*/ | ||
UCT_CUDADRV_FUNC_LOG_WARN( | ||
cuPointerSetAttribute(&value, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, | ||
cuPointerSetAttribute(&sync_memops_value, | ||
CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, | ||
(CUdeviceptr)address)); | ||
#endif | ||
} | ||
|
||
static ucs_status_t | ||
|
@@ -823,6 +835,33 @@ static uct_md_ops_t md_ops = { | |
.detect_memory_type = uct_cuda_copy_md_detect_memory_type | ||
}; | ||
|
||
static ucs_status_t uct_cuda_copy_md_check_is_ctx_set_flags_supported(void) | ||
{ | ||
static ucs_status_t status = UCS_ERR_LAST; | ||
rakhmets marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
#if CUDA_VERSION >= 12000 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why needed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
CUdriverProcAddressQueryResult sym_status; | ||
CUresult cu_err; | ||
|
||
if (status == UCS_ERR_LAST) { | ||
cu_err = cuGetProcAddress("cuCtxSetFlags", | ||
(void**)&uct_cuda_cuCtxSetFlags_func, | ||
12010, CU_GET_PROC_ADDRESS_DEFAULT, | ||
&sym_status); | ||
|
||
if ((cu_err == CUDA_SUCCESS) && | ||
(sym_status == CU_GET_PROC_ADDRESS_SUCCESS)) { | ||
status = UCS_OK; | ||
} else { | ||
uct_cuda_cuCtxSetFlags_func = NULL; | ||
status = UCS_ERR_UNSUPPORTED; | ||
} | ||
} | ||
#endif | ||
|
||
return status; | ||
} | ||
|
||
static ucs_status_t | ||
uct_cuda_copy_md_open(uct_component_t *component, const char *md_name, | ||
const uct_md_config_t *md_config, uct_md_h *md_p) | ||
|
@@ -850,6 +889,20 @@ uct_cuda_copy_md_open(uct_component_t *component, const char *md_name, | |
md->sync_memops_set = 0; | ||
md->granularity = SIZE_MAX; | ||
|
||
status = uct_cuda_copy_md_check_is_ctx_set_flags_supported(); | ||
if ((status != UCS_OK) && (md->config.enable_fabric != UCS_NO)) { | ||
if (md->config.enable_fabric == UCS_YES) { | ||
ucs_error("failed to enable fabric memory allocations as cuda " | ||
"driver library does not support cuCtxSetFlags()"); | ||
goto err_free_md; | ||
} | ||
|
||
ucs_diag("disabled fabric memory allocations as cuda driver library " | ||
"does not support cuCtxSetFlags()"); | ||
rakhmets marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
md->config.enable_fabric = UCS_NO; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looks like it affects only cuda_copy memory allocations, but what happens if we get a fabric memory from user buffer and then we don't actually set sync memops for it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should show now be handled right? |
||
} | ||
|
||
if ((config->cuda_async_mem_type != UCS_MEMORY_TYPE_CUDA) && | ||
(config->cuda_async_mem_type != UCS_MEMORY_TYPE_CUDA_MANAGED)) { | ||
ucs_warn("wrong memory type for async memory allocations: \"%s\";" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To simplify the code, we could have this function call the needed function pointer, and move the global var inside it.
Something like
ucs_status_t uct_cuda_copy_set_ctx_flags(unsigned flags)
and have it return UCS_ERR_UNSUPPORTED if the func pointer is not found.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought about it but went for two step approach as we need:
md
andaddress
as parameter, in case we cannot usecuCtxSetFlags()