From a91c73c661285796de59614309a3079bd430b7f4 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Thu, 17 Apr 2025 12:17:47 +0000 Subject: [PATCH 01/25] Add FSDP configs --- examples/accelerate_configs/fsdp1.yaml | 28 +++++++++++++++++++++ examples/accelerate_configs/fsdp2.yaml | 24 ++++++++++++++++++ examples/accelerate_configs/fsdp_qlora.yaml | 25 ------------------ 3 files changed, 52 insertions(+), 25 deletions(-) create mode 100644 examples/accelerate_configs/fsdp1.yaml create mode 100644 examples/accelerate_configs/fsdp2.yaml delete mode 100644 examples/accelerate_configs/fsdp_qlora.yaml diff --git a/examples/accelerate_configs/fsdp1.yaml b/examples/accelerate_configs/fsdp1.yaml new file mode 100644 index 0000000000..c01b0b567b --- /dev/null +++ b/examples/accelerate_configs/fsdp1.yaml @@ -0,0 +1,28 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: true + fsdp_offload_params: false + fsdp_reshard_after_forward: FULL_SHARD + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sync_module_states: true + fsdp_use_orig_params: true + fsdp_version: 1 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/accelerate_configs/fsdp2.yaml b/examples/accelerate_configs/fsdp2.yaml new file mode 100644 index 0000000000..2b544b1e53 --- /dev/null +++ b/examples/accelerate_configs/fsdp2.yaml @@ -0,0 +1,24 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_cpu_ram_efficient_loading: true + fsdp_offload_params: false + fsdp_reshard_after_forward: true + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_version: 2 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/accelerate_configs/fsdp_qlora.yaml b/examples/accelerate_configs/fsdp_qlora.yaml deleted file mode 100644 index 93b3541470..0000000000 --- a/examples/accelerate_configs/fsdp_qlora.yaml +++ /dev/null @@ -1,25 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -distributed_type: FSDP -downcast_bf16: 'no' -fsdp_config: - fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP - fsdp_backward_prefetch: BACKWARD_PRE - fsdp_cpu_ram_efficient_loading: true - fsdp_forward_prefetch: false - fsdp_offload_params: true - fsdp_sharding_strategy: FULL_SHARD - fsdp_state_dict_type: SHARDED_STATE_DICT - fsdp_sync_module_states: true - fsdp_use_orig_params: false -machine_rank: 0 -main_training_function: main -mixed_precision: 'bf16' -num_machines: 1 -num_processes: 8 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false \ No newline at end of file From debdc8e87c9eb216419f65a615fecb04b17e044d Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Thu, 17 Apr 2025 14:04:46 +0000 Subject: [PATCH 02/25] Bump accelerate --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f3e06c98f5..5c9a727039 100644 --- a/setup.py +++ b/setup.py @@ -72,7 +72,7 @@ __version__ = "0.17.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) REQUIRED_PKGS = [ - "accelerate>=0.34.0", + "accelerate>=1.6.0", "datasets>=3.0.0", "rich", # rich shouldn't be a required package for trl, we should remove it from here "transformers>=4.46.0", From e7f74f79ed9d259074c9116fc8271bed53d8b200 Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Thu, 17 Apr 2025 14:05:16 +0000 Subject: [PATCH 03/25] Update prepare --- trl/models/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/models/utils.py b/trl/models/utils.py index c93e2f9e90..6d837bde3f 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -266,7 +266,7 @@ def prepare_fsdp(model, accelerator): accelerator.state.fsdp_plugin.set_auto_wrap_policy(model) fsdp_plugin = accelerator.state.fsdp_plugin kwargs = { - "sharding_strategy": fsdp_plugin.sharding_strategy, + "sharding_strategy": fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward, "cpu_offload": fsdp_plugin.cpu_offload, "auto_wrap_policy": fsdp_plugin.auto_wrap_policy, "mixed_precision": fsdp_plugin.mixed_precision_policy, From 499167d54f9b36f5f604563d6afa5f8e3a5e4bd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 21 Apr 2025 22:17:49 +0000 Subject: [PATCH 04/25] update version accelerate in test --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 429ddef2cd..ddc69813e7 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -231,7 +231,7 @@ jobs: - name: Install dependencies run: | source .venv/bin/activate - uv pip install accelerate==0.34.0 + uv pip install accelerate==1.6.0 uv pip install datasets==3.0.0 uv pip install transformers==4.46.0 uv pip install ".[dev]" From ca08043183baccb0b8dae34ac9091fa6c69122cb Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 2 May 2025 09:42:34 +0000 Subject: [PATCH 05/25] Add full state dict --- examples/accelerate_configs/fsdp2.yaml | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/accelerate_configs/fsdp2.yaml b/examples/accelerate_configs/fsdp2.yaml index 2b544b1e53..6dd9baf0e5 100644 --- a/examples/accelerate_configs/fsdp2.yaml +++ b/examples/accelerate_configs/fsdp2.yaml @@ -1,3 +1,4 @@ +# Requires accelerate 1.6.0 or higher compute_environment: LOCAL_MACHINE debug: false distributed_type: FSDP @@ -9,7 +10,7 @@ fsdp_config: fsdp_cpu_ram_efficient_loading: true fsdp_offload_params: false fsdp_reshard_after_forward: true - fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_state_dict_type: FULL_STATE_DICT fsdp_version: 2 machine_rank: 0 main_training_function: main diff --git a/setup.py b/setup.py index 7665bc47ce..86e906444b 100644 --- a/setup.py +++ b/setup.py @@ -72,7 +72,7 @@ __version__ = "0.18.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) REQUIRED_PKGS = [ - "accelerate>=1.6.0", + "accelerate>=0.34.0", "datasets>=3.0.0", "rich", # rich shouldn't be a required package for trl, we should remove it from here "transformers>=4.46.0", From 91ae801716deccdf4556ba6285776aa86bc7571e Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Fri, 2 May 2025 09:44:30 +0000 Subject: [PATCH 06/25] Revert --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ddc69813e7..429ddef2cd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -231,7 +231,7 @@ jobs: - name: Install dependencies run: | source .venv/bin/activate - uv pip install accelerate==1.6.0 + uv pip install accelerate==0.34.0 uv pip install datasets==3.0.0 uv pip install transformers==4.46.0 uv pip install ".[dev]" From eb3ed1b065e859ea6465581cddf24ce5d550946a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 3 May 2025 06:03:12 +0000 Subject: [PATCH 07/25] return_remaining_strings=True --- trl/scripts/sft.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py index eeaaa0d118..923829f775 100644 --- a/trl/scripts/sft.py +++ b/trl/scripts/sft.py @@ -142,5 +142,8 @@ def make_parser(subparsers: argparse._SubParsersAction = None): if __name__ == "__main__": parser = make_parser() - script_args, training_args, model_args = parser.parse_args_and_config() + # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments. + # To ensure that their parsing does not interfere with the script arguments, parse the arguments with + # `return_remaining_strings=True`, then ignore the remaining strings. + script_args, training_args, model_args, _ = parser.parse_args_and_config(return_remaining_strings=True) main(script_args, training_args, model_args) From cd1c0b4c073f36bae3c0301d8f179d6bfdf7d8c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 3 May 2025 06:04:00 +0000 Subject: [PATCH 08/25] TRLParser compat with subparsers --- trl/scripts/utils.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/trl/scripts/utils.py b/trl/scripts/utils.py index 3cc55b0d4a..543331cb90 100644 --- a/trl/scripts/utils.py +++ b/trl/scripts/utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import importlib import inspect import logging @@ -215,20 +216,33 @@ def parse_args_and_config( def set_defaults_with_config(self, **kwargs) -> list[str]: """ - Overrides the parser's default values with those provided via keyword arguments. + Overrides the parser's default values with those provided via keyword arguments, including for subparsers. Any argument with an updated default will also be marked as not required if it was previously required. Returns a list of strings that were not consumed by the parser. """ - # If an argument is in the kwargs, update its default and set it as not required - for action in self._actions: - if action.dest in kwargs: - action.default = kwargs.pop(action.dest) - action.required = False - remaining_strings = [item for key, value in kwargs.items() for item in [f"--{key}", str(value)]] - return remaining_strings + + def apply_defaults(parser, kw): + used_keys = set() + for action in parser._actions: + # Handle subparsers recursively + if isinstance(action, argparse._SubParsersAction): + for subparser in action.choices.values(): + used_keys.update(apply_defaults(subparser, kw)) + elif action.dest in kw: + action.default = kw[action.dest] + action.required = False + used_keys.add(action.dest) + return used_keys + + used_keys = apply_defaults(self, kwargs) + # Remaining args not consumed by the parser + remaining = [ + item for key, value in kwargs.items() if key not in used_keys for item in (f"--{key}", str(value)) + ] + return remaining def get_git_commit_hash(package_name): From 9570ff28b62be952b080e37908c49eafe9fb66fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 3 May 2025 06:38:50 +0000 Subject: [PATCH 09/25] test subpaser config handling --- tests/test_cli_utils.py | 77 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py index aab1105a69..5b1cf50cc7 100644 --- a/tests/test_cli_utils.py +++ b/tests/test_cli_utils.py @@ -163,3 +163,80 @@ def test_parse_args_and_config_with_remaining_strings_in_config_and_args(self, m self.assertIsInstance(result_args[0], MyDataclass) self.assertEqual(result_args[0].arg1, 2) self.assertEqual(result_args[1], ["--remaining_string_in_config", "abc", "--remaining_string_in_args", "def"]) + + @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value")) + @patch("yaml.safe_load") + def test_subparsers_with_config_defaults(self, mock_yaml_load): + """Test that config defaults are applied to all subparsers.""" + mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"} + + # Create the main parser + parser = TrlParser() + + # Add subparsers + subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser) + + # Create a subparser for a specific command + subparsers.add_parser("subcommand", dataclass_types=[MyDataclass]) + + # Parse with config file + args = ["subcommand", "--config", "config.yaml"] + result_args = parser.parse_args_and_config(args) + + # Check main parser arguments + self.assertEqual(len(result_args), 1) + + # Check that config values were applied to the subparser + self.assertEqual(result_args[0].arg1, 2) # Default from config + self.assertEqual(result_args[0].arg2, "config_value") # Default from config + + @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value")) + @patch("yaml.safe_load") + def test_subparsers_with_config_defaults_and_arg_override(self, mock_yaml_load): + """Test that config defaults are applied to all subparsers.""" + mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"} + + # Create the main parser + parser = TrlParser() + + # Add subparsers + subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser) + + # Create a subparser for a specific command + subparsers.add_parser("subcommand", dataclass_types=[MyDataclass]) + + # Test with command line arguments overriding config + args = ["subcommand", "--arg1", "3", "--config", "config.yaml"] + result_args = parser.parse_args_and_config(args) + + # Command line arguments should override config + self.assertEqual(result_args[0].arg1, 3) + self.assertEqual(result_args[0].arg2, "config_value") # Still from config + + @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value")) + @patch("yaml.safe_load") + def test_subparsers_multiple_with_config_defaults(self, mock_yaml_load): + """Test that config defaults are applied to all subparsers.""" + mock_yaml_load.return_value = {"arg1": 2, "arg2": "config_value"} + + # Create the main parser + parser = TrlParser() + + # Add subparsers + subparsers = parser.add_subparsers(dest="command", parser_class=TrlParser) + + # Create a subparser for a specific command + subparsers.add_parser("subcommand0", dataclass_types=[MyDataclass]) + subparsers.add_parser("subcommand1", dataclass_types=[MyDataclass]) + + for idx in range(2): + # Parse with config file + args = [f"subcommand{idx}", "--config", "config.yaml"] + result_args = parser.parse_args_and_config(args) + + # Check main parser arguments + self.assertEqual(len(result_args), 1) + + # Check that config values were applied to the subparser + self.assertEqual(result_args[0].arg1, 2) # Default from config + self.assertEqual(result_args[0].arg2, "config_value") # Default from config From 770d7b5a9b2c8fed241bc1f4a40a6b616614a1b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 3 May 2025 06:54:53 +0000 Subject: [PATCH 10/25] allow launch argument in cli args for sft --- trl/cli.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/trl/cli.py b/trl/cli.py index 3e86de9cf4..05fcbcdfde 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -45,8 +45,11 @@ def main(): make_sft_parser(subparsers) make_vllm_serve_parser(subparsers) - # Parse the arguments - args = parser.parse_args_and_config()[0] + # Parse the arguments; the remaining ones (unknown) are passed to the 'accelerate launch' subparser. + # Duplicates may occur if the same argument is provided in both the config file and CLI. + # For example: unknown = ["--num_processes", "4", "--num_processes", "8"]. + # Deduplication and precedence (CLI over config) are handled later by launch_command_parser. + args, launch_args = parser.parse_args_and_config(return_remaining_strings=True) if args.command == "chat": (chat_args,) = parser.parse_args_and_config() @@ -83,12 +86,14 @@ def main(): launch_command(args) # launch training elif args.command == "sft": - # Get the default args for the launch command + # Get the path to the training script sft_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "sft.py") - args = launch_command_parser().parse_args([sft_training_script]) - # Feed the args to the launch command - args.training_script_args = sys.argv[2:] # remove "trl" and "sft" + # This simulates running: `accelerate launch sft.py `. + # Note that the training script args may include launch-related arguments (e.g., `--num_processes`), + # but we rely on the script to ignore any that don't apply to it. + training_script_args = sys.argv[2:] # Remove "trl" and "sft" + args = launch_command_parser().parse_args(launch_args + [sft_training_script] + training_script_args) launch_command(args) # launch training elif args.command == "vllm-serve": From 17c0c9f961e25c1afb144f71ca7dc354c038379d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 3 May 2025 06:57:17 +0000 Subject: [PATCH 11/25] better comment --- trl/cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/cli.py b/trl/cli.py index 05fcbcdfde..252117b2ac 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -45,9 +45,9 @@ def main(): make_sft_parser(subparsers) make_vllm_serve_parser(subparsers) - # Parse the arguments; the remaining ones (unknown) are passed to the 'accelerate launch' subparser. + # Parse the arguments; the remaining ones (`launch_args`) are passed to the 'accelerate launch' subparser. # Duplicates may occur if the same argument is provided in both the config file and CLI. - # For example: unknown = ["--num_processes", "4", "--num_processes", "8"]. + # For example: launch_args = `["--num_processes", "4", "--num_processes", "8"]`. # Deduplication and precedence (CLI over config) are handled later by launch_command_parser. args, launch_args = parser.parse_args_and_config(return_remaining_strings=True) From 63b8d73767dc5f8d12925f853071772898bbb5a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 4 May 2025 22:07:42 +0000 Subject: [PATCH 12/25] add accelerate configs --- MANIFEST.in | 3 ++- setup.cfg | 4 +++- trl/accelerate_configs/deepspeed_zero1.yaml | 20 +++++++++++++++++ trl/accelerate_configs/deepspeed_zero2.yaml | 21 +++++++++++++++++ trl/accelerate_configs/deepspeed_zero3.yaml | 22 ++++++++++++++++++ trl/accelerate_configs/fsdp_qlora.yaml | 25 +++++++++++++++++++++ trl/accelerate_configs/multi_gpu.yaml | 16 +++++++++++++ trl/accelerate_configs/single_gpu.yaml | 16 +++++++++++++ 8 files changed, 125 insertions(+), 2 deletions(-) create mode 100644 trl/accelerate_configs/deepspeed_zero1.yaml create mode 100644 trl/accelerate_configs/deepspeed_zero2.yaml create mode 100644 trl/accelerate_configs/deepspeed_zero3.yaml create mode 100644 trl/accelerate_configs/fsdp_qlora.yaml create mode 100644 trl/accelerate_configs/multi_gpu.yaml create mode 100644 trl/accelerate_configs/single_gpu.yaml diff --git a/MANIFEST.in b/MANIFEST.in index 515f89f484..8855af1a5a 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,4 +2,5 @@ include LICENSE include CONTRIBUTING.md include README.md recursive-exclude * __pycache__ -include trl/templates/*.md \ No newline at end of file +include trl/templates/*.md +include trl/accelerate_configs/*.yaml \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index ee1c685e56..113d6c324c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,7 +35,9 @@ exclude = tests* [options.package_data] -trl = templates/*.md +trl = + templates/*.md + accelerate_configs/*.yaml [options.extras_require] bco = diff --git a/trl/accelerate_configs/deepspeed_zero1.yaml b/trl/accelerate_configs/deepspeed_zero1.yaml new file mode 100644 index 0000000000..d5b5f782fb --- /dev/null +++ b/trl/accelerate_configs/deepspeed_zero1.yaml @@ -0,0 +1,20 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + gradient_accumulation_steps: 1 + zero3_init_flag: false + zero_stage: 1 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/accelerate_configs/deepspeed_zero2.yaml b/trl/accelerate_configs/deepspeed_zero2.yaml new file mode 100644 index 0000000000..239b14ac3a --- /dev/null +++ b/trl/accelerate_configs/deepspeed_zero2.yaml @@ -0,0 +1,21 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/accelerate_configs/deepspeed_zero3.yaml b/trl/accelerate_configs/deepspeed_zero3.yaml new file mode 100644 index 0000000000..b5a1201f8a --- /dev/null +++ b/trl/accelerate_configs/deepspeed_zero3.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + deepspeed_multinode_launcher: standard + offload_optimizer_device: none + offload_param_device: none + zero3_init_flag: true + zero3_save_16bit_model: true + zero_stage: 3 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/accelerate_configs/fsdp_qlora.yaml b/trl/accelerate_configs/fsdp_qlora.yaml new file mode 100644 index 0000000000..93b3541470 --- /dev/null +++ b/trl/accelerate_configs/fsdp_qlora.yaml @@ -0,0 +1,25 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +fsdp_config: + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: false + fsdp_offload_params: true + fsdp_sharding_strategy: FULL_SHARD + fsdp_state_dict_type: SHARDED_STATE_DICT + fsdp_sync_module_states: true + fsdp_use_orig_params: false +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false \ No newline at end of file diff --git a/trl/accelerate_configs/multi_gpu.yaml b/trl/accelerate_configs/multi_gpu.yaml new file mode 100644 index 0000000000..15dad9be3b --- /dev/null +++ b/trl/accelerate_configs/multi_gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: MULTI_GPU +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/accelerate_configs/single_gpu.yaml b/trl/accelerate_configs/single_gpu.yaml new file mode 100644 index 0000000000..ebd00a0671 --- /dev/null +++ b/trl/accelerate_configs/single_gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: "NO" +downcast_bf16: 'no' +gpu_ids: all +machine_rank: 0 +main_training_function: main +mixed_precision: 'bf16' +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false From f6cb3e59da9718ee040e121623f9150d468a2f96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 5 May 2025 00:16:43 +0000 Subject: [PATCH 13/25] rewrite the cli doc --- docs/source/clis.md | 108 +++++++++++++++++++++++++------------------- 1 file changed, 62 insertions(+), 46 deletions(-) diff --git a/docs/source/clis.md b/docs/source/clis.md index 68a3366122..9fa4e25935 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -1,8 +1,8 @@ # Command Line Interfaces (CLIs) -You can use TRL to fine-tune your language model with methods like Supervised Fine-Tuning (SFT) or Direct Policy Optimization (DPO) using the command line interface (CLI). +TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly. -Currently supported CLIs are: +Currently supported commands are: #### Training commands @@ -18,87 +18,103 @@ Currently supported CLIs are: ## Fine-tuning with the CLI -Before getting started, pick up a Language Model from Hugging Face Hub. Supported models can be found with the filter "text-generation" within models. Also make sure to pick up a relevant dataset for your task. +To fine-tune a model, for example, you can run: + + + -Before using the `sft` or `dpo` commands make sure to run: ```bash -accelerate config +trl sft --model_name_or_path Qwen/Qwen2.5-0.5B --dataset_name stanfordnlp/imdb ``` -and pick up the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command. -We also recommend you passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with `trl sft` command. + + -```yaml -model_name_or_path: - Qwen/Qwen2.5-0.5B -dataset_name: - stanfordnlp/imdb -report_to: - none -learning_rate: - 0.0001 -lr_scheduler_type: - cosine +```bash +trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B --dataset_name ... ``` -Save that config in a `.yaml` and get started immediately! An example CLI config is available as `examples/cli_configs/example_config.yaml`. Note you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g. from the root folder: + + -```bash -trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts -``` +### Configuration file -Will force-use `cosine_with_restarts` for `lr_scheduler_type`. +You can also configure your training setup using a YAML configuration file, which helps keep your command-line usage clean and reproducible. Below is an example of a minimal configuration file: -### Supported Arguments + + -We do support all arguments from `transformers.TrainingArguments`, for loading your model, we support all arguments from `~trl.ModelConfig`: +```yaml +# example_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: stanfordnlp/imdb +``` -[[autodoc]] ModelConfig +To launch training with this config, run: + +```bash +trl sft --config example_config.yaml +``` -You can pass any of these arguments either to the CLI or the YAML file. + + -### Supervised Fine-tuning (SFT) +```yaml +# example_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: ... +``` -Follow the basic instructions above and run `trl sft --output_dir <*args>`: +To launch training with this config, run: ```bash -trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb +trl dpo --config example_config.yaml ``` -The SFT CLI is based on the `trl/scripts/sft.py` script. + + -### Direct Policy Optimization (DPO) +### Use the CLI for distributed training -To use the DPO CLI, you need to have a dataset in the TRL format such as +The TRL CLI supports **all the arguments** of `accelerate launch`. See https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch. Consequelntly you can easily distribute the training leveraging `accelerate`. Example with `num_processes`: -* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-helpful-base-trl-style -* TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style -These datasets always have at least three columns `prompt, chosen, rejected`: - -* `prompt` is a list of strings. -* `chosen` is the chosen response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating) -* `rejected` is the rejected response [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating) + + +```bash +trl sft --model_name_or_path Qwen/Qwen2.5-0.5B --dataset_name stanfordnlp/imdb --num_processes 4 +``` -To do a quick start, you can run the following command: + + ```bash -trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style +trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B --dataset_name ... --num_processes 4 ``` + + -The DPO CLI is based on the `trl/scripts/dpo.py` script. +TRL provides some predefined configurations for distrubtued training. To use then simply use the `--accelerate_config` argument. For example, to use the DeepSpeed ZeRO Stage 2, run: + + -#### Custom preference dataset +```bash +trl sft --model_name_or_path Qwen/Qwen2.5-0.5B --dataset_name stanfordnlp/imdb --accelerate_config deepspeed_zero2 +``` -Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`): + + ```bash -python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org +trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B --dataset_name ... --accelerate_config deepspeed_zero2 ``` + + + ## Chat interface From db6ab7380eb9b5f48f84bb385e00fac1e5fd7dcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 5 May 2025 00:32:15 +0000 Subject: [PATCH 14/25] accelerate config --- trl/cli.py | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/trl/cli.py b/trl/cli.py index 252117b2ac..780a6e7a5e 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os +import importlib.resources as resources import sys import warnings @@ -51,14 +51,31 @@ def main(): # Deduplication and precedence (CLI over config) are handled later by launch_command_parser. args, launch_args = parser.parse_args_and_config(return_remaining_strings=True) + # Replace `--accelerate_config foo` with `--config_file trl/accelerate_configs/foo.yaml` if it is present in the + # launch_args. It allows the user to use predefined accelerate configs from the `trl` package. + if "--accelerate_config" in launch_args: + # Get the index of the '--accelerate_config' argument and the corresponding config name + config_index = launch_args.index("--accelerate_config") + config_name = launch_args[config_index + 1] + + # Construct the file path from the package resources + accelerate_config_path = resources.files("trl.accelerate_configs").joinpath(f"{config_name}.yaml") + + # Remove '--accelerate_config' and its corresponding config name + launch_args.pop(config_index) + launch_args.pop(config_index) + + # Insert '--config_file' and the absolute path to the front of the list + launch_args = ["--config_file", str(accelerate_config_path)] + launch_args + if args.command == "chat": (chat_args,) = parser.parse_args_and_config() chat_main(chat_args) if args.command == "dpo": # Get the default args for the launch command - dpo_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "dpo.py") - args = launch_command_parser().parse_args([dpo_training_script]) + dpo_training_script = resources.files("trl.scripts").joinpath("dpo.py") + args = launch_command_parser().parse_args([str(dpo_training_script)]) # Feed the args to the launch command args.training_script_args = sys.argv[2:] # remove "trl" and "dpo" @@ -69,8 +86,8 @@ def main(): elif args.command == "grpo": # Get the default args for the launch command - grpo_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "grpo.py") - args = launch_command_parser().parse_args([grpo_training_script]) + grpo_training_script = resources.files("trl.scripts").joinpath("grpo.py") + args = launch_command_parser().parse_args([str(grpo_training_script)]) # Feed the args to the launch command args.training_script_args = sys.argv[2:] # remove "trl" and "grpo" @@ -78,8 +95,8 @@ def main(): elif args.command == "kto": # Get the default args for the launch command - kto_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "kto.py") - args = launch_command_parser().parse_args([kto_training_script]) + kto_training_script = resources.files("trl.scripts").joinpath("kto.py") + args = launch_command_parser().parse_args([str(kto_training_script)]) # Feed the args to the launch command args.training_script_args = sys.argv[2:] # remove "trl" and "kto" @@ -87,13 +104,13 @@ def main(): elif args.command == "sft": # Get the path to the training script - sft_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "sft.py") + sft_training_script = resources.files("trl.scripts").joinpath("sft.py") # This simulates running: `accelerate launch sft.py `. # Note that the training script args may include launch-related arguments (e.g., `--num_processes`), # but we rely on the script to ignore any that don't apply to it. training_script_args = sys.argv[2:] # Remove "trl" and "sft" - args = launch_command_parser().parse_args(launch_args + [sft_training_script] + training_script_args) + args = launch_command_parser().parse_args(launch_args + [str(sft_training_script)] + training_script_args) launch_command(args) # launch training elif args.command == "vllm-serve": From 8202f0a55c56be572b1f9ad6bcc5652a95b021e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 5 May 2025 00:32:31 +0000 Subject: [PATCH 15/25] further improve the doc --- docs/source/clis.md | 77 +++++++++++++++++++++++++++++++++------------ 1 file changed, 57 insertions(+), 20 deletions(-) diff --git a/docs/source/clis.md b/docs/source/clis.md index 9fa4e25935..80cb92c4bb 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -15,101 +15,138 @@ Currently supported commands are: - `trl env`: get the system information - `trl vllm-serve`: serve a model with vLLM +Absolutely — here's a **refined, ultra-clear, and developer-friendly** version of your CLI documentation. It keeps things concise while improving clarity, flow, and formatting. -## Fine-tuning with the CLI +--- -To fine-tune a model, for example, you can run: +## Fine-Tuning with the TRL CLI + +### Basic Usage + +You can launch training directly from the CLI by specifying required arguments like the model and dataset: ```bash -trl sft --model_name_or_path Qwen/Qwen2.5-0.5B --dataset_name stanfordnlp/imdb +trl sft \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name stanfordnlp/imdb ``` ```bash -trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B --dataset_name ... +trl dpo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name anthropic/hh-rlhf ``` -### Configuration file +### Using Configuration Files -You can also configure your training setup using a YAML configuration file, which helps keep your command-line usage clean and reproducible. Below is an example of a minimal configuration file: +To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file: ```yaml -# example_config.yaml +# sft_config.yaml model_name_or_path: Qwen/Qwen2.5-0.5B dataset_name: stanfordnlp/imdb ``` -To launch training with this config, run: +Launch with: ```bash -trl sft --config example_config.yaml +trl sft --config sft_config.yaml ``` ```yaml -# example_config.yaml +# dpo_config.yaml model_name_or_path: Qwen/Qwen2.5-0.5B -dataset_name: ... +dataset_name: anthropic/hh-rlhf +output_dir: ./results/dpo +learning_rate: 5e-7 ``` -To launch training with this config, run: +Launch with: ```bash -trl dpo --config example_config.yaml +trl dpo --config dpo_config.yaml ``` -### Use the CLI for distributed training +### Scaling Up with Accelerate -The TRL CLI supports **all the arguments** of `accelerate launch`. See https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch. Consequelntly you can easily distribute the training leveraging `accelerate`. Example with `num_processes`: +TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelerate), making it easy to scale training across multiple GPUs, machines, or use advanced setups like DeepSpeed — all from the same CLI. +You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. ```bash -trl sft --model_name_or_path Qwen/Qwen2.5-0.5B --dataset_name stanfordnlp/imdb --num_processes 4 +trl sft \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name stanfordnlp/imdb \ + --num_processes 4 ``` ```bash -trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B --dataset_name ... --num_processes 4 +trl dpo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name anthropic/hh-rlhf \ + --num_processes 4 ``` -TRL provides some predefined configurations for distrubtued training. To use then simply use the `--accelerate_config` argument. For example, to use the DeepSpeed ZeRO Stage 2, run: +### Using Predefined Accelerate Configs + +TRL includes built-in Accelerate configuration profiles to simplify distributed training. Use the `--accelerate_config` flag to load one by name: + +#### Available presets: + +* `deepspeed_zero1` — DeepSpeed ZeRO Stage 1 +* `deepspeed_zero2` — DeepSpeed ZeRO Stage 2 +* `deepspeed_zero3` — DeepSpeed ZeRO Stage 3 +* `fsdp_qlora` — Fully Sharded Data Parallel with QLoRA +* `multi_gpu` — Multi-GPU training +* `single_gpu` — Single-GPU training + +#### Example usage: ```bash -trl sft --model_name_or_path Qwen/Qwen2.5-0.5B --dataset_name stanfordnlp/imdb --accelerate_config deepspeed_zero2 +trl sft \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name stanfordnlp/imdb \ + --accelerate_config deepspeed_zero2 ``` ```bash -trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B --dataset_name ... --accelerate_config deepspeed_zero2 +trl dpo \ + --model_name_or_path Qwen/Qwen2.5-0.5B \ + --dataset_name anthropic/hh-rlhf \ + --accelerate_config deepspeed_zero3 ``` From dc4f49cf48d6a856f1f920806eccf87d1cd24eed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 5 May 2025 00:34:30 +0000 Subject: [PATCH 16/25] rm chatgpt blabla --- docs/source/clis.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/source/clis.md b/docs/source/clis.md index 80cb92c4bb..865ef0d2f4 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -15,9 +15,6 @@ Currently supported commands are: - `trl env`: get the system information - `trl vllm-serve`: serve a model with vLLM -Absolutely — here's a **refined, ultra-clear, and developer-friendly** version of your CLI documentation. It keeps things concise while improving clarity, flow, and formatting. - ---- ## Fine-Tuning with the TRL CLI From e8fa32f8f25d37fce5a1a005a5dc4ba12de4d57b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 5 May 2025 00:35:52 +0000 Subject: [PATCH 17/25] simplify --- docs/source/clis.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/clis.md b/docs/source/clis.md index 865ef0d2f4..f2c1181afb 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -69,8 +69,6 @@ trl sft --config sft_config.yaml # dpo_config.yaml model_name_or_path: Qwen/Qwen2.5-0.5B dataset_name: anthropic/hh-rlhf -output_dir: ./results/dpo -learning_rate: 5e-7 ``` Launch with: From 7e42d823c1fd4360a1f164cef6c4dda34c040aff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 5 May 2025 01:02:47 +0000 Subject: [PATCH 18/25] Is it clearer? --- docs/source/clis.md | 49 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/docs/source/clis.md b/docs/source/clis.md index f2c1181afb..6c6d125ead 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -23,7 +23,7 @@ Currently supported commands are: You can launch training directly from the CLI by specifying required arguments like the model and dataset: - + ```bash trl sft \ @@ -32,7 +32,7 @@ trl sft \ ``` - + ```bash trl dpo \ @@ -48,7 +48,7 @@ trl dpo \ To keep your CLI commands clean and reproducible, you can define all training arguments in a YAML configuration file: - + ```yaml # sft_config.yaml @@ -63,7 +63,7 @@ trl sft --config sft_config.yaml ``` - + ```yaml # dpo_config.yaml @@ -84,10 +84,10 @@ trl dpo --config dpo_config.yaml TRL CLI natively supports [🤗 Accelerate](https://huggingface.co/docs/accelerate), making it easy to scale training across multiple GPUs, machines, or use advanced setups like DeepSpeed — all from the same CLI. -You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. +You can pass any `accelerate launch` arguments directly to `trl`, such as `--num_processes`. For more information see [Using accelerate launch](https://huggingface.co/docs/accelerate/en/basic_tutorials/launch#using-accelerate-launch). - + ```bash trl sft \ @@ -97,7 +97,23 @@ trl sft \ ``` - + + +```yaml +# sft_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: stanfordnlp/imdb +num_processes: 4 +``` + +Launch with: + +```bash +trl sft --config sft_config.yaml +``` + + + ```bash trl dpo \ @@ -106,6 +122,21 @@ trl dpo \ --num_processes 4 ``` + + + +```yaml +# dpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: anthropic/hh-rlhf +num_processes: 4 +``` + +Launch with: + +```bash +trl dpo --config dpo_config.yaml +``` @@ -125,7 +156,7 @@ TRL includes built-in Accelerate configuration profiles to simplify distributed #### Example usage: - + ```bash trl sft \ @@ -135,7 +166,7 @@ trl sft \ ``` - + ```bash trl dpo \ From 1fbb56fb5ed3244cb3bc17daeabf5eb865dd1a28 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 5 May 2025 01:10:51 +0000 Subject: [PATCH 19/25] other examples --- docs/source/clis.md | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/docs/source/clis.md b/docs/source/clis.md index 6c6d125ead..4350014c28 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -156,7 +156,7 @@ TRL includes built-in Accelerate configuration profiles to simplify distributed #### Example usage: - + ```bash trl sft \ @@ -166,15 +166,45 @@ trl sft \ ``` - + + +```yaml +# sft_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: stanfordnlp/imdb +accelerate_config: deepspeed_zero2 +``` + +Launch with: + +```bash +trl sft --config sft_config.yaml +``` + + ```bash trl dpo \ --model_name_or_path Qwen/Qwen2.5-0.5B \ --dataset_name anthropic/hh-rlhf \ - --accelerate_config deepspeed_zero3 + --accelerate_config deepspeed_zero2 +``` + + + + +```yaml +# dpo_config.yaml +model_name_or_path: Qwen/Qwen2.5-0.5B +dataset_name: anthropic/hh-rlhf +accelerate_config: deepspeed_zero2 ``` +Launch with: + +```bash +trl dpo --config dpo_config.yaml +``` From 6eccaec2f0705bfeea10983086fd07e6e5026f55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 5 May 2025 01:28:42 +0000 Subject: [PATCH 20/25] even better --- docs/source/clis.md | 48 +++++++++++++++++++++++++++------------------ trl/cli.py | 14 +++++++++++-- 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/docs/source/clis.md b/docs/source/clis.md index 4350014c28..03d43e559b 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -4,14 +4,14 @@ TRL provides a powerful command-line interface (CLI) to fine-tune large language Currently supported commands are: -#### Training commands +#### Training Commands - `trl dpo`: fine-tune a LLM with DPO - `trl grpo`: fine-tune a LLM with GRPO - `trl kto`: fine-tune a LLM with KTO - `trl sft`: fine-tune a LLM with SFT -#### Other commands +#### Other Commands - `trl env`: get the system information - `trl vllm-serve`: serve a model with vLLM @@ -140,29 +140,38 @@ trl dpo --config dpo_config.yaml -### Using Predefined Accelerate Configs +### Using `--accelerate_config` for Accelerate Configuration -TRL includes built-in Accelerate configuration profiles to simplify distributed training. Use the `--accelerate_config` flag to load one by name: +The `--accelerate_config` flag in TRL lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either: -#### Available presets: +* the name of a predefined config profile (built into TRL), or +* a path to a custom Accelerate YAML config file. -* `deepspeed_zero1` — DeepSpeed ZeRO Stage 1 -* `deepspeed_zero2` — DeepSpeed ZeRO Stage 2 -* `deepspeed_zero3` — DeepSpeed ZeRO Stage 3 -* `fsdp_qlora` — Fully Sharded Data Parallel with QLoRA -* `multi_gpu` — Multi-GPU training -* `single_gpu` — Single-GPU training +#### Predefined Config Profiles -#### Example usage: +TRL provides several ready-to-use Accelerate configs to simplify common training setups: - +| Name | Description | +| ----------------- | -------------------------------------- | +| `deepspeed_zero1` | DeepSpeed ZeRO Stage 1 | +| `deepspeed_zero2` | DeepSpeed ZeRO Stage 2 | +| `deepspeed_zero3` | DeepSpeed ZeRO Stage 3 | +| `fsdp_qlora` | Fully Sharded Data Parallel with QLoRA | +| `multi_gpu` | Multi-GPU training | +| `single_gpu` | Single-GPU training | + +To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_config/`. + +#### Example Usage + + ```bash trl sft \ --model_name_or_path Qwen/Qwen2.5-0.5B \ --dataset_name stanfordnlp/imdb \ - --accelerate_config deepspeed_zero2 + --accelerate_config deepspeed_zero2 # or path/to/my/accelerate/config.yaml ``` @@ -172,7 +181,7 @@ trl sft \ # sft_config.yaml model_name_or_path: Qwen/Qwen2.5-0.5B dataset_name: stanfordnlp/imdb -accelerate_config: deepspeed_zero2 +accelerate_config: deepspeed_zero2 # or path/to/my/accelerate/config.yaml ``` Launch with: @@ -180,6 +189,7 @@ Launch with: ```bash trl sft --config sft_config.yaml ``` + @@ -187,7 +197,7 @@ trl sft --config sft_config.yaml trl dpo \ --model_name_or_path Qwen/Qwen2.5-0.5B \ --dataset_name anthropic/hh-rlhf \ - --accelerate_config deepspeed_zero2 + --accelerate_config deepspeed_zero2 # or path/to/my/accelerate/config.yaml ``` @@ -197,7 +207,7 @@ trl dpo \ # dpo_config.yaml model_name_or_path: Qwen/Qwen2.5-0.5B dataset_name: anthropic/hh-rlhf -accelerate_config: deepspeed_zero2 +accelerate_config: deepspeed_zero2 # or path/to/my/accelerate/config.yaml ``` Launch with: @@ -208,7 +218,7 @@ trl dpo --config dpo_config.yaml -## Chat interface +## Chat Interface @@ -239,7 +249,7 @@ Besides talking to the model there are a few commands you can use: - `save` or `save {SAVE_NAME}`: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided - `exit`: closes the interface -## Getting the system information +## Getting the System Information You can get the system information by running the following command: diff --git a/trl/cli.py b/trl/cli.py index 780a6e7a5e..7d85d570e6 100644 --- a/trl/cli.py +++ b/trl/cli.py @@ -13,6 +13,7 @@ # limitations under the License. import importlib.resources as resources +import os import sys import warnings @@ -58,8 +59,17 @@ def main(): config_index = launch_args.index("--accelerate_config") config_name = launch_args[config_index + 1] - # Construct the file path from the package resources - accelerate_config_path = resources.files("trl.accelerate_configs").joinpath(f"{config_name}.yaml") + # If the config_name correspond to a path in the filesystem, we don't want to override it + if os.path.isfile(config_name): + accelerate_config_path = config_name + elif resources.files("trl.accelerate_configs").joinpath(f"{config_name}.yaml").exists(): + # Get the predefined accelerate config path from the package resources + accelerate_config_path = resources.files("trl.accelerate_configs").joinpath(f"{config_name}.yaml") + else: + raise ValueError( + f"Accelerate config {config_name} is neither a file nor a valid config in the `trl` package. " + "Please provide a valid config name or a path to a config file." + ) # Remove '--accelerate_config' and its corresponding config name launch_args.pop(config_index) From 69b90d540190d0a6285b19cf9821fb9233dedeb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 5 May 2025 01:52:57 +0000 Subject: [PATCH 21/25] detail --- docs/source/clis.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/clis.md b/docs/source/clis.md index 03d43e559b..7e4636d571 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -142,7 +142,7 @@ trl dpo --config dpo_config.yaml ### Using `--accelerate_config` for Accelerate Configuration -The `--accelerate_config` flag in TRL lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either: +The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either: * the name of a predefined config profile (built into TRL), or * a path to a custom Accelerate YAML config file. From f0eabff2f70fb59353880401abe8f6c4bdfe533d Mon Sep 17 00:00:00 2001 From: Lewis Tunstall Date: Mon, 5 May 2025 10:30:40 +0000 Subject: [PATCH 22/25] Bump min version --- examples/accelerate_configs/fsdp2.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/accelerate_configs/fsdp2.yaml b/examples/accelerate_configs/fsdp2.yaml index 6dd9baf0e5..af498f3ece 100644 --- a/examples/accelerate_configs/fsdp2.yaml +++ b/examples/accelerate_configs/fsdp2.yaml @@ -1,4 +1,4 @@ -# Requires accelerate 1.6.0 or higher +# Requires accelerate 1.7.0 or higher compute_environment: LOCAL_MACHINE debug: false distributed_type: FSDP From c584b21e2c65306fa636d1c074ec1209ea111476 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Mon, 5 May 2025 09:04:43 -0700 Subject: [PATCH 23/25] Update docs/source/clis.md Co-authored-by: lewtun --- docs/source/clis.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/clis.md b/docs/source/clis.md index 7e4636d571..40491fb053 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -153,7 +153,7 @@ TRL provides several ready-to-use Accelerate configs to simplify common training | Name | Description | | ----------------- | -------------------------------------- | -| `deepspeed_zero1` | DeepSpeed ZeRO Stage 1 | +| `zero1` | DeepSpeed ZeRO Stage 1 | | `deepspeed_zero2` | DeepSpeed ZeRO Stage 2 | | `deepspeed_zero3` | DeepSpeed ZeRO Stage 3 | | `fsdp_qlora` | Fully Sharded Data Parallel with QLoRA | From 7a71b9d0001bcb904482377500a315926d6fedc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 5 May 2025 16:28:55 +0000 Subject: [PATCH 24/25] deepspeed_zeroN -> zeroN --- docs/source/clis.md | 25 ++++++++++--------- .../{deepspeed_zero1.yaml => zero1.yaml} | 0 .../{deepspeed_zero2.yaml => zero2.yaml} | 0 .../{deepspeed_zero3.yaml => zero3.yaml} | 0 4 files changed, 13 insertions(+), 12 deletions(-) rename trl/accelerate_configs/{deepspeed_zero1.yaml => zero1.yaml} (100%) rename trl/accelerate_configs/{deepspeed_zero2.yaml => zero2.yaml} (100%) rename trl/accelerate_configs/{deepspeed_zero3.yaml => zero3.yaml} (100%) diff --git a/docs/source/clis.md b/docs/source/clis.md index 40491fb053..2dac2a8524 100644 --- a/docs/source/clis.md +++ b/docs/source/clis.md @@ -151,14 +151,15 @@ The `--accelerate_config` flag lets you easily configure distributed training wi TRL provides several ready-to-use Accelerate configs to simplify common training setups: -| Name | Description | -| ----------------- | -------------------------------------- | -| `zero1` | DeepSpeed ZeRO Stage 1 | -| `deepspeed_zero2` | DeepSpeed ZeRO Stage 2 | -| `deepspeed_zero3` | DeepSpeed ZeRO Stage 3 | -| `fsdp_qlora` | Fully Sharded Data Parallel with QLoRA | -| `multi_gpu` | Multi-GPU training | -| `single_gpu` | Single-GPU training | +| Name | Description | +| ------------ | ----------------------------------- | +| `fsdp1` | Fully Sharded Data Parallel Stage 1 | +| `fsdp2` | Fully Sharded Data Parallel Stage 2 | +| `zero1` | DeepSpeed ZeRO Stage 1 | +| `zero2` | DeepSpeed ZeRO Stage 2 | +| `zero3` | DeepSpeed ZeRO Stage 3 | +| `multi_gpu` | Multi-GPU training | +| `single_gpu` | Single-GPU training | To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_config/`. @@ -171,7 +172,7 @@ To use one of these, just pass the name to `--accelerate_config`. TRL will autom trl sft \ --model_name_or_path Qwen/Qwen2.5-0.5B \ --dataset_name stanfordnlp/imdb \ - --accelerate_config deepspeed_zero2 # or path/to/my/accelerate/config.yaml + --accelerate_config zero2 # or path/to/my/accelerate/config.yaml ``` @@ -181,7 +182,7 @@ trl sft \ # sft_config.yaml model_name_or_path: Qwen/Qwen2.5-0.5B dataset_name: stanfordnlp/imdb -accelerate_config: deepspeed_zero2 # or path/to/my/accelerate/config.yaml +accelerate_config: zero2 # or path/to/my/accelerate/config.yaml ``` Launch with: @@ -197,7 +198,7 @@ trl sft --config sft_config.yaml trl dpo \ --model_name_or_path Qwen/Qwen2.5-0.5B \ --dataset_name anthropic/hh-rlhf \ - --accelerate_config deepspeed_zero2 # or path/to/my/accelerate/config.yaml + --accelerate_config zero2 # or path/to/my/accelerate/config.yaml ``` @@ -207,7 +208,7 @@ trl dpo \ # dpo_config.yaml model_name_or_path: Qwen/Qwen2.5-0.5B dataset_name: anthropic/hh-rlhf -accelerate_config: deepspeed_zero2 # or path/to/my/accelerate/config.yaml +accelerate_config: zero2 # or path/to/my/accelerate/config.yaml ``` Launch with: diff --git a/trl/accelerate_configs/deepspeed_zero1.yaml b/trl/accelerate_configs/zero1.yaml similarity index 100% rename from trl/accelerate_configs/deepspeed_zero1.yaml rename to trl/accelerate_configs/zero1.yaml diff --git a/trl/accelerate_configs/deepspeed_zero2.yaml b/trl/accelerate_configs/zero2.yaml similarity index 100% rename from trl/accelerate_configs/deepspeed_zero2.yaml rename to trl/accelerate_configs/zero2.yaml diff --git a/trl/accelerate_configs/deepspeed_zero3.yaml b/trl/accelerate_configs/zero3.yaml similarity index 100% rename from trl/accelerate_configs/deepspeed_zero3.yaml rename to trl/accelerate_configs/zero3.yaml From 2a52ee64261148260b8f08130634341e14e2a375 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 5 May 2025 16:29:20 +0000 Subject: [PATCH 25/25] remove fsdp qlora and add fsdp1/2 --- trl/accelerate_configs/fsdp1.yaml | 28 ++++++++++++++++++++++++++ trl/accelerate_configs/fsdp2.yaml | 25 +++++++++++++++++++++++ trl/accelerate_configs/fsdp_qlora.yaml | 25 ----------------------- 3 files changed, 53 insertions(+), 25 deletions(-) create mode 100644 trl/accelerate_configs/fsdp1.yaml create mode 100644 trl/accelerate_configs/fsdp2.yaml delete mode 100644 trl/accelerate_configs/fsdp_qlora.yaml diff --git a/trl/accelerate_configs/fsdp1.yaml b/trl/accelerate_configs/fsdp1.yaml new file mode 100644 index 0000000000..c01b0b567b --- /dev/null +++ b/trl/accelerate_configs/fsdp1.yaml @@ -0,0 +1,28 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_backward_prefetch: BACKWARD_PRE + fsdp_cpu_ram_efficient_loading: true + fsdp_forward_prefetch: true + fsdp_offload_params: false + fsdp_reshard_after_forward: FULL_SHARD + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sync_module_states: true + fsdp_use_orig_params: true + fsdp_version: 1 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/accelerate_configs/fsdp2.yaml b/trl/accelerate_configs/fsdp2.yaml new file mode 100644 index 0000000000..af498f3ece --- /dev/null +++ b/trl/accelerate_configs/fsdp2.yaml @@ -0,0 +1,25 @@ +# Requires accelerate 1.7.0 or higher +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: FSDP +downcast_bf16: 'no' +enable_cpu_affinity: false +fsdp_config: + fsdp_activation_checkpointing: false + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_cpu_ram_efficient_loading: true + fsdp_offload_params: false + fsdp_reshard_after_forward: true + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_version: 2 +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/trl/accelerate_configs/fsdp_qlora.yaml b/trl/accelerate_configs/fsdp_qlora.yaml deleted file mode 100644 index 93b3541470..0000000000 --- a/trl/accelerate_configs/fsdp_qlora.yaml +++ /dev/null @@ -1,25 +0,0 @@ -compute_environment: LOCAL_MACHINE -debug: false -distributed_type: FSDP -downcast_bf16: 'no' -fsdp_config: - fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP - fsdp_backward_prefetch: BACKWARD_PRE - fsdp_cpu_ram_efficient_loading: true - fsdp_forward_prefetch: false - fsdp_offload_params: true - fsdp_sharding_strategy: FULL_SHARD - fsdp_state_dict_type: SHARDED_STATE_DICT - fsdp_sync_module_states: true - fsdp_use_orig_params: false -machine_rank: 0 -main_training_function: main -mixed_precision: 'bf16' -num_machines: 1 -num_processes: 8 -rdzv_backend: static -same_network: true -tpu_env: [] -tpu_use_cluster: false -tpu_use_sudo: false -use_cpu: false \ No newline at end of file