Skip to content

Commit 6b245ff

Browse files
committed
also optimize for explicit dev id
1 parent c9fac0b commit 6b245ff

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

cuda_core/cuda/core/experimental/_device.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from cuda.core.experimental._utils.cuda_utils import (
1414
ComputeCapability,
1515
CUDAError,
16+
_check_driver_error,
1617
driver,
1718
handle_return,
1819
precondition,
@@ -948,7 +949,7 @@ class Device:
948949

949950
__slots__ = ("_id", "_mr", "_has_inited", "_properties")
950951

951-
def __new__(cls, device_id=None):
952+
def __new__(cls, device_id: int = None):
952953
global _is_cuInit
953954
if _is_cuInit is False:
954955
with _lock:
@@ -960,14 +961,14 @@ def __new__(cls, device_id=None):
960961
err, dev = driver.cuCtxGetDevice()
961962
if err == 0:
962963
device_id = int(dev)
963-
else:
964+
elif err == 201: # CUDA_ERROR_INVALID_CONTEXT
964965
ctx = handle_return(driver.cuCtxGetCurrent())
965966
assert int(ctx) == 0
966967
device_id = 0 # cudart behavior
967-
else:
968-
total = handle_return(driver.cuDeviceGetCount())
969-
if not isinstance(device_id, int) or not (0 <= device_id < total):
970-
raise ValueError(f"device_id must be within [0, {total}), got {device_id}")
968+
else:
969+
_check_driver_error(err)
970+
elif device_id < 0:
971+
raise ValueError(f"device_id must be >= 0, got {device_id}")
971972

972973
# ensure Device is singleton
973974
try:
@@ -995,7 +996,10 @@ def __new__(cls, device_id=None):
995996
dev._properties = None
996997
devices.append(dev)
997998

998-
return devices[device_id]
999+
try:
1000+
return devices[device_id]
1001+
except IndexError:
1002+
raise ValueError(f"device_id must be within [0, {len(devices)}), got {device_id}") from None
9991003

10001004
def _check_context_initialized(self, *args, **kwargs):
10011005
if not self._has_inited:

0 commit comments

Comments
 (0)