diff --git a/python/mlc_llm/__init__.py b/python/mlc_llm/__init__.py
index 4843c6766d..66285cea4e 100644
--- a/python/mlc_llm/__init__.py
+++ b/python/mlc_llm/__init__.py
@@ -4,6 +4,5 @@
 """
 from . import protocol, serve
-from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig
 from .libinfo import __version__
 from .serve import AsyncMLCEngine, MLCEngine
diff --git a/python/mlc_llm/protocol/__init__.py b/python/mlc_llm/protocol/__init__.py
index 8cd2a69ca7..b430746477 100644
--- a/python/mlc_llm/protocol/__init__.py
+++ b/python/mlc_llm/protocol/__init__.py
@@ -1,4 +1,9 @@
-"""Definitions of pydantic models for API entry points and configurations"""
-from . import openai_api_protocol
+"""Definitions of pydantic models for API entry points and configurations
 
-RequestProtocol = openai_api_protocol.CompletionRequest
+Note
+----
+We use the following convention:
+
+- filename_protocol: if the classes can appear in an API endpoint
+- filename_config: for other config classes
+"""
diff --git a/python/mlc_llm/protocol/generation_config.py b/python/mlc_llm/protocol/generation_config.py
new file mode 100644
index 0000000000..6cd5e82cf0
--- /dev/null
+++ b/python/mlc_llm/protocol/generation_config.py
@@ -0,0 +1,32 @@
+"""Low-level generation config class"""
+# pylint: disable=missing-class-docstring,too-many-instance-attributes
+from typing import Dict, List, Optional
+
+from pydantic import BaseModel
+
+from .debug_protocol import DebugConfig
+from .openai_api_protocol import RequestResponseFormat
+
+
+class GenerationConfig(BaseModel):
+    """The generation configuration class.
+
+    This is a config class used by Engine internally.
+    """
+
+    n: int = 1
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+    repetition_penalty: Optional[float] = None
+    logprobs: bool = False
+    top_logprobs: int = 0
+    logit_bias: Optional[Dict[int, float]] = None
+    # internally we use -1 to represent infinite
+    max_tokens: int = -1
+    seed: Optional[int] = None
+    stop_strs: Optional[List[str]] = None
+    stop_token_ids: Optional[List[int]] = None
+    response_format: Optional[RequestResponseFormat] = None
+    debug_config: Optional[DebugConfig] = None
diff --git a/python/mlc_llm/serve/__init__.py b/python/mlc_llm/serve/__init__.py
index 4ef4470399..6b122bdf64 100644
--- a/python/mlc_llm/serve/__init__.py
+++ b/python/mlc_llm/serve/__init__.py
@@ -2,7 +2,7 @@
 
 # Load MLC LLM library by importing base
 from .. import base
-from .config import DebugConfig, EngineConfig, GenerationConfig
+from .config import EngineConfig
 from .data import Data, ImageData, RequestStreamOutput, TextData, TokenData
 from .engine import AsyncMLCEngine, MLCEngine
 from .grammar import BNFGrammar, GrammarStateMatcher
diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py
index f4fadf0dae..bf79bb672f 100644
--- a/python/mlc_llm/serve/config.py
+++ b/python/mlc_llm/serve/config.py
@@ -2,147 +2,7 @@
 
 import json
 from dataclasses import asdict, dataclass, field
-from typing import Dict, List, Literal, Optional, Tuple, Union
-
-
-@dataclass
-class ResponseFormat:
-    """The response format dataclass.
-
-    Parameters
-    ----------
-    type : Literal["text", "json_object"]
-        The type of response format. Default: "text".
-
-    schema : Optional[str]
-        The JSON schema string for the JSON response format. If None, a legal json string without
-        special restrictions will be generated.
-
-        Could be specified when the response format is "json_object". Default: None.
-    """
-
-    type: Literal["text", "json_object"] = "text"
-    schema: Optional[str] = None
-
-    def __post_init__(self):
-        if self.schema is not None and self.type != "json_object":
-            raise ValueError("JSON schema is only supported in JSON response format")
-
-
-@dataclass
-class DebugConfig:
-    """The debug configuration dataclass.
-
-    Parameters
-    ----------
-    ignore_eos : bool
-        When it is true, ignore the eos token and generate tokens until `max_tokens`.
-        Default is set to False.
-
-    pinned_system_prompt : bool
-        Whether the input and generated data pinned in engine. Default is set to False.
-        This can be used for system prompt or other purpose, if the data is aimed to be
-        kept all the time.
-
-    special_request: Optional[string]
-        Special requests to send to engine
-    """
-
-    ignore_eos: bool = False
-    pinned_system_prompt: bool = False
-    special_request: Optional[Literal["query_engine_metrics"]] = None
-
-
-@dataclass
-class GenerationConfig:  # pylint: disable=too-many-instance-attributes
-    """The generation configuration dataclass.
-
-    Parameters
-    ----------
-    n : int
-        How many chat completion choices to generate for each input message.
-
-    temperature : Optional[float]
-        The value that applies to logits and modulates the next token probabilities.
-
-    top_p : Optional[float]
-        In sampling, only the most probable tokens with probabilities summed up to
-        `top_p` are kept for sampling.
-
-    frequency_penalty : Optional[float]
-        Positive values penalize new tokens based on their existing frequency
-        in the text so far, decreasing the model's likelihood to repeat the same
-        line verbatim.
-
-    presence_penalty : Optional[float]
-        Positive values penalize new tokens based on whether they appear in the text
-        so far, increasing the model's likelihood to talk about new topics.
-
-    repetition_penalty : float
-        The penalty term that applies to logits to control token repetition in generation.
-        It will be suppressed when any of frequency_penalty and presence_penalty is
-        non-zero.
-
-    logprobs : bool
-        Whether to return log probabilities of the output tokens or not.
-        If true, the log probabilities of each output token will be returned.
-
-    top_logprobs : int
-        An integer between 0 and 5 specifying the number of most likely
-        tokens to return at each token position, each with an associated
-        log probability.
-        `logprobs` must be set to True if this parameter is used.
-
-    logit_bias : Optional[Dict[int, float]]
-        The bias logit value added to selected tokens prior to sampling.
-
-    max_tokens : Optional[int]
-        The maximum number of generated tokens,
-        or None, in which case the generation will not stop
-        until exceeding model capability or hit any stop criteria.
-
-    seed : Optional[int]
-        The random seed of the generation.
-        The seed will be a random value if not specified.
-
-    stop_strs : List[str]
-        The list of strings that mark the end of generation.
-
-    stop_token_ids : List[int]
-        The list of token ids that mark the end of generation.
-
-    response_format : ResponseFormat
-        The response format of the generation output.
-
-    debug_config : Optional[DebugConfig]
-        The optional debug configuration.
-    """
-
-    n: int = 1
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
-    frequency_penalty: Optional[float] = None
-    presence_penalty: Optional[float] = None
-    repetition_penalty: float = 1.0
-    logprobs: bool = False
-    top_logprobs: int = 0
-    logit_bias: Optional[Dict[int, float]] = field(default_factory=dict)  # type: ignore
-
-    max_tokens: Optional[int] = 128
-    seed: Optional[int] = None
-    stop_strs: List[str] = field(default_factory=list)
-    stop_token_ids: List[int] = field(default_factory=list)
-
-    response_format: ResponseFormat = field(default_factory=ResponseFormat)
-
-    debug_config: Optional[DebugConfig] = field(default_factory=DebugConfig)
-
-    def asjson(self) -> str:
-        """Return the config in string of JSON format."""
-        return json.dumps(asdict(self))
-
-    @staticmethod
-    def from_json(json_str: str) -> "GenerationConfig":
-        """Construct a config from JSON string."""
-        return GenerationConfig(**json.loads(json_str))
+from typing import List, Literal, Optional, Tuple, Union
 
 
 @dataclass
diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py
index e072d1028d..012f450bb2 100644
--- a/python/mlc_llm/serve/engine.py
+++ b/python/mlc_llm/serve/engine.py
@@ -22,8 +22,9 @@
 from tvm.runtime import Device
 
 from mlc_llm.protocol import debug_protocol, openai_api_protocol
+from mlc_llm.protocol.generation_config import GenerationConfig
 from mlc_llm.serve import data, engine_utils
-from mlc_llm.serve.config import EngineConfig, GenerationConfig
+from mlc_llm.serve.config import EngineConfig
 from mlc_llm.streamer import TextStreamer
 from mlc_llm.support import logging
@@ -1372,7 +1373,9 @@ async def _generate(
         # Create the request with the given id, input data, generation
         # config and the created callback.
         input_data = engine_utils.convert_prompts_to_data(prompt)
-        request = self._ffi["create_request"](request_id, input_data, generation_config.asjson())
+        request = self._ffi["create_request"](
+            request_id, input_data, generation_config.model_dump_json()
+        )
 
         # Create the unique async request stream of the request.
         stream = engine_base.AsyncRequestStream()
@@ -1898,7 +1901,9 @@ def _generate(  # pylint: disable=too-many-locals
         # Create the request with the given id, input data, generation
         # config and the created callback.
         input_data = engine_utils.convert_prompts_to_data(prompt)
-        request = self._ffi["create_request"](request_id, input_data, generation_config.asjson())
+        request = self._ffi["create_request"](
+            request_id, input_data, generation_config.model_dump_json()
+        )
 
         # Record the stream in the tracker
         self.state.sync_output_queue = queue.Queue()
diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py
index d8b1842c0b..8aa8d52b97 100644
--- a/python/mlc_llm/serve/engine_base.py
+++ b/python/mlc_llm/serve/engine_base.py
@@ -18,9 +18,10 @@
 from mlc_llm.protocol import openai_api_protocol
 from mlc_llm.protocol.conversation_protocol import Conversation
+from mlc_llm.protocol.generation_config import GenerationConfig
 from mlc_llm.protocol.mlc_chat_config import MLCChatConfig
 from mlc_llm.serve import data, engine_utils
-from mlc_llm.serve.config import EngineConfig, GenerationConfig
+from mlc_llm.serve.config import EngineConfig
 from mlc_llm.serve.event_trace_recorder import EventTraceRecorder
 from mlc_llm.streamer import TextStreamer
 from mlc_llm.support import download_cache, logging
diff --git a/python/mlc_llm/serve/engine_utils.py b/python/mlc_llm/serve/engine_utils.py
index c2d686d583..6ccbc0e621 100644
--- a/python/mlc_llm/serve/engine_utils.py
+++ b/python/mlc_llm/serve/engine_utils.py
@@ -3,10 +3,13 @@
 import uuid
 from typing import Any, Callable, Dict, List, Optional, Union
 
-from mlc_llm.protocol import RequestProtocol, error_protocol, openai_api_protocol
+from mlc_llm.protocol import error_protocol, openai_api_protocol
+from mlc_llm.protocol.generation_config import GenerationConfig
 from mlc_llm.serve import data
 
-from .config import DebugConfig, GenerationConfig, ResponseFormat
+RequestProtocol = Union[
+    openai_api_protocol.CompletionRequest, openai_api_protocol.ChatCompletionRequest
+]
 
 
 def get_unsupported_fields(request: RequestProtocol) -> List[str]:
@@ -20,9 +23,7 @@ def get_unsupported_fields(request: RequestProtocol) -> List[str]:
     raise RuntimeError("Cannot reach here")
 
 
-def openai_api_get_generation_config(
-    request: Union[openai_api_protocol.CompletionRequest, openai_api_protocol.ChatCompletionRequest]
-) -> Dict[str, Any]:
+def openai_api_get_generation_config(request: RequestProtocol) -> Dict[str, Any]:
     """Create the generation config from the given request."""
     kwargs: Dict[str, Any] = {}
     arg_names = [
@@ -36,6 +37,8 @@ def openai_api_get_generation_config(
         "top_logprobs",
         "logit_bias",
         "seed",
+        "response_format",
+        "debug_config",
     ]
     for arg_name in arg_names:
         kwargs[arg_name] = getattr(request, arg_name)
@@ -45,12 +48,6 @@ def openai_api_get_generation_config(
         kwargs["max_tokens"] = -1
     if request.stop is not None:
         kwargs["stop_strs"] = [request.stop] if isinstance(request.stop, str) else request.stop
-    if request.response_format is not None:
-        kwargs["response_format"] = ResponseFormat(
-            **request.response_format.model_dump(by_alias=True)
-        )
-    if request.debug_config is not None:
-        kwargs["debug_config"] = DebugConfig(**request.debug_config.model_dump())
 
     return kwargs
diff --git a/python/mlc_llm/serve/request.py b/python/mlc_llm/serve/request.py
index d9260e6598..10c2e0577d 100644
--- a/python/mlc_llm/serve/request.py
+++ b/python/mlc_llm/serve/request.py
@@ -4,8 +4,9 @@
 import tvm._ffi
 from tvm.runtime import Object
 
+from mlc_llm.protocol.generation_config import GenerationConfig
+
 from . import _ffi_api
-from .config import GenerationConfig
 from .data import Data
@@ -29,6 +30,6 @@ def inputs(self) -> List[Data]:
     @property
     def generation_config(self) -> GenerationConfig:
         """The generation config of the request."""
-        return GenerationConfig.from_json(
+        return GenerationConfig.model_validate_json(
             _ffi_api.RequestGetGenerationConfigJSON(self)  # type: ignore  # pylint: disable=no-member
         )
diff --git a/python/mlc_llm/serve/sync_engine.py b/python/mlc_llm/serve/sync_engine.py
index 460bc4d52e..5b5fd9cd98 100644
--- a/python/mlc_llm/serve/sync_engine.py
+++ b/python/mlc_llm/serve/sync_engine.py
@@ -13,8 +13,9 @@
 import tvm
 
+from mlc_llm.protocol.generation_config import GenerationConfig
 from mlc_llm.serve import data
-from mlc_llm.serve.config import EngineConfig, GenerationConfig
+from mlc_llm.serve.config import EngineConfig
 from mlc_llm.serve.engine_base import (
     EngineMetrics,
     _check_engine_config,
@@ -307,7 +308,7 @@ def create_request(
         """
         if not isinstance(inputs, list):
             inputs = [inputs]
-        return self._ffi["create_request"](request_id, inputs, generation_config.asjson())
+        return self._ffi["create_request"](request_id, inputs, generation_config.model_dump_json())
 
     def add_request(self, request: Request) -> None:
         """Add a new request to the engine.
diff --git a/python/mlc_llm/testing/debug_chat.py b/python/mlc_llm/testing/debug_chat.py
index 6aacce1faf..6f25328c8f 100644
--- a/python/mlc_llm/testing/debug_chat.py
+++ b/python/mlc_llm/testing/debug_chat.py
@@ -373,9 +373,6 @@ def generate(
         generate_length : int
            How many tokens to generate.
-
-        generation_config : Optional[GenerationConfig]
-            Will be used to override the GenerationConfig in ``mlc-chat-config.json``.
         """
         out_tokens = []
diff --git a/tests/python/serve/evaluate_engine.py b/tests/python/serve/evaluate_engine.py
index 608f69dd4c..7767c30abc 100644
--- a/tests/python/serve/evaluate_engine.py
+++ b/tests/python/serve/evaluate_engine.py
@@ -4,7 +4,7 @@
 import random
 from typing import List, Tuple
 
-from mlc_llm.serve import GenerationConfig
+from mlc_llm.protocol.generation_config import GenerationConfig
 from mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine
diff --git a/tests/python/serve/test_serve_async_engine.py b/tests/python/serve/test_serve_async_engine.py
index 1884359718..993e5b60b3 100644
--- a/tests/python/serve/test_serve_async_engine.py
+++ b/tests/python/serve/test_serve_async_engine.py
@@ -3,7 +3,8 @@
 import asyncio
 from typing import List
 
-from mlc_llm.serve import AsyncMLCEngine, EngineConfig, GenerationConfig
+from mlc_llm.protocol.generation_config import GenerationConfig
+from mlc_llm.serve import AsyncMLCEngine, EngineConfig
 from mlc_llm.testing import require_test_model
 
 prompts = [
@@ -20,7 +21,7 @@
 ]
 
 
-@require_test_model("Llama-2-7b-chat-hf-q0f16-MLC")
+@require_test_model("Llama-2-7b-chat-hf-q4f16_1-MLC")
 async def test_engine_generate(model: str):
     # Create engine
     async_engine = AsyncMLCEngine(
@@ -48,9 +49,12 @@ async def generate_task(
         async for delta_outputs in async_engine._generate(
             prompt, generation_cfg, request_id=request_id
         ):
-            assert len(delta_outputs) == generation_cfg.n
-            for i, delta_output in enumerate(delta_outputs):
-                output_texts[rid][i] += delta_output.delta_text
+            if len(delta_outputs) == generation_cfg.n:
+                for i, delta_output in enumerate(delta_outputs):
+                    output_texts[rid][i] += delta_output.delta_text
+            else:
+                assert len(delta_outputs) == 1
+                assert len(delta_outputs[0].request_final_usage_json_str) != 0
 
     tasks = [
         asyncio.create_task(
@@ -75,7 +79,7 @@ async def generate_task(
     del async_engine
 
 
-@require_test_model("Llama-2-7b-chat-hf-q0f16-MLC")
+@require_test_model("Llama-2-7b-chat-hf-q4f16_1-MLC")
 async def test_chat_completion(model: str):
     # Create engine
     async_engine = AsyncMLCEngine(
@@ -126,7 +130,7 @@ async def generate_task(prompt: str, request_id: str):
     del async_engine
 
 
-@require_test_model("Llama-2-7b-chat-hf-q0f16-MLC")
+@require_test_model("Llama-2-7b-chat-hf-q4f16_1-MLC")
 async def test_chat_completion_non_stream(model: str):
     # Create engine
     async_engine = AsyncMLCEngine(
diff --git a/tests/python/serve/test_serve_async_engine_spec.py b/tests/python/serve/test_serve_async_engine_spec.py
index c3d4c37756..476d970e1c 100644
--- a/tests/python/serve/test_serve_async_engine_spec.py
+++ b/tests/python/serve/test_serve_async_engine_spec.py
@@ -3,7 +3,8 @@
 import asyncio
 from typing import List
 
-from mlc_llm.serve import AsyncMLCEngine, EngineConfig, GenerationConfig
+from mlc_llm.protocol.generation_config import GenerationConfig
+from mlc_llm.serve import AsyncMLCEngine, EngineConfig
 from mlc_llm.testing import require_test_model
 
 prompts = [
diff --git a/tests/python/serve/test_serve_engine.py b/tests/python/serve/test_serve_engine.py
index 670d33b236..899629a448 100644
--- a/tests/python/serve/test_serve_engine.py
+++ b/tests/python/serve/test_serve_engine.py
@@ -2,7 +2,8 @@
 # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable
 from typing import List
 
-from mlc_llm.serve import EngineConfig, GenerationConfig, MLCEngine
+from mlc_llm.protocol.generation_config import GenerationConfig
+from mlc_llm.serve import EngineConfig, MLCEngine
 from mlc_llm.testing import require_test_model
 
 prompts = [
diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py
index d85ab8e762..13d12f5a29 100644
--- a/tests/python/serve/test_serve_engine_grammar.py
+++ b/tests/python/serve/test_serve_engine_grammar.py
@@ -7,8 +7,9 @@
 import pytest
 from pydantic import BaseModel
 
-from mlc_llm.serve import AsyncMLCEngine, GenerationConfig
-from mlc_llm.serve.config import ResponseFormat
+from mlc_llm.protocol.generation_config import GenerationConfig
+from mlc_llm.protocol.openai_api_protocol import RequestResponseFormat as ResponseFormat
+from mlc_llm.serve import AsyncMLCEngine
 from mlc_llm.serve.sync_engine import SyncMLCEngine
 from mlc_llm.testing import require_test_model
diff --git a/tests/python/serve/test_serve_engine_image.py b/tests/python/serve/test_serve_engine_image.py
index b1cdf1fcea..0fdf141faf 100644
--- a/tests/python/serve/test_serve_engine_image.py
+++ b/tests/python/serve/test_serve_engine_image.py
@@ -1,7 +1,8 @@
 import json
 from pathlib import Path
 
-from mlc_llm.serve import GenerationConfig, data
+from mlc_llm.protocol.generation_config import GenerationConfig
+from mlc_llm.serve import data
 from mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine
diff --git a/tests/python/serve/test_serve_engine_prefix_cache.py b/tests/python/serve/test_serve_engine_prefix_cache.py
index ca55540fff..0a32c04b11 100644
--- a/tests/python/serve/test_serve_engine_prefix_cache.py
+++ b/tests/python/serve/test_serve_engine_prefix_cache.py
@@ -1,4 +1,5 @@
-from mlc_llm.serve import DebugConfig, GenerationConfig
+from mlc_llm.protocol.debug_protocol import DebugConfig
+from mlc_llm.protocol.generation_config import GenerationConfig
 from mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine
 from mlc_llm.testing import require_test_model
diff --git a/tests/python/serve/test_serve_engine_rnn.py b/tests/python/serve/test_serve_engine_rnn.py
index 090c06dbc3..194e7ec35d 100644
--- a/tests/python/serve/test_serve_engine_rnn.py
+++ b/tests/python/serve/test_serve_engine_rnn.py
@@ -2,7 +2,8 @@
 # pylint: disable=too-many-arguments,too-many-locals,unused-argument,unused-variable
 from typing import List
 
-from mlc_llm.serve import EngineConfig, GenerationConfig, MLCEngine
+from mlc_llm.protocol.generation_config import GenerationConfig
+from mlc_llm.serve import EngineConfig, MLCEngine
 
 prompts = [
     "What is the meaning of life?",
diff --git a/tests/python/serve/test_serve_engine_spec.py b/tests/python/serve/test_serve_engine_spec.py
index b37e7c8051..61a40476ae 100644
--- a/tests/python/serve/test_serve_engine_spec.py
+++ b/tests/python/serve/test_serve_engine_spec.py
@@ -4,7 +4,8 @@
 import numpy as np
 
-from mlc_llm.serve import GenerationConfig, Request, RequestStreamOutput, data
+from mlc_llm.protocol.generation_config import GenerationConfig
+from mlc_llm.serve import Request, RequestStreamOutput, data
 from mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine
 from mlc_llm.testing import require_test_model
diff --git a/tests/python/serve/test_serve_sync_engine.py b/tests/python/serve/test_serve_sync_engine.py
index 8dbc60925e..b889628592 100644
--- a/tests/python/serve/test_serve_sync_engine.py
+++ b/tests/python/serve/test_serve_sync_engine.py
@@ -4,7 +4,8 @@
 import numpy as np
 
-from mlc_llm.serve import GenerationConfig, Request, RequestStreamOutput, data
+from mlc_llm.protocol.generation_config import GenerationConfig
+from mlc_llm.serve import Request, RequestStreamOutput, data
 from mlc_llm.serve.sync_engine import EngineConfig, SyncMLCEngine
 from mlc_llm.testing import require_test_model
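Usage sketch (not part of the patch): after this change, GenerationConfig is a pydantic model in mlc_llm.protocol.generation_config, and the engine round-trips it through JSON with pydantic's model_dump_json() / model_validate_json() in place of the removed asjson() / from_json() helpers. The snippet below assumes an mlc_llm build that includes this patch; the field values are illustrative only.

```python
from mlc_llm.protocol.generation_config import GenerationConfig

# Construct the low-level config directly; in the server path these values
# come from engine_utils.openai_api_get_generation_config as plain kwargs.
config = GenerationConfig(
    n=1,
    temperature=0.7,
    top_p=0.95,
    max_tokens=-1,  # -1 means "no explicit limit" internally
    stop_strs=["</s>"],
)

# The engine passes the config across the FFI boundary as a JSON string.
config_json = config.model_dump_json()

# On the way back (e.g. Request.generation_config), the JSON string is
# parsed into a GenerationConfig again.
restored = GenerationConfig.model_validate_json(config_json)
assert restored.temperature == 0.7
```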