-
Notifications
You must be signed in to change notification settings - Fork 39
/
config.py
168 lines (156 loc) · 5.54 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import argparse
import sys
import torch
from multiprocessing import cpu_count
def use_fp32_config(config_dir="configs"):
    """Rewrite the JSON training configs so every ``true`` becomes ``false``.

    The five known config files inside *config_dir* are read, every
    occurrence of the text ``true`` is replaced by ``false``, and the file
    is written back only when its content actually changed.

    Args:
        config_dir: Directory that holds the ``*.json`` config files.
            Defaults to ``"configs"`` (relative to the current working
            directory), which is the location the original code used.

    Raises:
        OSError: If one of the config files cannot be read or written
            (for example ``FileNotFoundError`` when it does not exist).
    """
    # NOTE(review): the replacement is purely textual and touches every
    # occurrence of "true" in the file, not one specific key -- confirm the
    # configs hold no other `true` value (or string containing "true") that
    # must be preserved.
    for config_file in (
        "32k.json",
        "40k.json",
        "48k.json",
        "48k_v2.json",
        "32k_v2.json",
    ):
        path = f"{config_dir}/{config_file}"
        # Explicit UTF-8 keeps decoding independent of the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            original = f.read()
        patched = original.replace("true", "false")
        # Skip the write when nothing changed so repeated calls do not
        # needlessly rewrite the files.
        if patched != original:
            with open(path, "w", encoding="utf-8") as f:
                f.write(patched)
