diff --git a/paddlenlp/datasets/formatter.py b/paddlenlp/datasets/formatter.py
new file mode 100644
index 000000000000..9c9a2ec49401
--- /dev/null
+++ b/paddlenlp/datasets/formatter.py
@@ -0,0 +1,195 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from enum import Enum, unique
+from typing import Optional, Union
+
+from typing_extensions import override
+
# Type alias for template "slots": each element is either a literal string
# (possibly containing "{{placeholder}}" markers), a set naming a special
# token (e.g. {"eos_token"}), or a dict with an explicit {"token": ...}.
SLOTS = list[Union[str, set[str], dict[str, str]]]


# Markup tokens that tag knowledge-base results inside a user message.
# NOTE(review): all eight entries are the literal string "[]" — the original
# angle-bracket tag names (e.g. "[<kg-res>]") appear to have been stripped
# by a markup sanitizer; verify against the upstream source before relying
# on this list.
KG_RES_MARKUPS = [
    "[]",
    "[]",
    "[]",
    "[]",
    "[]",
    "[]",
    "[]",
    "[]",
]
+
+
@unique
class Role(str, Enum):
    """Conversation roles used in chat message dicts (str-valued enum)."""

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
+
+
def extract_knowledge(text):
    """Extract structured knowledge from markup-tagged text.

    Recognizes several kinds of bracketed markup (knowledge-base results,
    search results, prompt results, compute results, citation/retrieval
    references) and rewrites the message into an instruction-style prompt.

    Args:
        text (str): Input text containing markup.

    Returns:
        str: Processed knowledge string.

    Raises:
        ValueError: If no valid knowledge pattern found.
    """

    # Knowledge-base markup: strip all markers and prepend a KB instruction.
    if any(markup in text for markup in KG_RES_MARKUPS):
        for markup in KG_RES_MARKUPS + ["[]", "[]"]:
            text = text.replace(markup, "")
        text = f"知识库:{text.strip()}\n根据所提供的知识库信息,回答问题并补全对话:"
        return text

    # NOTE(review): the opening delimiters in the patterns below are
    # reconstructed to mirror the surviving closing tags (e.g.
    # "[</search-res>]"); the opening literals looked sanitizer-stripped
    # ("\[\]") in the reviewed source. Confirm the exact spelling upstream.
    res = re.findall(
        r"\[<search-res>\](.*?)\[<\/search-res>\]",
        text,
        re.DOTALL | re.MULTILINE,
    )
    if res:
        text = f"{res[0].strip()}\n根据以上参考文章回答问题,补全对话"
        return text

    res = re.findall(
        r"\[<prompt-res>\](.*?)\[<\/prompt-res>\]",
        text,
        re.DOTALL | re.MULTILINE,
    )
    if res:
        return res[0].strip()

    res = re.findall(
        r"\[<compute-res>\](.*?)\[<\/compute-res>\]",
        text,
        re.DOTALL | re.MULTILINE,
    )
    if res:
        text = f"参考文章1:{res[0].strip()}\n根据以上参考文章回答问题,补全对话"
        return text

    res = re.findall(
        r"\[<citation-ref>\](.*?)\[<\/citation-ref>\]",
        text,
        re.DOTALL | re.MULTILINE,
    )
    if res:
        text = (
            "请参考搜索结果回答下面问题并使用引用标记来标注回答内容参考的搜索结果序号,"
            "例如^[1]^ (引用单个搜索结果),^[1][2]^(引用多个搜索结果),"
            "其中方括号中的数字是搜索结果序号。引用标记只能出现在句尾标点符号前。\n"
            "以下是搜索结果(每行开头[1]、[2]、...是搜索结果序号),"
            f"可以对答案中的核心部分进行markdown加粗(**加粗内容**):\n{res[0].strip()}\n"
            "根据以上搜索结果回答问题并标注引用,补全对话"
        )
        return text

    res = re.findall(
        r"\[<retrieve-ref>\](.*?)\[<\/retrieve-ref>\]",
        text,
        re.DOTALL | re.MULTILINE,
    )
    if res:
        text = (
            "请你扮演一个专家,参考搜索结果中正确、可信、高质量的信息回答问题,并注明答案中引用的搜索结果,"
            "格式为^[2]^表示引用了第2条搜索结果,^[1][3]^表示引用第1和第3条搜索结果。"
            "每条搜索结果包含若干相关内容片段。同时你需要遵循以下原则回答问题:\n"
            "1. 严格遵循搜索结果作答,可以承认不知道答案,并尝试给出一些搜索结果中的相关背景信息。\n"
            "2. 如果搜索结果存在多种可能的答案,要罗列出每种情况。\n"
            "3. 如果问题涉及金融、医疗、法律等存在风险的领域,请在结尾提醒用户注意并进行免责说明。\n"
            f"搜索结果:\n{res[0].strip()}\n\n现在,请根据上面的搜索结果回答问题并标注引用,补全对话"
        )
        return text

    raise ValueError(f"Cannot extract knowledge from `{text}`")
