Skip to content

Commit

Permalink
add chat template for Yi (#779)
Browse files Browse the repository at this point in the history
  • Loading branch information
AllentDan authored Dec 4, 2023
1 parent 816022e commit 12dc3e1
Showing 1 changed file with 73 additions and 0 deletions.
73 changes: 73 additions & 0 deletions lmdeploy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,79 @@ def messages2prompt(self, messages, sequence_start=True):
return ret


@MODELS.register_module(name='yi')
class Yi(BaseModel):
"""Chat template of Yi model."""

def __init__(self,
system='<|im_start|>system\n',
meta_instruction=None,
user='<|im_start|>user\n',
eoh='<|im_end|>\n',
eoa='<|im_end|>\n',
eosys='<|im_end|>\n',
assistant='<|im_start|>assistant\n',
stop_words=['<|im_end|>', '<|endoftext|>'],
**kwargs):
super().__init__(**kwargs)
self.system = system
self.meta_instruction = meta_instruction
self.user = user
self.eoh = eoh
self.eoa = eoa
self.eosys = eosys
self.assistant = assistant
self.stop_words = stop_words

def decorate_prompt(self, prompt, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
chat template.
Args:
prompt (str): user's input prompt
sequence_start (bool): indicator for the first round chat of a
session sequence
Returns:
str: the concatenated prompt
"""
assert self.capability == 'chat', \
f'{type(self).__name__} has no capability of {self.capability}'
if sequence_start:
if self.meta_instruction is None:
return f'{self.user}{prompt}{self.eoh}' \
f'{self.assistant}'
return f'{self.system}{self.meta_instruction}{self.eosys}' \
f'{self.user}{prompt}{self.eoh}' \
f'{self.assistant}'
else:
return f'{self.user}{prompt}{self.eoh}' \
f'{self.assistant}'

def messages2prompt(self, messages, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
chat template.
Args:
messages (str | List): user's input prompt
Returns:
str: the concatenated prompt
"""

if isinstance(messages, str):
return self.get_prompt(messages, sequence_start)
eox_map = dict(user=self.eoh, assistant=self.eoa, system=self.eosys)
ret = ''
if self.meta_instruction:
ret += f'{self.system}:{self.meta_instruction}{self.eosys}'

for message in messages:
role = message['role']
content = message['content']
ret += f'{eval(f"self.{role}")}{content}{eox_map[role]}'
ret += f'{self.assistant}'
return ret


def main(model_name: str = 'test'):
assert model_name in MODELS.module_dict.keys(), \
f"'{model_name}' is not supported. " \
Expand Down

0 comments on commit 12dc3e1

Please sign in to comment.