forked from hpcaitech/Open-Sora
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy path install-check-pytorch23.py
96 lines (84 loc) · 2.63 KB
/
install-check-pytorch23.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import subprocess
def _decode_output(data):
    """Decode captured process output as UTF-8 text and strip surrounding whitespace."""
    # errors="replace" keeps a stray non-UTF-8 byte from raising UnicodeDecodeError.
    return (data or b"").decode("utf-8", errors="replace").strip()


def run_command(command):
    """Run *command* and capture its output.

    Args:
        command: A command string, which is run through the shell as before, or
            a sequence of program arguments, which is run directly without a shell.

    Returns:
        ``(stdout, stderr)`` as stripped text when the command exits with
        status 0. Otherwise ``(None, message)``, where ``message`` is a
        non-empty description of the failure that also includes whatever the
        process wrote to stderr.
    """
    try:
        result = subprocess.run(
            command,
            # Strings keep the original shell semantics; argument lists avoid the shell.
            shell=isinstance(command, (str, bytes)),
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except subprocess.CalledProcessError as e:
        # str(e) only carries the exit status; the process's own stderr
        # (e.g. an import traceback) is what explains why it failed.
        stderr = _decode_output(e.stderr)
        message = f"{e}\n{stderr}" if stderr else str(e)
        return None, message
    except OSError as e:
        # Raised when an argument-list command names a program that doesn't exist.
        return None, str(e)
    return _decode_output(result.stdout), _decode_output(result.stderr)
def check_nvcc():
    """Check that ``nvcc --version`` reports CUDA "release 12.1".

    Prints a one-line status. Returns True if the check passed, False otherwise.
    """
    print("Checking nvcc version... ", end="")
    output, error = run_command("nvcc --version")
    if output and "release 12.1" in output:
        print("OK")
        return True
    # output is None when the command itself failed (e.g. nvcc not on PATH);
    # show the error text instead of a bare "None".
    print("FAILED. Found: ", output or error)
    return False
def check_python():
    """Check that ``python --version`` on PATH reports Python 3.10.x.

    Prints a one-line status. Returns True if the check passed, False otherwise.
    """
    print("Checking Python version... ", end="")
    output, error = run_command("python --version")
    if output and "Python 3.10." in output:
        print("OK")
        return True
    # Fall back to the error/stderr text when stdout is empty or the command
    # failed, so the user sees what was actually found.
    print("FAILED. Found: ", output or error)
    return False
def check_pytorch():
    """Check that ``python`` imports torch and its ``__version__`` contains "2.3.0".

    Prints a one-line status. Returns True if the check passed, False otherwise.
    """
    print("Checking PyTorch version... ", end="")
    output, error = run_command('python -c "import torch; print(torch.__version__)"')
    if output and "2.3.0" in output:
        print("OK")
        return True
    # output is None when the import failed; report the error instead of "None".
    print("FAILED. Found: ", output or error)
    return False
def check_cuda_version():
    """Check that ``torch.version.cuda`` (as seen by ``python``) contains "12.1".

    Prints a one-line status. Returns True if the check passed, False otherwise.
    """
    print("Checking CUDA version... ", end="")
    output, error = run_command('python -c "import torch; print(torch.version.cuda)"')
    if output and "12.1" in output:
        print("OK")
        return True
    # output is None when the command failed; report the error instead of "None".
    print("FAILED. Found: ", output or error)
    return False
def check_apex():
    """Check that ``python -c "import apex"`` runs successfully.

    Prints a one-line status. Returns True if the import succeeded, False otherwise.
    """
    print("Checking Apex... ", end="")
    output, error = run_command('python -c "import apex"')
    # Judge success by exit status (run_command returns output=None on a
    # non-zero exit), not by stderr: a successful import may still write
    # warnings to stderr.
    if output is None:
        print("FAILED. Found: ", error)
        return False
    print("OK")
    return True
def check_flash_attn():
    """Check that ``python -c "import flash_attn"`` runs successfully.

    Prints a one-line status. Returns True if the import succeeded, False otherwise.
    """
    print("Checking Flash Attention... ", end="")
    output, error = run_command('python -c "import flash_attn"')
    # Judge success by exit status (run_command returns output=None on a
    # non-zero exit), not by stderr: a successful import may still write
    # warnings to stderr.
    if output is None:
        print("FAILED. Found: ", error)
        return False
    print("OK")
    return True
def check_xformers():
    """Check that ``python -m xformers.info`` runs and its output mentions "xFormers".

    Prints a one-line status. Returns True if the check passed, False otherwise.
    """
    print("Checking xFormers... ", end="")
    output, error = run_command('python -m xformers.info')
    if output and "xFormers" in output:
        print("OK")
        return True
    # output is None when the module failed to run; report the error instead of "None".
    print("FAILED. Found: ", output or error)
    return False
def main():
    """Run every environment check and print an overall summary.

    All checks run even after one fails, so every problem is reported at once.

    Returns:
        True if all checks passed, False otherwise.
    """
    print("Starting environment check...\n")
    checks = [
        check_nvcc,
        check_python,
        check_pytorch,
        check_cuda_version,
        check_apex,
        check_flash_attn,
        check_xformers,
    ]
    # Build the full list first (not a generator inside all()) so a failing
    # check does not short-circuit the remaining ones.
    results = [check() for check in checks]
    all_checks_passed = all(results)
    if all_checks_passed:
        print("\nSUCCESS: All checks passed!")
    else:
        print("\nFAILED: Some checks did not pass.")
    return all_checks_passed


if __name__ == "__main__":
    # Exit non-zero on failure so shell scripts / CI can detect a broken environment.
    raise SystemExit(0 if main() else 1)