+
+
@dataclass
class Formatter(ABC):
    """Base class for slot formatters used to render one chat turn."""

    # The template pieces this formatter renders (see the SLOTS alias).
    slots: SLOTS = field(default_factory=list)
    # Optional name of a tool-call format; unused by the base class.
    tool_format: Optional[str] = None

    @abstractmethod
    def apply(self, **kwargs) -> SLOTS:
        r"""Forms a list of slots according to the inputs to encode."""
        ...
+
+
@dataclass
class EmptyFormatter(Formatter):
    """Formatter whose slots are fixed text; placeholders are forbidden."""

    def __post_init__(self):
        # Reject any "{{name}}" marker inside string slots.
        placeholder = re.compile(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}")
        if any(placeholder.search(slot) for slot in self.slots if isinstance(slot, str)):
            raise ValueError("Empty formatter should not contain any placeholder.")

    @override
    def apply(self, **kwargs) -> SLOTS:
        """Return the fixed slots unchanged; keyword arguments are ignored."""
        return self.slots
+
+
@dataclass
class StringFormatter(Formatter):
    """Formatter that substitutes "{{name}}" placeholders in string slots."""

    def __post_init__(self):
        # At least one string slot must carry a "{{name}}" placeholder,
        # otherwise apply() would have nothing to substitute.
        placeholder = re.compile(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}")
        if not any(placeholder.search(slot) for slot in self.slots if isinstance(slot, str)):
            raise ValueError("A placeholder is required in the string formatter.")

    @override
    def apply(self, **kwargs) -> SLOTS:
        """Substitute keyword arguments into the slots.

        Args:
            **kwargs: Mapping of placeholder name to replacement string.

        Returns:
            SLOTS: Slots with "{{name}}" markers replaced (first occurrence
            per slot only); dict/set slots are passed through unchanged.

        Raises:
            RuntimeError: If a replacement value is not a string, or a slot
                has an unsupported type.
        """
        elements = []
        for slot in self.slots:
            if isinstance(slot, str):
                for name, value in kwargs.items():
                    if not isinstance(value, str):
                        # Bug fix: message previously read "got {name} : s{value}"
                        # with a stray "s" garbling the reported value.
                        raise RuntimeError(f"Expected a string, got {name}: {value}")

                    slot = slot.replace("{{" + name + "}}", value, 1)
                elements.append(slot)
            elif isinstance(slot, (dict, set)):
                elements.append(slot)
            else:
                raise RuntimeError(f"Input must be string, set[str] or dict[str, str], got {type(slot)}.")

        return elements
+
+
@dataclass
class KnowledgeFormatter(StringFormatter):
    # Runs the raw user message through `extract_knowledge` before the usual
    # placeholder substitution, turning knowledge markup into an instruction.
    @override
    def apply(self, **kwargs) -> SLOTS:
        # Trailing newline keeps the "User: ...\nAssistant:" layout intact.
        content: str = extract_knowledge(kwargs.pop("content")) + "\n"
        idx: int = kwargs.pop("idx")
        return super().apply(content=content, idx=idx)
diff --git a/paddlenlp/datasets/template.py b/paddlenlp/datasets/template.py
new file mode 100644
index 000000000000..e3836b553e63
--- /dev/null
+++ b/paddlenlp/datasets/template.py
@@ -0,0 +1,543 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import re
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+from .formatter import EmptyFormatter, KnowledgeFormatter, Role, StringFormatter
+
+if TYPE_CHECKING:
+ from paddlenlp.transformers import PretrainedTokenizer
+
+ from .formatter import SLOTS, Formatter
+
+
+logger = logging.getLogger(__name__)
+
+
def contain_tokens(text, token_list):
    """Checks if any token in list exist in the text.

    Args:
        text (List[str]): Input text sequences to check.
        token_list (List[str]): tokens to search for.

    Returns:
        bool: True if any is found, False otherwise.
    """
    return any(token in piece for token in token_list for piece in text)
