Skip to content

Commit

Permalink
Move 'import cupy' to an _import module.
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Dec 26, 2023
1 parent 2e15d62 commit cbbea0e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/ragged/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
RegularArray,
)

from . import _import
from ._typing import Device, Dtype, NestedSequence, Shape, SupportsDLPack


Expand Down Expand Up @@ -96,8 +97,7 @@ def __init__(
if isinstance(self._impl, ak.Array) and device != ak.backend(self._impl):
self._impl = ak.to_backend(self._impl, device)
elif isinstance(self._impl, np.ndarray) and device == "cuda":
import cupy as cp # pylint: disable=C0415

cp = _import.cupy()
self._impl = cp.array(self._impl.item())

def __str__(self) -> str:
Expand Down
22 changes: 22 additions & 0 deletions src/ragged/common/_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/ragged/blob/main/LICENSE

from __future__ import annotations

from typing import Any


def cupy() -> Any:
try:
import cupy as cp # pylint: disable=C0415

return cp
except ModuleNotFoundError as err:
error_message = """to use the "cuda" backend, you must install cupy:
pip install cupy
or
conda install -c conda-forge cupy
"""
raise ModuleNotFoundError(error_message) from err

0 comments on commit cbbea0e

Please sign in to comment.