fix setup chat format #3404


Draft · wants to merge 1 commit into base: main
tests/test_dataset_formatting.py (25 changes: 10 additions & 15 deletions)
@@ -19,7 +19,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from trl.extras.dataset_formatting import get_formatting_func_from_dataset
-from trl.models.utils import ChatMlSpecialTokens, setup_chat_format
+from trl.models.utils import setup_chat_format


 class DatasetFormattingTestCase(unittest.TestCase):
@@ -119,29 +119,24 @@ class SetupChatFormatTestCase(unittest.TestCase):
     def setUp(self):
         self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
         self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
-        # remove built-in chat_template to simulate a model having no chat_template
+        # Remove built-in chat_template to simulate a model having no chat_template
         self.tokenizer.chat_template = None

     def test_setup_chat_format(self):
-        modified_model, modified_tokenizer = setup_chat_format(
-            self.model, self.tokenizer, format="chatml", resize_to_multiple_of=64
-        )
+        _, modified_tokenizer = setup_chat_format(self.model, self.tokenizer, format="chatml")

-        _chatml = ChatMlSpecialTokens()
         # Check if special tokens are correctly set
         self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>")
         self.assertEqual(modified_tokenizer.pad_token, "<|im_end|>")
         self.assertEqual(modified_tokenizer.bos_token, "<|im_start|>")
-        self.assertEqual(modified_tokenizer.eos_token, _chatml.eos_token)
-        self.assertEqual(modified_tokenizer.pad_token, _chatml.pad_token)
-        self.assertEqual(modified_tokenizer.bos_token, _chatml.bos_token)
-        self.assertEqual((self.model.get_input_embeddings().weight.shape[0] % 64), 0)
+
+    def test_setup_chat_format_with_resize(self):
+        modified_model, _ = setup_chat_format(self.model, self.tokenizer, format="chatml", resize_to_multiple_of=123)
+
+        # Check that the input embeddings have been resized to a multiple of 123
+        self.assertEqual((modified_model.get_input_embeddings().weight.shape[0] % 123), 0)

     def test_example_with_setup_model(self):
-        modified_model, modified_tokenizer = setup_chat_format(
-            self.model,
-            self.tokenizer,
-        )
+        _, modified_tokenizer = setup_chat_format(self.model, self.tokenizer)
         messages = [
             {"role": "system", "content": "You are helpful"},
             {"role": "user", "content": "Hello"},
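For reference, here is a minimal sketch (not part of this PR; the diff above is truncated after the `messages` list) of how the tokenizer returned by `setup_chat_format` behaves after these changes. The model name follows the test fixture, and the rendered output in the comments assumes the standard ChatML template:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl.models.utils import setup_chat_format

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
tokenizer.chat_template = None  # simulate a model that ships without a chat template

_, modified_tokenizer = setup_chat_format(model, tokenizer)

messages = [
    {"role": "system", "content": "You are helpful"},
    {"role": "user", "content": "Hello"},
]
prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False)
# With the ChatML template, `prompt` renders roughly as:
# <|im_start|>system
# You are helpful<|im_end|>
# <|im_start|>user
# Hello<|im_end|>
```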
trl/models/utils.py (62 changes: 34 additions & 28 deletions)
@@ -16,7 +16,7 @@
 from contextlib import contextmanager
 from copy import deepcopy
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Literal, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

 import torch.nn as nn
 from packaging import version
@@ -76,58 +76,64 @@ def chat_template(self):
 def setup_chat_format(
     model: PreTrainedModel,
     tokenizer: PreTrainedTokenizer,
-    format: Optional[Literal["chatml"]] = "chatml",
+    format: str = "chatml",
     resize_to_multiple_of: Optional[int] = None,
 ) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
     """
-    Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens.
+    Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the
+    embedding layer of the model based on the new special tokens.

-    If the model already has a chat template, this will throw an error. If you want to overwrite it, please set `tokenizer.chat_template` to `None`.
+    If the model already has a chat template, this will throw an error. If you want to overwrite it, please set
+    `tokenizer.chat_template` to `None` before calling this function.

     Args:
-        model (`~transformers.PreTrainedModel`): The model to be modified.
-        tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
-        format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
-        resize_to_multiple_of (`int` or `None`): Number to resize the embedding layer to. Defaults to None.
+        model (`~transformers.PreTrainedModel`):
+            Model to be modified.
+        tokenizer (`~transformers.PreTrainedTokenizer`):
+            Tokenizer to be modified.
+        format (`str`, *optional*, defaults to `"chatml"`):
+            Format to be set. This can be one of `{"chatml"}`.
+        resize_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
+            If not `None`, the model's embedding layer will be resized to a multiple of this number.

     Returns:
-        model (`~transformers.PreTrainedModel`): The modified model.
-        tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.
+        model (`~transformers.PreTrainedModel`):
+            Modified model.
+        tokenizer (`~transformers.PreTrainedTokenizer`):
+            Modified tokenizer.
     """
-    # check if model already had a chat template
+    # Check if model already had a chat template
     if tokenizer.chat_template is not None:
         raise ValueError(
-            "Chat template is already added to the tokenizer. If you want to overwrite it, please set it to None"
+            "Chat template is already added to the tokenizer. If you want to overwrite it, please set it to None "
+            "before calling this function."
         )

-    # check if format available and retrieve
     if format not in FORMAT_MAPPING:
-        raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")
+        raise ValueError(f"Format {format} not supported. Supported formats are: {', '.join(FORMAT_MAPPING.keys())}")

     chat_format = FORMAT_MAPPING[format]()

-    # set special tokens and them
+    # Set special tokens and chat template
+    tokenizer.chat_template = chat_format.chat_template
     tokenizer.eos_token = chat_format.eos_token
     tokenizer.pad_token = chat_format.pad_token
     tokenizer.bos_token = chat_format.bos_token
     tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})
-    # set chat format for tokenizer
-    tokenizer.chat_template = chat_format.chat_template

-    # resize embedding layer to a multiple of 64, https://x.com/karpathy/status/1621578354024677377
+    # Resize embedding layer
+    # This can lead to significant speedup, see https://x.com/karpathy/status/1621578354024677377
     model.resize_token_embeddings(
-        len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None
+        new_num_tokens=tokenizer.vocab_size + len(tokenizer.added_tokens_encoder.keys()),
+        pad_to_multiple_of=resize_to_multiple_of,
    )

     # Update the model config to use the new eos & bos tokens
     if getattr(model, "config", None) is not None:
         model.config.pad_token_id = tokenizer.pad_token_id
         model.config.bos_token_id = tokenizer.bos_token_id
         model.config.eos_token_id = tokenizer.eos_token_id

     # Update the generation config to use the new eos & bos token
     if getattr(model, "generation_config", None) is not None:
-        model.generation_config.bos_token_id = tokenizer.bos_token_id
-        model.generation_config.eos_token_id = tokenizer.eos_token_id
         model.generation_config.pad_token_id = tokenizer.pad_token_id
+        model.generation_config.bos_token_id = tokenizer.bos_token_id
+        model.generation_config.eos_token_id = tokenizer.eos_token_id

     return model, tokenizer

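Taken together, here is a hedged usage sketch of the updated `setup_chat_format` API (again not part of this PR; the model name mirrors the test fixture, and the multiple of 64 is illustrative, echoing the Karpathy note referenced in the code):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl.models.utils import setup_chat_format

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
tokenizer.chat_template = None  # the function raises if a chat template is already set

model, tokenizer = setup_chat_format(model, tokenizer, format="chatml", resize_to_multiple_of=64)

# The embedding matrix now covers the ChatML special tokens and is padded
# to a multiple of 64, as exercised by the tests above.
assert model.get_input_embeddings().weight.shape[0] % 64 == 0
assert tokenizer.eos_token == "<|im_end|>" and tokenizer.bos_token == "<|im_start|>"
```

Note that the resize now derives `new_num_tokens` from `tokenizer.vocab_size` plus the entries in `added_tokens_encoder`, instead of `len(tokenizer)`; where the two disagree for a given tokenizer, the embedding size follows the former.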