diff --git a/lmdeploy/model.py b/lmdeploy/model.py index 44eb1aec80..7a25901669 100644 --- a/lmdeploy/model.py +++ b/lmdeploy/model.py @@ -716,6 +716,79 @@ def messages2prompt(self, messages, sequence_start=True): return ret +@MODELS.register_module(name='yi') +class Yi(BaseModel): + """Chat template of Yi model.""" + + def __init__(self, + system='<|im_start|>system\n', + meta_instruction=None, + user='<|im_start|>user\n', + eoh='<|im_end|>\n', + eoa='<|im_end|>\n', + eosys='<|im_end|>\n', + assistant='<|im_start|>assistant\n', + stop_words=['<|im_end|>', '<|endoftext|>'], + **kwargs): + super().__init__(**kwargs) + self.system = system + self.meta_instruction = meta_instruction + self.user = user + self.eoh = eoh + self.eoa = eoa + self.eosys = eosys + self.assistant = assistant + self.stop_words = stop_words + + def decorate_prompt(self, prompt, sequence_start=True): + """Return the prompt that is concatenated with other elements in the + chat template. + + Args: + prompt (str): user's input prompt + sequence_start (bool): indicator for the first round chat of a + session sequence + Returns: + str: the concatenated prompt + """ + assert self.capability == 'chat', \ + f'{type(self).__name__} has no capability of {self.capability}' + if sequence_start: + if self.meta_instruction is None: + return f'{self.user}{prompt}{self.eoh}' \ + f'{self.assistant}' + return f'{self.system}{self.meta_instruction}{self.eosys}' \ + f'{self.user}{prompt}{self.eoh}' \ + f'{self.assistant}' + else: + return f'{self.user}{prompt}{self.eoh}' \ + f'{self.assistant}' + + def messages2prompt(self, messages, sequence_start=True): + """Return the prompt that is concatenated with other elements in the + chat template. + + Args: + messages (str | List): user's input prompt + Returns: + str: the concatenated prompt + """ + + if isinstance(messages, str): + return self.get_prompt(messages, sequence_start) + eox_map = dict(user=self.eoh, assistant=self.eoa, system=self.eosys) + ret = '' + if self.meta_instruction: + ret += f'{self.system}:{self.meta_instruction}{self.eosys}' + + for message in messages: + role = message['role'] + content = message['content'] + ret += f'{eval(f"self.{role}")}{content}{eox_map[role]}' + ret += f'{self.assistant}' + return ret + + def main(model_name: str = 'test'): assert model_name in MODELS.module_dict.keys(), \ f"'{model_name}' is not supported. " \