Skip to content

Commit

Permalink
[python] Update chat properties (#1709)
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 authored Mar 31, 2024
1 parent 2b54a93 commit bbd9042
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 57 deletions.
61 changes: 39 additions & 22 deletions engines/python/setup/djl_python/chat_completions/chat_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# the specific language governing permissions and limitations under the License.
from typing import Optional, Union, List, Dict

from pydantic.v1 import BaseModel, Field, validator, root_validator
from pydantic import BaseModel, Field, validator, root_validator


class ChatProperties(BaseModel):
Expand All @@ -22,25 +22,28 @@ class ChatProperties(BaseModel):
"""

messages: List[Dict[str, str]]
model: Optional[str] # UNUSED
frequency_penalty: Optional[float] = 0
logit_bias: Optional[dict] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] # Currently only support 1
max_new_tokens: Optional[int] = Field(alias="max_tokens")
n: Optional[int] = 1 # Currently only support 1
presence_penalty: Optional[float] = 0
seed: Optional[int]
stop_sequences: Optional[Union[str, list]] = Field(alias="stop")
temperature: Optional[int] = 1
top_p: Optional[int] = 1
user: Optional[str]
model: Optional[str] = Field(default=None, exclude=True) # Unused
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = Field(default=None, exclude=True)
logprobs: Optional[bool] = Field(default=False, exclude=True)
top_logprobs: Optional[int] = Field(default=None,
serialization_alias="logprobs")
max_tokens: Optional[int] = Field(default=None,
serialization_alias="max_new_tokens")
n: Optional[int] = Field(default=1, exclude=True)
presence_penalty: Optional[float] = 0.0
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
user: Optional[str] = Field(default=None, exclude=True)

@validator('messages', pre=True)
def validate_messages(
cls, messages: List[Dict[str, str]]) -> List[Dict[str, str]]:
if messages is None:
return messages
return None

for message in messages:
if not ("role" in message and "content" in message):
Expand All @@ -52,17 +55,31 @@ def validate_messages(
@validator('frequency_penalty', pre=True)
def validate_frequency_penalty(cls, frequency_penalty: float) -> float:
    """Validate the OpenAI-style ``frequency_penalty`` parameter.

    Coerces the raw value to ``float`` and enforces the documented
    ``[-2.0, 2.0]`` range. ``None`` (parameter not supplied) is passed
    through as ``None``.

    :raises ValueError: if the value falls outside ``[-2.0, 2.0]``.
    """
    if frequency_penalty is None:
        return None

    frequency_penalty = float(frequency_penalty)
    if frequency_penalty < -2.0 or frequency_penalty > 2.0:
        raise ValueError("frequency_penalty must be between -2.0 and 2.0.")
    return frequency_penalty

@validator('top_logprobs', pre=True)
def validate_top_logprobs(cls, top_logprobs: float) -> float:
@validator('logit_bias', pre=True)
def validate_logit_bias(cls, logit_bias: Dict[str, float]):
    """Validate the ``logit_bias`` mapping of token id -> bias.

    Every bias value must lie in ``[-100, 100]``; the token-id keys are
    not validated here. ``None`` is passed through unchanged.

    :raises ValueError: if any bias value falls outside ``[-100, 100]``.
    """
    if logit_bias is None:
        return None

    # Only the bias values are range-checked, so iterate values directly
    # instead of unpacking unused token ids from .items().
    for bias in logit_bias.values():
        if bias < -100.0 or bias > 100.0:
            raise ValueError(
                "logit_bias value must be between -100 and 100.")
    return logit_bias

@validator('top_logprobs')
def validate_top_logprobs(cls, top_logprobs: int, values):
if top_logprobs is None:
return top_logprobs
return None

if not values.get('logprobs'):
return None

top_logprobs = int(top_logprobs)
if top_logprobs < 0 or top_logprobs > 20:
Expand All @@ -72,7 +89,7 @@ def validate_top_logprobs(cls, top_logprobs: float) -> float:
@validator('presence_penalty', pre=True)
def validate_presence_penalty(cls, presence_penalty: float) -> float:
if presence_penalty is None:
return presence_penalty
return None

presence_penalty = float(presence_penalty)
if presence_penalty < -2.0 or presence_penalty > 2.0:
Expand All @@ -82,9 +99,9 @@ def validate_presence_penalty(cls, presence_penalty: float) -> float:
@validator('temperature', pre=True)
def validate_temperature(cls, temperature: float) -> float:
    """Validate the sampling ``temperature`` parameter.

    Coerces the raw value to ``float`` and enforces the documented
    ``[0.0, 2.0]`` range. ``None`` (parameter not supplied) is passed
    through as ``None``.

    :raises ValueError: if the value falls outside ``[0.0, 2.0]``.
    """
    if temperature is None:
        return None

    temperature = float(temperature)
    if temperature < 0.0 or temperature > 2.0:
        raise ValueError("temperature must be between 0 and 2.")
    return temperature
16 changes: 6 additions & 10 deletions engines/python/setup/djl_python/chat_completions/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,11 @@ def parse_chat_completions_request(inputs: map, is_rolling_batch: bool,
f"Cannot provide chat completion for tokenizer: {tokenizer.__class__}, "
f"please ensure that your tokenizer supports chat templates.")
chat_params = ChatProperties(**inputs)
_inputs = tokenizer.apply_chat_template(chat_params.messages,
tokenize=False)
_param = chat_params.dict(exclude_unset=True,
exclude={
'messages', 'model', 'logit_bias',
'top_logprobs', 'n', 'user'
})
_param["details"] = True
_param["output_formatter"] = "jsonlines_chat" if inputs.get(
"stream", False) else "json_chat"
_param = chat_params.model_dump(by_alias=True, exclude_unset=True)
_messages = _param.pop("messages")
_inputs = tokenizer.apply_chat_template(_messages, tokenize=False)
_param["details"] = True # Enable details for chat completions
_param[
"output_formatter"] = "jsonlines_chat" if chat_params.stream else "json_chat"

return _inputs, _param
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def translate_triton_params(self, parameters: dict) -> dict:
parameters["temperature"] = parameters.get("temperature", 0.8)
if "length_penalty" in parameters.keys():
parameters['len_penalty'] = parameters.pop('length_penalty')
parameters["streaming"] = parameters.get("streaming", True)
parameters["streaming"] = parameters.pop(
"stream", parameters.get("streaming", True))
return parameters

@stop_on_any_exception
Expand Down
60 changes: 36 additions & 24 deletions engines/python/setup/djl_python/tests/test_properties_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,32 +617,34 @@ def test_chat_min_configs():
chat_configs = ChatProperties(**properties)
self.assertEqual(chat_configs.messages, properties["messages"])
self.assertIsNone(chat_configs.model)
self.assertEqual(chat_configs.frequency_penalty, 0)
self.assertEqual(chat_configs.frequency_penalty, 0.0)
self.assertIsNone(chat_configs.logit_bias)
self.assertFalse(chat_configs.logprobs)
self.assertIsNone(chat_configs.top_logprobs)
self.assertIsNone(chat_configs.max_new_tokens)
self.assertIsNone(chat_configs.max_tokens)
self.assertEqual(chat_configs.n, 1)
self.assertEqual(chat_configs.presence_penalty, 0)
self.assertEqual(chat_configs.presence_penalty, 0.0)
self.assertIsNone(chat_configs.seed)
self.assertIsNone(chat_configs.stop_sequences)
self.assertEqual(chat_configs.temperature, 1)
self.assertEqual(chat_configs.top_p, 1)
self.assertIsNone(chat_configs.stop)
self.assertFalse(chat_configs.stream)
self.assertEqual(chat_configs.temperature, 1.0)
self.assertEqual(chat_configs.top_p, 1.0)
self.assertIsNone(chat_configs.user)

def test_chat_all_configs():
properties["model"] = "model"
properties["frequency_penalty"] = "1"
properties["logit_bias"] = {"2435": -100, "640": -100}
properties["frequency_penalty"] = "1.0"
properties["logit_bias"] = {"2435": -100.0, "640": -100.0}
properties["logprobs"] = "false"
properties["top_logprobs"] = "3"
properties["max_tokens"] = "256"
properties["n"] = "1"
properties["presence_penalty"] = "1"
properties["presence_penalty"] = "1.0"
properties["seed"] = "123"
properties["stop"] = "stop"
properties["temperature"] = "1"
properties["top_p"] = "3"
properties["stop"] = ["stop"]
properties["stream"] = "true"
properties["temperature"] = "1.0"
properties["top_p"] = "3.0"
properties["user"] = "user"

chat_configs = ChatProperties(**properties)
Expand All @@ -652,18 +654,18 @@ def test_chat_all_configs():
float(properties['frequency_penalty']))
self.assertEqual(chat_configs.logit_bias, properties['logit_bias'])
self.assertFalse(chat_configs.logprobs)
self.assertEqual(chat_configs.top_logprobs,
int(properties['top_logprobs']))
self.assertEqual(chat_configs.max_new_tokens,
self.assertIsNone(chat_configs.top_logprobs)
self.assertEqual(chat_configs.max_tokens,
int(properties['max_tokens']))
self.assertEqual(chat_configs.n, int(properties['n']))
self.assertEqual(chat_configs.presence_penalty,
float(properties['presence_penalty']))
self.assertEqual(chat_configs.seed, int(properties['seed']))
self.assertEqual(chat_configs.stop_sequences, properties['stop'])
self.assertEqual(chat_configs.stop, properties['stop'])
self.assertTrue(chat_configs.stream)
self.assertEqual(chat_configs.temperature,
int(properties['temperature']))
self.assertEqual(chat_configs.top_p, int(properties['top_p']))
float(properties['temperature']))
self.assertEqual(chat_configs.top_p, float(properties['top_p']))
self.assertEqual(chat_configs.user, properties['user'])

def test_invalid_configs():
Expand All @@ -678,34 +680,44 @@ def test_invalid_configs():
ChatProperties(**test_properties)

test_properties = dict(properties)
test_properties["frequency_penalty"] = "-3"
test_properties["frequency_penalty"] = "-3.0"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)
test_properties["frequency_penalty"] = "3.0"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)

test_properties = dict(properties)
test_properties["logit_bias"] = {"2435": -100.0, "640": 200.0}
with self.assertRaises(ValueError):
ChatProperties(**test_properties)
test_properties["frequency_penalty"] = "3"
test_properties["logit_bias"] = {"2435": -200.0, "640": 100.0}
with self.assertRaises(ValueError):
ChatProperties(**test_properties)

test_properties = dict(properties)
test_properties["logprobs"] = "true"
test_properties["top_logprobs"] = "-1"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)
test_properties["logprobs"] = "true"
test_properties["top_logprobs"] = "30"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)

test_properties = dict(properties)
test_properties["presence_penalty"] = "-3"
test_properties["presence_penalty"] = "-3.0"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)
test_properties["presence_penalty"] = "3"
test_properties["presence_penalty"] = "3.0"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)

test_properties = dict(properties)
test_properties["temperature"] = "-1"
test_properties["temperature"] = "-1.0"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)
test_properties["temperature"] = "3"
test_properties["temperature"] = "3.0"
with self.assertRaises(ValueError):
ChatProperties(**test_properties)

Expand Down

0 comments on commit bbd9042

Please sign in to comment.