Fix Python CUDA loading issues #7939
@@ -470,24 +470,33 @@ static void RegisterExecutionProviders(InferenceSession* sess, const std::vector
 #endif
   } else if (type == kCudaExecutionProvider) {
 #ifdef USE_CUDA
-    const auto it = provider_options_map.find(type);
-    CUDAExecutionProviderInfo info{};
-    if (it != provider_options_map.end())
-      GetProviderInfo_CUDA()->CUDAExecutionProviderInfo__FromProviderOptions(it->second, info);
-    else {
-      info.device_id = cuda_device_id;
-      info.gpu_mem_limit = gpu_mem_limit;
-      info.arena_extend_strategy = arena_extend_strategy;
-      info.cudnn_conv_algo_search = cudnn_conv_algo_search;
-      info.do_copy_in_default_stream = do_copy_in_default_stream;
-      info.external_allocator_info = external_allocator_info;
-    }
-
-    // This variable is never initialized because the APIs by which it should be initialized are deprecated; however,
-    // they still exist and are in use. Nevertheless, it is used to return the CUDAAllocator, hence we must try to
-    // initialize it here if we can, since FromProviderOptions might contain an external CUDA allocator.
-    external_allocator_info = info.external_allocator_info;
-    RegisterExecutionProvider(sess, *GetProviderInfo_CUDA()->CreateExecutionProviderFactory(info));
+    if (auto* cuda_provider_info = TryGetProviderInfo_CUDA()) {
+      const auto it = provider_options_map.find(type);
+      CUDAExecutionProviderInfo info{};
+      if (it != provider_options_map.end())
+        cuda_provider_info->CUDAExecutionProviderInfo__FromProviderOptions(it->second, info);
+      else {
+        info.device_id = cuda_device_id;
+        info.gpu_mem_limit = gpu_mem_limit;
+        info.arena_extend_strategy = arena_extend_strategy;
+        info.cudnn_conv_algo_search = cudnn_conv_algo_search;
+        info.do_copy_in_default_stream = do_copy_in_default_stream;
+        info.external_allocator_info = external_allocator_info;
+      }
+
+      // This variable is never initialized because the APIs by which it should be initialized are deprecated; however,
+      // they still exist and are in use. Nevertheless, it is used to return the CUDAAllocator, hence we must try to
+      // initialize it here if we can, since FromProviderOptions might contain an external CUDA allocator.
+      external_allocator_info = info.external_allocator_info;
+      RegisterExecutionProvider(sess, *cuda_provider_info->CreateExecutionProviderFactory(info));
+    } else {
+      if (!Env::Default().GetEnvironmentVar("CUDA_PATH").empty()) {
+        ORT_THROW("CUDA_PATH is set but CUDA wasn't able to be loaded. Please install the correct version of CUDA and cuDNN as mentioned in the GPU requirements page, make sure they're in the PATH, and that your GPU is supported.");
+      }
+    }
 #endif
   } else if (type == kRocmExecutionProvider) {
 #ifdef USE_ROCM
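For context, here is a minimal sketch (not part of the PR) of how this code path is reached from Python. The model path is a placeholder, and the fallback behavior described in the comments assumes the change above:

import onnxruntime as ort

# "model.onnx" is a placeholder. Requesting CUDAExecutionProvider drives the
# RegisterExecutionProviders path patched above: with this PR, a CUDA provider
# library that fails to load no longer aborts session creation unless
# CUDA_PATH is set, and the session falls back to the CPU provider instead.
sess = ort.InferenceSession(
    "model.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
print(sess.get_providers())  # shows which providers were actually registered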
Review comments on the new CUDA_PATH check:

Can you please add a note here saying why we're checking for this? Something like: we want to allow the flexibility to use a GPU package on a CPU-only machine, unless the user intended to run inferencing on the GPU.

Would be nice to add something else to the message for the user to check?

Good idea, error message updated.
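A rough Python sketch (illustrative names, not the actual onnxruntime API) of the policy the reviewers are describing: stay quiet when a GPU package runs on a CPU-only machine, but fail loudly when CUDA_PATH signals GPU intent.

import os

def maybe_register_cuda(load_cuda_provider):
    """Illustrative only: mirrors the C++ logic in the diff above."""
    try:
        provider = load_cuda_provider()
    except OSError:
        provider = None
    if provider is not None:
        return provider  # CUDA loaded fine; register it.
    if os.environ.get("CUDA_PATH"):
        # GPU intent was signaled but CUDA could not be loaded: surface the error.
        raise RuntimeError(
            "CUDA_PATH is set but CUDA wasn't able to be loaded. "
            "Check your CUDA/cuDNN installation and PATH."
        )
    # No CUDA_PATH: allow a GPU package to run on a CPU-only machine.
    return None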
Is there a better error message we can give to users? Even if the user enables the CUDA provider, this might fail, and the error message doesn't make clear why it failed. The user can't take any action other than filing a GitHub issue, which increases support costs.
Second this. We need to somehow propagate the actual error.
I worked on improving the error messages. It used to show this on failure:

2021-06-09 18:24:21.8575539 [E:onnxruntime:Default, provider_bridge_ort.cc:1000 onnxruntime::ProviderLibrary::Get] Failed to load library, error code: 126

With these changes, it now shows this:

2021-06-09 18:54:17.7950730 [E:onnxruntime:Default, provider_bridge_ort.cc:1000 onnxruntime::ProviderLibrary::Get] LoadLibrary failed with error 126 "The specified module could not be found." when trying to load "C:\Code\Github\onnxruntime\build\Windows\Debug\Debug\onnxruntime\capi\onnxruntime_providers_cuda.dll"

So now, on Windows, the error code is translated into its message text instead of just showing a number.
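For reference, the same error-code-to-text translation (FormatMessage on Windows) can be reproduced in Python. This sketch is Windows-only and is not part of the actual C++ fix:

import ctypes
import sys

# Windows-only: translate Win32 error code 126 into its message text, the same
# mapping the improved LoadLibrary error reporting now applies in C++.
if sys.platform == "win32":
    print(ctypes.FormatError(126))  # "The specified module could not be found."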