From 759e1ddf63f5963dc10f6d5e75a64f3a0f2a0395 Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Wed, 11 Oct 2023 10:22:19 +0800 Subject: [PATCH] make IPv6 compatible, safe run for coroutine interrupting (#487) * make IPv6 compatible, safe run for coroutine interrupting * instance_id -> session_id and fix api_client.py * update doc * remove useless faq * safe ip mapping * update app.py * remove print * update doc --- benchmark/profile_restful_api.py | 6 +- docs/en/restful_api.md | 13 ++-- docs/zh_cn/restful_api.md | 13 ++-- lmdeploy/serve/async_engine.py | 94 +++++++++++++++++------------ lmdeploy/serve/gradio/app.py | 23 +++---- lmdeploy/serve/openai/api_client.py | 16 +++-- lmdeploy/serve/openai/api_server.py | 70 ++++++++++----------- lmdeploy/serve/openai/protocol.py | 9 ++- 8 files changed, 137 insertions(+), 107 deletions(-) diff --git a/benchmark/profile_restful_api.py b/benchmark/profile_restful_api.py index ff1db7b4b5..ed922bfd7a 100644 --- a/benchmark/profile_restful_api.py +++ b/benchmark/profile_restful_api.py @@ -14,7 +14,7 @@ def get_streaming_response(prompt: str, api_url: str, - instance_id: int, + session_id: int, request_output_len: int, stream: bool = True, sequence_start: bool = True, @@ -24,7 +24,7 @@ def get_streaming_response(prompt: str, pload = { 'prompt': prompt, 'stream': stream, - 'instance_id': instance_id, + 'session_id': session_id, 'request_output_len': request_output_len, 'sequence_start': sequence_start, 'sequence_end': sequence_end, @@ -36,7 +36,7 @@ def get_streaming_response(prompt: str, stream=stream) for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, - delimiter=b'\0'): + delimiter=b'\n'): if chunk: data = json.loads(chunk.decode('utf-8')) output = data['text'] diff --git a/docs/en/restful_api.md b/docs/en/restful_api.md index c5a4a0de07..cb70e26375 100644 --- a/docs/en/restful_api.md +++ b/docs/en/restful_api.md @@ -22,7 +22,7 @@ from typing import Iterable, List def get_streaming_response(prompt: str, api_url: str, - instance_id: int, + session_id: int, request_output_len: int, stream: bool = True, sequence_start: bool = True, @@ -32,7 +32,7 @@ def get_streaming_response(prompt: str, pload = { 'prompt': prompt, 'stream': stream, - 'instance_id': instance_id, + 'session_id': session_id, 'request_output_len': request_output_len, 'sequence_start': sequence_start, 'sequence_end': sequence_end, @@ -41,7 +41,7 @@ def get_streaming_response(prompt: str, response = requests.post( api_url, headers=headers, json=pload, stream=stream) for chunk in response.iter_lines( - chunk_size=8192, decode_unicode=False, delimiter=b'\0'): + chunk_size=8192, decode_unicode=False, delimiter=b'\n'): if chunk: data = json.loads(chunk.decode('utf-8')) output = data['text'] @@ -91,7 +91,7 @@ curl http://{server_ip}:{server_port}/generate \ -H "Content-Type: application/json" \ -d '{ "prompt": "Hello! How are you?", - "instance_id": 1, + "session_id": 1, "sequence_start": true, "sequence_end": true }' @@ -146,11 +146,10 @@ python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True 2. When OOM appeared at the server side, please reduce the number of `instance_num` when lanching the service. -3. When the request with the same `instance_id` to `generate` got a empty return value and a negative `tokens`, please consider setting `sequence_start=false` for the second question and the same for the afterwards. +3. When the request with the same `session_id` to `generate` got a empty return value and a negative `tokens`, please consider setting `sequence_start=false` for the second question and the same for the afterwards. 4. Requests were previously being handled sequentially rather than concurrently. To resolve this issue, - - kindly provide unique instance_id values when calling the `generate` API or else your requests may be associated with client IP addresses - - additionally, setting `stream=true` enables processing multiple requests simultaneously + - kindly provide unique session_id values when calling the `generate` API or else your requests may be associated with client IP addresses 5. Both `generate` api and `v1/chat/completions` upport engaging in multiple rounds of conversation, where input `prompt` or `messages` consists of either single strings or entire chat histories.These inputs are interpreted using multi-turn dialogue modes. However, ff you want to turn the mode of and manage the chat history in clients, please the parameter `sequence_end: true` when utilizing the `generate` function, or specify `renew_session: true` when making use of `v1/chat/completions` diff --git a/docs/zh_cn/restful_api.md b/docs/zh_cn/restful_api.md index ab35ead124..2b56fa0f26 100644 --- a/docs/zh_cn/restful_api.md +++ b/docs/zh_cn/restful_api.md @@ -24,7 +24,7 @@ from typing import Iterable, List def get_streaming_response(prompt: str, api_url: str, - instance_id: int, + session_id: int, request_output_len: int, stream: bool = True, sequence_start: bool = True, @@ -34,7 +34,7 @@ def get_streaming_response(prompt: str, pload = { 'prompt': prompt, 'stream': stream, - 'instance_id': instance_id, + 'session_id': session_id, 'request_output_len': request_output_len, 'sequence_start': sequence_start, 'sequence_end': sequence_end, @@ -43,7 +43,7 @@ def get_streaming_response(prompt: str, response = requests.post( api_url, headers=headers, json=pload, stream=stream) for chunk in response.iter_lines( - chunk_size=8192, decode_unicode=False, delimiter=b'\0'): + chunk_size=8192, decode_unicode=False, delimiter=b'\n'): if chunk: data = json.loads(chunk.decode('utf-8')) output = data['text'] @@ -93,7 +93,7 @@ curl http://{server_ip}:{server_port}/generate \ -H "Content-Type: application/json" \ -d '{ "prompt": "Hello! How are you?", - "instance_id": 1, + "session_id": 1, "sequence_start": true, "sequence_end": true }' @@ -148,12 +148,11 @@ python -m lmdeploy.serve.gradio.app restful_api_url server_ip --restful_api True 2. 当服务端显存 OOM 时,可以适当减小启动服务时的 `instance_num` 个数 -3. 当同一个 `instance_id` 的请求给 `generate` 函数后,出现返回空字符串和负值的 `tokens`,应该是第二次问话没有设置 `sequence_start=false` +3. 当同一个 `session_id` 的请求给 `generate` 函数后,出现返回空字符串和负值的 `tokens`,应该是第二次问话没有设置 `sequence_start=false` 4. 如果感觉请求不是并发地被处理,而是一个一个地处理,请设置好以下参数: - - 不同的 instance_id 传入 `generate` api。否则,我们将自动绑定会话 id 为请求端的 ip 地址编号。 - - 设置 `stream=true` 使模型在前向传播时可以允许其他请求进入被处理 + - 不同的 session_id 传入 `generate` api。否则,我们将自动绑定会话 id 为请求端的 ip 地址编号。 5. `generate` api 和 `v1/chat/completions` 均支持多轮对话。`messages` 或者 `prompt` 参数既可以是一个简单字符串表示用户的单词提问,也可以是一段对话历史。 两个 api 都是默认开启多伦对话的,如果你想关闭这个功能,然后在客户端管理会话记录,请设置 `sequence_end: true` 传入 `generate`,或者设置 diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py index 40f87ac0ea..e2c4b36840 100644 --- a/lmdeploy/serve/async_engine.py +++ b/lmdeploy/serve/async_engine.py @@ -47,14 +47,31 @@ def __init__(self, model_path, instance_num=32, tp=1) -> None: self.starts = [None] * instance_num self.steps = {} + def stop_session(self, session_id: int): + instance_id = session_id % self.instance_num + input_ids = self.tokenizer.encode('') + for outputs in self.generators[instance_id].stream_infer( + session_id, + input_ids, + request_output_len=0, + sequence_start=False, + sequence_end=False, + stop=True): + pass + self.available[instance_id] = True + @contextmanager - def safe_run(self, instance_id: int, stop: bool = False): + def safe_run(self, instance_id: int, session_id: Optional[int] = None): self.available[instance_id] = False - yield + try: + yield + except (Exception, asyncio.CancelledError) as e: # noqa + self.stop_session(session_id) self.available[instance_id] = True - async def get_embeddings(self, prompt): - prompt = self.model.get_prompt(prompt) + async def get_embeddings(self, prompt, do_prerpocess=False): + if do_prerpocess: + prompt = self.model.get_prompt(prompt) input_ids = self.tokenizer.encode(prompt) return input_ids @@ -68,7 +85,7 @@ async def get_generator(self, instance_id: int, stop: bool = False): async def generate( self, messages, - instance_id, + session_id, stream_response=True, sequence_start=True, sequence_end=False, @@ -85,7 +102,7 @@ async def generate( Args: messages (str | List): chat history or prompt - instance_id (int): actually request host ip + session_id (int): the session id stream_response (bool): whether return responses streamingly request_output_len (int): output token nums sequence_start (bool): indicator for starting a sequence @@ -102,8 +119,7 @@ async def generate( 1.0 means no penalty ignore_eos (bool): indicator for ignoring eos """ - session_id = instance_id - instance_id %= self.instance_num + instance_id = session_id % self.instance_num if str(session_id) not in self.steps: self.steps[str(session_id)] = 0 if step != 0: @@ -119,7 +135,7 @@ async def generate( finish_reason) else: generator = await self.get_generator(instance_id, stop) - with self.safe_run(instance_id): + with self.safe_run(instance_id, session_id): response_size = 0 async for outputs in generator.async_stream_infer( session_id=session_id, @@ -188,14 +204,14 @@ async def generate_openai( instance_id %= self.instance_num sequence_start = False generator = await self.get_generator(instance_id) - self.available[instance_id] = False if renew_session: # renew a session empty_input_ids = self.tokenizer.encode('') for outputs in generator.stream_infer(session_id=session_id, input_ids=[empty_input_ids], request_output_len=0, sequence_start=False, - sequence_end=True): + sequence_end=True, + stop=True): pass self.steps[str(session_id)] = 0 if str(session_id) not in self.steps: @@ -212,31 +228,31 @@ async def generate_openai( yield GenOut('', self.steps[str(session_id)], len(input_ids), 0, finish_reason) else: - response_size = 0 - async for outputs in generator.async_stream_infer( - session_id=session_id, - input_ids=[input_ids], - stream_output=stream_response, - request_output_len=request_output_len, - sequence_start=(sequence_start), - sequence_end=False, - step=self.steps[str(session_id)], - stop=stop, - top_k=top_k, - top_p=top_p, - temperature=temperature, - repetition_penalty=repetition_penalty, - ignore_eos=ignore_eos, - random_seed=seed if sequence_start else None): - res, tokens = outputs[0] - # decode res - response = self.tokenizer.decode(res.tolist(), - offset=response_size) - # response, history token len, input token len, gen token len - yield GenOut(response, self.steps[str(session_id)], - len(input_ids), tokens, finish_reason) - response_size = tokens - - # update step - self.steps[str(session_id)] += len(input_ids) + tokens - self.available[instance_id] = True + with self.safe_run(instance_id, session_id): + response_size = 0 + async for outputs in generator.async_stream_infer( + session_id=session_id, + input_ids=[input_ids], + stream_output=stream_response, + request_output_len=request_output_len, + sequence_start=(sequence_start), + sequence_end=False, + step=self.steps[str(session_id)], + stop=stop, + top_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_penalty=repetition_penalty, + ignore_eos=ignore_eos, + random_seed=seed if sequence_start else None): + res, tokens = outputs[0] + # decode res + response = self.tokenizer.decode(res.tolist(), + offset=response_size) + # response, history len, input len, generation len + yield GenOut(response, self.steps[str(session_id)], + len(input_ids), tokens, finish_reason) + response_size = tokens + + # update step + self.steps[str(session_id)] += len(input_ids) + tokens diff --git a/lmdeploy/serve/gradio/app.py b/lmdeploy/serve/gradio/app.py index 954a5bcd32..71db7a2749 100644 --- a/lmdeploy/serve/gradio/app.py +++ b/lmdeploy/serve/gradio/app.py @@ -12,6 +12,7 @@ from lmdeploy.serve.gradio.css import CSS from lmdeploy.serve.openai.api_client import (get_model_list, get_streaming_response) +from lmdeploy.serve.openai.api_server import ip2id from lmdeploy.serve.turbomind.chatbot import Chatbot THEME = gr.themes.Soft( @@ -37,7 +38,7 @@ def chat_stream(state_chatbot: Sequence, llama_chatbot: Chatbot, instruction = state_chatbot[-1][0] session_id = threading.current_thread().ident if request is not None: - session_id = int(request.kwargs['client']['host'].replace('.', '')) + session_id = ip2id(request.kwargs['client']['host']) bot_response = llama_chatbot.stream_infer( session_id, instruction, f'{session_id}-{len(state_chatbot)}') @@ -166,7 +167,7 @@ def chat_stream_restful( """ session_id = threading.current_thread().ident if request is not None: - session_id = int(request.kwargs['client']['host'].replace('.', '')) + session_id = ip2id(request.kwargs['client']['host']) bot_summarized_response = '' state_chatbot = state_chatbot + [(instruction, None)] @@ -176,7 +177,7 @@ def chat_stream_restful( for response, tokens, finish_reason in get_streaming_response( instruction, f'{InterFace.restful_api_url}/generate', - instance_id=session_id, + session_id=session_id, request_output_len=512, sequence_start=(len(state_chatbot) == 1), sequence_end=False): @@ -212,12 +213,12 @@ def reset_restful_func(instruction_txtbox: gr.Textbox, state_chatbot: gr.State, session_id = threading.current_thread().ident if request is not None: - session_id = int(request.kwargs['client']['host'].replace('.', '')) + session_id = ip2id(request.kwargs['client']['host']) # end the session for response, tokens, finish_reason in get_streaming_response( '', f'{InterFace.restful_api_url}/generate', - instance_id=session_id, + session_id=session_id, request_output_len=0, sequence_start=False, sequence_end=True): @@ -241,11 +242,11 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button, """ session_id = threading.current_thread().ident if request is not None: - session_id = int(request.kwargs['client']['host'].replace('.', '')) + session_id = ip2id(request.kwargs['client']['host']) # end the session for out in get_streaming_response('', f'{InterFace.restful_api_url}/generate', - instance_id=session_id, + session_id=session_id, request_output_len=0, sequence_start=False, sequence_end=False, @@ -259,7 +260,7 @@ def cancel_restful_func(state_chatbot: gr.State, cancel_btn: gr.Button, messages.append(dict(role='assistant', content=qa[1])) for out in get_streaming_response(messages, f'{InterFace.restful_api_url}/generate', - instance_id=session_id, + session_id=session_id, request_output_len=0, sequence_start=True, sequence_end=False): @@ -346,7 +347,7 @@ async def chat_stream_local( """ session_id = threading.current_thread().ident if request is not None: - session_id = int(request.kwargs['client']['host'].replace('.', '')) + session_id = ip2id(request.kwargs['client']['host']) bot_summarized_response = '' state_chatbot = state_chatbot + [(instruction, None)] @@ -391,7 +392,7 @@ async def reset_local_func(instruction_txtbox: gr.Textbox, session_id = threading.current_thread().ident if request is not None: - session_id = int(request.kwargs['client']['host'].replace('.', '')) + session_id = ip2id(request.kwargs['client']['host']) # end the session async for out in InterFace.async_engine.generate('', session_id, @@ -419,7 +420,7 @@ async def cancel_local_func(state_chatbot: gr.State, cancel_btn: gr.Button, """ session_id = threading.current_thread().ident if request is not None: - session_id = int(request.kwargs['client']['host'].replace('.', '')) + session_id = ip2id(request.kwargs['client']['host']) # end the session async for out in InterFace.async_engine.generate('', session_id, diff --git a/lmdeploy/serve/openai/api_client.py b/lmdeploy/serve/openai/api_client.py index 449b8a294a..a8718331be 100644 --- a/lmdeploy/serve/openai/api_client.py +++ b/lmdeploy/serve/openai/api_client.py @@ -17,7 +17,7 @@ def get_model_list(api_url: str): def get_streaming_response(prompt: str, api_url: str, - instance_id: int, + session_id: int, request_output_len: int = 512, stream: bool = True, sequence_start: bool = True, @@ -28,7 +28,7 @@ def get_streaming_response(prompt: str, pload = { 'prompt': prompt, 'stream': stream, - 'instance_id': instance_id, + 'session_id': session_id, 'request_output_len': request_output_len, 'sequence_start': sequence_start, 'sequence_end': sequence_end, @@ -41,7 +41,7 @@ def get_streaming_response(prompt: str, stream=stream) for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, - delimiter=b'\0'): + delimiter=b'\n'): if chunk: data = json.loads(chunk.decode('utf-8')) output = data.pop('text', '') @@ -62,12 +62,20 @@ def main(restful_api_url: str, session_id: int = 0): while True: prompt = input_prompt() if prompt == 'exit': + for output, tokens, finish_reason in get_streaming_response( + '', + f'{restful_api_url}/generate', + session_id=session_id, + request_output_len=0, + sequence_start=(nth_round == 1), + sequence_end=True): + pass exit(0) else: for output, tokens, finish_reason in get_streaming_response( prompt, f'{restful_api_url}/generate', - instance_id=session_id, + session_id=session_id, request_output_len=512, sequence_start=(nth_round == 1), sequence_end=False): diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index 647c36609c..94271c4b9b 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import json import os import time from http import HTTPStatus @@ -7,7 +6,7 @@ import fire import uvicorn -from fastapi import BackgroundTasks, FastAPI, Request +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse @@ -16,8 +15,8 @@ ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingsRequest, - EmbeddingsResponse, ErrorResponse, GenerateRequest, ModelCard, ModelList, - ModelPermission, UsageInfo) + EmbeddingsResponse, ErrorResponse, GenerateRequest, GenerateResponse, + ModelCard, ModelList, ModelPermission, UsageInfo) os.environ['TM_LOG_LEVEL'] = 'ERROR' @@ -73,6 +72,16 @@ async def check_request(request) -> Optional[JSONResponse]: return ret +def ip2id(host_ip: str): + """Convert host ip address to session id.""" + if '.' in host_ip: # IPv4 + return int(host_ip.replace('.', '')[-8:]) + if ':' in host_ip: # IPv6 + return int(host_ip.replace(':', '')[-8:], 16) + print('Warning, could not get session id from ip, set it 0') + return 0 + + @app.post('/v1/chat/completions') async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Request = None): @@ -106,19 +115,18 @@ async def chat_completions_v1(request: ChatCompletionRequest, - presence_penalty (replaced with repetition_penalty) - frequency_penalty (replaced with repetition_penalty) """ - instance_id = int(raw_request.client.host.replace('.', '')) - + session_id = ip2id(raw_request.client.host) error_check_ret = await check_request(request) if error_check_ret is not None: return error_check_ret model_name = request.model - request_id = str(instance_id) + request_id = str(session_id) created_time = int(time.time()) result_generator = VariableInterface.async_engine.generate_openai( request.messages, - instance_id, + session_id, True, # always use stream to enable batching request.renew_session, request_output_len=request.max_tokens if request.max_tokens else 512, @@ -128,15 +136,6 @@ async def chat_completions_v1(request: ChatCompletionRequest, repetition_penalty=request.repetition_penalty, ignore_eos=request.ignore_eos) - async def abort_request() -> None: - async for _ in VariableInterface.async_engine.generate_openai( - request.messages, - instance_id, - True, - request.renew_session, - stop=True): - pass - def create_stream_response_json( index: int, text: str, @@ -181,12 +180,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: # Streaming response if request.stream: - background_tasks = BackgroundTasks() - # Abort the request if the client disconnects. - background_tasks.add_task(abort_request) return StreamingResponse(completion_stream_generator(), - media_type='text/event-stream', - background=background_tasks) + media_type='text/event-stream') # Non-streaming response final_res = None @@ -194,7 +189,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: async for res in result_generator: if await raw_request.is_disconnected(): # Abort the request if the client disconnects. - await abort_request() + VariableInterface.async_engine.stop_session(session_id) return create_error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected') final_res = res @@ -257,7 +252,7 @@ async def generate(request: GenerateRequest, raw_request: Request = None): The request should be a JSON object with the following fields: - prompt: the prompt to use for the generation. - - instance_id: determine which instance will be called. If not specified + - session_id: determine which instance will be called. If not specified with a value other than -1, using host ip directly. - sequence_start (bool): indicator for starting a sequence. - sequence_end (bool): indicator for ending a sequence @@ -275,13 +270,13 @@ async def generate(request: GenerateRequest, raw_request: Request = None): 1.0 means no penalty - ignore_eos (bool): indicator for ignoring eos """ - if request.instance_id == -1: - instance_id = int(raw_request.client.host.replace('.', '')) - request.instance_id = instance_id + if request.session_id == -1: + session_id = ip2id(raw_request.client.host) + request.session_id = session_id generation = VariableInterface.async_engine.generate( request.prompt, - request.instance_id, + request.session_id, stream_response=True, # always use stream to enable batching sequence_start=request.sequence_start, sequence_end=request.sequence_end, @@ -296,21 +291,26 @@ async def generate(request: GenerateRequest, raw_request: Request = None): # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: async for out in generation: - ret = { - 'text': out.response, - 'tokens': out.generate_token_len, - 'finish_reason': out.finish_reason - } - yield (json.dumps(ret) + '\0').encode('utf-8') + chunk = GenerateResponse(text=out.response, + tokens=out.generate_token_len, + finish_reason=out.finish_reason) + data = chunk.model_dump_json() + yield f'{data}\n' if request.stream: - return StreamingResponse(stream_results()) + return StreamingResponse(stream_results(), + media_type='text/event-stream') else: ret = {} text = '' tokens = 0 finish_reason = None async for out in generation: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. + VariableInterface.async_engine.stop_session(session_id) + return create_error_response(HTTPStatus.BAD_REQUEST, + 'Client disconnected') text += out.response tokens = out.generate_token_len finish_reason = out.finish_reason diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py index 8f2919a1a5..b4eeadff74 100644 --- a/lmdeploy/serve/openai/protocol.py +++ b/lmdeploy/serve/openai/protocol.py @@ -190,7 +190,7 @@ class EmbeddingsResponse(BaseModel): class GenerateRequest(BaseModel): """Generate request.""" prompt: Union[str, List[Dict[str, str]]] - instance_id: int = -1 + session_id: int = -1 sequence_start: bool = True sequence_end: bool = False stream: bool = False @@ -201,3 +201,10 @@ class GenerateRequest(BaseModel): temperature: float = 0.8 repetition_penalty: float = 1.0 ignore_eos: bool = False + + +class GenerateResponse(BaseModel): + """Generate response.""" + text: str + tokens: int + finish_reason: Optional[Literal['stop', 'length']] = None