13
13
from cuda .core .experimental ._utils .cuda_utils import (
14
14
ComputeCapability ,
15
15
CUDAError ,
16
+ _check_driver_error ,
16
17
driver ,
17
18
handle_return ,
18
19
precondition ,
@@ -948,7 +949,7 @@ class Device:
948
949
949
950
__slots__ = ("_id" , "_mr" , "_has_inited" , "_properties" )
950
951
951
- def __new__ (cls , device_id = None ):
952
+ def __new__ (cls , device_id : int = None ):
952
953
global _is_cuInit
953
954
if _is_cuInit is False :
954
955
with _lock :
@@ -960,14 +961,14 @@ def __new__(cls, device_id=None):
960
961
err , dev = driver .cuCtxGetDevice ()
961
962
if err == 0 :
962
963
device_id = int (dev )
963
- else :
964
+ elif err == 201 : # CUDA_ERROR_INVALID_CONTEXT
964
965
ctx = handle_return (driver .cuCtxGetCurrent ())
965
966
assert int (ctx ) == 0
966
967
device_id = 0 # cudart behavior
967
- else :
968
- total = handle_return ( driver . cuDeviceGetCount () )
969
- if not isinstance ( device_id , int ) or not ( 0 <= device_id < total ) :
970
- raise ValueError (f"device_id must be within [0, { total } ) , got { device_id } " )
968
+ else :
969
+ _check_driver_error ( err )
970
+ elif device_id < 0 :
971
+ raise ValueError (f"device_id must be >= 0 , got { device_id } " )
971
972
972
973
# ensure Device is singleton
973
974
try :
@@ -995,7 +996,10 @@ def __new__(cls, device_id=None):
995
996
dev ._properties = None
996
997
devices .append (dev )
997
998
998
- return devices [device_id ]
999
+ try :
1000
+ return devices [device_id ]
1001
+ except IndexError :
1002
+ raise ValueError (f"device_id must be within [0, { len (devices )} ), got { device_id } " ) from None
999
1003
1000
1004
def _check_context_initialized (self , * args , ** kwargs ):
1001
1005
if not self ._has_inited :
0 commit comments