
Commit

feat: 1. adapt the API for venus; 2. support multi-LoRA
jimpang committed Feb 4, 2024
1 parent 1af090b commit a90d068
Showing 3 changed files with 77 additions and 56 deletions.
91 changes: 51 additions & 40 deletions vllm/engine/async_llm_engine.py
@@ -1,15 +1,14 @@
import asyncio
import time
from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union, AsyncIterator)
from typing import (Any, AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type, Union)

from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster, ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

@@ -77,7 +76,7 @@ def __init__(self) -> None:
self._request_streams: Dict[str, AsyncStream] = {}
self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
self._new_requests: asyncio.Queue[Tuple[AsyncStream,
dict]] = asyncio.Queue()
dict]] = asyncio.Queue()
self.new_requests_event = None

def __contains__(self, item):
@@ -135,7 +134,7 @@ def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
self._finished_requests.put_nowait(request_id)

if request_id not in self._request_streams or self._request_streams[
request_id].finished:
request_id].finished:
# The request has already finished or been aborted.
return

@@ -203,11 +202,11 @@ async def step_async(self) -> List[RequestOutput]:
return self._process_model_outputs(output, scheduler_outputs)

async def encode_request_async(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
@@ -218,14 +217,14 @@ async def encode_request_async(
return prompt_token_ids

async def add_request_async(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -249,12 +248,12 @@ async def add_request_async(
)

async def _run_workers_async(
self,
method: str,
*args,
driver_args: Optional[List[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
self,
method: str,
*args,
driver_args: Optional[List[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
coros = []
@@ -317,6 +316,9 @@ def __init__(self,
self.log_requests = log_requests
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs)
# jimpang: for lora
self.lora_names_map = {}
self.last_lora_id = 1

self.background_loop = None
# We need to keep a reference to unshielded
@@ -410,14 +412,14 @@ async def run_engine_loop(self):
await asyncio.sleep(0)

async def add_request(
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
self,
request_id: str,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
@@ -427,7 +429,7 @@ async def add_request(
shortened_prompt = shortened_prompt[:self.max_log_len]
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:self.
max_log_len]
max_log_len]
logger.info(f"Received request {request_id}: "
f"prompt: {shortened_prompt!r}, "
f"prefix_pos: {prefix_pos},"
@@ -473,13 +475,13 @@ async def add_request(
return stream

async def generate(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
self,
prompt: Optional[str],
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
@@ -552,6 +554,15 @@ async def generate(
# This should not be used for logging, as it is monotonic time.
arrival_time = time.monotonic()

# jimpang: process lora id
if lora_request:
if lora_request.lora_name in self.lora_names_map:
lora_request.lora_int_id = self.lora_names_map[lora_request.lora_name]
else:
self.last_lora_id = self.last_lora_id + 1
lora_request.lora_int_id = self.last_lora_id
self.lora_names_map[lora_request.lora_name] = lora_request.lora_int_id

try:
stream = await self.add_request(
request_id,
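The multi-LoRA change keeps a per-engine `lora_names_map` plus a `last_lora_id` counter, so repeated requests for the same adapter name reuse one integer id while new names get the next id. Below is a minimal standalone sketch of that bookkeeping; the `_LoRARequestStub` and `_LoRAIdMapper` names are illustrative stand-ins, not part of the commit.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class _LoRARequestStub:
    """Simplified stand-in for vllm.lora.request.LoRARequest."""
    lora_name: str
    lora_int_id: int
    lora_local_path: str


class _LoRAIdMapper:
    """Mirrors the lora_names_map / last_lora_id bookkeeping added to AsyncLLMEngine."""

    def __init__(self) -> None:
        self.lora_names_map: Dict[str, int] = {}
        self.last_lora_id = 1

    def assign(self, lora_request: _LoRARequestStub) -> _LoRARequestStub:
        # Reuse the integer id for a name seen before; otherwise hand out
        # the next id and remember it, as generate() now does.
        if lora_request.lora_name in self.lora_names_map:
            lora_request.lora_int_id = self.lora_names_map[lora_request.lora_name]
        else:
            self.last_lora_id += 1
            lora_request.lora_int_id = self.last_lora_id
            self.lora_names_map[lora_request.lora_name] = lora_request.lora_int_id
        return lora_request


mapper = _LoRAIdMapper()
a = mapper.assign(_LoRARequestStub("adapter-a", 0, "/path/a"))
b = mapper.assign(_LoRARequestStub("adapter-b", 0, "/path/b"))
a_again = mapper.assign(_LoRARequestStub("adapter-a", 0, "/path/a"))
assert (a.lora_int_id, b.lora_int_id, a_again.lora_int_id) == (2, 3, 2)
```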
37 changes: 26 additions & 11 deletions vllm/entrypoints/api_server.py
@@ -2,12 +2,13 @@
import json
from typing import AsyncGenerator

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

@@ -35,22 +36,36 @@ async def generate(request: Request) -> Response:
prompt = request_dict.pop("prompt")
prefix_pos = request_dict.pop("prefix_pos", None)
stream = request_dict.pop("stream", False)
# lora
lora_id = request_dict.pop("lora_id", None)
lora_path = request_dict.pop("lora_path", None)
if lora_id is None or lora_path is None:
lora_request = None
else:
lora_request = LoRARequest(lora_name=lora_id, lora_int_id=0, lora_local_path=lora_path)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()

results_generator = engine.generate(prompt,
sampling_params,
request_id,
prefix_pos=prefix_pos)
# jimpang add
prompt_token_ids = None
if prompt and len(prompt) > 0:
first_element = prompt[0]
if isinstance(first_element, int):
prompt_token_ids = prompt
prompt = None

results_generator = engine.generate(
prompt=prompt, sampling_params=sampling_params, request_id=request_id, prompt_token_ids=prompt_token_ids,
lora_request=lora_request)

# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
text_outputs = [
prompt + output.text for output in request_output.outputs
output.text for output in request_output.outputs
]
ret = {"text": text_outputs}
output_tokens = [output.token_ids for output in request_output.outputs]
ret = {"text": text_outputs, "output_token_ids": output_tokens}
yield (json.dumps(ret) + "\0").encode("utf-8")

if stream:
@@ -66,9 +81,9 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
final_output = request_output

assert final_output is not None
prompt = final_output.prompt
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
text_outputs = [output.text for output in final_output.outputs]
output_tokens = [output.token_ids for output in final_output.outputs]
ret = {"text": text_outputs, "output_token_ids": output_tokens}
return JSONResponse(ret)


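On the API-server side, a request can now carry `lora_id` and `lora_path` (both are required for a LoRARequest to be built), the `prompt` field may be a list of token ids instead of text, and the response returns raw completions plus `output_token_ids` rather than prompt-prefixed text. A sketch of a non-streaming client call under those assumptions; the endpoint address and adapter path are placeholders.

```python
# Hypothetical client call against the modified /generate endpoint;
# host, port, and adapter path are placeholders.
import requests

payload = {
    "prompt": "What is the capital of France?",  # or a list of token ids
    "max_tokens": 64,
    "temperature": 0.0,
    # If either field is missing, the request runs against the base model.
    "lora_id": "my-adapter",
    "lora_path": "/models/loras/my-adapter",
}

resp = requests.post("http://localhost:8000/generate", json=payload)
resp.raise_for_status()
body = resp.json()
print(body["text"])              # list of completion strings (no prompt prefix anymore)
print(body["output_token_ids"])  # list of token-id lists, one per completion
```

With `"stream": true` the same fields arrive as null-delimited JSON chunks from the streaming generator instead of a single JSON body.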
5 changes: 0 additions & 5 deletions vllm/lora/request.py
@@ -19,11 +19,6 @@ class LoRARequest:
lora_int_id: int
lora_local_path: str

def __post_init__(self):
if self.lora_int_id < 1:
raise ValueError(
f"lora_int_id must be > 0, got {self.lora_int_id}")

def __eq__(self, value: object) -> bool:
return isinstance(
value, LoRARequest) and self.lora_int_id == value.lora_int_id
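Dropping the `__post_init__` validation is what allows the placeholder `lora_int_id=0` that api_server.py now passes; the engine overwrites it from `lora_names_map` before use. A small illustration, assuming this patched fork of vllm is installed:

```python
# With the __post_init__ check removed, a LoRARequest can be constructed with a
# placeholder id of 0 (as api_server.py now does); the engine assigns the real
# integer id later via lora_names_map.
from vllm.lora.request import LoRARequest

req = LoRARequest(lora_name="my-adapter", lora_int_id=0,
                  lora_local_path="/models/loras/my-adapter")
assert req.lora_int_id == 0  # would have raised ValueError before this commit
```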
