🤝 Compatibility of the TRL CLI with accelerate arguments #3409

Merged on May 6, 2025 (31 commits)

Commits (changes from all commits):

- `a91c73c` Add FSDP configs (lewtun, Apr 17, 2025)
- `debdc8e` Bump accelerate (lewtun, Apr 17, 2025)
- `e7f74f7` Update prepare (lewtun, Apr 17, 2025)
- `499167d` update version accelerate in test (qgallouedec, Apr 21, 2025)
- `6f23aa0` Merge branch 'main' into fix-fsdp2 (qgallouedec, Apr 21, 2025)
- `fd70eb2` Merge branch 'main' into fix-fsdp2 (lewtun, May 2, 2025)
- `ca08043` Add full state dict (lewtun, May 2, 2025)
- `91ae801` Revert (lewtun, May 2, 2025)
- `eb3ed1b` return_remaining_strings=True (qgallouedec, May 3, 2025)
- `cd1c0b4` TRLParser compat with subparsers (qgallouedec, May 3, 2025)
- `9570ff2` test subpaser config handling (qgallouedec, May 3, 2025)
- `770d7b5` allow launch argument in cli args for sft (qgallouedec, May 3, 2025)
- `17c0c9f` better comment (qgallouedec, May 3, 2025)
- `63b8d73` add accelerate configs (qgallouedec, May 4, 2025)
- `f6cb3e5` rewrite the cli doc (qgallouedec, May 5, 2025)
- `db6ab73` accelerate config (qgallouedec, May 5, 2025)
- `8202f0a` further improve the doc (qgallouedec, May 5, 2025)
- `bbd186e` Merge branch 'main' into compat-cli-with-accelerate-args (qgallouedec, May 5, 2025)
- `dc4f49c` rm chatgpt blabla (qgallouedec, May 5, 2025)
- `e8fa32f` simplify (qgallouedec, May 5, 2025)
- `7e42d82` Is it clearer? (qgallouedec, May 5, 2025)
- `1fbb56f` other examples (qgallouedec, May 5, 2025)
- `6eccaec` even better (qgallouedec, May 5, 2025)
- `69b90d5` detail (qgallouedec, May 5, 2025)
- `9cfde38` Merge branch 'main' into fix-fsdp2 (lewtun, May 5, 2025)
- `f0eabff` Bump min version (lewtun, May 5, 2025)
- `c584b21` Update docs/source/clis.md (qgallouedec, May 5, 2025)
- `1f0ebcc` Merge branch 'fix-fsdp2' into compat-cli-with-accelerate-args (qgallouedec, May 5, 2025)
- `7a71b9d` deepspeed_zeroN -> zeroN (qgallouedec, May 5, 2025)
- `2a52ee6` remove fsdp qlora and add fsdp1/2 (qgallouedec, May 5, 2025)
- `be268de` Merge branch 'main' into compat-cli-with-accelerate-args (qgallouedec, May 6, 2025)

3 changes: 2 additions & 1 deletion MANIFEST.in
@@ -2,4 +2,5 @@ include LICENSE
include CONTRIBUTING.md
include README.md
recursive-exclude * __pycache__
include trl/templates/*.md
include trl/accelerate_configs/*.yaml
214 changes: 167 additions & 47 deletions docs/source/clis.md
@@ -1,105 +1,225 @@
# Command Line Interfaces (CLIs)

TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly.

Currently supported commands are:

#### Training Commands

- `trl dpo`: fine-tune an LLM with DPO
- `trl grpo`: fine-tune an LLM with GRPO
- `trl kto`: fine-tune an LLM with KTO
- `trl sft`: fine-tune an LLM with SFT

#### Other Commands

- `trl env`: get the system information
- `trl vllm-serve`: serve a model with vLLM
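
For example, you can spin up a vLLM server for fast generation directly from the CLI. This is a minimal sketch; the `--model` flag is assumed here, so check `trl vllm-serve --help` for the full set of options:

```bash
trl vllm-serve --model Qwen/Qwen2.5-0.5B
```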

## Fine-Tuning with the TRL CLI

### Basic Usage

You can launch training directly from the CLI by specifying required arguments like the model and dataset:

<hfoptions id="command_line">
<hfoption id="SFT">

```bash
trl sft \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name stanfordnlp/imdb
```

</hfoption>
<hfoption id="DPO">

```bash
trl dpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name anthropic/hh-rlhf
```

</hfoption>
</hfoptions>

### Using Configuration Files

To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file:

<hfoptions id="config_file">
<hfoption id="SFT">

```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
```

Launch with:

```bash
trl sft --config sft_config.yaml
```

</hfoption>
<hfoption id="DPO">

```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
```

Launch with:

```bash
trl dpo --config dpo_config.yaml
```

</hfoption>
</hfoptions>

Arguments passed explicitly on the command line override the values from the config file.
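
For example, the following reuses `sft_config.yaml` but overrides the learning rate from the command line:

```bash
trl sft --config sft_config.yaml --learning_rate 1e-4
```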

### Scaling Up with Accelerate

The TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelerate), making it easy to scale training across multiple GPUs or machines, or to use advanced setups like DeepSpeed, all from the same CLI.

You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. For more information see [Using accelerate launch](https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch).
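
Under the hood, the CLI forwards these extra arguments to `accelerate launch`. As a rough sketch (the exact entry point and defaults are managed by the CLI), the SFT example below behaves like:

```bash
# Illustrative equivalent of `trl sft ... --num_processes 4`, assumed rather than the literal internal command
accelerate launch --num_processes 4 trl/scripts/sft.py \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name stanfordnlp/imdb
```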

<hfoptions id="launch_args">
<hfoption id="SFT inline">

```bash
trl sft \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name stanfordnlp/imdb \
--num_processes 4
```

</hfoption>
<hfoption id="SFT w/ config file">

```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
num_processes: 4
```

Launch with:

```bash
trl sft --config sft_config.yaml
```

</hfoption>
<hfoption id="DPO inline">

```bash
trl dpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name anthropic/hh-rlhf \
--num_processes 4
```

</hfoption>
<hfoption id="DPO w/ config file">

```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
num_processes: 4
```

Launch with:

```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
</hfoptions>

### Using `--accelerate_config` for Accelerate Configuration

The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either:

* the name of a predefined config profile (built into TRL), or
* a path to a custom Accelerate YAML config file.

#### Predefined Config Profiles

TRL provides several ready-to-use Accelerate configs to simplify common training setups:

| Name         | Description                         |
| ------------ | ----------------------------------- |
| `fsdp1`      | Fully Sharded Data Parallel Stage 1 |
| `fsdp2`      | Fully Sharded Data Parallel Stage 2 |
| `zero1`      | DeepSpeed ZeRO Stage 1              |
| `zero2`      | DeepSpeed ZeRO Stage 2              |
| `zero3`      | DeepSpeed ZeRO Stage 3              |
| `multi_gpu`  | Multi-GPU training                  |
| `single_gpu` | Single-GPU training                 |

To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_configs/`.
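
For reference, here is a minimal sketch of what an Accelerate config file for DeepSpeed ZeRO Stage 2 can look like. It is illustrative only, not necessarily the exact profile shipped in `trl/accelerate_configs/`:

```yaml
# zero2.yaml (illustrative sketch)
compute_environment: LOCAL_MACHINE
distributed_type: DEEPSPEED
deepspeed_config:
  zero_stage: 2
  offload_optimizer_device: none
  offload_param_device: none
mixed_precision: bf16
num_machines: 1
num_processes: 8
```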

#### Example Usage

<hfoptions id="accelerate_config">
<hfoption id="SFT inline">

```bash
trl sft \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name stanfordnlp/imdb \
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```

</hfoption>
<hfoption id="SFT w/ config file">

```yaml
# sft_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: stanfordnlp/imdb
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```

Launch with:

```bash
trl sft --config sft_config.yaml
```

</hfoption>
<hfoption id="DPO inline">

```bash
trl dpo \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name anthropic/hh-rlhf \
--accelerate_config zero2 # or path/to/my/accelerate/config.yaml
```

</hfoption>
<hfoption id="DPO w/ config file">

```yaml
# dpo_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: anthropic/hh-rlhf
accelerate_config: zero2 # or path/to/my/accelerate/config.yaml
```

Launch with:

```bash
trl dpo --config dpo_config.yaml
```
</hfoption>
</hfoptions>

## Chat Interface

<Tip warning={true}>

@@ -130,7 +250,7 @@ Besides talking to the model there are a few commands you can use:
- `save` or `save {SAVE_NAME}`: save the current chat and settings to a file (by default `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml`, or `{SAVE_NAME}` if provided)
- `exit`: closes the interface

## Getting the System Information

You can get the system information by running the following command:
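
```bash
trl env
```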

4 changes: 3 additions & 1 deletion setup.cfg
@@ -35,7 +35,9 @@ exclude =
tests*

[options.package_data]
trl =
    templates/*.md
    accelerate_configs/*.yaml

[options.extras_require]
bco =
77 changes: 77 additions & 0 deletions tests/test_cli_utils.py
@@ -163,3 +163,80 @@ def test_parse_args_and_config_with_remaining_strings_in_config_and_args(self, m
        self.assertIsInstance(result_args[0], MyDataclass)
        self.assertEqual(result_args[0].arg1, 2)
        self.assertEqual(result_args[1], ["--remaining_string_in_config", "abc", "--remaining_string_in_args", "def"])

    @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value"))
    @patch("yaml.safe_load")
    def test_subparsers_with_config_defaults(self, mock_yaml_load):
        """Test that config defaults are applied to all subparsers."""
        mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"}

        # Create the main parser
        parser = TrlParser()

        # Add subparsers
        subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser)

        # Create a subparser for a specific command
        subparsers.add_parser("subcommand", dataclass_types=[MyDataclass])

        # Parse with config file
        args = ["subcommand", "--config", "config.yaml"]
        result_args = parser.parse_args_and_config(args)

        # Check main parser arguments
        self.assertEqual(len(result_args), 1)

        # Check that config values were applied to the subparser
        self.assertEqual(result_args[0].arg1, 2)  # Default from config
        self.assertEqual(result_args[0].arg2, "config_value")  # Default from config

    @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value"))
    @patch("yaml.safe_load")
    def test_subparsers_with_config_defaults_and_arg_override(self, mock_yaml_load):
        """Test that command-line arguments override config defaults in subparsers."""
        mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"}

        # Create the main parser
        parser = TrlParser()

        # Add subparsers
        subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser)

        # Create a subparser for a specific command
        subparsers.add_parser("subcommand", dataclass_types=[MyDataclass])

        # Test with command line arguments overriding config
        args = ["subcommand", "--arg1", "3", "--config", "config.yaml"]
        result_args = parser.parse_args_and_config(args)

        # Command line arguments should override config
        self.assertEqual(result_args[0].arg1, 3)
        self.assertEqual(result_args[0].arg2, "config_value")  # Still from config

    @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value"))
    @patch("yaml.safe_load")
    def test_subparsers_multiple_with_config_defaults(self, mock_yaml_load):
        """Test that config defaults are applied to multiple subparsers."""
        mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"}

        # Create the main parser
        parser = TrlParser()

        # Add subparsers
        subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser)

        # Create two subparsers for different commands
        subparsers.add_parser("subcommand0", dataclass_types=[MyDataclass])
        subparsers.add_parser("subcommand1", dataclass_types=[MyDataclass])

        for idx in range(2):
            # Parse with config file
            args = [f"subcommand{idx}", "--config", "config.yaml"]
            result_args = parser.parse_args_and_config(args)

            # Check main parser arguments
            self.assertEqual(len(result_args), 1)

            # Check that config values were applied to the subparser
            self.assertEqual(result_args[0].arg1, 2)  # Default from config
            self.assertEqual(result_args[0].arg2, "config_value")  # Default from config