+
+
+@dataclass
+class Template:
+ format_user: "Formatter"
+ format_assistant: "Formatter"
+ format_system: "Formatter"
+ format_knowledge: "Formatter"
+ format_prefix: "Formatter"
+ default_system: str
+ stop_words: list[str]
+ thought_words: tuple[str, str]
+ efficient_eos: bool
+ replace_eos: bool
+ replace_jinja_template: bool
+
+ def encode_oneturn(
+ self,
+ tokenizer: "PretrainedTokenizer",
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ enable_thinking: bool = False,
+ ) -> tuple[list[int], list[int]]:
+ r"""Return a single pair of token ids representing prompt and response respectively."""
+ system = None
+ if messages[0]["role"] == Role.SYSTEM.value:
+ system = messages[0]["content"]
+ messages = messages[1:]
+ encoded_messages = self._encode(tokenizer, messages, system)
+ prompt_ids = []
+ for encoded_ids in encoded_messages[:-1]:
+ prompt_ids += encoded_ids
+
+ response_ids = encoded_messages[-1]
+ return prompt_ids, response_ids
+
+ def encode_multiturn(
+ self,
+ tokenizer: "PretrainedTokenizer",
+ messages: list[dict[str, str]],
+ system: Optional[str] = None,
+ ) -> list[tuple[list[int], list[int]]]:
+ r"""Return multiple pairs of token ids representing prompts and responses respectively."""
+ system = None
+ if messages[0]["role"] == Role.SYSTEM.value:
+ system = messages[0]["content"]
+ messages = messages[1:]
+ encoded_messages = self._encode(tokenizer, messages, system)
+ return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
+
+ def get_stop_token_ids(self, tokenizer: "PretrainedTokenizer") -> list[int]:
+ r"""Return stop token ids."""
+ stop_token_ids = {tokenizer.eos_token_id}
+ for token in self.stop_words:
+ stop_token_ids.add(tokenizer.convert_tokens_to_ids(token))
+
+ return list(stop_token_ids)
+
+ def add_thought(self, content: str) -> str:
+ r"""Add empty thought to assistant message."""
+ return f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n" + content
+
+ def remove_thought(self, content: str) -> str:
+ r"""Remove thought from assistant message."""
+ pattern = re.compile(f"{re.escape(self.thought_words[0])}(.*?){re.escape(self.thought_words[1])}", re.DOTALL)
+ return re.sub(pattern, "", content).lstrip("\n")
+
+ def get_thought_word_ids(self, tokenizer: "PretrainedTokenizer") -> list[int]:
+ r"""Get the token ids of thought words."""
+ return tokenizer.encode(f"{self.thought_words[0]}\n\n{self.thought_words[1]}\n\n", add_special_tokens=False)
+
    def _convert_elements_to_ids(self, tokenizer: "PretrainedTokenizer", elements: "SLOTS") -> list[int]:
        r"""Convert elements to token ids.

        Strings are tokenized, ``{"token": ...}`` dicts are looked up directly,
        and ``{"bos_token"}`` / ``{"eos_token"}`` sets map to the tokenizer's
        special ids (skipped when the tokenizer does not define them).
        """
        token_ids = []
        for elem in elements:
            if isinstance(elem, str):
                # Skip empty strings — nothing to tokenize.
                if len(elem) != 0:
                    token_ids += tokenizer.encode(elem, add_special_tokens=False)["input_ids"]
            elif isinstance(elem, dict):
                token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
            elif isinstance(elem, set):
                if "bos_token" in elem and tokenizer.bos_token_id is not None:
                    token_ids += [tokenizer.bos_token_id]
                elif "eos_token" in elem and tokenizer.eos_token_id is not None:
                    token_ids += [tokenizer.eos_token_id]
            else:
                raise ValueError(f"Input must be string, set[str] or dict[str, str], got {type(elem)}")

        return token_ids
