feat: Support llama-2 model (#347)
Close #343
csunny authored Jul 20, 2023
2 parents 412b104 + 168c754 commit b2fb374
Showing 12 changed files with 671 additions and 13 deletions.
2 changes: 1 addition & 1 deletion docs/modules/llms.md
@@ -11,7 +11,7 @@ cp .env.template .env
LLM_MODEL=vicuna-13b
MODEL_SERVER=http://127.0.0.1:8000
```
- now we support models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b.
+ now we support models vicuna-13b, vicuna-7b, chatglm-6b, flan-t5-base, guanaco-33b-merged, falcon-40b, gorilla-7b, llama-2-7b, llama-2-13b.

if you want to use another model, such as chatglm-6b, you just need to update the .env config file.
```
LLM_MODEL=chatglm-6b
```
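
With this change, pointing the stack at one of the new models is the same one-line edit. A minimal sketch of the `.env` settings, assuming the Llama-2 chat weights have been downloaded into the models directory under the names registered in `pilot/configs/model_config.py`:

```
LLM_MODEL=llama-2-7b
MODEL_SERVER=http://127.0.0.1:8000
```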
3 changes: 3 additions & 0 deletions pilot/configs/model_config.py
@@ -47,6 +47,9 @@
"gorilla-7b": os.path.join(MODEL_PATH, "gorilla-7b"),
"gptj-6b": os.path.join(MODEL_PATH, "ggml-gpt4all-j-v1.3-groovy.bin"),
"proxyllm": "proxyllm",
"llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
"llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
"llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
}

# Load model config
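These registry entries are plain name-to-path mappings, so the rest of the stack stays model-agnostic. A minimal sketch of a lookup, assuming the dict shown above is exported as `LLM_MODEL_CONFIG` (the lookup itself is illustrative, not part of this commit):

```python
from pilot.configs.model_config import LLM_MODEL_CONFIG

# Resolve the on-disk checkpoint directory for the model chosen in .env,
# e.g. "llama-2-7b" -> <MODEL_PATH>/Llama-2-7b-chat-hf
model_path = LLM_MODEL_CONFIG["llama-2-7b"]
print(model_path)
```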
14 changes: 14 additions & 0 deletions pilot/model/adapter.py
@@ -263,12 +263,26 @@ def loader(self, model_path: str, from_pretrained_kwargs: dict):
        return "proxyllm", None


class Llama2Adapter(BaseLLMAdaper):
    """The model adapter for llama-2"""

    def match(self, model_path: str):
        return "llama-2" in model_path.lower()

    def loader(self, model_path: str, from_pretrained_kwargs: dict):
        model, tokenizer = super().loader(model_path, from_pretrained_kwargs)
        # Keep the model config's special-token ids in sync with the
        # tokenizer so generation stops and pads with the right tokens.
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.pad_token_id
        return model, tokenizer

register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
register_llm_model_adapters(FalconAdapater)
register_llm_model_adapters(GorillaAdapter)
register_llm_model_adapters(GPT4AllAdapter)
register_llm_model_adapters(Llama2Adapter)
# TODO: vicuna is supported by default; other models still need testing and evaluation

# just for test_py, remove this later
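Because adapters are matched by substring against the model path, no call sites change when a model is added; registration order decides ties. A rough sketch of that dispatch, with a hypothetical helper and adapter list for illustration (the actual lookup in `adapter.py` may differ):

```python
def find_adapter(model_path: str, adapters: list):
    """Return the first registered adapter whose match() accepts the path."""
    for adapter in adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"No adapter registered for {model_path}")

# A path like "models/Llama-2-7b-chat-hf" contains "llama-2" (case-insensitive),
# so it would be claimed by Llama2Adapter.
```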
308 changes: 308 additions & 0 deletions pilot/model/conversation.py
@@ -0,0 +1,308 @@
"""
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
Conversation prompt templates.
"""

import dataclasses
from enum import auto, IntEnum
from typing import List, Any, Dict, Callable


class SeparatorStyle(IntEnum):
    """Separator styles."""

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
    ADD_COLON_SPACE_SINGLE = auto()
    NO_COLON_SINGLE = auto()
    NO_COLON_TWO = auto()
    ADD_NEW_LINE_SINGLE = auto()
    LLAMA2 = auto()
    CHATGLM = auto()
    CHATML = auto()
    CHATINTERN = auto()
    DOLLY = auto()
    RWKV = auto()
    PHOENIX = auto()
    ROBIN = auto()


@dataclasses.dataclass
class Conversation:
"""A class that manages prompt templates and keeps all conversation history."""

# The name of this template
name: str
# The system prompt
system: str
# Two roles
roles: List[str]
# All messages. Each item is (role, message).
messages: List[List[str]]
# The number of few shot examples
offset: int
# Separators
sep_style: SeparatorStyle
sep: str
sep2: str = None
# Stop criteria (the default one is EOS token)
stop_str: str = None
# Stops generation if meeting any token in this list
stop_token_ids: List[int] = None

# format system message
system_formatter: Callable = None

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ": "  # must end with a space
            return ret
        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
            ret = "" if self.system == "" else self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + "\n" + message + self.sep
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + message + seps[i % 2]
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.RWKV:
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += (
                        role
                        + ": "
                        + message.replace("\r\n", "\n").replace("\n\n", "\n")
                    )
                    ret += "\n\n"
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.LLAMA2:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if i == 0:
                        ret += self.system + message
                    else:
                        ret += role + " " + message + seps[i % 2]
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.CHATGLM:
            # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
            # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
            round_add_n = 1 if self.name == "chatglm2" else 0
            if self.system:
                ret = self.system + self.sep
            else:
                ret = ""

            for i, (role, message) in enumerate(self.messages):
                if i % 2 == 0:
                    ret += f"[Round {i//2 + round_add_n}]{self.sep}"

                if message:
                    ret += f"{role}{message}{self.sep}"
                else:
                    ret += f"{role}:"
            return ret
        elif self.sep_style == SeparatorStyle.CHATML:
            ret = "" if self.system == "" else self.system + self.sep + "\n"
            for role, message in self.messages:
                if message:
                    ret += role + "\n" + message + self.sep + "\n"
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.CHATINTERN:
            # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if i % 2 == 0:
                    ret += "<s>"
                if message:
                    ret += role + ":" + message + seps[i % 2] + "\n"
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ":\n" + message + seps[i % 2]
                    if i % 2 == 1:
                        ret += "\n\n"
                else:
                    ret += role + ":\n"
            return ret
        elif self.sep_style == SeparatorStyle.PHOENIX:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + ": " + "<s>" + message + "</s>"
                else:
                    ret += role + ": " + "<s>"
            return ret
        elif self.sep_style == SeparatorStyle.ROBIN:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ":\n" + message + self.sep
                else:
                    ret += role + ":\n"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.
        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def update_system_message(self, system_message: str):
        """Update system message"""
        if self.system_formatter:
            self.system = self.system_formatter(system_message)
        else:
            self.system = system_message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        ret = [{"role": "system", "content": self.system}]

        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            else:
                if msg is not None:
                    ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        return Conversation(
            name=self.name,
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
            system_formatter=self.system_formatter,
        )

    def dict(self):
        return {
            "template_name": self.name,
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template."""
    if not override:
        assert (
            template.name not in conv_templates
        ), f"{template.name} has been registered."

    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template."""
    return conv_templates[name].copy()


# llama2 template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
register_conv_template(
    Conversation(
        name="llama-2",
        system="<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
        roles=("[INST]", "[/INST]"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2=" </s><s>",
        stop_token_ids=[2],
        system_formatter=lambda msg: f"<s>[INST] <<SYS>>\n{msg}\n<</SYS>>\n\n",
    )
)

# TODO: support conversation templates for other models
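
As a usage sketch (inferred from the API above, not code in this commit), downstream callers can fetch a copy of the template, append turns, and render the prompt:

```python
from pilot.model.conversation import get_conv_template

conv = get_conv_template("llama-2")
conv.append_message(conv.roles[0], "Hello, who are you?")
conv.append_message(conv.roles[1], None)  # None marks the assistant slot to fill
prompt = conv.get_prompt()
# With SeparatorStyle.LLAMA2, the result starts with the "<s>[INST] <<SYS>>"
# system block and ends with "[/INST]", matching the referenced llama-2 format.
print(prompt)
```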