-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Close #343
- Loading branch information
Showing
12 changed files
with
671 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,308 @@ | ||
""" | ||
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py | ||
Conversation prompt templates. | ||
""" | ||
|
||
import dataclasses | ||
from enum import auto, IntEnum | ||
from typing import List, Any, Dict, Callable | ||
|
||
|
||
class SeparatorStyle(IntEnum):
    """Enumerates the message-separator conventions a template can use.

    Each member selects one formatting branch in ``Conversation.get_prompt``.
    Values are written out explicitly (matching what ``auto()`` produced, 1..14)
    so they remain stable even if members are later reordered.
    """

    ADD_COLON_SINGLE = 1
    ADD_COLON_TWO = 2
    ADD_COLON_SPACE_SINGLE = 3
    NO_COLON_SINGLE = 4
    NO_COLON_TWO = 5
    ADD_NEW_LINE_SINGLE = 6
    LLAMA2 = 7
    CHATGLM = 8
    CHATML = 9
    CHATINTERN = 10
    DOLLY = 11
    RWKV = 12
    PHOENIX = 13
    ROBIN = 14
|
||
|
||
@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history.

    A template instance is registered once (see ``register_conv_template``) and
    then cloned per conversation via ``copy()``; messages are appended to the
    clone and rendered to a model-specific prompt string by ``get_prompt()``.
    """

    # The name of this template (registry key).
    name: str
    # The system prompt, already formatted with any template-specific wrapper.
    system: str
    # Two roles, in speaking order: (user role, assistant role).
    roles: List[str]
    # All messages. Each item is (role, message); message may be None to mark
    # the slot the model is about to fill in (see update_last_message).
    messages: List[List[str]]
    # The number of few shot examples (leading messages skipped by the
    # gradio/OpenAI converters below).
    offset: int
    # Separators; how they are applied depends on sep_style.
    sep_style: SeparatorStyle
    sep: str
    # Second separator — only used by the *_TWO / LLAMA2 / CHATINTERN / DOLLY
    # styles, which alternate separators between user and assistant turns.
    sep2: str = None
    # Stop criteria (the default one is EOS token).
    stop_str: str = None
    # Stops generation if meeting any token in this list.
    stop_token_ids: List[int] = None

    # Optional callable that wraps a raw system message into this template's
    # expected system format (see update_system_message).
    system_formatter: Callable = None

    def get_prompt(self) -> str:
        """Render the full conversation into a single prompt string.

        The exact byte layout is model-specific and selected by
        ``self.sep_style``. A turn whose message is falsy (typically None)
        renders only the role header, leaving the prompt open for the model
        to complete.

        Raises:
            ValueError: If ``self.sep_style`` is not a known style.
        """
        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
            # "system<sep>role: message<sep>..." — one separator everywhere.
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
            # Alternates sep (after user turns) and sep2 (after assistant turns).
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ": "  # open turn must end with a space, not just the colon
            return ret
        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
            # Role and message on separate lines; empty system contributes nothing.
            ret = "" if self.system == "" else self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + "\n" + message + self.sep
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            # Role token is glued directly to the message (no ": ").
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + message + seps[i % 2]
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.RWKV:
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += (
                        role
                        + ": "
                        # RWKV prompts use "\n" line endings and no blank
                        # lines inside a message; normalize both here.
                        + message.replace("\r\n", "\n").replace("\n\n", "\n")
                    )
                    ret += "\n\n"
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.LLAMA2:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if i == 0:
                        # First user message is folded into the system block
                        # (system already ends with "<</SYS>>\n\n").
                        ret += self.system + message
                    else:
                        ret += role + " " + message + seps[i % 2]
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.CHATGLM:
            # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
            # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
            # chatglm2 numbers rounds starting from 1; chatglm from 0.
            round_add_n = 1 if self.name == "chatglm2" else 0
            if self.system:
                ret = self.system + self.sep
            else:
                ret = ""

            for i, (role, message) in enumerate(self.messages):
                # A "round" is one user+assistant pair; header goes before the
                # user turn (even i).
                if i % 2 == 0:
                    ret += f"[Round {i//2 + round_add_n}]{self.sep}"

                if message:
                    ret += f"{role}:{message}{self.sep}"
                else:
                    ret += f"{role}:"
            return ret
        elif self.sep_style == SeparatorStyle.CHATML:
            ret = "" if self.system == "" else self.system + self.sep + "\n"
            for role, message in self.messages:
                if message:
                    ret += role + "\n" + message + self.sep + "\n"
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.CHATINTERN:
            # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if i % 2 == 0:
                    # Each round is opened with a BOS marker before the user turn.
                    ret += "<s>"
                if message:
                    ret += role + ":" + message + seps[i % 2] + "\n"
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ":\n" + message + seps[i % 2]
                    if i % 2 == 1:
                        # Blank line after each completed assistant turn.
                        ret += "\n\n"
                else:
                    ret += role + ":\n"
            return ret
        elif self.sep_style == SeparatorStyle.PHOENIX:
            # Each message is wrapped in <s>...</s>; an open turn emits only "<s>".
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + ": " + "<s>" + message + "</s>"
                else:
                    ret += role + ": " + "<s>"
            return ret
        elif self.sep_style == SeparatorStyle.ROBIN:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ":\n" + message + self.sep
                else:
                    ret += role + ":\n"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role: str, message: str):
        """Append a new message.

        ``message`` may be None to open a slot for the model's reply;
        fill it later with ``update_last_message``.
        """
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def update_system_message(self, system_message: str):
        """Replace the system prompt, applying ``system_formatter`` if set."""
        if self.system_formatter:
            self.system = self.system_formatter(system_message)
        else:
            self.system = system_message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format.

        Returns a list of [user_msg, assistant_msg] pairs, skipping the
        first ``offset`` (few-shot) messages.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                # User turn starts a new pair; assistant slot filled below.
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format.

        Even-indexed messages (after ``offset``) are mapped to "user",
        odd-indexed to "assistant"; a None assistant slot is dropped.
        """
        ret = [{"role": "system", "content": self.system}]

        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            else:
                if msg is not None:
                    ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        """Return an independent clone; the message list is deep-copied one level."""
        return Conversation(
            name=self.name,
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
            system_formatter=self.system_formatter,
        )

    def dict(self):
        """Return a JSON-serializable summary (separator/stop config omitted)."""
        return {
            "template_name": self.name,
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }
|
||
|
||
# A global registry for all conversation templates | ||
conv_templates: Dict[str, Conversation] = {} | ||
|
||
|
||
def register_conv_template(template: Conversation, override: bool = False): | ||
"""Register a new conversation template.""" | ||
if not override: | ||
assert ( | ||
template.name not in conv_templates | ||
), f"{template.name} has been registered." | ||
|
||
conv_templates[template.name] = template | ||
|
||
|
||
def get_conv_template(name: str) -> Conversation:
    """Look up a registered template by name and return an independent copy.

    A copy is handed out so callers can append messages without mutating
    the shared registry entry.

    Raises:
        KeyError: If no template named ``name`` has been registered.
    """
    template = conv_templates[name]
    return template.copy()
|
||
|
||
# llama2 template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
register_conv_template(
    Conversation(
        name="llama-2",
        # System block already wrapped in <s>[INST] <<SYS>> ... <</SYS>>;
        # the LLAMA2 branch of get_prompt concatenates the first user
        # message directly after it.
        system="<s>[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
        roles=("[INST]", "[/INST]"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        # sep2 closes the assistant turn and opens the next round: " </s><s>".
        sep2=" </s><s>",
        # NOTE(review): token id 2 is presumably the Llama EOS id — confirm
        # against the tokenizer in use.
        stop_token_ids=[2],
        # Rewraps a raw system message into the llama-2 <<SYS>> envelope.
        system_formatter=lambda msg: f"<s>[INST] <<SYS>>\n{msg}\n<</SYS>>\n\n",
    )
)
|
||
# TODO Support other model conversation template |
Oops, something went wrong.