Skip to content

Commit

Permalink
CUDA Toolkit version + Jax incompatibility check (#166)
Browse files Browse the repository at this point in the history
  • Loading branch information
nickjbrowning authored Jan 23, 2025
1 parent f7b5b71 commit e94a3bb
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 2 deletions.
5 changes: 4 additions & 1 deletion sphericart-jax/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
name = "sphericart-jax"
dynamic = ["version"]
requires-python = ">=3.9"
dependencies = ["jax >= 0.4.18"]
dependencies = [
"jax >= 0.4.18",
"packaging",
]

readme = "README.md"
license = {text = "Apache-2.0"}
Expand Down
61 changes: 60 additions & 1 deletion sphericart-jax/python/sphericart/jax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,77 @@
import jax
from packaging import version
import warnings

from .lib import sphericart_jax_cpu
from .spherical_harmonics import spherical_harmonics, solid_harmonics # noqa: F401


def get_minimum_cuda_version_for_jax(jax_version):
"""
Get the minimum required CUDA version for a specific JAX version.
Args:
jax_version (str): Installed JAX version, e.g., '0.4.11'.
Returns:
tuple: Minimum required CUDA version as (major, minor), e.g., (11, 8).
"""
# Define ranges of JAX versions and their corresponding minimum CUDA versions
version_ranges = [
(
version.parse("0.4.26"),
version.parse("999.999.999"),
(12, 1),
), # JAX 0.4.26 and later: CUDA 12.1+
(
version.parse("0.4.11"),
version.parse("0.4.25"),
(11, 8),
), # JAX 0.4.11 - 0.4.25: CUDA 11.8+
]

jax_ver = version.parse(jax_version)

# Find the appropriate CUDA version range
for start, end, cuda_version in version_ranges:
if start <= jax_ver <= end:
return cuda_version

raise ValueError(f"Unsupported JAX version: {jax_version}")


# register the operations to xla
for _name, _value in sphericart_jax_cpu.registrations().items():
jax.lib.xla_client.register_custom_call_target(_name, _value, platform="cpu")

has_sphericart_jax_cuda = False
try:
from .lib import sphericart_jax_cuda

has_sphericart_jax_cuda = True
# register the operations to xla
for _name, _value in sphericart_jax_cuda.registrations().items():
jax.lib.xla_client.register_custom_call_target(_name, _value, platform="gpu")

except ImportError:
has_sphericart_jax_cuda = False
pass

if has_sphericart_jax_cuda:
from .lib.sphericart_jax_cuda import get_cuda_runtime_version

# check the jaxlib version is suitable for the host cudatoolkit.
cuda_version = get_cuda_runtime_version()
cuda_version = (cuda_version["major"], cuda_version["minor"])
jax_version = jax.__version__
required_version = get_minimum_cuda_version_for_jax(jax_version)
if cuda_version < required_version:
warnings.warn(
"The installed CUDA Toolkit version is "
f"{cuda_version[0]}.{cuda_version[1]}, which "
f"is not compatible with the installed JAX version {jax_version}. "
"The minimum required CUDA Toolkit for your JAX version "
f"is {required_version[0]}.{required_version[1]}. "
"Please upgrade your CUDA Toolkit to meet the requirements, or ",
"downgrade JAX to a compatible version.",
stacklevel=2,
)
15 changes: 15 additions & 0 deletions sphericart-jax/src/sphericart_jax_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
#include <mutex>
#include <tuple>

#include "dynamic_cuda.hpp"
#include "sphericart_cuda.hpp"
#include "sphericart/pybind11_kernel_helpers.hpp"

using namespace pybind11::literals;

struct SphDescriptor {
std::int64_t n_samples;
std::int64_t lmax;
Expand Down Expand Up @@ -115,11 +118,23 @@ pybind11::dict Registrations() {
return dict;
}

std::pair<int, int> getCUDARuntimeVersion() {
int version;
CUDART_SAFE_CALL(CUDART_INSTANCE.cudaRuntimeGetVersion(&version));
int major = version / 1000;
int minor = (version % 1000) / 10;
return {major, minor};
}

PYBIND11_MODULE(sphericart_jax_cuda, m) {
m.def("registrations", &Registrations);
m.def("build_sph_descriptor", [](std::int64_t n_samples, std::int64_t lmax) {
return PackDescriptor(SphDescriptor{n_samples, lmax});
});
m.def("get_cuda_runtime_version", []() {
auto [major, minor] = getCUDARuntimeVersion();
return pybind11::dict("major"_a = major, "minor"_a = minor);
});
}

} // namespace cuda
Expand Down
4 changes: 4 additions & 0 deletions sphericart/include/dynamic_cuda.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class CUDART {
using cudaDeviceSynchronize_t = cudaError_t (*)(void);
using cudaPointerGetAttributes_t = cudaError_t (*)(cudaPointerAttributes*, const void*);
using cudaFree_t = cudaError_t (*)(void*);
using cudaRuntimeGetVersion_t = cudaError_t (*)(int*);

cudaGetDeviceCount_t cudaGetDeviceCount;
cudaGetDevice_t cudaGetDevice;
Expand All @@ -103,6 +104,7 @@ class CUDART {
cudaDeviceSynchronize_t cudaDeviceSynchronize;
cudaPointerGetAttributes_t cudaPointerGetAttributes;
cudaFree_t cudaFree;
cudaRuntimeGetVersion_t cudaRuntimeGetVersion;

CUDART() {
#ifdef __linux__
Expand All @@ -124,6 +126,8 @@ class CUDART {
cudaPointerGetAttributes =
load<cudaPointerGetAttributes_t>(cudartHandle, "cudaPointerGetAttributes");
cudaFree = load<cudaFree_t>(cudartHandle, "cudaFree");
cudaRuntimeGetVersion =
load<cudaRuntimeGetVersion_t>(cudartHandle, "cudaRuntimeGetVersion");
}
}

Expand Down

0 comments on commit e94a3bb

Please sign in to comment.