diff --git a/engines/python/setup/djl_python/chat_completions/chat_properties.py b/engines/python/setup/djl_python/chat_completions/chat_properties.py
index 2630b8248..79ce07e80 100644
--- a/engines/python/setup/djl_python/chat_completions/chat_properties.py
+++ b/engines/python/setup/djl_python/chat_completions/chat_properties.py
@@ -12,7 +12,7 @@
 # the specific language governing permissions and limitations under the License.
 
 from typing import Optional, Union, List, Dict
-from pydantic.v1 import BaseModel, Field, validator, root_validator
+from pydantic import BaseModel, Field, validator, root_validator
 
 
 class ChatProperties(BaseModel):
@@ -22,25 +22,28 @@ class ChatProperties(BaseModel):
     """
 
     messages: List[Dict[str, str]]
-    model: Optional[str]  # UNUSED
-    frequency_penalty: Optional[float] = 0
-    logit_bias: Optional[dict] = None
-    logprobs: Optional[bool] = False
-    top_logprobs: Optional[int]  # Currently only support 1
-    max_new_tokens: Optional[int] = Field(alias="max_tokens")
-    n: Optional[int] = 1  # Currently only support 1
-    presence_penalty: Optional[float] = 0
-    seed: Optional[int]
-    stop_sequences: Optional[Union[str, list]] = Field(alias="stop")
-    temperature: Optional[int] = 1
-    top_p: Optional[int] = 1
-    user: Optional[str]
+    model: Optional[str] = Field(default=None, exclude=True)  # Unused
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = Field(default=None, exclude=True)
+    logprobs: Optional[bool] = Field(default=False, exclude=True)
+    top_logprobs: Optional[int] = Field(default=None,
+                                        serialization_alias="logprobs")
+    max_tokens: Optional[int] = Field(default=None,
+                                      serialization_alias="max_new_tokens")
+    n: Optional[int] = Field(default=1, exclude=True)
+    presence_penalty: Optional[float] = 0.0
+    seed: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = None
+    stream: Optional[bool] = False
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    user: Optional[str] = Field(default=None, exclude=True)
 
     @validator('messages', pre=True)
     def validate_messages(
             cls, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
         if messages is None:
-            return messages
+            return None
 
         for message in messages:
             if not ("role" in message and "content" in message):
@@ -52,17 +55,31 @@ def validate_messages(
     @validator('frequency_penalty', pre=True)
     def validate_frequency_penalty(cls, frequency_penalty: float) -> float:
         if frequency_penalty is None:
-            return frequency_penalty
+            return None
 
         frequency_penalty = float(frequency_penalty)
         if frequency_penalty < -2.0 or frequency_penalty > 2.0:
             raise ValueError("frequency_penalty must be between -2.0 and 2.0.")
         return frequency_penalty
 
-    @validator('top_logprobs', pre=True)
-    def validate_top_logprobs(cls, top_logprobs: float) -> float:
+    @validator('logit_bias', pre=True)
+    def validate_logit_bias(cls, logit_bias: Dict[str, float]):
+        if logit_bias is None:
+            return None
+
+        for token_id, bias in logit_bias.items():
+            if bias < -100.0 or bias > 100.0:
+                raise ValueError(
+                    "logit_bias value must be between -100 and 100.")
+        return logit_bias
+
+    @validator('top_logprobs')
+    def validate_top_logprobs(cls, top_logprobs: int, values):
         if top_logprobs is None:
-            return top_logprobs
+            return None
+
+        if not values.get('logprobs'):
+            return None
 
         top_logprobs = int(top_logprobs)
         if top_logprobs < 0 or top_logprobs > 20:
@@ -72,7 +89,7 @@ def validate_top_logprobs(cls, top_logprobs: float) -> float:
     @validator('presence_penalty', pre=True)
     def validate_presence_penalty(cls, presence_penalty: float) -> float:
         if presence_penalty is None:
-            return presence_penalty
+            return None
 
         presence_penalty = float(presence_penalty)
         if presence_penalty < -2.0 or presence_penalty > 2.0:
@@ -82,9 +99,9 @@ def validate_presence_penalty(cls, presence_penalty: float) -> float:
     @validator('temperature', pre=True)
     def validate_temperature(cls, temperature: float) -> float:
         if temperature is None:
-            return temperature
+            return None
 
         temperature = float(temperature)
-        if temperature < 0 or temperature > 2:
+        if temperature < 0.0 or temperature > 2.0:
             raise ValueError("temperature must be between 0 and 2.")
         return temperature
diff --git a/engines/python/setup/djl_python/chat_completions/chat_utils.py b/engines/python/setup/djl_python/chat_completions/chat_utils.py
index 0a86e8a7a..e0cb94649 100644
--- a/engines/python/setup/djl_python/chat_completions/chat_utils.py
+++ b/engines/python/setup/djl_python/chat_completions/chat_utils.py
@@ -29,15 +29,11 @@ def parse_chat_completions_request(inputs: map, is_rolling_batch: bool,
             f"Cannot provide chat completion for tokenizer: {tokenizer.__class__}, "
             f"please ensure that your tokenizer supports chat templates.")
     chat_params = ChatProperties(**inputs)
-    _inputs = tokenizer.apply_chat_template(chat_params.messages,
-                                            tokenize=False)
-    _param = chat_params.dict(exclude_unset=True,
-                              exclude={
-                                  'messages', 'model', 'logit_bias',
-                                  'top_logprobs', 'n', 'user'
-                              })
-    _param["details"] = True
-    _param["output_formatter"] = "jsonlines_chat" if inputs.get(
-        "stream", False) else "json_chat"
+    _param = chat_params.model_dump(by_alias=True, exclude_unset=True)
+    _messages = _param.pop("messages")
+    _inputs = tokenizer.apply_chat_template(_messages, tokenize=False)
+    _param["details"] = True  # Enable details for chat completions
+    _param[
+        "output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat"
 
     return _inputs, _param
diff --git a/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py b/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py
index 5e21b8506..9c2632739 100644
--- a/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py
+++ b/engines/python/setup/djl_python/rolling_batch/trtllm_rolling_batch.py
@@ -73,7 +73,8 @@ def translate_triton_params(self, parameters: dict) -> dict:
         parameters["temperature"] = parameters.get("temperature", 0.8)
         if "length_penalty" in parameters.keys():
             parameters['len_penalty'] = parameters.pop('length_penalty')
-        parameters["streaming"] = parameters.get("streaming", True)
+        parameters["streaming"] = parameters.pop(
+            "stream", parameters.get("streaming", True))
         return parameters
 
     @stop_on_any_exception
diff --git a/engines/python/setup/djl_python/tests/test_properties_manager.py b/engines/python/setup/djl_python/tests/test_properties_manager.py
index b5abca7f7..1167b8473 100644
--- a/engines/python/setup/djl_python/tests/test_properties_manager.py
+++ b/engines/python/setup/djl_python/tests/test_properties_manager.py
@@ -617,32 +617,34 @@ def test_chat_min_configs():
             chat_configs = ChatProperties(**properties)
             self.assertEqual(chat_configs.messages, properties["messages"])
             self.assertIsNone(chat_configs.model)
-            self.assertEqual(chat_configs.frequency_penalty, 0)
+            self.assertEqual(chat_configs.frequency_penalty, 0.0)
             self.assertIsNone(chat_configs.logit_bias)
             self.assertFalse(chat_configs.logprobs)
             self.assertIsNone(chat_configs.top_logprobs)
-            self.assertIsNone(chat_configs.max_new_tokens)
+            self.assertIsNone(chat_configs.max_tokens)
             self.assertEqual(chat_configs.n, 1)
-            self.assertEqual(chat_configs.presence_penalty, 0)
+            self.assertEqual(chat_configs.presence_penalty, 0.0)
             self.assertIsNone(chat_configs.seed)
-            self.assertIsNone(chat_configs.stop_sequences)
-            self.assertEqual(chat_configs.temperature, 1)
-            self.assertEqual(chat_configs.top_p, 1)
+            self.assertIsNone(chat_configs.stop)
+            self.assertFalse(chat_configs.stream)
+            self.assertEqual(chat_configs.temperature, 1.0)
+            self.assertEqual(chat_configs.top_p, 1.0)
             self.assertIsNone(chat_configs.user)
 
         def test_chat_all_configs():
             properties["model"] = "model"
-            properties["frequency_penalty"] = "1"
-            properties["logit_bias"] = {"2435": -100, "640": -100}
+            properties["frequency_penalty"] = "1.0"
+            properties["logit_bias"] = {"2435": -100.0, "640": -100.0}
             properties["logprobs"] = "false"
             properties["top_logprobs"] = "3"
             properties["max_tokens"] = "256"
             properties["n"] = "1"
-            properties["presence_penalty"] = "1"
+            properties["presence_penalty"] = "1.0"
             properties["seed"] = "123"
-            properties["stop"] = "stop"
-            properties["temperature"] = "1"
-            properties["top_p"] = "3"
+            properties["stop"] = ["stop"]
+            properties["stream"] = "true"
+            properties["temperature"] = "1.0"
+            properties["top_p"] = "3.0"
             properties["user"] = "user"
 
             chat_configs = ChatProperties(**properties)
@@ -652,18 +654,18 @@ def test_chat_all_configs():
                              float(properties['frequency_penalty']))
             self.assertEqual(chat_configs.logit_bias, properties['logit_bias'])
             self.assertFalse(chat_configs.logprobs)
-            self.assertEqual(chat_configs.top_logprobs,
-                             int(properties['top_logprobs']))
-            self.assertEqual(chat_configs.max_new_tokens,
+            self.assertIsNone(chat_configs.top_logprobs)
+            self.assertEqual(chat_configs.max_tokens,
                              int(properties['max_tokens']))
             self.assertEqual(chat_configs.n, int(properties['n']))
             self.assertEqual(chat_configs.presence_penalty,
                              float(properties['presence_penalty']))
             self.assertEqual(chat_configs.seed, int(properties['seed']))
-            self.assertEqual(chat_configs.stop_sequences, properties['stop'])
+            self.assertEqual(chat_configs.stop, properties['stop'])
+            self.assertTrue(chat_configs.stream)
             self.assertEqual(chat_configs.temperature,
-                             int(properties['temperature']))
-            self.assertEqual(chat_configs.top_p, int(properties['top_p']))
+                             float(properties['temperature']))
+            self.assertEqual(chat_configs.top_p, float(properties['top_p']))
             self.assertEqual(chat_configs.user, properties['user'])
 
         def test_invalid_configs():
@@ -678,34 +680,44 @@ def test_invalid_configs():
                 ChatProperties(**test_properties)
 
             test_properties = dict(properties)
-            test_properties["frequency_penalty"] = "-3"
+            test_properties["frequency_penalty"] = "-3.0"
+            with self.assertRaises(ValueError):
+                ChatProperties(**test_properties)
+            test_properties["frequency_penalty"] = "3.0"
+            with self.assertRaises(ValueError):
+                ChatProperties(**test_properties)
+
+            test_properties = dict(properties)
+            test_properties["logit_bias"] = {"2435": -100.0, "640": 200.0}
             with self.assertRaises(ValueError):
                 ChatProperties(**test_properties)
-            test_properties["frequency_penalty"] = "3"
+            test_properties["logit_bias"] = {"2435": -200.0, "640": 100.0}
             with self.assertRaises(ValueError):
                 ChatProperties(**test_properties)
 
             test_properties = dict(properties)
+            test_properties["logprobs"] = "true"
             test_properties["top_logprobs"] = "-1"
             with self.assertRaises(ValueError):
                 ChatProperties(**test_properties)
+            test_properties["logprobs"] = "true"
             test_properties["top_logprobs"] = "30"
             with self.assertRaises(ValueError):
                 ChatProperties(**test_properties)
 
             test_properties = dict(properties)
-            test_properties["presence_penalty"] = "-3"
+            test_properties["presence_penalty"] = "-3.0"
             with self.assertRaises(ValueError):
                 ChatProperties(**test_properties)
-            test_properties["presence_penalty"] = "3"
+            test_properties["presence_penalty"] = "3.0"
             with self.assertRaises(ValueError):
                 ChatProperties(**test_properties)
 
             test_properties = dict(properties)
-            test_properties["temperature"] = "-1"
+            test_properties["temperature"] = "-1.0"
             with self.assertRaises(ValueError):
                 ChatProperties(**test_properties)
-            test_properties["temperature"] = "3"
+            test_properties["temperature"] = "3.0"
             with self.assertRaises(ValueError):
                 ChatProperties(**test_properties)
 