class Config:
    """Runtime configuration: CLI options, compute device and window sizes.

    Constructing an instance parses the command line, probes the available
    compute backend (CUDA, MPS, CPU or DirectML) and derives the four
    ``x_*`` values returned by :meth:`device_config`.
    """

    def __init__(self):
        """Initialise defaults, parse CLI arguments and configure the device."""
        self.device = "cuda:0"  # replaced by "mps", "cpu" or a DirectML device later
        self.is_half = True  # cleared whenever fp32 is forced
        self.n_cpu = 0  # 0 means "not set yet"; filled in by device_config
        self.gpu_name = None
        self.gpu_mem = None  # whole GiB of the CUDA device, when one exists
        self.has_gpu = torch.cuda.is_available()
        (
            self.python_cmd,
            self.listen_port,
            self.iscolab,
            self.noparallel,
            self.noautoopen,
            self.dml,
        ) = self.arg_parse()
        self.instead = ""  # name of the fallback device, if any
        self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()

    @staticmethod
    def arg_parse() -> tuple:
        """Parse the command-line options understood by the application.

        Unrecognised arguments are tolerated (and printed) so the module can
        be imported from environments such as a Jupyter notebook.

        Returns:
            A tuple ``(pycmd, port, colab, noparallel, noautoopen, dml)``.
        """
        exe = sys.executable or "python"
        parser = argparse.ArgumentParser()
        parser.add_argument("--port", type=int, default=7865, help="Listen port")
        parser.add_argument("--pycmd", type=str, default=exe, help="Python command")
        parser.add_argument("--colab", action="store_true", help="Launch in colab")
        parser.add_argument(
            "--noparallel", action="store_true", help="Disable parallel processing"
        )
        parser.add_argument(
            "--noautoopen",
            action="store_true",
            help="Do not open in browser automatically",
        )
        parser.add_argument(
            "--dml",
            action="store_true",
            help="torch_dml",
        )
        # parse_known_args allows import into a Jupyter notebook
        cmd_opts, unknown = parser.parse_known_args()
        print(f"unknown args: {unknown}")
        # NOTE: the port value is not range-checked here.
        return (
            cmd_opts.pycmd,
            cmd_opts.port,
            cmd_opts.colab,
            cmd_opts.noparallel,
            cmd_opts.noautoopen,
            cmd_opts.dml,
        )

    # The MPS backend is only present in some PyTorch builds and needs
    # macOS 12.3+, so look it up with `getattr` and then actually try it.
    @staticmethod
    def has_mps() -> bool:
        """Return True only when a tensor can really be moved to the MPS device."""
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is None or not mps_backend.is_available():
            return False
        try:
            torch.zeros(1).to(torch.device("mps"))
            return True
        except Exception:
            # Any failure while touching the device means MPS is unusable.
            return False

    @staticmethod
    def _swap_onnxruntime(marker_dll, park_as, restore_from):
        """Make sure the active onnxruntime package contains *marker_dll*.

        When ``runtime\\Lib\\site-packages\\onnxruntime\\capi\\<marker_dll>``
        is missing, the active ``onnxruntime`` directory is renamed to
        *park_as* and the directory *restore_from* is renamed to
        ``onnxruntime``. Both renames are best effort.
        """
        site_packages = r"runtime\Lib\site-packages"
        active = site_packages + "\\onnxruntime"
        if os.path.exists(active + "\\capi\\" + marker_dll):
            return
        moves = (
            (active, site_packages + "\\" + park_as),
            (site_packages + "\\" + restore_from, active),
        )
        for src, dst in moves:
            try:
                os.rename(src, dst)
            except OSError:
                # Deliberate best effort: the bundled runtime layout may not
                # exist (e.g. not the packaged Windows build), so a failed
                # rename is ignored rather than aborting start-up.
                pass

    def device_config(self) -> tuple:
        """Select the compute device and return the window parameters.

        Side effects: updates ``device``, ``is_half``, ``gpu_name``,
        ``gpu_mem``, ``n_cpu`` and ``instead``; calls ``use_fp32_config()``
        whenever fp32 is forced for CUDA, MPS or CPU; may rename the bundled
        onnxruntime directories to match the chosen backend.

        Returns:
            A tuple ``(x_pad, x_query, x_center, x_max)``.
        """
        if torch.cuda.is_available():
            i_device = int(self.device.split(":")[-1])
            self.gpu_name = torch.cuda.get_device_name(i_device)
            name_upper = self.gpu_name.upper()
            # Name-based list of GPUs for which half precision is disabled.
            if (
                ("16" in self.gpu_name and "V100" not in name_upper)
                or "P40" in name_upper
                or "1060" in self.gpu_name
                or "1070" in self.gpu_name
                or "1080" in self.gpu_name
            ):
                print("Found GPU", self.gpu_name, ", force to fp32")
                self.is_half = False
                use_fp32_config()
            else:
                print("Found GPU", self.gpu_name)
            # Total device memory truncated to whole GiB.
            self.gpu_mem = int(
                torch.cuda.get_device_properties(i_device).total_memory
                / 1024
                / 1024
                / 1024
            )
        elif self.has_mps():
            print("No supported Nvidia GPU found")
            self.device = self.instead = "mps"
            self.is_half = False
            use_fp32_config()
        else:
            print("No supported Nvidia GPU found")
            self.device = self.instead = "cpu"
            self.is_half = False
            use_fp32_config()

        if self.n_cpu == 0:
            self.n_cpu = cpu_count()

        if self.is_half:
            # Settings for 6 GB of VRAM
            x_pad = 3
            x_query = 10
            x_center = 60
            x_max = 64
        else:
            # Settings for 5 GB of VRAM
            x_pad = 1
            x_query = 6
            x_center = 38
            x_max = 41

        # Smaller windows for CUDA devices reporting 4 GiB or less.
        if self.gpu_mem is not None and self.gpu_mem <= 4:
            x_pad = 1
            x_query = 5
            x_center = 30
            x_max = 32

        if self.dml:
            print("use DirectML instead")
            self._swap_onnxruntime(
                "DirectML.dll", "onnxruntime-cuda", "onnxruntime-dml"
            )
            import torch_directml

            self.device = torch_directml.device(torch_directml.default_device())
            # NOTE(review): is_half is cleared only after the window sizes
            # above were chosen, and use_fp32_config() is not called on this
            # path -- confirm that this is intended.
            self.is_half = False
        else:
            if self.instead:
                print(f"use {self.instead} instead")
            self._swap_onnxruntime(
                "onnxruntime_providers_cuda.dll",
                "onnxruntime-dml",
                "onnxruntime-cuda",
            )
        return x_pad, x_query, x_center, x_max