Skip to content

Commit b4e9528

Browse files
authored
[Core] Streamline stream termination in AsyncLLMEngine (vllm-project#7336)
1 parent 57b7be0 commit b4e9528

File tree

2 files changed

+26
-21
lines changed

2 files changed

+26
-21
lines changed

tests/async_engine/test_request_tracker.py

+4 -2
Original file line number | Diff line number | Diff line change
@@ -47,8 +47,10 @@ async def test_request_tracker():
4747
assert tracker.new_requests_event.is_set()
4848
await tracker.wait_for_new_requests()
4949
new, aborted = tracker.get_new_and_aborted_requests()
50-
assert len(aborted) == 1
51-
assert "4" in aborted
50+
# aborted new requests will cancel each other out -
51+
# there's no need for them to propagate into the
52+
# engine
53+
assert not aborted
5254
assert not new
5355
assert stream_4.finished
5456

vllm/engine/async_llm_engine.py

+22 -19
Original file line number | Diff line number | Diff line change
@@ -85,11 +85,14 @@ def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
8585
return
8686
self._queue.put_nowait(item)
8787

88-
def finish(self, cancelled: bool = False) -> None:
88+
def finish(
89+
self,
90+
exception: Optional[Union[BaseException, Type[BaseException]]] = None,
91+
) -> None:
8992
if not self._finished:
9093
self._finished = True
9194
self._queue.put_nowait(
92-
asyncio.CancelledError if cancelled else STOP_ITERATION)
95+
exception if exception is not None else STOP_ITERATION)
9396

9497
@property
9598
def finished(self) -> bool:
@@ -133,14 +136,12 @@ def propagate_exception(self,
133136
"""Propagate an exception to request streams
134137
(all if request_id is None)."""
135138
if request_id is not None:
136-
self._request_streams[request_id].put(exc)
137-
self.abort_request(request_id)
139+
self.abort_request(request_id, exception=exc)
138140
else:
139-
# NB: list() used here because self.abort_request pops the stream
141+
# NB: tuple() used here because self.abort_request pops the stream
140142
# out of self._request_streams, so we can't iterate on it directly
141-
for rid, stream in list(self._request_streams.items()):
142-
stream.put(exc)
143-
self.abort_request(rid)
143+
for rid in tuple(self._request_streams.keys()):
144+
self.abort_request(rid, exception=exc)
144145

145146
def process_request_output(self,
146147
request_output: Union[RequestOutput,
@@ -167,14 +168,13 @@ def process_request_output(self,
167168

168169
def process_exception(self,
169170
request_id: str,
170-
exception: Exception,
171+
exception: BaseException,
171172
*,
172173
verbose: bool = False) -> None:
173174
"""Propagate an exception from the engine."""
174-
self._request_streams[request_id].put(exception)
175175
if verbose:
176176
logger.info("Finished request %s.", request_id)
177-
self.abort_request(request_id)
177+
self.abort_request(request_id, exception=exception)
178178

179179
def add_request(self,
180180
request_id: str,
@@ -203,7 +203,8 @@ def add_request(self,
203203
def abort_request(self,
204204
request_id: str,
205205
*,
206-
cancelled: bool = False,
206+
exception: Optional[Union[BaseException,
207+
Type[BaseException]]] = None,
207208
verbose: bool = False) -> None:
208209
"""Abort a request during next background loop iteration."""
209210
if verbose:
@@ -213,7 +214,7 @@ def abort_request(self,
213214

214215
stream = self._request_streams.pop(request_id, None)
215216
if stream is not None:
216-
stream.finish(cancelled=cancelled)
217+
stream.finish(exception=exception)
217218

218219
def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
219220
"""Get the new requests and finished requests to be
@@ -227,12 +228,14 @@ def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
227228

228229
while not self._new_requests.empty():
229230
stream, new_request = self._new_requests.get_nowait()
230-
if stream.request_id in finished_requests:
231+
request_id = stream.request_id
232+
if request_id in finished_requests:
231233
# The request has already been aborted.
232-
stream.finish(cancelled=True)
233-
continue
234-
self._request_streams[stream.request_id] = stream
235-
new_requests.append(new_request)
234+
stream.finish(asyncio.CancelledError)
235+
finished_requests.discard(request_id)
236+
else:
237+
self._request_streams[request_id] = stream
238+
new_requests.append(new_request)
236239

237240
return new_requests, finished_requests
238241

@@ -1015,7 +1018,7 @@ def _abort(self, request_id: str) -> None:
10151018
request_id: The unique id of the request.
10161019
"""
10171020
self._request_tracker.abort_request(request_id,
1018-
cancelled=True,
1021+
exception=asyncio.CancelledError,
10191022
verbose=self.log_requests)
10201023

10211024
async def get_model_config(self) -> ModelConfig:

0 commit comments

Comments
 (0)