-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
186 lines (159 loc) · 5.69 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import os
import shutil
import subprocess
import sys
import time
import typing as tp
from pathlib import Path
import requests
def _symlink_unless_present(source: str, link_path: str) -> None:
    """Create ``link_path`` as a symlink to ``source``; keep whatever already exists there."""
    try:
        os.symlink(source, link_path)
    except FileExistsError:
        print(f"Ignoring existing file at {link_path}")


def maybe_download_tarball_with_pget(
    url: str,
    dest: str,
):
    """
    Download a tarball with ``pget`` and extract it, unless the files are already present.

    The files are stored under ``/weights/triton`` when ``/weights`` exists or can be
    created, and ``dest`` is then made a symlink to that directory. If ``/weights``
    cannot be created because of missing permissions, the files are stored directly
    in ``dest``.

    Nothing is downloaded when the storage directory already exists and is non-empty.

    Args:
        url (str): URL to the tarball
        dest (str): Path to the directory where the tarball should be decompressed

    Returns:
        path (str): ``dest``, the path under which the files are reachable

    Raises:
        subprocess.CalledProcessError: if ``pget`` exits with a non-zero status.
    """
    try:
        Path("/weights").mkdir(exist_ok=True)
        first_dest = "/weights/triton"
    except PermissionError:
        print("/weights doesn't exist, and we couldn't create it")
        first_dest = dest

    # if the storage directory exists and is not empty, there is nothing to download
    if os.path.exists(first_dest) and os.listdir(first_dest):
        print(f"Files already present in `{first_dest}`, nothing will be downloaded.")
        if first_dest != dest:
            _symlink_unless_present(first_dest, dest)
        return dest

    # if the storage directory exists but is empty, remove it so we can pull with pget
    if os.path.exists(first_dest):
        shutil.rmtree(first_dest)

    print("Downloading model assets...")
    command = ["pget", url, first_dest, "-x"]
    subprocess.check_call(command, close_fds=True)

    if first_dest != dest:
        # Same tolerance as the "already present" path above: a pre-existing `dest`
        # must not turn a completed download into a crash.
        _symlink_unless_present(first_dest, dest)
    return dest
