diff --git a/src/ragged/common/__init__.py b/src/ragged/common/__init__.py index 076ca20..9887ae8 100644 --- a/src/ragged/common/__init__.py +++ b/src/ragged/common/__init__.py @@ -14,6 +14,7 @@ RegularArray, ) +from . import _import from ._typing import Device, Dtype, NestedSequence, Shape, SupportsDLPack @@ -96,8 +97,7 @@ def __init__( if isinstance(self._impl, ak.Array) and device != ak.backend(self._impl): self._impl = ak.to_backend(self._impl, device) elif isinstance(self._impl, np.ndarray) and device == "cuda": - import cupy as cp # pylint: disable=C0415 - + cp = _import.cupy() self._impl = cp.array(self._impl.item()) def __str__(self) -> str: diff --git a/src/ragged/common/_import.py b/src/ragged/common/_import.py new file mode 100644 index 0000000..05c73ef --- /dev/null +++ b/src/ragged/common/_import.py @@ -0,0 +1,22 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE + +from __future__ import annotations + +from typing import Any + + +def cupy() -> Any: + try: + import cupy as cp # pylint: disable=C0415 + + return cp + except ModuleNotFoundError as err: + error_message = """to use the "cuda" backend, you must install cupy: + + pip install cupy + +or + + conda install -c conda-forge cupy +""" + raise ModuleNotFoundError(error_message) from err