[Bench] Add support for multiple backends #3037

Draft · wants to merge 1 commit into main
45 changes: 42 additions & 3 deletions python/mlc_llm/bench/api_endpoint.py
@@ -41,19 +41,23 @@ def __init__(  # pylint: disable=too-many-arguments
         self,
         host: str,
         port: int,
+        backend: str,
         timeout: Optional[float] = None,
         include_server_metrics: bool = False,
+        no_debug_config: bool = False,
     ) -> None:
         super().__init__(include_server_metrics=include_server_metrics)
 
         import aiohttp  # pylint: disable=import-outside-toplevel,import-error
 
+        self.backend = backend
         self.timeout = timeout
         self.client: aiohttp.ClientSession = None
         self.url = f"http://{host}:{port}/v1/chat/completions"
         self.headers = {"Content-Type": "application/json"}
         if os.getenv("MLC_LLM_API_KEY"):
             self.headers["Authorization"] = f"Bearer {os.getenv('MLC_LLM_API_KEY')}"
+        self.no_debug_config = no_debug_config
 
     async def __aenter__(self) -> Self:
         import aiohttp  # pylint: disable=import-outside-toplevel,import-error
@@ -67,7 +71,7 @@ async def __aexit__(self, exc_type, exc_value, tb) -> None:
     async def __call__(  # pylint: disable=too-many-branches,too-many-statements,too-many-locals
         self, request_record: RequestRecord
     ) -> RequestRecord:
-        payload = request_record.chat_cmpl.model_dump()
+        payload = request_record.chat_cmpl.model_dump(exclude_unset=True, exclude_none=True)
         if self.timeout is not None and "timeout" not in payload:
             payload["timeout"] = self.timeout
         if self.include_server_metrics:
@@ -80,7 +84,28 @@ async def __call__(  # pylint: disable=too-many-branches,too-many-statements,too-many-locals
             and request_record.chat_cmpl.debug_config.ignore_eos
         ):
             payload["ignore_eos"] = True
+            if not self.no_debug_config:
+                payload["debug_config"] = {"ignore_eos": True}
 
+        if self.backend == "vllm":
+            if payload.get("debug_config") and "ignore_eos" in payload["debug_config"]:
+                payload["ignore_eos"] = payload["debug_config"]["ignore_eos"]
+            payload.pop("debug_config", None)
+            if "response_format" in payload:
+                if "json_schema" in payload["response_format"]:
+                    payload["guided_json"] = json.loads(payload["response_format"]["json_schema"])
+                    payload["guided_decoding_backend"] = "outlines"
+                payload.pop("response_format")
+        elif self.backend == "llama.cpp":
+            if "response_format" in payload and "json_schema" in payload["response_format"]:
+                payload["response_format"]["schema"] = json.loads(
+                    payload["response_format"]["json_schema"]
+                )
+                payload["response_format"].pop("json_schema")
+        else:
+            if "response_format" in payload and "json_schema" in payload["response_format"]:
+                payload["response_format"]["schema"] = payload["response_format"]["json_schema"]
+                payload["response_format"].pop("json_schema")
         generated_text = ""
         first_chunk_output_str = ""
         time_to_first_token_s = None
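For readers skimming the diff: the per-backend rewriting above can be read as a standalone helper. The sketch below is illustrative only (the translate_payload name and the sample payload are not part of this PR), and it assumes the payload shape produced by model_dump(exclude_unset=True, exclude_none=True):

# Illustrative sketch of the per-backend payload translation above.
import json


def translate_payload(payload: dict, backend: str) -> dict:
    """Rewrite an MLC-style request payload for the target backend."""
    if backend == "vllm":
        # vLLM takes a top-level ignore_eos plus guided-decoding fields
        # instead of debug_config / response_format.
        if payload.get("debug_config") and "ignore_eos" in payload["debug_config"]:
            payload["ignore_eos"] = payload["debug_config"]["ignore_eos"]
        payload.pop("debug_config", None)
        response_format = payload.pop("response_format", None)
        if response_format and "json_schema" in response_format:
            payload["guided_json"] = json.loads(response_format["json_schema"])
            payload["guided_decoding_backend"] = "outlines"
    elif backend == "llama.cpp":
        # llama.cpp expects the parsed schema object under response_format["schema"].
        response_format = payload.get("response_format", {})
        if "json_schema" in response_format:
            response_format["schema"] = json.loads(response_format.pop("json_schema"))
    else:
        # MLC serving expects the schema string under the "schema" key.
        response_format = payload.get("response_format", {})
        if "json_schema" in response_format:
            response_format["schema"] = response_format.pop("json_schema")
    return payload


payload = {
    "messages": [{"role": "user", "content": "Hi"}],
    "response_format": {"type": "json_object", "json_schema": '{"type": "object"}'},
    "debug_config": {"ignore_eos": True},
}
print(translate_payload(payload, "vllm"))
# {'messages': [...], 'ignore_eos': True, 'guided_json': {'type': 'object'},
#  'guided_decoding_backend': 'outlines'}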
@@ -441,19 +466,34 @@ async def __call__(  # pylint: disable=too-many-branches,too-many-locals,too-man
     "sglang",
     "tensorrt-llm",
     "vllm",
+    "llama.cpp",
+    "vllm-chat",
+    "llama.cpp-chat",
 ]
 
 
 def create_api_endpoint(args: argparse.Namespace) -> APIEndPoint:
     """Create an API endpoint instance with regard to the specified endpoint kind."""
     if args.api_endpoint in ["openai", "mlc", "sglang"]:
         return OpenAIEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)
-    if args.api_endpoint == "vllm":
+    if args.api_endpoint in ["vllm", "llama.cpp"]:
         return OpenAIEndPoint(
             args.host, args.port, args.timeout, include_server_metrics=False, no_debug_config=True
         )
     if args.api_endpoint == "openai-chat":
-        return OpenAIChatEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)
+        return OpenAIChatEndPoint(
+            args.host, args.port, args.api_endpoint, args.timeout, args.include_server_metrics
+        )
+    if args.api_endpoint in ["vllm-chat", "llama.cpp-chat"]:
+        return OpenAIChatEndPoint(
+            args.host,
+            args.port,
+            args.api_endpoint[:-5],
+            args.timeout,
+            include_server_metrics=False,
+            no_debug_config=True,
+        )
+
     if args.api_endpoint == "tensorrt-llm":
         return TensorRTLLMEndPoint(args.host, args.port, args.timeout)
     raise ValueError(f'Unrecognized endpoint "{args.api_endpoint}"')
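A quick smoke test of the factory above; the values are hypothetical, and only the argparse fields the function actually reads are filled in:

# Hypothetical smoke test for the factory above (example values only).
import argparse

from mlc_llm.bench.api_endpoint import create_api_endpoint

args = argparse.Namespace(
    api_endpoint="llama.cpp-chat",  # must be one of SUPPORTED_BACKENDS
    host="127.0.0.1",
    port=8080,
    timeout=60.0,
    include_server_metrics=False,
)
endpoint = create_api_endpoint(args)
# "llama.cpp-chat"[:-5] strips "-chat", so OpenAIChatEndPoint is constructed
# with backend="llama.cpp", include_server_metrics=False, no_debug_config=True.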
51 changes: 48 additions & 3 deletions python/mlc_llm/bench/dataset.py
@@ -243,11 +243,17 @@ class JSONModeEvalDataset(Dataset):  # pylint: disable=too-few-public-methods
     """The dataset class for JSON dataset."""
 
     def __init__(self, tokenizer: AutoTokenizer) -> None:
-        raw_dataset = load_dataset("NousResearch/json-mode-eval")
+        raw_dataset = load_dataset("NousResearch/json-mode-eval", split="train")
         self.tokenizer = tokenizer
         self.dataset = []
-        for data in raw_dataset["train"]:
-            messages = data["prompt"]
+        for data in raw_dataset:
+            data = self._process_data(data)
+            messages = [
+                {
+                    "content": data["prompt"][0]["content"] + " " + data["prompt"][1]["content"],
+                    "role": data["prompt"][1]["role"],
+                },
+            ]
             schema = {
                 "type": "json_object",
                 "schema": data["schema"],
@@ -259,6 +265,42 @@ def __init__(self, tokenizer: AutoTokenizer) -> None:
             )
             self.dataset.append((messages, schema, num_tokens))
 
+    def _process_data(self, data):
+        data["prompt"][0]["content"] = data["prompt"][0]["content"].replace(
+            ", 'format': 'email'", ""
+        )
+        data["schema"] = data["schema"].replace(', "format": "email"', "")
+
+        data["prompt"][0]["content"] = data["prompt"][0]["content"].replace(
+            ", 'pattern': '\\\\d{5}'", ""
+        )
+        data["schema"] = data["schema"].replace(', "pattern": "\\\\d{5}"', "")
+
+        schema_str = data["schema"]
+        schema = json.loads(schema_str)
+        new_schema = None
+        if "type" not in schema:
+            if len(schema.keys()) == 1:
+                key = list(schema.keys())[0]
+                new_schema = {"title": key, **schema[key]}
+            else:
+                new_schema = {"type": "object", **schema}
+        if new_schema is None:
+            return data
+        return {
+            "prompt": [
+                {
+                    "content": "You are a helpful assistant that answers in JSON. "
+                    "Here's the json schema you must adhere to:"
+                    f"\n<schema>\n{new_schema}\n</schema>\n",
+                    "role": "system",
+                },
+                data["prompt"][1],
+            ],
+            "completion": data["completion"],
+            "schema": json.dumps(new_schema),
+        }
+
     def generate_request_records(
         self,
         input_len: Optional[int],
@@ -288,6 +330,9 @@ def generate_request_records(
                     model="",
                     max_tokens=output_length,
                     response_format=schema,
+                    debug_config=DebugConfig(
+                        grammar_execution_mode="constraint",
+                    ),
                 ),
                 metrics=Metrics(
                     success=False,
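For context on _process_data: some json-mode-eval rows nest the whole schema under a single top-level key with no "type" field, and the method lifts that key into a proper JSON schema. A worked example with a made-up row:

# Worked example of the single-key branch in _process_data (the sample is
# made up, shaped like a NousResearch/json-mode-eval row).
import json

schema_str = '{"User": {"type": "object", "properties": {"name": {"type": "string"}}}}'
schema = json.loads(schema_str)

# No top-level "type" and exactly one key, so the key becomes the title:
key = list(schema.keys())[0]
new_schema = {"title": key, **schema[key]}
print(json.dumps(new_schema))
# {"title": "User", "type": "object", "properties": {"name": {"type": "string"}}}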
2 changes: 1 addition & 1 deletion python/mlc_llm/bench/request_processor.py
@@ -131,7 +131,7 @@ def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
             request_record.chat_cmpl.top_p = self.top_p
             request_record.chat_cmpl.frequency_penalty = 0.0
             request_record.chat_cmpl.presence_penalty = 0.0
-            request_record.chat_cmpl.tool_choice = "none"
+            request_record.chat_cmpl.tool_choice = None
             if self.ignore_eos:
                 request_record.chat_cmpl.debug_config = DebugConfig(ignore_eos=True)
         return request_records
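The tool_choice change pairs with the model_dump(exclude_unset=True, exclude_none=True) change in api_endpoint.py: a None field is dropped from the serialized payload entirely, whereas the old "none" string was sent to every backend, some of which may reject it. A minimal pydantic sketch (an illustrative model, not the actual mlc_llm protocol class):

# Why None instead of "none": with exclude_none=True the field disappears from
# the wire payload entirely. Illustrative model, not the real mlc_llm class.
from typing import Optional

from pydantic import BaseModel


class ChatCmpl(BaseModel):
    model: str = ""
    tool_choice: Optional[str] = None


req = ChatCmpl(model="m", tool_choice="none")
print(req.model_dump(exclude_unset=True, exclude_none=True))
# {'model': 'm', 'tool_choice': 'none'}  -> the string is sent to the server

req.tool_choice = None
print(req.model_dump(exclude_unset=True, exclude_none=True))
# {'model': 'm'}  -> the field is omitted from the payload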