Skip to content

Commit

Permalink
Add llama3 chat template
Browse files Browse the repository at this point in the history
  • Loading branch information
AllentDan committed Apr 19, 2024
1 parent 1f72b8f commit 6777757
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions lmdeploy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,44 @@ def match(cls, model_path: str) -> Optional[str]:
return 'llama2'


@MODELS.register_module(name='llama3')
class Llama3(BaseChatTemplate):
"""Chat template of LLaMA3 model."""

def __init__(self,
system='<|start_header_id|>system<|end_header_id|>\n\n',
meta_instruction=None,
eosys='<|eot_id|>',
assistant='<|start_header_id|>assistant<|end_header_id|>\n\n',
eoa='<|eot_id|>',
user='<|start_header_id|>user<|end_header_id|>\n\n',
eoh='<|eot_id|>',
stop_words=['<|eot_id|>', '<|end_of_text|>'],
session_len=8192,
**kwargs):
super().__init__(system=system,
meta_instruction=meta_instruction,
eosys=eosys,
assistant=assistant,
eoa=eoa,
user=user,
eoh=eoh,
stop_words=stop_words,
session_len=session_len,
**kwargs)

@classmethod
def match(cls, model_path: str) -> Optional[str]:
"""Return the model_name that was registered to MODELS.
Args:
model_path (str): the model path used for matching.
"""
if 'llama-3' in model_path.lower() or 'llama3' in model_path.lower():
if 'instruct' in model_path.lower():
return 'llama3'


@MODELS.register_module(name='qwen-14b')
@MODELS.register_module(name='qwen-7b')
@MODELS.register_module(name='qwen')
Expand Down

0 comments on commit 6777757

Please sign in to comment.