@MODELS.register_module(name='llama3')
class Llama3(BaseChatTemplate):
    """Chat template of the LLaMA3 model.

    The class only supplies LLaMA3-specific default values; every argument
    is forwarded as-is to the parent chat template constructor.
    """

    def __init__(self,
                 system='<|start_header_id|>system<|end_header_id|>\n\n',
                 meta_instruction=None,
                 eosys='<|eot_id|>',
                 assistant='<|start_header_id|>assistant<|end_header_id|>\n\n',
                 eoa='<|eot_id|>',
                 user='<|start_header_id|>user<|end_header_id|>\n\n',
                 eoh='<|eot_id|>',
                 stop_words=['<|eot_id|>', '<|end_of_text|>'],
                 session_len=8192,
                 **kwargs):
        """Initialize the template with the LLaMA3 defaults.

        The role descriptions below are inferred from the parameter names;
        confirm them against the parent class.

        Args:
            system (str): text for the system role header.
            meta_instruction (str | None): system message content, ``None``
                by default.
            eosys (str): text that ends the system part.
            assistant (str): text for the assistant role header.
            eoa (str): text that ends the assistant part.
            user (str): text for the user role header.
            eoh (str): text that ends the user part.
            stop_words (list[str]): words passed to the parent as stop words.
            session_len (int): session length passed to the parent.
            **kwargs: any extra keyword arguments, forwarded untouched.
        """
        # Gather the named fields first, then hand everything to the parent
        # constructor in a single call.
        template_fields = dict(
            system=system,
            meta_instruction=meta_instruction,
            eosys=eosys,
            assistant=assistant,
            eoa=eoa,
            user=user,
            eoh=eoh,
            stop_words=stop_words,
            session_len=session_len,
        )
        super().__init__(**template_fields, **kwargs)

    @classmethod
    def match(cls, model_path: str) -> Optional[str]:
        """Return the model_name that was registered to MODELS.

        Matching is case-insensitive: the path has to contain ``llama-3`` or
        ``llama3`` and, in addition, ``instruct``.

        Args:
            model_path (str): the model path used for matching.

        Returns:
            Optional[str]: ``'llama3'`` on a match, otherwise ``None``.
        """
        lowered = model_path.lower()
        names_llama3 = 'llama-3' in lowered or 'llama3' in lowered
        if names_llama3 and 'instruct' in lowered:
            return 'llama3'
        return None