+
    def _encode(
        self,
        tokenizer: "PretrainedTokenizer",
        messages: list[dict[str, str]],
        system: Optional[str],
    ) -> list[list[int]]:
        r"""Encode formatted inputs to pairs of token ids.

        Turn 0: prefix + system + query resp
        Turn t: query resp.
        """
        system = system or self.default_system
        encoded_messages = []
        for i, message in enumerate(messages):
            elements = []

            # The first turn carries the template prefix and the system prompt.
            if i == 0:
                elements += self.format_prefix.apply()
                if system:
                    elements += self.format_system.apply(content=(system))

            # NOTE: comparing a plain str against Role works because Role is a
            # (str, Enum) whose members compare equal to their values.
            if message["role"] == Role.USER:
                # The knowledge formatter applies only to the final user turn
                # (i == len(messages) - 2) and only when the message contains
                # tokenizer-specific markup tokens.
                if (
                    self.format_knowledge
                    and hasattr(tokenizer, "markup_tokens")
                    and i == len(messages) - 2
                    and contain_tokens([message["content"]], tokenizer.markup_tokens)
                ):
                    elements += self.format_knowledge.apply(content=message["content"], idx=str(i // 2))
                else:
                    elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
            elif message["role"] == Role.ASSISTANT:
                elements += self.format_assistant.apply(content=message["content"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(message["role"]))

            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

        return encoded_messages
+
    @staticmethod
    def _add_or_replace_eos_token(tokenizer: "PretrainedTokenizer", eos_token: str) -> None:
        r"""Add or replace eos token to the tokenizer.

        No-op when the tokenizer already uses *eos_token*. Warns when the
        vocabulary actually grew, since callers may need `resize_vocab`.
        """
        if tokenizer.eos_token == eos_token:
            return

        # If there was no EOS id before, this is an addition, not a replacement.
        is_added = tokenizer.eos_token_id is None
        num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})

        if is_added:
            logger.info(f"Add eos token: {tokenizer.eos_token}.")
        else:
            logger.info(f"Replace eos token: {tokenizer.eos_token}.")

        if num_added_tokens > 0:
            logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+
    def fix_special_tokens(self, tokenizer: "PretrainedTokenizer") -> None:
        r"""Add eos token and pad token to the tokenizer.

        When `replace_eos` is set, the first stop word becomes the EOS token
        and the remaining stop words are registered as additional special
        tokens. A missing pad token falls back to the EOS token.
        """
        stop_words = self.stop_words
        if self.replace_eos:
            if not stop_words:
                raise ValueError("Stop words are required to replace the EOS token.")

            self._add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
            stop_words = stop_words[1:]  # the first stop word is consumed as EOS

        if tokenizer.eos_token_id is None:
            self._add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")

        if tokenizer.pad_token_id is None:
            tokenizer.pad_token = tokenizer.eos_token
            logger.info(f"Add pad token: {tokenizer.pad_token}")

        if stop_words:
            # Keep any existing special tokens; only append the new stop words.
            num_added_tokens = tokenizer.add_special_tokens(
                dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
            )
            logger.info("Add {} to stop words.".format(",".join(stop_words)))
            if num_added_tokens > 0:
                logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+
+ @staticmethod
+ def _jinja_escape(content: str) -> str:
+ r"""Escape single quotes in content."""
+ return content.replace("'", r"\'")
+
    @staticmethod
    def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PretrainedTokenizer", placeholder: str = "content") -> str:
        r"""Convert slots to jinja template.

        Each slot becomes a quoted jinja string fragment (with the bare
        *placeholder* variable substituted where "{{content}}" appears);
        the fragments are joined with " + " into one jinja expression.
        """
        slot_items = []
        for slot in slots:
            if isinstance(slot, str):
                # Split around the placeholder; at most one "{{content}}" per
                # slot is handled here.
                slot_pieces = slot.split("{{content}}")
                if slot_pieces[0]:
                    slot_items.append("'" + Template._jinja_escape(slot_pieces[0]) + "'")
                if len(slot_pieces) > 1:
                    slot_items.append(placeholder)
                    if slot_pieces[1]:
                        slot_items.append("'" + Template._jinja_escape(slot_pieces[1]) + "'")
            elif isinstance(slot, set):  # do not use {{ eos_token }} since it may be replaced
                if "bos_token" in slot and tokenizer.bos_token_id is not None:
                    slot_items.append("'" + tokenizer.bos_token + "'")
                elif "eos_token" in slot and tokenizer.eos_token_id is not None:
                    slot_items.append("'" + tokenizer.eos_token + "'")
            elif isinstance(slot, dict):
                # Explicit {"token": ...} entries are emitted literally.
                slot_items.append("'" + slot.get("token") + "'")
                # raise ValueError("Dict is not supported.")

        return " + ".join(slot_items)
+
    def _get_jinja_template(self, tokenizer: "PretrainedTokenizer") -> str:
        r"""Return the jinja template.

        Builds a chat template string from the formatter slots. When
        `format_knowledge` is set, a specialized ERNIE-style template with
        knowledge-markup handling is emitted instead; it relies on the custom
        `regex_findall` jinja filter registered in `tokenizer_utils`.
        """
        prefix = self._convert_slots_to_jinja(self.format_prefix.apply(), tokenizer)
        system = self._convert_slots_to_jinja(self.format_system.apply(), tokenizer, placeholder="system_message")
        user = self._convert_slots_to_jinja(self.format_user.apply(), tokenizer)
        assistant = self._convert_slots_to_jinja(self.format_assistant.apply(), tokenizer)
        jinja_template = ""
        if prefix:
            jinja_template += "{{ " + prefix + " }}"

        if self.default_system:
            jinja_template += "{% set system_message = '" + self._jinja_escape(self.default_system) + "' %}"

        if not self.format_knowledge:
            # Generic template: optional system message, then user/assistant turns.
            jinja_template += (
                "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}"
                "{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% endif %}"
                "{% if system_message is defined %}{{ " + system + " }}{% endif %}"
                "{% for message in loop_messages %}"
                "{% set content = message['content'] %}"
                "{% if message['role'] == 'user' %}"
                "{{ " + user + " }}"
                "{% elif message['role'] == 'assistant' %}"
                "{% endif %}"
                "{% endfor %}"
            )
        else:
            # Knowledge-aware template mirroring `extract_knowledge`.
            # NOTE(review): the empty '[]' markers and '\[\](.*?)\[\]' patterns
            # look sanitizer-stripped (compare the closing tags that survive in
            # `extract_knowledge`); verify against the upstream source.
            jinja_template += (
                "{% set KG_RES_MARKUPS = ['[]', '[]', '[]', '[]'] %}{{'<|begin_of_sentence|>'}}"
                "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}"
                "{% else %}{% set loop_messages = messages %}{% endif %}{% if system_message is defined %}{{ system_message + '\n' }}"
                "{% endif %}{% set ns = namespace(last_user_message=None) %}{% for message in loop_messages %}{% if message['role'] == 'user' %}"
                "{% set ns.last_user_message = message['content'] %}{% endif %}{% endfor %}{% for message in loop_messages %}"
                "{% set content = message['content'] %}{% if message['role'] == 'user' %}{% if content == ns.last_user_message %}{% set text = content %}"
                "{% set ns = namespace(has_markup=False) %}{% for markup in KG_RES_MARKUPS + ['[]', '[]'] %}{% if markup in text %}"
                "{% set ns.has_markup = True %}{% set text = text.replace(markup, '') %}{% endif %}{% endfor %}{% if ns.has_markup == True %}"
                "{{ 'User: 知识库:' + text.strip() + '\n根据所提供的知识库信息,回答问题并补全对话:\nAssistant: ' }}{% else %}"
                "{% set res = text | regex_findall('\[\](.*?)\[\]', multiline=True, dotall=True) %}{% if res %}"
                "{{ 'User: ' + res[0].strip() + '\n根据以上参考文章回答问题,补全对话\nAssistant: ' }}{% else %}"
                "{% set res = text | regex_findall('\[\](.*?)\[\]', multiline=True, dotall=True) %}{% if res %}"
                "{{ 'User: ' + res[0].strip() + '\nAssistant: ' }}{% else %}"
                "{% set res = text | regex_findall('\[\](.*?)\[\]', multiline=True, dotall=True) %}{% if res %}"
                "{{ 'User: 参考文章1:' + res[0].strip() + '\n根据以上参考文章回答问题,补全对话\nAssistant: ' }}{% else %}"
                "{% set res = text | regex_findall('\[\](.*?)\[\]', multiline=True, dotall=True) %}"
                "{% if res %} User: 请参考搜索结果回答下面问题并使用引用标记来标注回答内容参考的搜索结果序号,例如^[1]^ (引用单个搜索结果),^[1][2]^(引用多个搜索结果),其中方括号中的数字是搜索结果序号。引用标记只能出现在句尾标点符号前。 以下是搜索结果(每行开头[1]、[2]、...是搜索结果序号),可以对答案中的核心部分进行markdown加粗(加粗内容): {{ res[0].strip() }} 根据以上搜索结果回答问题并标注引用,补全对话 Assistant: {% else %}"
                "{% set res = text | regex_findall('\[\](.*?)\[\]', multiline=True, dotall=True) %}"
                "{% if res %} User: 请你扮演一个专家,参考搜索结果中正确、可信、高质量的信息回答问题,并注明答案中引用的搜索结果,格式为^[2]^表示引用了第2条搜索结果,^[1][3]^表示引用第1和第3条搜索结果。每条搜索结果包含若干相关内容片段。同时你需要遵循以下原则回答问题: 1. 严格遵循搜索结果作答,可以承认不知道答案,并尝试给出一些搜索结果中的相关背景信息。 2. 如果搜索结果存在多种可能的答案,要罗列出每种情况。 3. 如果问题涉及金融、医疗、法律等存在风险的领域,请在结尾提醒用户注意并进行免责说明。 搜索结果: {{ res[0].strip() }} 现在,请根据上面的搜索结果回答问题并标注引用,补全对话 Assistant: {% else %}"
                "{{ 'User: ' + content + '\nAssistant: ' }}{% endif %}{% endif %}{% endif %}{% endif %}{% endif %}{% endif %}"
                "{% else %}{{ 'User: ' + content + '\nAssistant: ' }}{% endif %}"
                "{% elif message['role'] == 'assistant' %}{{ content + '<|end_of_sentence|>' }}{% endif %}{% endfor %}"
            )
        return jinja_template
+
+ def fix_jinja_template(self, tokenizer: "PretrainedTokenizer") -> None:
+ r"""Replace the jinja template in the tokenizer."""
+ if tokenizer.chat_template is None or self.replace_jinja_template:
+ try:
+ tokenizer.chat_template = self._get_jinja_template(tokenizer)
+ except ValueError as e:
+ logger.info(f"Cannot add this chat template to tokenizer: {e}.")
+
+
+TEMPLATES: dict[str, "Template"] = {}
+
+
def register_template(
    name: str,
    format_user: Optional["Formatter"] = None,
    format_assistant: Optional["Formatter"] = None,
    format_system: Optional["Formatter"] = None,
    format_knowledge: Optional["Formatter"] = None,
    format_prefix: Optional["Formatter"] = None,
    default_system: str = "",
    stop_words: Optional[list[str]] = None,
    thought_words: Optional[tuple[str, str]] = None,
    efficient_eos: bool = False,
    replace_eos: bool = False,
    replace_jinja_template: bool = False,
    template_class: type["Template"] = Template,
) -> None:
    r"""Register a chat template.

    To add the following chat template:
    ```
    user prompt here
    model response here
    user prompt here
    model response here
    ```

    The corresponding code should be:
    ```
    register_template(
        name="custom",
        format_user=StringFormatter(slots=["{{content}}\n"]),
        format_assistant=StringFormatter(slots=["{{content}}\n"]),
        format_prefix=EmptyFormatter(),
    )
    ```

    Raises:
        ValueError: If a template with the same name was already registered.
    """
    if name in TEMPLATES:
        raise ValueError(f"Template {name} already exists.")

    # With efficient_eos the assistant slot omits the trailing EOS token.
    default_slots = ["{{content}}"] if efficient_eos else ["{{content}}", {"eos_token"}]
    default_user_formatter = StringFormatter(slots=["{{content}}"])
    default_assistant_formatter = StringFormatter(slots=default_slots)
    default_prefix_formatter = EmptyFormatter()
    TEMPLATES[name] = template_class(
        format_user=format_user or default_user_formatter,
        format_assistant=format_assistant or default_assistant_formatter,
        # The plain "{{content}}" formatter doubles as the system formatter.
        format_system=format_system or default_user_formatter,
        format_knowledge=format_knowledge,
        format_prefix=format_prefix or default_prefix_formatter,
        default_system=default_system,
        stop_words=stop_words or [],
        thought_words=thought_words or ("", ""),
        efficient_eos=efficient_eos,
        replace_eos=replace_eos,
        replace_jinja_template=replace_jinja_template,
    )
+
+
def parse_template(tokenizer: "PretrainedTokenizer") -> "Template":
    r"""Extract a chat template from the tokenizer.

    Renders probe conversations through `tokenizer.apply_chat_template` and
    diffs the results to recover the prefix, system, user and assistant slots.
    """

    def find_diff(short_str: str, long_str: str) -> str:
        """Return the characters of `long_str` not matched, in order, by `short_str`."""
        i, j = 0, 0
        diff = ""
        while i < len(short_str) and j < len(long_str):
            if short_str[i] == long_str[j]:
                i += 1
                j += 1
            else:
                diff += long_str[j]
                j += 1

        return diff

    # Tokens emitted for an empty prompt form the fixed prefix (e.g. BOS).
    prefix = tokenizer.decode(tokenizer.encode("")["input_ids"])

    messages = [{"role": "system", "content": "{{content}}"}]
    system_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)[len(prefix) :]

    messages = [{"role": "system", "content": ""}, {"role": "user", "content": "{{content}}"}]
    user_slot_empty_system = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    user_slot_empty_system = user_slot_empty_system[len(prefix) :]

    messages = [{"role": "user", "content": "{{content}}"}]
    user_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    user_slot = user_slot[len(prefix) :]

    messages = [{"role": "user", "content": "{{content}}"}, {"role": "assistant", "content": "{{content}}"}]
    assistant_slot = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)

    messages = [
        {"role": "user", "content": "{{content}}"},
        {"role": "assistant", "content": "{{content}}"},
        {"role": "user", "content": "{{content}}"},
    ]
    assistant_slot = tokenizer.encode(assistant_slot[len(prefix) + len(user_slot) :], add_special_tokens=False)[
        "input_ids"
    ]

    # Render one extra turn so trailing tokens that actually belong to the
    # next user turn can be detected and dropped via the common id prefix.
    assistant_slot_further = tokenizer.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
    assistant_slot_further = tokenizer.encode(
        assistant_slot_further[len(prefix) + len(user_slot) :], add_special_tokens=False
    )["input_ids"]

    assistant_slot = tokenizer.decode(os.path.commonprefix([assistant_slot, assistant_slot_further]))

    # Bug fix: this line previously read `.replace("", "").replace("", "")`,
    # a no-op, although the comment says it removes thought tags — the tag
    # literals were apparently stripped by a sanitizer. Restored to the
    # conventional "<think>"/"</think>" pair; confirm against upstream.
    assistant_slot = assistant_slot.replace("<think>", "").replace("</think>", "").lstrip("\n")  # remove thought tags

    if len(user_slot) > len(user_slot_empty_system):
        default_system = find_diff(user_slot_empty_system, user_slot)
        sole_system = system_slot.replace("{{content}}", default_system, 1)
        user_slot = user_slot[len(sole_system) :]
    else:  # if default_system is empty, user_slot_empty_system will be longer than user_slot
        default_system = ""

    return Template(
        format_user=StringFormatter(slots=[user_slot]),
        format_assistant=StringFormatter(slots=[assistant_slot]),
        format_system=StringFormatter(slots=[system_slot]),
        format_knowledge=KnowledgeFormatter(slots=[user_slot]),
        format_prefix=EmptyFormatter(slots=[prefix]) if prefix else EmptyFormatter(),
        default_system=default_system,
        stop_words=[],
        thought_words=("", ""),  # NOTE(review): likely ("<think>", "</think>") before tag stripping
        efficient_eos=False,
        replace_eos=False,
        replace_jinja_template=False,
    )
