Support batch size > 1 #80

Draft · wants to merge 2 commits into base: main
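For reference, a minimal sketch of how the new batched /generate endpoint could be exercised with several requests in flight at once. It assumes the server from api_server.py below is already running on localhost:8000; the prompts, the thread wrapper, and the max_tokens value are illustrative only and not part of this PR.

import json
import threading

import requests

API_URL = "http://localhost:8000/generate"  # assumed default host/port


def run_one(prompt: str) -> None:
    # The server streams NUL-delimited JSON chunks, each carrying the full text so far.
    resp = requests.post(API_URL, json={"prompt": prompt, "max_tokens": 64}, stream=True)
    last = None
    for chunk in resp.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            last = json.loads(chunk.decode("utf-8"))["text"][0]
    print(f"{prompt!r} -> {last!r}")


# Firing several requests concurrently lets the server batch them together
# (up to max_batch_size, collected over a roughly 30 ms window).
threads = [threading.Thread(target=run_one, args=(p,))
           for p in ("San Francisco is a", "The capital of France is")]
for t in threads:
    t.start()
for t in threads:
    t.join()
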
79 changes: 79 additions & 0 deletions medusa/inference/api_client.py
@@ -0,0 +1,79 @@
"""Example Python client for vllm.entrypoints.api_server"""

import argparse
import json
from typing import Iterable, List

import requests


def clear_line(n: int = 1) -> None:
LINE_UP = '\033[1A'
LINE_CLEAR = '\x1b[2K'
for _ in range(n):
print(LINE_UP, end=LINE_CLEAR, flush=True)


def post_http_request(prompt: str,
                      api_url: str,
                      n: int = 1,
                      stream: bool = False) -> requests.Response:
    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "max_tokens": 150,
    }
    print(pload)
    # Note: the server always streams, so the request is opened with stream=True
    # regardless of the `stream` argument; `n` is not forwarded yet.
    response = requests.post(api_url, headers=headers, json=pload, stream=True)
    return response


def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
for chunk in response.iter_lines(chunk_size=8192,
decode_unicode=False,
delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"]
yield output


def get_response(response: requests.Response) -> List[str]:
print(response.content)
data = json.loads(response.content)
output = data["text"]
return output

def add_prefix(prompt):
    # Wrap the raw prompt in the Vicuna-style chat template expected by the model.
    prompt_ = f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {prompt} ASSISTANT:"
    return prompt_

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--n", type=int, default=1)
parser.add_argument("--prompt", type=str, default="San Francisco is a")
parser.add_argument("--stream", action="store_true")
args = parser.parse_args()
prompt = args.prompt
api_url = f"http://{args.host}:{args.port}/generate"
    n = args.n
    # Note: these assignments override the CLI arguments for testing; streaming is
    # forced on and a fixed Chinese prompt ("What is your name?") replaces --prompt.
    stream = True
    prompt = "你叫什么名字?"
    prompt = add_prefix(prompt)
print(f"Prompt: {prompt!r}\n", flush=True)
response = post_http_request(prompt, api_url, n, stream)

if stream:
num_printed_lines = 0
for h in get_streaming_response(response):
clear_line(num_printed_lines)
num_printed_lines = 0
for i, line in enumerate(h):
num_printed_lines += 1
print(f"Beam candidate {i}: {line!r}", flush=True)
else:
output = get_response(response)
for i, line in enumerate(output):
print(f"Beam candidate {i}: {line!r}", flush=True)
175 changes: 175 additions & 0 deletions medusa/inference/api_server.py
@@ -0,0 +1,175 @@
import argparse
import asyncio
import json
import uuid
from collections import deque

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import Response, StreamingResponse

from medusa.model.medusa_model import MedusaModel

TIMEOUT_KEEP_ALIVE = 5 # seconds.
engine = None
max_batch_size = 5
request_queue = deque()
id2result = {}

async def handle_request(request_data):
request_queue.append(request_data)

async def get_batch_from_queue():
    prompts = []
    ids = []
    if args.origin_model:
        request_dict_ = {"temperature": 0.5, "max_tokens": 150, "top_p": 0.85}
    else:
        request_dict_ = {"temperature": 0.0, "max_tokens": 150, "top_p": 0.85}
    max_tokens = None
    start_time = asyncio.get_event_loop().time()  # record when batching started
    while len(prompts) < max_batch_size:
        # Stop collecting once the 30 ms batching window has elapsed.
        if asyncio.get_event_loop().time() - start_time >= 0.03:
            break
        # If the queue is empty, wait 1 ms and try again.
        if not request_queue:
            await asyncio.sleep(0.001)
            continue
        request_dict = request_queue.popleft()
        if request_dict.get("max_tokens", None):
            if max_tokens:
                max_tokens = max(max_tokens, request_dict["max_tokens"])
            else:
                max_tokens = request_dict["max_tokens"]
        prompts.append(request_dict.pop("prompt"))
        ids.append(request_dict.pop("unique_id"))
    if max_tokens:
        request_dict_["max_tokens"] = max_tokens
    # Use the sampling parameters of the last request in the batch, if provided.
    if len(prompts) > 0 and request_dict.get("temperature", None):
        request_dict_["temperature"] = request_dict["temperature"]
    if len(prompts) > 0 and request_dict.get("top_p", None):
        request_dict_["top_p"] = request_dict["top_p"]
    return prompts, ids, request_dict_


async def run_model():
    while True:
        prompts, ids, request_dict = await get_batch_from_queue()
        if len(prompts) > 0:
            print(f"batch size: {len(prompts)}")
            encoded_inputs = tokenizer(prompts, padding=True, truncation=True, return_tensors="pt")
            input_ids = encoded_inputs['input_ids'].to(engine.base_model.device)
            attention_mask = encoded_inputs['attention_mask'].to(engine.base_model.device)
            for request_output in engine.medusa_generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                temperature=request_dict["temperature"],
                max_steps=request_dict["max_tokens"],
                top_p=request_dict["top_p"]
            ):
                await asyncio.sleep(0.001)
                for index, id in enumerate(ids):
                    if id2result[id] is None:
                        id2result[id] = {'text': None, 'sign': None, 'finished': False}
                    # Only publish a new result when the decoded text has changed.
                    if id2result[id]['text'] != request_output["text"][index]:
                        id2result[id]['text'] = request_output["text"][index]
                        id2result[id]['sign'] = str(uuid.uuid4())

            for index, id in enumerate(ids):
                id2result[id]['finished'] = True
        else:
            pass

app = FastAPI()

@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)

@app.on_event("startup")
async def startup_event():
asyncio.create_task(run_model())

@app.post("/generate")
async def generate(request: Request) -> Response:
request_dict = await request.json()
unique_id = str(uuid.uuid4())
request_dict["unique_id"] = unique_id
id2result[unique_id] = None
await handle_request(request_dict) ##接收数据放入queue

async def stream_results():
previous_sign = None
while True: ##循环取输出输出
result = id2result.get(unique_id, None)
if result is not None:
if result['sign'] != previous_sign: ##是否更新
full_sentence = result['text']
ret = {"text":[full_sentence]}
previous_sign = result['sign']
yield (json.dumps(ret) + "\0").encode("utf-8")
else:
if result['finished']: ##是否写完
print(f"{unique_id} 全部输出完毕,删除")
id2result.pop(unique_id)
break
await asyncio.sleep(0.001)
else:
await asyncio.sleep(0.001)

return StreamingResponse(stream_results()) ##返回数据


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument("--model", type=str, required=True, help="Model name or path.")
parser.add_argument("--origin-model", action="store_true")
parser.add_argument(
"--load-in-8bit", action="store_true", help="Use 8-bit quantization"
)
parser.add_argument(
"--load-in-4bit", action="store_true", help="Use 4-bit quantization"
)
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")

args = parser.parse_args()
if args.origin_model:
import types
        from medusa.model.origin_model import Model, Tokenizer, medusa_generate
from transformers_stream_generator import init_stream_support
init_stream_support()
engine = Model.from_pretrained(args.model)
tokenizer = Tokenizer.from_pretrained(args.model)
engine.medusa_generate = types.MethodType(medusa_generate, engine)
engine.tokenizer = tokenizer
print("启动原始模型")
else:
engine = MedusaModel.from_pretrained(
args.model,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
device_map="auto",
load_in_8bit=args.load_in_8bit,
load_in_4bit=args.load_in_4bit,
)
tokenizer = engine.get_tokenizer()
print("启动medusa模型")
app.root_path = args.root_path
uvicorn.run(app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile)
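
The heart of the batching support above is get_batch_from_queue: it drains the shared request queue until it has max_batch_size prompts or a roughly 30 ms window has elapsed, whichever comes first. A stripped-down sketch of that time-window collector, with illustrative names and none of the Medusa-specific plumbing:

import asyncio
from collections import deque

queue: deque = deque()   # illustrative stand-in for request_queue
MAX_BATCH = 5            # illustrative stand-in for max_batch_size
WINDOW_S = 0.03          # batching window in seconds


async def collect_batch() -> list:
    """Collect up to MAX_BATCH items, waiting at most WINDOW_S for stragglers."""
    batch = []
    start = asyncio.get_event_loop().time()
    while len(batch) < MAX_BATCH:
        if asyncio.get_event_loop().time() - start >= WINDOW_S:
            break                        # window elapsed: run with what we have
        if not queue:
            await asyncio.sleep(0.001)   # nothing queued yet; poll again shortly
            continue
        batch.append(queue.popleft())
    return batch

One consequence of this design is that a lone request can wait up to the full window before generation starts, trading a little latency for larger batches.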
98 changes: 98 additions & 0 deletions medusa/inference/inference_test.py
@@ -0,0 +1,98 @@
# Adapted from: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
"""
Quick test of batched (batch size > 1) Medusa generation.

Usage:
python3 -m medusa.inference.inference_test --model <model_name_or_path>
"""
import argparse

import torch

from medusa.model.medusa_model import MedusaModel

def main(args):
prefix = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {0} ASSISTANT:"
# prompt = ["你叫什么名字"]
# prompt = ["你叫什么名字", "中国的首都是哪里呢?"]
prompt = ["openai是家什么公司?", "2+2等于几?"]
prompt = [prefix.format(p) for p in prompt]
model = MedusaModel.from_pretrained(
args.model,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
device_map="auto",
load_in_8bit=args.load_in_8bit,
load_in_4bit=args.load_in_4bit,
)
tokenizer = model.get_tokenizer()
    # Tokenize the batch of prompts.
    encoded_inputs = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
    # Move the encoded inputs to the device the model lives on.
    input_ids = encoded_inputs['input_ids'].to(model.base_model.device)
    attention_mask = encoded_inputs['attention_mask'].to(model.base_model.device)
for output in model.medusa_generate(
input_ids,
attention_mask=attention_mask,
temperature=args.temperature,
# temperature=0,
max_steps=args.max_steps
):
print(output['text'])


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, required=True, help="Model name or path.")
parser.add_argument(
"--load-in-8bit", action="store_true", help="Use 8-bit quantization"
)
parser.add_argument(
"--load-in-4bit", action="store_true", help="Use 4-bit quantization"
)
parser.add_argument(
"--conv-template", type=str, default=None, help="Conversation prompt template."
)
parser.add_argument(
"--conv-system-msg", type=str, default=None, help="Conversation system message."
)
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--max-steps", type=int, default=10)
parser.add_argument("--no-history", action="store_true")
parser.add_argument(
"--style",
type=str,
default="simple",
choices=["simple", "rich", "programmatic"],
help="Display style.",
)
parser.add_argument(
"--multiline",
action="store_true",
help="Enable multiline input. Use ESC+Enter for newline.",
)
parser.add_argument(
"--mouse",
action="store_true",
help="[Rich Style]: Enable mouse support for cursor positioning.",
)
parser.add_argument(
"--debug",
action="store_true",
help="Print useful debug information (e.g., prompts)",
)
args = parser.parse_args()
main(args)
29 changes: 29 additions & 0 deletions medusa/inference/origin_inference.py
@@ -0,0 +1,29 @@
import torch
import types
from medusa.model.origin_model import Model, Tokenizer, medusa_generate
from transformers_stream_generator import init_stream_support
init_stream_support()


prefix = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {0} ASSISTANT:"
prompt = ["openai是家什么公司?", "2+2等于几?"]
prompt = [prefix.format(p) for p in prompt]
# Path to a local snapshot of the Medusa checkpoint; adjust as needed.
model_dir = '/mnt/wx/.cache/huggingface/hub/models--FasterDecoding--medusa-vicuna-7b-v1.3/snapshots/82ac200bf7502419cb49a9e0adcbebe3d1d293f1/'
model = Model.from_pretrained(model_dir)
tokenizer = Tokenizer.from_pretrained(model_dir)
model_inputs = tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
# Attach the tokenizer and the medusa_generate method to this model instance.
model.tokenizer = tokenizer
model.medusa_generate = types.MethodType(medusa_generate, model)
input_ids = model_inputs['input_ids'].to(model.device)
attention_mask = model_inputs['attention_mask'].to(model.device)
generator = model.medusa_generate(input_ids=input_ids,
attention_mask=attention_mask,
temperature=0.1,
max_steps=20,
top_p=0.8)
for token in generator:
print(token['text'])

