💬 Fix setup_chat_format and add clone_chat_template #3404

Merged 23 commits from `fix-setup-chat-format` into `main` on Jun 15, 2025.

Commits
7454b1f
fix setup chat format
qgallouedec May 2, 2025
77340e8
Merge branch 'main' into fix-setup-chat-format
qgallouedec May 31, 2025
92e5353
update doc
qgallouedec May 31, 2025
809aa5d
test
qgallouedec May 31, 2025
6692f6f
new func!
qgallouedec May 31, 2025
79f744a
use it in sft script
qgallouedec May 31, 2025
64c7d71
fix import and add example
qgallouedec May 31, 2025
e42c72e
remove type hint
qgallouedec May 31, 2025
e54765e
Update test_dataset_formatting.py
qgallouedec May 31, 2025
1e42c76
Merge branch 'main' into fix-setup-chat-format
qgallouedec Jun 6, 2025
68f11fa
Merge branch 'main' into fix-setup-chat-format
qgallouedec Jun 6, 2025
9b50a7a
Apply suggestions from code review
qgallouedec Jun 6, 2025
0db69b7
Apply suggestions from code review
qgallouedec Jun 6, 2025
f1fbe73
Rename setup_chat_template to clone_chat_template and update referenc…
qgallouedec Jun 6, 2025
a2d9cce
improve qol and ensure added tokens from source
qgallouedec Jun 7, 2025
063f137
propagate fix
qgallouedec Jun 7, 2025
1e57173
Fix value head assertion to check weight shape instead of num_embeddings
qgallouedec Jun 7, 2025
461b17e
Merge branch 'main' into fix-setup-chat-format
qgallouedec Jun 9, 2025
4e8de26
Merge branch 'main' into fix-setup-chat-format
qgallouedec Jun 12, 2025
fc0c9c0
Merge branch 'main' into fix-setup-chat-format
qgallouedec Jun 13, 2025
dfd5c91
Update sft_trainer.md
qgallouedec Jun 13, 2025
8cfeb96
Update utils.py
qgallouedec Jun 13, 2025
03d78b1
Merge branch 'main' into fix-setup-chat-format
qgallouedec Jun 13, 2025
4 changes: 4 additions & 0 deletions docs/source/model_utils.md
@@ -1,5 +1,9 @@
# Model Utilities

## clone_chat_template

[[autodoc]] clone_chat_template

## get_act_offloading_ctx_manager

[[autodoc]] models.get_act_offloading_ctx_manager
12 changes: 6 additions & 6 deletions docs/source/sft_trainer.md
@@ -63,26 +63,26 @@ If you’d like to compute loss on both the prompt **and** the completion while
### Add Special Tokens for Chat Format

Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system, and help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
-The [`setup_chat_format`] function in `trl` easily sets up a model and tokenizer for conversational AI tasks. This function:
+The [`clone_chat_template`] function prepares a model and tokenizer for conversational AI tasks. This function:
- Adds special tokens to the tokenizer, e.g., `<|im_start|>` and `<|im_end|>`, to indicate the start and end of a conversation.
- Resizes the model’s embedding layer to accommodate the new tokens.
-- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
+- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format.
- _Optionally_, you can pass `resize_to_multiple_of` to round the embedding layer up to a multiple of that value, e.g., `64` (see the sketch after the example below). If you would like to see more templates supported in the future, please open a GitHub issue on [trl](https://github.com/huggingface/trl).

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
-from trl import setup_chat_format
+from trl import clone_chat_template

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

-# Set up the chat format with the default 'chatml' format
-model, tokenizer = setup_chat_format(model, tokenizer)
+# Set up the chat format
+model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
```
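
As mentioned in the list above, here is a minimal sketch of the optional `resize_to_multiple_of` behavior (reusing the model and tokenizer loaded in the example above):

```python
# Round the embedding size up to a multiple of 64 (can help kernel efficiency)
model, tokenizer = clone_chat_template(
    model, tokenizer, "Qwen/Qwen3-0.6B", resize_to_multiple_of=64
)
assert model.get_input_embeddings().num_embeddings % 64 == 0
```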

> [!WARNING]
-> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases, it is not necessary to apply `setup_chat_format()`, as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in `SFTConfig`; for example, for `Qwen/Qwen2.5-1.5B`, one should set `eos_token="<|im_end|>"`.
+> Some base models, like those from Qwen, have a predefined chat template in the model's tokenizer. In these cases, it is not necessary to apply [`clone_chat_template()`], as the tokenizer already handles the formatting. However, it is necessary to align the EOS token with the chat template to ensure the model's responses terminate correctly. In these cases, specify `eos_token` in [`SFTConfig`]; for example, for `Qwen/Qwen2.5-1.5B`, one should set `eos_token="<|im_end|>"`.
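
For example, a minimal sketch of applying that advice (the `output_dir` value is just a placeholder):

```python
from trl import SFTConfig

# Align the EOS token with the Qwen chat template so responses stop at <|im_end|>
training_args = SFTConfig(output_dir="qwen-sft", eos_token="<|im_end|>")
```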

With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.

41 changes: 38 additions & 3 deletions tests/test_dataset_formatting.py
@@ -19,7 +19,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl.extras.dataset_formatting import get_formatting_func_from_dataset
-from trl.models.utils import ChatMlSpecialTokens, setup_chat_format
+from trl.models.utils import ChatMlSpecialTokens, clone_chat_template, setup_chat_format


class DatasetFormattingTestCase(unittest.TestCase):
@@ -124,7 +124,7 @@ def setUp(self):

def test_setup_chat_format(self):
modified_model, modified_tokenizer = setup_chat_format(
-self.model, self.tokenizer, format="chatml", resize_to_multiple_of=64
+self.model, self.tokenizer, format="chatml", resize_to_multiple_of=123
)

_chatml = ChatMlSpecialTokens()
@@ -135,7 +135,7 @@ def test_setup_chat_format(self):
self.assertEqual(modified_tokenizer.eos_token, _chatml.eos_token)
self.assertEqual(modified_tokenizer.pad_token, _chatml.pad_token)
self.assertEqual(modified_tokenizer.bos_token, _chatml.bos_token)
-self.assertEqual((self.model.get_input_embeddings().weight.shape[0] % 64), 0)
+self.assertEqual((modified_model.get_input_embeddings().num_embeddings % 123), 0)

def test_example_with_setup_model(self):
modified_model, modified_tokenizer = setup_chat_format(
@@ -153,3 +153,38 @@ def test_example_with_setup_model(self):
prompt,
"<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n",
)


class CloneChatTemplateTestCase(unittest.TestCase):
def setUp(self):
# This tokenizer doesn't have a chat_template by default
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
# This one has a chat_template by default
self.source = "trl-internal-testing/tiny-Qwen3ForCausalLM"

def test_clone(self):
_, modified_tokenizer = clone_chat_template(self.model, self.tokenizer, self.source)

# Check if special tokens are correctly set
self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>")

def test_clone_with_resize(self):
modified_model, _ = clone_chat_template(self.model, self.tokenizer, self.source, resize_to_multiple_of=123)

# Check that the input embeddings have been resized to a multiple of 123
self.assertEqual((modified_model.get_input_embeddings().num_embeddings % 123), 0)

def test_apply_new_chat_template(self):
_, modified_tokenizer = clone_chat_template(self.model, self.tokenizer, self.source)
messages = [
{"role": "system", "content": "You are helpful"},
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi, how can I help you?"},
]
prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False)

self.assertEqual(
prompt,
"<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nHi, how can I help you?<|im_end|>\n",
)
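
Beyond what these tests cover, here is a minimal sketch (assuming the same tiny test checkpoints used above) of using the cloned template at inference time with `add_generation_prompt`:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import clone_chat_template

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM")
model, tokenizer = clone_chat_template(model, tokenizer, "trl-internal-testing/tiny-Qwen3ForCausalLM")

messages = [{"role": "user", "content": "Hello"}]
# add_generation_prompt appends the opening assistant tag so the model can respond
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
```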
2 changes: 2 additions & 0 deletions trl/__init__.py
@@ -41,6 +41,7 @@
"AutoModelForCausalLMWithValueHead",
"AutoModelForSeq2SeqLMWithValueHead",
"PreTrainedModelWrapper",
"clone_chat_template",
"create_reference_model",
"setup_chat_format",
],
@@ -136,6 +137,7 @@
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
PreTrainedModelWrapper,
clone_chat_template,
create_reference_model,
setup_chat_format,
)
2 changes: 2 additions & 0 deletions trl/models/__init__.py
@@ -23,6 +23,7 @@
"modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"],
"utils": [
"SUPPORTED_ARCHITECTURES",
"clone_chat_template",
"prepare_deepspeed",
"prepare_fsdp",
"setup_chat_format",
@@ -49,6 +50,7 @@
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
from .utils import (
SUPPORTED_ARCHITECTURES,
clone_chat_template,
prepare_deepspeed,
prepare_fsdp,
setup_chat_format,
81 changes: 79 additions & 2 deletions trl/models/utils.py
@@ -20,7 +20,7 @@

import torch.nn as nn
from packaging import version
-from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizer

from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead

@@ -82,6 +82,10 @@ def setup_chat_format(
"""
Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens.

<Tip warning={true}>
We recommend using [`clone_chat_template`] instead of this function.
</Tip>

If the model already has a chat template, this will throw an error. If you want to overwrite it, please set `tokenizer.chat_template` to `None`.

Args:
@@ -116,7 +120,11 @@

# resize embedding layer to a multiple of 64, https://x.com/karpathy/status/1621578354024677377
model.resize_token_embeddings(
-len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None
+# After studying many tokenizers, we found that len(tokenizer.vocab) is the most reliable way to get the vocab
+# size. Avoid using tokenizer.vocab_size or tokenizer.vocab_size + len(tokenizer.added_tokens_encoder),
+# as handling of special and added tokens varies across tokenizers.
+new_num_tokens=len(tokenizer.vocab),
+pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None,
)
# Update the model config to use the new eos & bos tokens
if getattr(model, "config", None) is not None:
@@ -132,6 +140,75 @@
return model, tokenizer


def clone_chat_template(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizer,
source_tokenizer_path: str,
resize_to_multiple_of: Optional[int] = 64,
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
"""
Clones a chat template from a source tokenizer to the target tokenizer and updates the model accordingly.

This function:
- Copies the chat template from a source tokenizer to the target tokenizer.
- Adds any new tokens from the source tokenizer to the target tokenizer.
- Sets and synchronizes the EOS token across the tokenizer and model.
- Resizes the model's token embeddings to match the new vocabulary size, optionally rounding it up to a multiple of
a specified value.

Args:
model (`PreTrainedModel`):
Model to update.
tokenizer (`PreTrainedTokenizer`):
Tokenizer to update.
source_tokenizer_path (`str`):
Path or identifier of the pretrained tokenizer to clone from.
resize_to_multiple_of (`int` or `None`, *optional*, defaults to `64`):
The embedding layer will be resized to the new vocabulary size. If this is not `None`, it will round up the
new vocabulary size to the nearest multiple of this value.

Returns:
model (`PreTrainedModel`):
Updated model with resized token embeddings and EOS token configured.
tokenizer (`PreTrainedTokenizer`):
Updated tokenizer with the chat template and special tokens applied.

Example:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import clone_chat_template

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
```
"""
# Load the source tokenizer containing the desired chat template
tokenizer_source = AutoTokenizer.from_pretrained(source_tokenizer_path)

# Copy the chat template from the source tokenizer
tokenizer.chat_template = tokenizer_source.get_chat_template()

# Ensure all added tokens from the source are available in the target tokenizer
tokenizer.add_tokens(list(tokenizer_source.added_tokens_decoder.values()))

# Set the EOS token from the source tokenizer (important for generation)
tokenizer.eos_token = tokenizer_source.eos_token
model.config.eos_token_id = tokenizer.eos_token_id
model.generation_config.eos_token_id = tokenizer.eos_token_id

# Resize model embeddings to include any new tokens, optionally rounding up to a multiple
model.resize_token_embeddings(
# After studying many tokenizers, we found that len(tokenizer.vocab) is the most reliable way to get the vocab
# size. Avoid using tokenizer.vocab_size or tokenizer.vocab_size + len(tokenizer.added_tokens_encoder),
# as handling of special and added tokens varies across tokenizers.
new_num_tokens=len(tokenizer.vocab),
pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None,
)

return model, tokenizer


def remove_hooks(model: "DeepSpeedEngine") -> None:
"""Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
if not hasattr(model, "optimizer"): # before the first training step, the model has no optimizer
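As an aside on the `len(tokenizer.vocab)` comment in the code above, here is a minimal sketch (using `gpt2` as an arbitrary example; the printed counts are what that tokenizer reports) of why `tokenizer.vocab_size` alone is unreliable after adding tokens:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(tokenizer.vocab_size)   # 50257 — base vocabulary only
tokenizer.add_tokens(["<|im_start|>", "<|im_end|>"])
print(tokenizer.vocab_size)   # still 50257 — add_tokens does not update it
print(len(tokenizer.vocab))   # 50259 — includes the two added tokens
```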
5 changes: 3 additions & 2 deletions trl/scripts/sft.py
@@ -63,10 +63,10 @@
SFTConfig,
SFTTrainer,
TrlParser,
+clone_chat_template,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
-setup_chat_format,
)


@@ -104,7 +104,8 @@ def main(script_args, training_args, model_args):

# Set default chat template if needed
if tokenizer.chat_template is None:
-model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
+# TODO: source should be passed as an argument
+model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
Review comment (Member):
We have found it useful internally to expose a chat_template arg in SFTConfig which allows one to define a custom template or copy-paste one from an existing model. Perhaps we could expose this along with a chat_template_clone arg (or something similar), now that I better understand what your intent was in this PR?

Reply (Member Author):
sounds good!


################
# Dataset