diff --git a/python/rmm/rmm/_cuda/gpu.py b/python/rmm/rmm/_cuda/gpu.py index 2a23b41e6..4a5f43fd4 100644 --- a/python/rmm/rmm/_cuda/gpu.py +++ b/python/rmm/rmm/_cuda/gpu.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. from cuda import cuda, cudart @@ -78,15 +78,13 @@ def runtimeGetVersion(): The version is returned as (1000 major + 10 minor). For example, CUDA 9.2 would be represented by 9020. - This calls numba.cuda.runtime.get_version() rather than cuda-python due to - current limitations in cuda-python. + This function automatically raises CUDARuntimeError with error message + and status code. """ - # TODO: Replace this with `cuda.cudart.cudaRuntimeGetVersion()` when the - # limitation is fixed. - import numba.cuda - - major, minor = numba.cuda.runtime.get_version() - return major * 1000 + minor * 10 + status, version = cudart.getLocalRuntimeVersion() + if status != cudart.cudaError_t.cudaSuccess: + raise CUDARuntimeError(status) + return version def getDeviceCount():