diff --git a/python/mlc_llm/bench/api_endpoint.py b/python/mlc_llm/bench/api_endpoint.py
index a44b6da690..1e0c4720f6 100644
--- a/python/mlc_llm/bench/api_endpoint.py
+++ b/python/mlc_llm/bench/api_endpoint.py
@@ -41,19 +41,23 @@ def __init__(  # pylint: disable=too-many-arguments
         self,
         host: str,
         port: int,
+        backend: str,
         timeout: Optional[float] = None,
         include_server_metrics: bool = False,
+        no_debug_config: bool = False,
     ) -> None:
         super().__init__(include_server_metrics=include_server_metrics)
 
         import aiohttp  # pylint: disable=import-outside-toplevel,import-error
 
+        self.backend = backend
         self.timeout = timeout
         self.client: aiohttp.ClientSession = None
         self.url = f"http://{host}:{port}/v1/chat/completions"
         self.headers = {"Content-Type": "application/json"}
         if os.getenv("MLC_LLM_API_KEY"):
             self.headers["Authorization"] = f"Bearer {os.getenv('MLC_LLM_API_KEY')}"
+        self.no_debug_config = no_debug_config
 
     async def __aenter__(self) -> Self:
         import aiohttp  # pylint: disable=import-outside-toplevel,import-error
@@ -67,7 +71,7 @@ async def __aexit__(self, exc_type, exc_value, tb) -> None:
     async def __call__(  # pylint: disable=too-many-branches,too-many-statements,too-many-locals
         self, request_record: RequestRecord
     ) -> RequestRecord:
-        payload = request_record.chat_cmpl.model_dump()
+        payload = request_record.chat_cmpl.model_dump(exclude_unset=True, exclude_none=True)
         if self.timeout is not None and "timeout" not in payload:
             payload["timeout"] = self.timeout
         if self.include_server_metrics:
@@ -80,7 +84,28 @@ async def __call__(  # pylint: disable=too-many-branches,too-many-statements,too
             and request_record.chat_cmpl.debug_config.ignore_eos
         ):
             payload["ignore_eos"] = True
+            if not self.no_debug_config:
+                payload["debug_config"] = {"ignore_eos": True}
+        if self.backend == "vllm":
+            if payload["debug_config"] and "ignore_eos" in payload["debug_config"]:
+                payload["ignore_eos"] = payload["debug_config"]["ignore_eos"]
+            payload.pop("debug_config")
+            if "response_format" in payload:
+                if "json_schema" in payload["response_format"]:
+                    payload["guided_json"] = json.loads(payload["response_format"]["json_schema"])
+                    payload["guided_decoding_backend"] = "outlines"
+                payload.pop("response_format")
+        elif self.backend == "llama.cpp":
+            if "response_format" in payload and "json_schema" in payload["response_format"]:
+                payload["response_format"]["schema"] = json.loads(
+                    payload["response_format"]["json_schema"]
+                )
+                payload["response_format"].pop("json_schema")
+        else:
+            if "response_format" in payload and "json_schema" in payload["response_format"]:
+                payload["response_format"]["schema"] = payload["response_format"]["json_schema"]
+                payload["response_format"].pop("json_schema")
 
         generated_text = ""
         first_chunk_output_str = ""
         time_to_first_token_s = None
@@ -441,6 +466,8 @@ async def __call__(  # pylint: disable=too-many-branches,too-many-locals,too-man
     "sglang",
     "tensorrt-llm",
     "vllm",
+    "vllm-chat",
+    "llama.cpp-chat",
 ]
 
 
@@ -448,12 +475,24 @@ def create_api_endpoint(args: argparse.Namespace) -> APIEndPoint:
     """Create an API endpoint instance with regard to the specified endpoint kind."""
     if args.api_endpoint in ["openai", "mlc", "sglang"]:
         return OpenAIEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)
-    if args.api_endpoint == "vllm":
+    if args.api_endpoint in ["vllm", "llama.cpp"]:
         return OpenAIEndPoint(
             args.host, args.port, args.timeout, include_server_metrics=False, no_debug_config=True
         )
     if args.api_endpoint == "openai-chat":
-        return OpenAIChatEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)
+        return OpenAIChatEndPoint(
+            args.host, args.port, args.api_endpoint, args.timeout, args.include_server_metrics
+        )
+    if args.api_endpoint in ["vllm-chat", "llama.cpp-chat"]:
+        return OpenAIChatEndPoint(
+            args.host,
+            args.port,
+            args.api_endpoint[:-5],
+            args.timeout,
+            include_server_metrics=False,
+            no_debug_config=True,
+        )
+
     if args.api_endpoint == "tensorrt-llm":
         return TensorRTLLMEndPoint(args.host, args.port, args.timeout)
     raise ValueError(f'Unrecognized endpoint "{args.api_endpoint}"')
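Note: the per-backend branching added to `OpenAIChatEndPoint.__call__` above is easiest to see on a concrete request. Below is a minimal standalone sketch, not part of the module: `convert_payload` and the sample request are illustrative stand-ins and the code is written defensively with `dict.pop` defaults, while the field names (`debug_config`, `response_format`, `guided_json`, `guided_decoding_backend`) are the ones used in the diff.

```python
# Standalone sketch of the per-backend payload rewriting done in
# OpenAIChatEndPoint.__call__ above. convert_payload and the sample request
# are illustrative only; the field names come from the diff.
import json


def convert_payload(payload: dict, backend: str) -> dict:
    if backend == "vllm":
        # vLLM has no debug_config: lift ignore_eos to the top level, drop the rest.
        debug_config = payload.pop("debug_config", None)
        if debug_config and "ignore_eos" in debug_config:
            payload["ignore_eos"] = debug_config["ignore_eos"]
        # vLLM takes guided-decoding fields instead of response_format.
        response_format = payload.pop("response_format", None)
        if response_format and "json_schema" in response_format:
            payload["guided_json"] = json.loads(response_format["json_schema"])
            payload["guided_decoding_backend"] = "outlines"
    elif backend == "llama.cpp":
        # llama.cpp expects the schema as a parsed object under "schema".
        if "response_format" in payload and "json_schema" in payload["response_format"]:
            payload["response_format"]["schema"] = json.loads(
                payload["response_format"].pop("json_schema")
            )
    else:
        # MLC-style servers: rename json_schema -> schema, keep it a string.
        if "response_format" in payload and "json_schema" in payload["response_format"]:
            payload["response_format"]["schema"] = payload["response_format"].pop("json_schema")
    return payload


request = {
    "messages": [{"role": "user", "content": "Return a JSON person record."}],
    "max_tokens": 128,
    "response_format": {"type": "json_object", "json_schema": '{"type": "object"}'},
    "debug_config": {"ignore_eos": True},
}
print(convert_payload(dict(request), backend="vllm"))
# {'messages': [...], 'max_tokens': 128, 'ignore_eos': True,
#  'guided_json': {'type': 'object'}, 'guided_decoding_backend': 'outlines'}
```

Doing this conversion in the benchmark client keeps the shared ChatCompletionRequest schema untouched while letting vLLM and llama.cpp servers receive the request shape they expect.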
diff --git a/python/mlc_llm/bench/dataset.py b/python/mlc_llm/bench/dataset.py
index a66f3011d6..9965007705 100644
--- a/python/mlc_llm/bench/dataset.py
+++ b/python/mlc_llm/bench/dataset.py
@@ -243,11 +243,17 @@ class JSONModeEvalDataset(Dataset):  # pylint: disable=too-few-public-methods
     """The dataset class for JSON dataset."""
 
     def __init__(self, tokenizer: AutoTokenizer) -> None:
-        raw_dataset = load_dataset("NousResearch/json-mode-eval")
+        raw_dataset = load_dataset("NousResearch/json-mode-eval", split="train")
         self.tokenizer = tokenizer
         self.dataset = []
-        for data in raw_dataset["train"]:
-            messages = data["prompt"]
+        for data in raw_dataset:
+            data = self._process_data(data)
+            messages = [
+                {
+                    "content": data["prompt"][0]["content"] + " " + data["prompt"][1]["content"],
+                    "role": data["prompt"][1]["role"],
+                },
+            ]
             schema = {
                 "type": "json_object",
                 "schema": data["schema"],
@@ -259,6 +265,42 @@ def __init__(self, tokenizer: AutoTokenizer) -> None:
             )
             self.dataset.append((messages, schema, num_tokens))
 
+    def _process_data(self, data):
+        data["prompt"][0]["content"] = data["prompt"][0]["content"].replace(
+            ", 'format': 'email'", ""
+        )
+        data["schema"] = data["schema"].replace(', "format": "email"', "")
+
+        data["prompt"][0]["content"] = data["prompt"][0]["content"].replace(
+            ", 'pattern': '\\\\d{5}'", ""
+        )
+        data["schema"] = data["schema"].replace(', "pattern": "\\\\d{5}"', "")
+
+        schema_str = data["schema"]
+        schema = json.loads(schema_str)
+        new_schema = None
+        if "type" not in schema:
+            if len(schema.keys()) == 1:
+                key = list(schema.keys())[0]
+                new_schema = {"title": key, **schema[key]}
+            else:
+                new_schema = {"type": "object", **schema}
+        if new_schema is None:
+            return data
+        return {
+            "prompt": [
+                {
+                    "content": "You are a helpful assistant that answers in JSON. "
+                    "Here's the json schema you must adhere to:"
+                    f"\n\n{new_schema}\n\n",
+                    "role": "system",
+                },
+                data["prompt"][1],
+            ],
+            "completion": data["completion"],
+            "schema": json.dumps(new_schema),
+        }
+
     def generate_request_records(
         self,
         input_len: Optional[int],
@@ -288,6 +330,9 @@ def generate_request_records(
                     model="",
                     max_tokens=output_length,
                     response_format=schema,
+                    debug_config=DebugConfig(
+                        grammar_execution_mode="constraint",
+                    ),
                 ),
                 metrics=Metrics(
                     success=False,
diff --git a/python/mlc_llm/bench/request_processor.py b/python/mlc_llm/bench/request_processor.py
index dd9c9f6150..c6ded2d818 100644
--- a/python/mlc_llm/bench/request_processor.py
+++ b/python/mlc_llm/bench/request_processor.py
@@ -131,7 +131,7 @@ def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
             request_record.chat_cmpl.top_p = self.top_p
             request_record.chat_cmpl.frequency_penalty = 0.0
             request_record.chat_cmpl.presence_penalty = 0.0
-            request_record.chat_cmpl.tool_choice = "none"
+            request_record.chat_cmpl.tool_choice = None
             if self.ignore_eos:
                 request_record.chat_cmpl.debug_config = DebugConfig(ignore_eos=True)
         return request_records
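For context on the `_process_data` rewrite in `JSONModeEvalDataset` above: some json-mode-eval rows ship schemas whose top level is a single named wrapper rather than a JSON schema with a `type` field. A small worked example of the two normalization cases, using a made-up entry (the sample schema is not from the dataset):

```python
# Worked example of the schema normalization performed by
# JSONModeEvalDataset._process_data above; the entry below is made up.
import json

# json-mode-eval style schema: a single named wrapper with no top-level "type".
schema = {
    "Person": {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }
}

new_schema = None
if "type" not in schema:
    if len(schema) == 1:
        # One wrapper key: hoist its body and keep the key as the schema title.
        key = next(iter(schema))
        new_schema = {"title": key, **schema[key]}
    else:
        # Several top-level properties: wrap them as an object schema.
        new_schema = {"type": "object", **schema}
if new_schema is None:
    new_schema = schema  # already a proper schema; keep it unchanged

print(json.dumps(new_schema, indent=2))
# {
#   "title": "Person",
#   "type": "object",
#   "properties": {...},
#   "required": ["name", "age"]
# }
```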