diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index a63d48016b83c..cef9974b44a53 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -1,15 +1,14 @@
 import asyncio
 import time
 from functools import partial
-from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
-                    Union, AsyncIterator)
+from typing import (Any, AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type, Union)
 
-from vllm.lora.request import LoRARequest
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.engine.ray_utils import initialize_cluster, ray
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 
@@ -77,7 +76,7 @@ def __init__(self) -> None:
         self._request_streams: Dict[str, AsyncStream] = {}
         self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
         self._new_requests: asyncio.Queue[Tuple[AsyncStream,
-                                                dict]] = asyncio.Queue()
+            dict]] = asyncio.Queue()
         self.new_requests_event = None
 
     def __contains__(self, item):
@@ -135,7 +134,7 @@ def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
         self._finished_requests.put_nowait(request_id)
 
         if request_id not in self._request_streams or self._request_streams[
-                request_id].finished:
+            request_id].finished:
             # The request has already finished or been aborted.
             return
 
@@ -203,11 +202,11 @@ async def step_async(self) -> List[RequestOutput]:
         return self._process_model_outputs(output, scheduler_outputs)
 
     async def encode_request_async(
-            self,
-            request_id: str,  # pylint: disable=unused-argument
-            prompt: Optional[str],
-            prompt_token_ids: Optional[List[int]] = None,
-            lora_request: Optional[LoRARequest] = None,
+        self,
+        request_id: str,  # pylint: disable=unused-argument
+        prompt: Optional[str],
+        prompt_token_ids: Optional[List[int]] = None,
+        lora_request: Optional[LoRARequest] = None,
     ):
         if prompt_token_ids is None:
             assert prompt is not None
@@ -218,14 +217,14 @@ async def encode_request_async(
         return prompt_token_ids
 
     async def add_request_async(
-            self,
-            request_id: str,
-            prompt: Optional[str],
-            sampling_params: SamplingParams,
-            prompt_token_ids: Optional[List[int]] = None,
-            arrival_time: Optional[float] = None,
-            lora_request: Optional[LoRARequest] = None,
-            prefix_pos: Optional[int] = None,
+        self,
+        request_id: str,
+        prompt: Optional[str],
+        sampling_params: SamplingParams,
+        prompt_token_ids: Optional[List[int]] = None,
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        prefix_pos: Optional[int] = None,
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -249,12 +248,12 @@ async def add_request_async(
         )
 
     async def _run_workers_async(
-            self,
-            method: str,
-            *args,
-            driver_args: Optional[List[Any]] = None,
-            driver_kwargs: Optional[Dict[str, Any]] = None,
-            **kwargs,
+        self,
+        method: str,
+        *args,
+        driver_args: Optional[List[Any]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
         coros = []
@@ -317,6 +316,9 @@ def __init__(self,
         self.log_requests = log_requests
         self.max_log_len = max_log_len
         self.engine = self._init_engine(*args, **kwargs)
+        # jimpang: for lora
+        self.lora_names_map = {}
+        self.last_lora_id = 1
 
         self.background_loop = None
         # We need to keep a reference to unshielded
@@ -410,14 +412,14 @@ async def run_engine_loop(self):
             await asyncio.sleep(0)
 
     async def add_request(
-            self,
-            request_id: str,
-            prompt: Optional[str],
-            sampling_params: SamplingParams,
-            prompt_token_ids: Optional[List[int]] = None,
-            arrival_time: Optional[float] = None,
-            lora_request: Optional[LoRARequest] = None,
-            prefix_pos: Optional[int] = None,
+        self,
+        request_id: str,
+        prompt: Optional[str],
+        sampling_params: SamplingParams,
+        prompt_token_ids: Optional[List[int]] = None,
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        prefix_pos: Optional[int] = None,
     ) -> AsyncStream:
         if self.log_requests:
             shortened_prompt = prompt
@@ -427,7 +429,7 @@ async def add_request(
                     shortened_prompt = shortened_prompt[:self.max_log_len]
                 if shortened_token_ids is not None:
                     shortened_token_ids = shortened_token_ids[:self.
-                                                              max_log_len]
+                        max_log_len]
             logger.info(f"Received request {request_id}: "
                         f"prompt: {shortened_prompt!r}, "
                         f"prefix_pos: {prefix_pos},"
@@ -473,13 +475,13 @@ async def add_request(
         return stream
 
     async def generate(
-            self,
-            prompt: Optional[str],
-            sampling_params: SamplingParams,
-            request_id: str,
-            prompt_token_ids: Optional[List[int]] = None,
-            lora_request: Optional[LoRARequest] = None,
-            prefix_pos: Optional[int] = None,
+        self,
+        prompt: Optional[str],
+        sampling_params: SamplingParams,
+        request_id: str,
+        prompt_token_ids: Optional[List[int]] = None,
+        lora_request: Optional[LoRARequest] = None,
+        prefix_pos: Optional[int] = None,
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
 
@@ -552,6 +554,15 @@ async def generate(
         # This should not be used for logging, as it is monotonic time.
         arrival_time = time.monotonic()
 
+        # jimpang: process lora id
+        if lora_request:
+            if lora_request.lora_name in self.lora_names_map:
+                lora_request.lora_int_id = self.lora_names_map[lora_request.lora_name]
+            else:
+                self.last_lora_id = self.last_lora_id + 1
+                lora_request.lora_int_id = self.last_lora_id
+                self.lora_names_map[lora_request.lora_name] = lora_request.lora_int_id
+
         try:
             stream = await self.add_request(
                 request_id,
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index f7b8d258fae4c..32ed0e5699566 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -2,12 +2,13 @@
 import json
 from typing import AsyncGenerator
 
+import uvicorn
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
-import uvicorn
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
 
@@ -35,22 +36,36 @@ async def generate(request: Request) -> Response:
     prompt = request_dict.pop("prompt")
     prefix_pos = request_dict.pop("prefix_pos", None)
     stream = request_dict.pop("stream", False)
+    # lora
+    lora_id = request_dict.pop("lora_id", None)
+    lora_path = request_dict.pop("lora_path", None)
+    if lora_id is None or lora_path is None:
+        lora_request = None
+    else:
+        lora_request = LoRARequest(lora_name=lora_id, lora_int_id=0, lora_local_path=lora_path)
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
-    results_generator = engine.generate(prompt,
-                                        sampling_params,
-                                        request_id,
-                                        prefix_pos=prefix_pos)
+    # jimpang add
+    prompt_token_ids = None
+    if prompt and len(prompt) > 0:
+        first_element = prompt[0]
+        if isinstance(first_element, int):
+            prompt_token_ids = prompt
+            prompt = None
+
+    results_generator = engine.generate(
+        prompt=prompt, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_token_ids,
+        lora_request=lora_request)
 
     # Streaming case
     async def stream_results() -> AsyncGenerator[bytes, None]:
         async for request_output in results_generator:
-            prompt = request_output.prompt
             text_outputs = [
-                prompt + output.text for output in request_output.outputs
+                output.text for output in request_output.outputs
             ]
-            ret = {"text": text_outputs}
+            output_tokens = [output.token_ids for output in request_output.outputs]
+            ret = {"text": text_outputs, "output_token_ids": output_tokens}
             yield (json.dumps(ret) + "\0").encode("utf-8")
 
     if stream:
@@ -66,9 +81,9 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
         final_output = request_output
 
     assert final_output is not None
-    prompt = final_output.prompt
-    text_outputs = [prompt + output.text for output in final_output.outputs]
-    ret = {"text": text_outputs}
+    text_outputs = [output.text for output in final_output.outputs]
+    output_tokens = [output.token_ids for output in final_output.outputs]
+    ret = {"text": text_outputs, "output_token_ids": output_tokens}
     return JSONResponse(ret)
 
 
diff --git a/vllm/lora/request.py b/vllm/lora/request.py
index bbbf4880ab81b..101dbbba4be97 100644
--- a/vllm/lora/request.py
+++ b/vllm/lora/request.py
@@ -19,11 +19,6 @@ class LoRARequest:
     lora_int_id: int
     lora_local_path: str
 
-    def __post_init__(self):
-        if self.lora_int_id < 1:
-            raise ValueError(
-                f"lora_int_id must be > 0, got {self.lora_int_id}")
-
     def __eq__(self, value: object) -> bool:
         return isinstance(
             value, LoRARequest) and self.lora_int_id == value.lora_int_id
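
Example client call for the patched /generate endpoint (a minimal sketch, not part of the patch): it assumes the api_server is running locally on port 8000, that the third-party requests package is installed, and that a LoRA adapter directory exists at the placeholder path. The lora_id / lora_path request fields and the output_token_ids response field are the ones introduced above; the endpoint URL, adapter path, prompt, and max_tokens value are illustrative assumptions.

# client_example.py -- illustrative sketch; URL, adapter path and prompt are assumptions
import requests

payload = {
    # A plain string prompt; a list of token ids would also work, since the
    # patched handler detects an int first element and forwards the list as
    # prompt_token_ids instead of prompt.
    "prompt": "Explain LoRA adapters in one sentence.",
    "stream": False,
    "max_tokens": 64,  # leftover fields are passed to SamplingParams(**request_dict)
    # Fields added by this patch; the server builds a LoRARequest from them,
    # and AsyncLLMEngine.generate maps lora_id (used as lora_name) to an int id.
    "lora_id": "my-adapter",
    "lora_path": "/path/to/lora_adapter",
}

resp = requests.post("http://localhost:8000/generate", json=payload)
resp.raise_for_status()
data = resp.json()
print(data["text"])              # one string per generated sequence
print(data["output_token_ids"])  # token ids, returned alongside text by this patch

Note that the server can construct the LoRARequest with lora_int_id=0 only because the __post_init__ validation is removed in the last hunk; the real integer id is assigned later from lora_names_map in AsyncLLMEngine.generate.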