class TritonHandler:
    """Builds the ``mpirun`` command line for ``tritonserver`` and launches it."""

    def __init__(
        self,
        world_size=1,
        tritonserver="/opt/tritonserver/bin/tritonserver",
        grpc_port="8001",
        http_port="8000",
        metrics_port="8002",
        force=False,
        log=False,
        log_file="triton_log.txt",
        model_repo=None,
    ):
        """
        Store the launch configuration.

        Args:
            world_size: number of ``tritonserver`` entries placed on the ``mpirun`` line.
            tritonserver: path to the ``tritonserver`` executable.
            grpc_port: value passed as ``--grpc-port``.
            http_port: value passed as ``--http-port``; also the port polled by ``start``.
            metrics_port: value passed as ``--metrics-port``.
            force: when True, ``start`` only warns if ``pgrep`` already reports a
                tritonserver process instead of raising.
            log: when True, verbose logging flags are added for the first entry.
            log_file: value passed as ``--log-file`` when ``log`` is True.
            model_repo: value passed as ``--model-repository``. Defaults to
                ``../all_models/gpt`` relative to this file's directory.
        """
        if model_repo is None:
            model_repo = str(Path(__file__).parent.absolute()) + "/../all_models/gpt"
        self.world_size = world_size
        self.tritonserver = tritonserver
        self.grpc_port = grpc_port
        self.http_port = http_port
        self.metrics_port = metrics_port
        self.force = force
        self.log = log
        self.log_file = log_file
        self.model_repo = model_repo

    def get_cmd(self):
        """
        Build the ``mpirun`` argument list.

        One ``-n 1 <tritonserver> ...`` group is emitted per rank in ``world_size``,
        each followed by a ``":"`` separator. Only the first group receives the
        logging flags, and each group gets its own ``shm-region-prefix-name``.

        Returns:
            list[str]: the argument vector, suitable for ``subprocess.Popen``.
        """
        cmd = ["mpirun", "--allow-run-as-root"]
        for i in range(self.world_size):
            cmd += ["-n", "1", self.tritonserver]
            if self.log and (i == 0):
                cmd += ["--log-verbose=3", f"--log-file={self.log_file}"]
            cmd += [
                f"--grpc-port={self.grpc_port}",
                f"--http-port={self.http_port}",
                f"--metrics-port={self.metrics_port}",
                f"--model-repository={self.model_repo}",
                "--disable-auto-complete-config",
                f"--backend-config=python,shm-region-prefix-name=prefix{i}_",
                ":",
            ]
        return cmd

    @staticmethod
    def _terminate(process, grace_period=10.0):
        """Ask ``process`` to exit, wait up to ``grace_period`` seconds, then kill and reap it."""
        process.terminate()
        try:
            process.wait(timeout=grace_period)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()

    def start(self):
        """
        Launch the server and poll its HTTP port until it answers with status 200.

        Polling uses exponential backoff (10 attempts, starting at 10 ms). On any
        failure the launched process is terminated before the exception propagates.

        Returns:
            bool: True once the server answered with HTTP 200.

        Raises:
            RuntimeError: if ``pgrep`` reports an existing tritonserver process and
                ``force`` is False, or if the server did not become reachable.
        """
        # NOTE(review): `-r R` presumably limits pgrep to processes in run state "R",
        # so an idle (sleeping) tritonserver may go unreported — confirm this is intended.
        res = subprocess.run(
            ["pgrep", "-r", "R", "tritonserver"], capture_output=True, encoding="utf-8"
        )
        if res.stdout:
            pids = res.stdout.replace("\n", " ").rstrip()
            msg = f"tritonserver process(es) already found with PID(s): {pids}.\n\tUse `kill {pids}` to stop them."
            if self.force:
                print(msg, file=sys.stderr)
            else:
                raise RuntimeError(msg + " Or use --force.")

        cmd = self.get_cmd()
        # No pipes are attached on purpose: the server keeps running after a
        # successful start, and nobody would drain them.
        process = subprocess.Popen(cmd)
        url = f"http://localhost:{self.http_port}"
        try:
            # Exponential backoff
            max_retries = 10
            request_timeout = 5  # seconds; keeps a stalled connection from hanging start()
            delay = 0.01  # initial delay
            for _ in range(max_retries):
                try:
                    response = requests.get(url, timeout=request_timeout)
                    if response.status_code == 200:
                        print("Server started successfully.")
                        return True
                except (
                    requests.exceptions.ConnectionError,
                    requests.exceptions.Timeout,
                ):
                    pass
                # Stop retrying as soon as the launched process is known to be gone.
                exit_code = process.poll()
                if exit_code is not None:
                    raise RuntimeError(
                        f"Server failed to start: process exited with code {exit_code}"
                    )
                time.sleep(delay)
                delay *= 2  # double the delay
            raise RuntimeError(
                f"Server failed to start: no HTTP 200 response from {url} "
                f"after {max_retries} attempts"
            )
        except BaseException:
            # Never wait on a still-running server here (that would block forever);
            # shut it down, reap it, and let the original exception propagate.
            self._terminate(process)
            raise
class StreamingTokenStopSequenceHandler:
    """
    Filters a stream of tokens against a list of stop sequences.

    Tokens are buffered while the buffered text is a prefix of some stop sequence.
    When the buffer equals a stop sequence the handler is marked as fulfilled;
    when it can no longer become one, the buffered text is released.
    """

    def __init__(
        self,
        stop_sequences: tp.List[str] = None,
    ):
        """Set up the handler; a missing or empty ``stop_sequences`` disables filtering."""
        self.cache = []
        self.stop_sequence_fulfilled = False
        self.stop_sequences = stop_sequences if stop_sequences else []

    def stop(self):
        """Mark the stop sequence as fulfilled."""
        self.stop_sequence_fulfilled = True

    def process(self, token):
        """
        Add ``token`` to the buffer and decide what, if anything, to release.

        Returns:
            The buffered text when it is not a prefix of any stop sequence,
            otherwise None (text withheld, or a stop sequence was completed).

        Raises:
            RuntimeError: if called after a stop sequence was already fulfilled.
        """
        if self.stop_sequence_fulfilled:
            raise RuntimeError(
                "Stop sequence has been fulfilled, but server is still yielding tokens"
            )

        self.cache.append(token)
        buffered = "".join(self.cache)

        # Exact hit on any stop sequence: drop the buffer and flag completion.
        if any(candidate == buffered for candidate in self.stop_sequences):
            self.cache.clear()
            self.stop()
            return None

        # Still a possible beginning of a stop sequence: keep withholding.
        if any(candidate.startswith(buffered) for candidate in self.stop_sequences):
            return None

        # No stop sequence can match any more: release everything buffered.
        self.cache.clear()
        return buffered

    def __call__(self, token):
        """Return ``token`` unchanged when no stop sequences are set, else ``process(token)``."""
        if not self.stop_sequences:
            return token
        return self.process(token)

    def finalize(self):
        """Return any text still buffered and empty the buffer; None if nothing is buffered."""
        if not self.cache:
            return None
        leftover = "".join(self.cache)
        self.cache.clear()
        return leftover