diff --git a/python/mlc_llm/bench/api_endpoint.py b/python/mlc_llm/bench/api_endpoint.py
index ccce42aec0..1e0c4720f6 100644
--- a/python/mlc_llm/bench/api_endpoint.py
+++ b/python/mlc_llm/bench/api_endpoint.py
@@ -41,19 +41,23 @@ def __init__(  # pylint: disable=too-many-arguments
         self,
         host: str,
         port: int,
+        backend: str,
         timeout: Optional[float] = None,
         include_server_metrics: bool = False,
+        no_debug_config: bool = False,
     ) -> None:
         super().__init__(include_server_metrics=include_server_metrics)
 
         import aiohttp  # pylint: disable=import-outside-toplevel,import-error
 
+        self.backend = backend
         self.timeout = timeout
         self.client: aiohttp.ClientSession = None
         self.url = f"http://{host}:{port}/v1/chat/completions"
         self.headers = {"Content-Type": "application/json"}
         if os.getenv("MLC_LLM_API_KEY"):
             self.headers["Authorization"] = f"Bearer {os.getenv('MLC_LLM_API_KEY')}"
+        self.no_debug_config = no_debug_config
 
     async def __aenter__(self) -> Self:
         import aiohttp  # pylint: disable=import-outside-toplevel,import-error
@@ -80,13 +84,28 @@ async def __call__(  # pylint: disable=too-many-branches,too-many-statements,too
             and request_record.chat_cmpl.debug_config.ignore_eos
         ):
             payload["ignore_eos"] = True
+            if not self.no_debug_config:
+                payload["debug_config"] = {"ignore_eos": True}
 
-        print(payload)
-
-        if "response_format" in payload and "json_schema" in payload["response_format"]:
-            payload["response_format"]["schema"] = payload["response_format"]["json_schema"]
-            payload["response_format"].pop("json_schema")
-
+        if self.backend == "vllm":
+            if payload["debug_config"] and "ignore_eos" in payload["debug_config"]:
+                payload["ignore_eos"] = payload["debug_config"]["ignore_eos"]
+                payload.pop("debug_config")
+            if "response_format" in payload:
+                if "json_schema" in payload["response_format"]:
+                    payload["guided_json"] = json.loads(payload["response_format"]["json_schema"])
+                    payload["guided_decoding_backend"] = "outlines"
+                payload.pop("response_format")
+        elif self.backend == "llama.cpp":
+            if "response_format" in payload and "json_schema" in payload["response_format"]:
+                payload["response_format"]["schema"] = json.loads(
+                    payload["response_format"]["json_schema"]
+                )
+                payload["response_format"].pop("json_schema")
+        else:
+            if "response_format" in payload and "json_schema" in payload["response_format"]:
+                payload["response_format"]["schema"] = payload["response_format"]["json_schema"]
+                payload["response_format"].pop("json_schema")
         generated_text = ""
         first_chunk_output_str = ""
         time_to_first_token_s = None
@@ -447,6 +466,8 @@ async def __call__(  # pylint: disable=too-many-branches,too-many-locals,too-man
     "sglang",
     "tensorrt-llm",
     "vllm",
+    "vllm-chat",
+    "llama.cpp-chat",
 ]
 
 
@@ -454,12 +475,24 @@ def create_api_endpoint(args: argparse.Namespace) -> APIEndPoint:
     """Create an API endpoint instance with regard to the specified endpoint kind."""
    if args.api_endpoint in ["openai", "mlc", "sglang"]:
         return OpenAIEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)
-    if args.api_endpoint == "vllm":
+    if args.api_endpoint in ["vllm", "llama.cpp"]:
         return OpenAIEndPoint(
             args.host, args.port, args.timeout, include_server_metrics=False, no_debug_config=True
         )
     if args.api_endpoint == "openai-chat":
-        return OpenAIChatEndPoint(args.host, args.port, args.timeout, args.include_server_metrics)
+        return OpenAIChatEndPoint(
+            args.host, args.port, args.api_endpoint, args.timeout, args.include_server_metrics
+        )
+    if args.api_endpoint in ["vllm-chat", "llama.cpp-chat"]:
+        return OpenAIChatEndPoint(
+            args.host,
+            args.port,
+            args.api_endpoint[:-5],
+            args.timeout,
+            include_server_metrics=False,
+            no_debug_config=True,
+        )
+
     if args.api_endpoint == "tensorrt-llm":
         return TensorRTLLMEndPoint(args.host, args.port, args.timeout)
     raise ValueError(f'Unrecognized endpoint "{args.api_endpoint}"')
diff --git a/python/mlc_llm/bench/dataset.py b/python/mlc_llm/bench/dataset.py
index f555a80547..9965007705 100644
--- a/python/mlc_llm/bench/dataset.py
+++ b/python/mlc_llm/bench/dataset.py
@@ -248,7 +248,12 @@ def __init__(self, tokenizer: AutoTokenizer) -> None:
         self.dataset = []
         for data in raw_dataset:
             data = self._process_data(data)
-            messages = data["prompt"]
+            messages = [
+                {
+                    "content": data["prompt"][0]["content"] + " " + data["prompt"][1]["content"],
+                    "role": data["prompt"][1]["role"],
+                },
+            ]
             schema = {
                 "type": "json_object",
                 "schema": data["schema"],
diff --git a/python/mlc_llm/bench/request_processor.py b/python/mlc_llm/bench/request_processor.py
index dd9c9f6150..c6ded2d818 100644
--- a/python/mlc_llm/bench/request_processor.py
+++ b/python/mlc_llm/bench/request_processor.py
@@ -131,7 +131,7 @@ def __call__(self, request_records: List[RequestRecord]) -> List[RequestRecord]:
             request_record.chat_cmpl.top_p = self.top_p
             request_record.chat_cmpl.frequency_penalty = 0.0
             request_record.chat_cmpl.presence_penalty = 0.0
-            request_record.chat_cmpl.tool_choice = "none"
+            request_record.chat_cmpl.tool_choice = None
             if self.ignore_eos:
                 request_record.chat_cmpl.debug_config = DebugConfig(ignore_eos=True)
         return request_records
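
Below is a standalone sketch, separate from the patch itself, of how the backend-specific payload translation in api_endpoint.py rewrites an MLC-style chat-completion payload for vLLM: debug_config.ignore_eos moves to a top-level ignore_eos flag, and response_format.json_schema becomes vLLM's guided_json with the "outlines" guided-decoding backend. The translate_for_vllm helper and the sample payload values are hypothetical and only illustrate the mapping.

import json

def translate_for_vllm(payload: dict) -> dict:
    # Promote MLC's debug_config.ignore_eos to vLLM's top-level ignore_eos flag.
    if payload.get("debug_config") and "ignore_eos" in payload["debug_config"]:
        payload["ignore_eos"] = payload["debug_config"]["ignore_eos"]
        payload.pop("debug_config")
    # Replace response_format.json_schema with vLLM's guided_json field.
    if "response_format" in payload:
        if "json_schema" in payload["response_format"]:
            payload["guided_json"] = json.loads(payload["response_format"]["json_schema"])
            payload["guided_decoding_backend"] = "outlines"
        payload.pop("response_format")
    return payload

payload = {
    "messages": [{"role": "user", "content": "Reply in JSON."}],
    "debug_config": {"ignore_eos": True},
    "response_format": {"type": "json_object", "json_schema": '{"type": "object"}'},
}
print(translate_for_vllm(payload))
# Resulting keys: messages, ignore_eos, guided_json, guided_decoding_backend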
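
Similarly, a small illustration (not part of the patch) of the dataset.py change: the two-message prompt (system + user) of each record is collapsed into a single message that concatenates both contents and keeps the second message's role. The sample prompt text below is made up.

prompt = [
    {"role": "system", "content": "You answer in JSON following the given schema."},
    {"role": "user", "content": "List three primary colors."},
]
messages = [
    {
        "content": prompt[0]["content"] + " " + prompt[1]["content"],
        "role": prompt[1]["role"],
    },
]
print(messages)
# -> a single user message whose content is the system and user texts joined by a space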