+
+
def get_template_and_fix_tokenizer(tokenizer: "PretrainedTokenizer", template: str = None) -> "Template":
    r"""Get chat template and fixes the tokenizer.

    Args:
        tokenizer: Tokenizer to patch (special tokens and jinja template).
        template: Registered template name; when None, the template is parsed
            from the tokenizer's own chat template if one exists, otherwise
            the "empty" template is used.

    Raises:
        ValueError: If `template` names an unregistered template.
    """
    if template is None:
        # Bug fix: the attribute is `chat_template`, not `chat_template1`,
        # which always raised/failed the isinstance check.
        if isinstance(tokenizer.chat_template, str):
            logger.warning("`template` was not specified, try parsing the chat template from the tokenizer.")
            template = parse_template(tokenizer)
        else:
            logger.warning("`template` was not specified, use `empty` template.")
            template = TEMPLATES["empty"]  # placeholder
    else:
        if template not in TEMPLATES:
            raise ValueError(f"Template {template} does not exist.")

        template = TEMPLATES[template]

    template.fix_special_tokens(tokenizer)
    template.fix_jinja_template(tokenizer)
    return template
+
+
+"""
+{% if not add_generation_prompt is defined %}
+{% set add_generation_prompt = false %}
+{% endif %}
+{% set loop_messages = messages %}
+{% for message in loop_messages %}
+{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}
+{% if loop.index0 == 0 %}
+{% set content = bos_token + content %}
+{% endif %}
+{{ content }}
+{% endfor %}
+{% if add_generation_prompt %}
+{{ '<|start_header_id|>assistant<|end_header_id|>\n \n' }}
+{% else %}
+{{ eos_token }}
+{% endif %}
+Template(efficient_eos=False, replace_eos=False, replace_jinja_template=False)
+"""
register_template(
    name="llama3",
    format_user=StringFormatter(
        slots=[
            "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        ]
    ),
    format_assistant=StringFormatter(slots=["{{content}}<|eot_id|><|end_of_text|>"]),
    format_system=StringFormatter(
        slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|><|end_of_text|>"]
    ),
    format_prefix=EmptyFormatter(slots=["<|begin_of_text|>"]),
    replace_jinja_template=True,
)


register_template(
    name="aquila",
    format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
    format_assistant=StringFormatter(slots=["{{content}}###"]),
    format_system=StringFormatter(slots=["System: {{content}}###"]),
    default_system=(
        "A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions."
    ),
    # NOTE(review): an empty-string stop word maps to no real token; this was
    # probably a special token such as "</s>" before markup stripping — confirm
    # against the upstream template definition.
    stop_words=[""],
)


register_template(
    name="atom",
    format_user=StringFormatter(
        slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
)


register_template(
    name="baichuan",
    # NOTE(review): the empty {"token": ""} entries look like stripped
    # special-token literals (reserved tokens in the Baichuan vocabulary) —
    # confirm against the upstream template definition.
    format_user=StringFormatter(slots=[{"token": ""}, "{{content}}", {"token": ""}]),
    efficient_eos=True,
)


register_template(
    name="45t",
    format_user=StringFormatter(slots=["User: ", "{{content}}\nAssistant: "]),
    format_assistant=StringFormatter(slots=["{{content}}", {"token": "<|end_of_sentence|>"}]),
    format_system=StringFormatter(slots=["{{content}}\n"]),
    format_prefix=EmptyFormatter(slots=[{"token": "<|begin_of_sentence|>"}]),
    format_knowledge=KnowledgeFormatter(slots=["User: {{content}}\nAssistant: "]),
    replace_jinja_template=True,
)
diff --git a/paddlenlp/transformers/tokenizer_utils.py b/paddlenlp/transformers/tokenizer_utils.py
index f6aca840d386..c011c17c8c70 100644
--- a/paddlenlp/transformers/tokenizer_utils.py
+++ b/paddlenlp/transformers/tokenizer_utils.py
@@ -526,8 +526,18 @@ def _compile_jinja_template(chat_template) -> Template:
def raise_exception(message):
raise TemplateError(message)
+ def regex_findall(s, pattern, multiline=False, dotall=False):
+ flags = 0
+ if multiline:
+ flags |= re.MULTILINE
+ if dotall:
+ flags |= re.DOTALL
+ return re.findall(pattern, s, flags)
+
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, keep_trailing_newline=True)
jinja_env.globals["raise_exception"] = raise_exception
+ jinja_env.filters["regex_findall"] = regex_findall
+ jinja_env.globals.update(regex_findall=regex_findall)
return jinja_env.from_string(chat_template)
def render_conversation(
diff --git a/paddlenlp/transformers/tokenizer_utils_base.py b/paddlenlp/transformers/tokenizer_utils_base.py
index 414f1b6bf103..140f1e4c3b09 100644
--- a/paddlenlp/transformers/tokenizer_utils_base.py
+++ b/paddlenlp/transformers/tokenizer_utils_base.py
@@ -1848,6 +1848,9 @@ def save_pretrained(self, save_directory, filename_prefix: Optional[str] = None,
for file_id in self.resource_files_names.keys():
tokenizer_config.pop(file_id, None)
+ if hasattr(self, "chat_template") and isinstance(self.chat_template, str):
+ tokenizer_config["chat_template"] = self.chat_template
+
# Sanitize AddedTokens
def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
if isinstance(obj, AddedToken):