diff --git a/cuda_core/cuda/core/experimental/_linker.py b/cuda_core/cuda/core/experimental/_linker.py index 6e36a2a5..6bad7d49 100644 --- a/cuda_core/cuda/core/experimental/_linker.py +++ b/cuda_core/cuda/core/experimental/_linker.py @@ -32,6 +32,7 @@ def _decide_nvjitlink_or_driver(): _driver_ver = handle_return(cuda.cuDriverGetVersion()) _driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10) try: + raise ImportError from cuda.bindings import nvjitlink as _nvjitlink from cuda.bindings._internal import nvjitlink as inner_nvjitlink except ImportError: diff --git a/cuda_core/tests/test_linker.py b/cuda_core/tests/test_linker.py index 54cd8cf4..e3120137 100644 --- a/cuda_core/tests/test_linker.py +++ b/cuda_core/tests/test_linker.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE +from contextlib import contextmanager, nullcontext + import pytest from cuda.core.experimental import Linker, LinkerOptions, Program, _linker @@ -18,6 +20,23 @@ device_function_c = "__device__ int C(int a, int b) { return a + b; }" culink_backend = _linker._decide_nvjitlink_or_driver() +skip_options = nullcontext +if not culink_backend: + from cuda.bindings import nvjitlink + + @contextmanager + def skip_version_specific_linker_options(): + if culink_backend: + return + try: + yield + except nvjitlink.nvJitLinkError as e: + if e.status == nvjitlink.Result.ERROR_UNRECOGNIZED_OPTION: + pytest.skip("current nvjitlink version does not support the option provided") + else: + raise + + skip_options = skip_version_specific_linker_options @pytest.fixture(scope="function") @@ -72,9 +91,11 @@ def compile_ltoir_functions(init_cuda): ], ) def test_linker_init(compile_ptx_functions, options): - linker = Linker(*compile_ptx_functions, options=options) - object_code = linker.link("cubin") - assert isinstance(object_code, ObjectCode) + with skip_options(): + linker = Linker(*compile_ptx_functions, options=options) + + object_code = linker.link("cubin") + assert isinstance(object_code, ObjectCode) def test_linker_init_invalid_arch():