Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UCT/CUDA: Runtime CUDA >= 12.3 to enable VMM #10396

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Changes from 8 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 61 additions & 8 deletions src/uct/cuda/cuda_copy/cuda_copy_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ static ucs_config_field_t uct_cuda_copy_md_config_table[] = {
{NULL}
};

/*
 * Pointer to the CUDA driver's cuCtxSetFlags() entry point, resolved at
 * runtime instead of being linked directly.
 *
 * uct_cuda_copy_md_check_is_ctx_set_flags_supported() fills it in through
 * cuGetProcAddress("cuCtxSetFlags", ...) and resets it to NULL when that lookup
 * does not succeed. uct_cuda_copy_sync_memops() compares it against NULL
 * before calling through it.
 */
static CUresult (*uct_cuda_cuCtxSetFlags_func)(unsigned);

static int uct_cuda_copy_md_is_dmabuf_supported()
{
int dmabuf_supported = 0;
Expand Down Expand Up @@ -515,22 +517,32 @@ static size_t uct_cuda_copy_md_get_total_device_mem(CUdevice cuda_device)
static void
uct_cuda_copy_sync_memops(uct_cuda_copy_md_t *md, const void *address)
{
unsigned sync_memops_value = 1;

#if HAVE_CUDA_FABRIC
ucs_status_t status;
if (!md->sync_memops_set) {
/* Synchronize future DMA operations for all memory types */
status = UCT_CUDADRV_FUNC_LOG_WARN(cuCtxSetFlags(CU_CTX_SYNC_MEMOPS));
if (status == UCS_OK) {
md->sync_memops_set = 1;

if (uct_cuda_cuCtxSetFlags_func != NULL) {
if (!md->sync_memops_set) {
/* Synchronize future DMA operations for all memory types */
status = UCT_CUDADRV_FUNC_LOG_WARN(
uct_cuda_cuCtxSetFlags_func(CU_CTX_SYNC_MEMOPS));
if (status == UCS_OK) {
md->sync_memops_set = 1;
}
}

return;
}
#else
unsigned value = 1;
(void)uct_cuda_cuCtxSetFlags_func;
#endif

/* Synchronize for DMA for legacy memory types*/
UCT_CUDADRV_FUNC_LOG_WARN(
cuPointerSetAttribute(&value, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS,
cuPointerSetAttribute(&sync_memops_value,
CU_POINTER_ATTRIBUTE_SYNC_MEMOPS,
(CUdeviceptr)address));
#endif
}

static ucs_status_t
Expand Down Expand Up @@ -823,6 +835,33 @@ static uct_md_ops_t md_ops = {
.detect_memory_type = uct_cuda_copy_md_detect_memory_type
};

static ucs_status_t uct_cuda_copy_md_check_is_ctx_set_flags_supported(void)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To simplify the code, we could have this function call the resolved function pointer itself, and move the global variable inside it.
Something like
`ucs_status_t uct_cuda_copy_set_ctx_flags(unsigned flags)`
and have it return `UCS_ERR_UNSUPPORTED` if the function pointer is not found.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought about it, but went for a two-step approach because we need to:

  1. disable fabric at init time
  2. set the flag with md and address as parameters, in case we cannot use cuCtxSetFlags()

{
static ucs_status_t status = UCS_ERR_LAST;
rakhmets marked this conversation as resolved.
Show resolved Hide resolved

#if CUDA_VERSION >= 12000
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cuGetProcAddress() prototype changed at CUDA_VERSION >= 12000, and we know that cuCtxSetFlags() also appeared after 12000, so there is no need to use the older cuGetProcAddress() prototype for this check.

CUdriverProcAddressQueryResult sym_status;
CUresult cu_err;

if (status == UCS_ERR_LAST) {
cu_err = cuGetProcAddress("cuCtxSetFlags",
(void**)&uct_cuda_cuCtxSetFlags_func,
12010, CU_GET_PROC_ADDRESS_DEFAULT,
&sym_status);

if ((cu_err == CUDA_SUCCESS) &&
(sym_status == CU_GET_PROC_ADDRESS_SUCCESS)) {
status = UCS_OK;
} else {
uct_cuda_cuCtxSetFlags_func = NULL;
status = UCS_ERR_UNSUPPORTED;
}
}
#endif

return status;
}

static ucs_status_t
uct_cuda_copy_md_open(uct_component_t *component, const char *md_name,
const uct_md_config_t *md_config, uct_md_h *md_p)
Expand Down Expand Up @@ -850,6 +889,20 @@ uct_cuda_copy_md_open(uct_component_t *component, const char *md_name,
md->sync_memops_set = 0;
md->granularity = SIZE_MAX;

status = uct_cuda_copy_md_check_is_ctx_set_flags_supported();
if ((status != UCS_OK) && (md->config.enable_fabric != UCS_NO)) {
if (md->config.enable_fabric == UCS_YES) {
ucs_error("failed to enable fabric memory allocations as cuda "
"driver library does not support cuCtxSetFlags()");
goto err_free_md;
}

ucs_diag("disabled fabric memory allocations as cuda driver library "
"does not support cuCtxSetFlags()");
rakhmets marked this conversation as resolved.
Show resolved Hide resolved

md->config.enable_fabric = UCS_NO;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this affects only cuda_copy memory allocations, but what happens if we get fabric memory from a user buffer and then don't actually set sync memops for it?
We could return UNSUPPORTED from uct_cuda_copy_sync_memops and, if not, return an error from cuda memory detection.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should now be handled, right?

}

if ((config->cuda_async_mem_type != UCS_MEMORY_TYPE_CUDA) &&
(config->cuda_async_mem_type != UCS_MEMORY_TYPE_CUDA_MANAGED)) {
ucs_warn("wrong memory type for async memory allocations: \"%s\";"
Expand Down
Loading