diff --git a/tests/lora/test_add_lora.py b/tests/lora/test_add_lora.py
new file mode 100644
index 0000000000000..df8031cba687b
--- /dev/null
+++ b/tests/lora/test_add_lora.py
@@ -0,0 +1,165 @@
+# SPDX-License-Identifier: Apache-2.0
+import asyncio
+import time
+from pathlib import Path
+from typing import List
+
+import pytest
+from huggingface_hub import snapshot_download
+
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.inputs import TextPrompt
+from vllm.lora.request import LoRARequest
+from vllm.sampling_params import SamplingParams
+from vllm.utils import merge_async_iterators
+
+MODEL_PATH = "meta-llama/Llama-2-7b-hf"
+LORA_MODULE_DOWNLOAD_PATH = None  # Populated by download_and_prepare_lora_module() #noqa
+LORA_RANK = 8
+DEFAULT_MAX_LORAS = 16 * 3
+
+
+def download_and_prepare_lora_module():
+    """
+    Request submission is expensive when the LoRA adapters have their own
+    tokenizers. This is because, for each request with a new LoRA adapter ID,
+    the front-end loads the tokenizer from disk.
+
+    In this test, as we are comparing request processing times, we want to
+    minimize any extra activity. To that end, we download the LoRA adapter
+    and remove all of its tokenizer files, so the engine falls back to the
+    base-model tokenizer.
+    """
+    global LORA_MODULE_DOWNLOAD_PATH
+
+    LORA_MODULE_HF_PATH = "yard1/llama-2-7b-sql-lora-test"
+    LORA_MODULE_DOWNLOAD_PATH = snapshot_download(repo_id=LORA_MODULE_HF_PATH)
+
+    tokenizer_files = [
+        'added_tokens.json', 'tokenizer_config.json', 'tokenizer.json',
+        'tokenizer.model'
+    ]
+    for tokenizer_file in tokenizer_files:
+        del_path = Path(LORA_MODULE_DOWNLOAD_PATH) / tokenizer_file
+        del_path.unlink()
+
+
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines_lora):
+    # Simple autouse wrapper to run both engines for each test.
+    # This can be promoted up to conftest.py to run for every
+    # test in a package.
+    pass
+
+
+def get_lora_requests() -> List[LoRARequest]:
+    lora_requests: List[LoRARequest] = [
+        LoRARequest(lora_name=f"{i}",
+                    lora_int_id=i,
+                    lora_path=LORA_MODULE_DOWNLOAD_PATH)
+        for i in range(1, DEFAULT_MAX_LORAS + 1)
+    ]
+    return lora_requests
+
+
+async def requests_processing_time(llm,
+                                   lora_requests: List[LoRARequest]) -> float:
+
+    sampling_params = SamplingParams(n=1,
+                                     temperature=0.0,
+                                     top_p=1.0,
+                                     ignore_eos=True,
+                                     max_tokens=1)
+
+    generators = []
+    start = time.perf_counter()
+
+    for lora_request in lora_requests:
+        lora_int_id = lora_request.lora_int_id
+        generator = llm.generate(
+            prompt=TextPrompt(prompt=f"hello {lora_int_id}",
+                              multi_modal_data=None),  # type: ignore
+            sampling_params=sampling_params,
+            lora_request=lora_request,
+            request_id=f"test{lora_int_id}")
+        generators.append(generator)
+
+    all_gens = merge_async_iterators(*generators)
+    async for i, res in all_gens:
+        pass
+
+    end = time.perf_counter()
+    return end - start
+
+
+@pytest.mark.asyncio
+async def test_add_lora():
+    """
+    The add_lora function is used to pre-load some LoRA adapters into the
+    engine in anticipation of future requests using these adapters. To test
+    this functionality, we use the async engine to process some requests - we
+    do it twice, once with add_lora() pre-loading and once without.
+
+    We measure the request processing time in both cases and expect it to be
+    lower when add_lora() pre-loading is used.
+    """
+
+    download_and_prepare_lora_module()
+
+    lora_requests: List[LoRARequest] = get_lora_requests()
+
+    max_loras = len(set([lr.lora_int_id for lr in lora_requests]))
+    # Create engine in eager mode. Due to the high max_loras, the CI can
+    # OOM during cuda-graph capture.
+    engine_args = AsyncEngineArgs(
+        model=MODEL_PATH,
+        enable_lora=True,
+        max_loras=max_loras,
+        max_lora_rank=LORA_RANK,
+        max_model_len=128,
+        gpu_memory_utilization=0.8,  # avoid OOM
+        enforce_eager=True)
+
+    # The run_with_both_engines_lora fixture sets up the `VLLM_USE_V1`
+    # environment variable. Reload vllm.engine.async_llm_engine, as
+    # vllm.engine.async_llm_engine.AsyncLLMEngine changes depending on the
+    # env var.
+    import importlib
+
+    import vllm.engine.async_llm_engine
+    importlib.reload(vllm.engine.async_llm_engine)
+    from vllm.entrypoints.openai.api_server import (
+        build_async_engine_client_from_engine_args)
+
+    # Split lora_requests into 3 parts.
+    part_size = len(lora_requests) // 3
+    dummy_run_requests = lora_requests[:part_size]
+    warmup_run_requests = lora_requests[part_size:part_size * 2]
+    cold_run_requests = lora_requests[part_size * 2:]
+
+    async with build_async_engine_client_from_engine_args(engine_args) as llm:
+
+        # Dummy run - so that any one-time work, such as Triton kernel
+        # compilation, is complete here.
+        await requests_processing_time(llm, dummy_run_requests)
+
+        # Run with add_lora() pre-loading (hot start).
+        for lr in warmup_run_requests:
+            await llm.add_lora(lr)
+        # Wait for the add_lora calls to complete on the server side.
+        await asyncio.sleep(30)
+        time_with_add_lora = await requests_processing_time(
+            llm, warmup_run_requests)
+
+        # Run without any pre-loading (cold start).
+        time_cold_start = await requests_processing_time(
+            llm, cold_run_requests)
+
+        print(f"time hot-start {time_with_add_lora} vs "
+              f"time cold-start {time_cold_start}")
+
+        assert time_with_add_lora < time_cold_start, (
+            f"time_with_add_lora={time_with_add_lora}, "
+            f"time_cold_start={time_cold_start}. "
+            "The engine request processing time with LoRA pre-loading "
+            "must be less than with on-demand LoRA loading.")
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 782fdcee3805a..dee7102bb47bf 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -134,3 +134,4 @@ class EngineCoreRequestType(enum.Enum):
     ABORT = b'\x01'
     PROFILE = b'\x02'
     RESET_PREFIX_CACHE = b'\x03'
+    ADD_LORA = b'\x04'
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index f19d2ed8bcb6c..a669c9f6267c3 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -361,6 +361,10 @@ async def stop_profile(self) -> None:
     async def reset_prefix_cache(self) -> None:
         await self.engine_core.reset_prefix_cache_async()
 
+    async def add_lora(self, lora_request: LoRARequest) -> None:
+        """Load a new LoRA adapter into the engine for future requests."""
+        await self.engine_core.add_lora_async(lora_request)
+
     @property
     def is_running(self) -> bool:
         return True
@@ -376,7 +380,3 @@ def errored(self) -> bool:
     @property
     def dead_error(self) -> BaseException:
         return Exception()  # TODO: implement
-
-    async def add_lora(self, lora_request: LoRARequest) -> None:
-        """Load a new LoRA adapter into the engine for future requests."""
-        raise NotImplementedError("LoRA not yet supported in V1")
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 4642ac1778ed0..401f331d81d42 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -13,6 +13,7 @@
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
@@ -146,6 +147,9 @@ def profile(self, is_start: bool = True):
     def reset_prefix_cache(self):
         self.scheduler.reset_prefix_cache()
 
+    def add_lora(self, lora_request: LoRARequest) -> None:
+        self.model_executor.add_lora(lora_request)
+
 
 class EngineCoreProc(EngineCore):
     """ZMQ-wrapper for running EngineCore in background process."""
@@ -262,12 +266,15 @@ def _handle_client_request(self, request_type: EngineCoreRequestType,
             self.reset_prefix_cache()
         elif request_type == EngineCoreRequestType.PROFILE:
             self.model_executor.profile(request)
+        elif request_type == EngineCoreRequestType.ADD_LORA:
+            self.model_executor.add_lora(request)
 
     def process_input_socket(self, input_path: str):
         """Input socket IO thread."""
 
         # Msgpack serialization decoding.
         add_request_decoder = MsgpackDecoder(EngineCoreRequest)
+        add_lora_decoder = MsgpackDecoder(LoRARequest)
         generic_decoder = MsgpackDecoder()
 
         with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
@@ -277,9 +284,14 @@ def process_input_socket(self, input_path: str):
             request_type = EngineCoreRequestType(bytes(type_frame.buffer))
 
             # Deserialize the request data.
-            decoder = add_request_decoder if (
-                request_type
-                == EngineCoreRequestType.ADD) else generic_decoder
+            decoder = None
+            if request_type == EngineCoreRequestType.ADD:
+                decoder = add_request_decoder
+            elif request_type == EngineCoreRequestType.ADD_LORA:
+                decoder = add_lora_decoder
+            else:
+                decoder = generic_decoder
+
             request = decoder.decode(data_frame.buffer)
 
             # Push to input queue for core busy loop.
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index b3de5cdc244f3..07176629e9491 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -12,6 +12,7 @@
 
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
                         make_zmq_socket)
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
@@ -77,6 +78,9 @@ def reset_prefix_cache(self) -> None:
     def abort_requests(self, request_ids: List[str]) -> None:
         raise NotImplementedError
 
+    def add_lora(self, lora_request: LoRARequest) -> None:
+        raise NotImplementedError
+
     async def get_output_async(self) -> EngineCoreOutputs:
         raise NotImplementedError
 
@@ -92,6 +96,9 @@ async def reset_prefix_cache_async(self) -> None:
     async def abort_requests_async(self, request_ids: List[str]) -> None:
         raise NotImplementedError
 
+    async def add_lora_async(self, lora_request: LoRARequest) -> None:
+        raise NotImplementedError
+
 
 class InprocClient(EngineCoreClient):
     """
@@ -125,6 +132,9 @@ def profile(self, is_start: bool = True) -> None:
     def reset_prefix_cache(self) -> None:
         self.engine_core.reset_prefix_cache()
 
+    def add_lora(self, lora_request: LoRARequest) -> None:
+        self.engine_core.add_lora(lora_request)
+
 
 class MPClient(EngineCoreClient):
     """
@@ -242,6 +252,9 @@ def profile(self, is_start: bool = True) -> None:
     def reset_prefix_cache(self) -> None:
         self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
 
+    def add_lora(self, lora_request: LoRARequest) -> None:
+        self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)
+
 
 class AsyncMPClient(MPClient):
     """Asyncio-compatible client for multi-proc EngineCore."""
@@ -295,3 +308,6 @@ async def profile_async(self, is_start: bool = True) -> None:
 
     async def reset_prefix_cache_async(self) -> None:
         await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
+
+    async def add_lora_async(self, lora_request: LoRARequest) -> None:
+        await self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 8f2ffe5f16ff5..10154a752393d 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -15,6 +15,7 @@
                               init_distributed_environment,
                               set_custom_all_reduce)
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
 from vllm.utils import GiB_bytes
@@ -234,6 +235,9 @@ def profile(self, is_start: bool = True):
         else:
             self.profiler.stop()
 
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        return self.model_runner.add_lora(lora_request)
+
     def check_health(self) -> None:
         # worker will always be healthy as long as it's running.
         return
diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py
index e7501ad2ea168..053897da0aa71 100644
--- a/vllm/v1/worker/lora_model_runner_mixin.py
+++ b/vllm/v1/worker/lora_model_runner_mixin.py
@@ -127,3 +127,8 @@ def maybe_profile_with_lora(self, lora_config: LoRAConfig,
 
             # __exit__ code
             self.lora_manager.remove_all_adapters()
+
+    def add_lora(self, lora_request: LoRARequest) -> bool:
+        if not self.lora_manager:
+            raise RuntimeError("LoRA is not enabled.")
+        return self.lora_manager.add_adapter(lora_request)
\ No newline at end of file
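
Usage note (not part of the patch): a minimal sketch of how the add_lora() pre-loading path introduced above can be exercised from client code, assuming the V1 engine is active (e.g. VLLM_USE_V1=1) and a LoRA adapter already exists on disk. The model name, adapter path, and request id below are illustrative placeholders, not values from this diff.

# Hypothetical sketch of the add_lora() pre-loading flow added in this diff.
import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM


async def main():
    engine_args = AsyncEngineArgs(model="meta-llama/Llama-2-7b-hf",
                                  enable_lora=True,
                                  max_loras=4,
                                  max_lora_rank=8,
                                  enforce_eager=True)
    llm = AsyncLLM.from_engine_args(engine_args)

    lora = LoRARequest(lora_name="sql-lora",
                       lora_int_id=1,
                       lora_path="/path/to/lora-adapter")  # placeholder path

    # Pre-load the adapter: AsyncLLM.add_lora() forwards the LoRARequest via
    # the new ADD_LORA message type to EngineCore, which loads it through the
    # executor/worker, so the first generate() call using this adapter avoids
    # the on-demand loading cost.
    await llm.add_lora(lora)

    final = None
    async for output in llm.generate(prompt="Hello",
                                     sampling_params=SamplingParams(
                                         max_tokens=8),
                                     lora_request=lora,
                                     request_id="req-1"):
        final = output
    if final is not None:
        print(final.outputs[0].text)

    llm.shutdown()


if __name__ == "__main__":
    asyncio.run(main())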