forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcuda.py
87 lines (79 loc) · 3.06 KB
/
cuda.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import glob
import re
import ctypes.util
from subprocess import Popen, PIPE
from .env import IS_WINDOWS, IS_LINUX, IS_DARWIN, check_env_flag, check_negative_env_flag
# Default CUDA toolkit root on Linux/macOS when $CUDA_HOME is unset.
LINUX_HOME = '/usr/local/cuda'
# Versioned toolkit directories (e.g. .../CUDA/v10.0) under the default
# Windows install root, globbed once at import time. glob's result order is
# filesystem-dependent; sort so that WINDOWS_HOME[0] is picked deterministically
# when several toolkits are installed.
WINDOWS_HOME = sorted(
    glob.glob('C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*'))
def find_nvcc():
    """Locate ``nvcc`` on ``PATH`` and return the directory that contains it.

    For a standard toolkit layout this is ``<CUDA root>/bin`` on every
    platform, so a caller takes ``os.path.dirname`` of the result to reach
    the CUDA root. The result is absolute and, on Windows, uses forward
    slashes (the same normalisation this module applies to ``CUDA_PATH``).

    Returns:
        The directory containing ``nvcc`` as a string, or None when ``nvcc``
        is not found on ``PATH``.
    """
    # shutil.which searches PATH in-process: no dependency on an external
    # `which`/`where` binary, no decoding of console output, and on Windows
    # it applies PATHEXT so that `nvcc` matches `nvcc.exe`.
    import shutil

    nvcc_path = shutil.which('nvcc')
    if not nvcc_path:
        return None
    # Return nvcc's own directory on all platforms (the Windows branch used
    # to go one level higher, which made the caller's dirname overshoot).
    nvcc_dir = os.path.abspath(os.path.dirname(nvcc_path))
    # os.sep is '\\' only on Windows; on POSIX this replace is a no-op.
    return nvcc_dir.replace(os.sep, '/')
def find_cuda_version(cuda_home):
    """Infer the CUDA toolkit version from an install directory.

    On Windows the version is parsed from the name of ``cuda_home`` itself
    (under the default install root that name looks like ``v10.0``).
    Elsewhere it is parsed from the names of the ``*cudart*`` files in the
    first of ``lib64`` / ``lib`` that exists under ``cuda_home``.

    A ``MAJOR.MINOR.PATCH`` match in any candidate is preferred; only if none
    exists is the first ``MAJOR.MINOR`` match returned.

    Args:
        cuda_home: CUDA install root, or None.

    Returns:
        The version string (e.g. ``'10.0.130'`` or ``'10.0'``), or None when
        ``cuda_home`` is None or no candidate name contains a version.
    """
    if cuda_home is None:
        return None

    if IS_WINDOWS:
        # Drop trailing separators first: basename('.../v10.0/') would be ''.
        candidate_names = [os.path.basename(cuda_home.rstrip('/\\'))]
    else:
        candidate_names = []
        for lib_dir in ('lib64', 'lib'):
            cuda_lib_path = os.path.join(cuda_home, lib_dir)
            if os.path.exists(cuda_lib_path):
                # glob's order is filesystem-dependent; sort so the chosen
                # candidate is deterministic when several files match.
                candidate_names = sorted(
                    os.path.basename(path)
                    for path in glob.glob(os.path.join(cuda_lib_path, '*cudart*')))
                break

    # Try the most specific pattern across all candidates before relaxing it.
    for pattern in (r'[0-9]+\.[0-9]+\.[0-9]+', r'[0-9]+\.[0-9]+'):
        version_regex = re.compile(pattern)
        for name in candidate_names:
            match = version_regex.search(name)
            if match:
                return match.group()
    return None
# Resolve this module's CUDA configuration at import time:
#   USE_CUDA     -- False when disabled by env flags, else True iff a toolkit
#                   root was found (CUDA_HOME is not None)
#   CUDA_HOME    -- toolkit root directory, or None
#   CUDA_VERSION -- result of find_cuda_version(CUDA_HOME), or None
# NOTE(review): because this runs at module top level, the helper names
# cuda_path / cudart_path below also remain as module globals.
if check_negative_env_flag('USE_CUDA') or check_env_flag('USE_ROCM'):
    # CUDA disabled, or a ROCm build requested. The exact flag semantics
    # live in .env -- presumably USE_CUDA=0 / USE_ROCM=1; confirm there.
    USE_CUDA = False
    CUDA_HOME = None
    CUDA_VERSION = None
else:
    if IS_LINUX or IS_DARWIN:
        # $CUDA_HOME, defaulting to LINUX_HOME.
        CUDA_HOME = os.getenv('CUDA_HOME', LINUX_HOME)
    else:
        # Windows (and any platform that is neither Linux nor macOS):
        # $CUDA_PATH with forward slashes, else the first entry of
        # WINDOWS_HOME if any toolkit directory was globbed.
        CUDA_HOME = os.getenv('CUDA_PATH', '').replace('\\', '/')
        if CUDA_HOME == '' and len(WINDOWS_HOME) > 0:
            CUDA_HOME = WINDOWS_HOME[0].replace('\\', '/')
    if not os.path.exists(CUDA_HOME):
        # The configured/default location is missing (an empty string also
        # fails os.path.exists), so search for the toolkit instead: via nvcc
        # on PATH on Linux and Windows, via the cudart library elsewhere
        # (i.e. macOS).
        if IS_LINUX or IS_WINDOWS:
            cuda_path = find_nvcc()
        else:
            # On macOS ctypes.util.find_library returns a full path, so its
            # dirname is the directory holding libcudart.
            cudart_path = ctypes.util.find_library('cudart')
            if cudart_path is not None:
                cuda_path = os.path.dirname(cudart_path)
            else:
                cuda_path = None
        # cuda_path is expected to be the toolkit's bin/ (nvcc) or lib/
        # (cudart) directory, so its parent is taken as the toolkit root.
        # NOTE(review): this relies on find_nvcc() returning nvcc's own
        # directory -- confirm against find_nvcc on Windows.
        if cuda_path is not None:
            CUDA_HOME = os.path.dirname(cuda_path)
        else:
            CUDA_HOME = None
    CUDA_VERSION = find_cuda_version(CUDA_HOME)
    USE_CUDA = CUDA_HOME is not None