diff --git a/sphericart-jax/pyproject.toml b/sphericart-jax/pyproject.toml index 77928308d..2036aa3f3 100644 --- a/sphericart-jax/pyproject.toml +++ b/sphericart-jax/pyproject.toml @@ -2,7 +2,10 @@ name = "sphericart-jax" dynamic = ["version"] requires-python = ">=3.9" -dependencies = ["jax >= 0.4.18"] +dependencies = [ + "jax >= 0.4.18", + "packaging", +] readme = "README.md" license = {text = "Apache-2.0"} diff --git a/sphericart-jax/python/sphericart/jax/__init__.py b/sphericart-jax/python/sphericart/jax/__init__.py index df1f0c216..cfcd8450c 100644 --- a/sphericart-jax/python/sphericart/jax/__init__.py +++ b/sphericart-jax/python/sphericart/jax/__init__.py @@ -1,18 +1,77 @@ import jax +from packaging import version +import warnings + from .lib import sphericart_jax_cpu from .spherical_harmonics import spherical_harmonics, solid_harmonics # noqa: F401 +def get_minimum_cuda_version_for_jax(jax_version): + """ + Get the minimum required CUDA version for a specific JAX version. + + Args: + jax_version (str): Installed JAX version, e.g., '0.4.11'. + + Returns: + tuple: Minimum required CUDA version as (major, minor), e.g., (11, 8). + """ + # Define ranges of JAX versions and their corresponding minimum CUDA versions + version_ranges = [ + ( + version.parse("0.4.26"), + version.parse("999.999.999"), + (12, 1), + ), # JAX 0.4.26 and later: CUDA 12.1+ + ( + version.parse("0.4.11"), + version.parse("0.4.25"), + (11, 8), + ), # JAX 0.4.11 - 0.4.25: CUDA 11.8+ + ] + + jax_ver = version.parse(jax_version) + + # Find the appropriate CUDA version range + for start, end, cuda_version in version_ranges: + if start <= jax_ver <= end: + return cuda_version + + raise ValueError(f"Unsupported JAX version: {jax_version}") + + # register the operations to xla for _name, _value in sphericart_jax_cpu.registrations().items(): jax.lib.xla_client.register_custom_call_target(_name, _value, platform="cpu") +has_sphericart_jax_cuda = False try: from .lib import sphericart_jax_cuda + has_sphericart_jax_cuda = True # register the operations to xla for _name, _value in sphericart_jax_cuda.registrations().items(): jax.lib.xla_client.register_custom_call_target(_name, _value, platform="gpu") - except ImportError: + has_sphericart_jax_cuda = False pass + +if has_sphericart_jax_cuda: + from .lib.sphericart_jax_cuda import get_cuda_runtime_version + + # check the jaxlib version is suitable for the host cudatoolkit. + cuda_version = get_cuda_runtime_version() + cuda_version = (cuda_version["major"], cuda_version["minor"]) + jax_version = jax.__version__ + required_version = get_minimum_cuda_version_for_jax(jax_version) + if cuda_version < required_version: + warnings.warn( + "The installed CUDA Toolkit version is " + f"{cuda_version[0]}.{cuda_version[1]}, which " + f"is not compatible with the installed JAX version {jax_version}. " + "The minimum required CUDA Toolkit for your JAX version " + f"is {required_version[0]}.{required_version[1]}. " + "Please upgrade your CUDA Toolkit to meet the requirements, or ", + "downgrade JAX to a compatible version.", + stacklevel=2, + ) diff --git a/sphericart-jax/src/sphericart_jax_cuda.cpp b/sphericart-jax/src/sphericart_jax_cuda.cpp index 23ba7f765..3fc041f98 100644 --- a/sphericart-jax/src/sphericart_jax_cuda.cpp +++ b/sphericart-jax/src/sphericart_jax_cuda.cpp @@ -8,9 +8,12 @@ #include #include +#include "dynamic_cuda.hpp" #include "sphericart_cuda.hpp" #include "sphericart/pybind11_kernel_helpers.hpp" +using namespace pybind11::literals; + struct SphDescriptor { std::int64_t n_samples; std::int64_t lmax; @@ -115,11 +118,23 @@ pybind11::dict Registrations() { return dict; } +std::pair getCUDARuntimeVersion() { + int version; + CUDART_SAFE_CALL(CUDART_INSTANCE.cudaRuntimeGetVersion(&version)); + int major = version / 1000; + int minor = (version % 1000) / 10; + return {major, minor}; +} + PYBIND11_MODULE(sphericart_jax_cuda, m) { m.def("registrations", &Registrations); m.def("build_sph_descriptor", [](std::int64_t n_samples, std::int64_t lmax) { return PackDescriptor(SphDescriptor{n_samples, lmax}); }); + m.def("get_cuda_runtime_version", []() { + auto [major, minor] = getCUDARuntimeVersion(); + return pybind11::dict("major"_a = major, "minor"_a = minor); + }); } } // namespace cuda diff --git a/sphericart/include/dynamic_cuda.hpp b/sphericart/include/dynamic_cuda.hpp index c6112cba3..486fc20e2 100644 --- a/sphericart/include/dynamic_cuda.hpp +++ b/sphericart/include/dynamic_cuda.hpp @@ -92,6 +92,7 @@ class CUDART { using cudaDeviceSynchronize_t = cudaError_t (*)(void); using cudaPointerGetAttributes_t = cudaError_t (*)(cudaPointerAttributes*, const void*); using cudaFree_t = cudaError_t (*)(void*); + using cudaRuntimeGetVersion_t = cudaError_t (*)(int*); cudaGetDeviceCount_t cudaGetDeviceCount; cudaGetDevice_t cudaGetDevice; @@ -103,6 +104,7 @@ class CUDART { cudaDeviceSynchronize_t cudaDeviceSynchronize; cudaPointerGetAttributes_t cudaPointerGetAttributes; cudaFree_t cudaFree; + cudaRuntimeGetVersion_t cudaRuntimeGetVersion; CUDART() { #ifdef __linux__ @@ -124,6 +126,8 @@ class CUDART { cudaPointerGetAttributes = load(cudartHandle, "cudaPointerGetAttributes"); cudaFree = load(cudartHandle, "cudaFree"); + cudaRuntimeGetVersion = + load(cudartHandle, "cudaRuntimeGetVersion"); } }