From ca793028c69433eae405009c5ebb790c6c2d40c4 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Fri, 29 Mar 2024 11:36:08 +0800 Subject: [PATCH] release v0.6.1 --- src/llmtuner/__init__.py | 2 +- src/llmtuner/train/dpo/workflow.py | 1 + src/llmtuner/train/ppo/trainer.py | 94 ++++++++++++++++++++++++++++-- src/llmtuner/train/ppo/workflow.py | 51 +--------------- src/llmtuner/train/utils.py | 4 +- 5 files changed, 95 insertions(+), 57 deletions(-) diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py index 6852ae2fdf..903e82ad90 100644 --- a/src/llmtuner/__init__.py +++ b/src/llmtuner/__init__.py @@ -7,5 +7,5 @@ from .webui import create_ui, create_web_demo -__version__ = "0.6.0" +__version__ = "0.6.1" __all__ = ["create_app", "ChatModel", "Evaluator", "export_model", "run_exp", "create_ui", "create_web_demo"] diff --git a/src/llmtuner/train/dpo/workflow.py b/src/llmtuner/train/dpo/workflow.py index 7014177a6c..851de9820b 100644 --- a/src/llmtuner/train/dpo/workflow.py +++ b/src/llmtuner/train/dpo/workflow.py @@ -28,6 +28,7 @@ def run_dpo( tokenizer = load_tokenizer(model_args) dataset = get_dataset(tokenizer, model_args, data_args, training_args, stage="rm") model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) + data_collator = DPODataCollatorWithPadding( tokenizer=tokenizer, pad_to_multiple_of=8, diff --git a/src/llmtuner/train/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py index a06d7ef14e..de87532ad0 100644 --- a/src/llmtuner/train/ppo/trainer.py +++ b/src/llmtuner/train/ppo/trainer.py @@ -6,20 +6,23 @@ import torch from tqdm import tqdm from transformers import GenerationConfig, Trainer, TrainerControl, TrainerState +from transformers.optimization import get_scheduler from transformers.trainer_pt_utils import remove_dummy_checkpoint from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME -from trl import PPOTrainer +from trl import PPOConfig, PPOTrainer from trl.core import PPODecorators, logprobs_from_logits from ...extras.callbacks import FixValueHeadModelCallback, LogCallback from ...extras.logging import get_logger from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor +from ..utils import create_custom_optimzer, create_custom_scheduler from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm if TYPE_CHECKING: - from transformers import Seq2SeqTrainingArguments, TrainerCallback + from datasets import Dataset + from transformers import DataCollatorWithPadding, PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerCallback from trl import AutoModelForCausalLMWithValueHead from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments @@ -40,10 +43,53 @@ def __init__( finetuning_args: "FinetuningArguments", generating_args: "GeneratingArguments", callbacks: List["TrainerCallback"], - reward_model: "AutoModelForCausalLMWithValueHead", - **kwargs, + model: "AutoModelForCausalLMWithValueHead", + reward_model: Optional["AutoModelForCausalLMWithValueHead"], + ref_model: Optional["AutoModelForCausalLMWithValueHead"], + tokenizer: "PreTrainedTokenizer", + dataset: "Dataset", + data_collator: "DataCollatorWithPadding", ): - PPOTrainer.__init__(self, **kwargs) + backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps + ppo_config = PPOConfig( + model_name=model_args.model_name_or_path, + learning_rate=training_args.learning_rate, + mini_batch_size=training_args.per_device_train_batch_size, + batch_size=backward_batch_size * finetuning_args.ppo_buffer_size, + gradient_accumulation_steps=training_args.gradient_accumulation_steps, + ppo_epochs=finetuning_args.ppo_epochs, + max_grad_norm=training_args.max_grad_norm, + seed=training_args.seed, + optimize_device_cache=True, + target=finetuning_args.ppo_target, + use_score_scaling=finetuning_args.ppo_score_norm, + use_score_norm=finetuning_args.ppo_score_norm, + whiten_rewards=finetuning_args.ppo_whiten_rewards, + accelerator_kwargs={"step_scheduler_with_optimizer": False}, + log_with=training_args.report_to[0] if training_args.report_to is not None else None, + project_kwargs={"logging_dir": training_args.logging_dir}, + ) + + # Create optimizer and scheduler + if training_args.max_steps > 0: + num_training_steps = training_args.max_steps + else: + total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size + num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size) + + optimizer = self.create_optimizer(model, training_args, finetuning_args) + scheduler = self.create_scheduler(training_args, num_training_steps, optimizer) + + PPOTrainer.__init__( + self, + config=ppo_config, + model=model, + ref_model=ref_model, + tokenizer=tokenizer, + dataset=dataset, + data_collator=data_collator, + lr_scheduler=scheduler, + ) self.args = training_args self.model_args = model_args @@ -205,6 +251,44 @@ def ppo_train(self, resume_from_checkpoint: Optional[str] = None) -> None: self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model) ) + def create_optimizer( + self, + model: "AutoModelForCausalLMWithValueHead", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + ) -> "torch.optim.Optimizer": + optimizer = create_custom_optimzer(model, training_args, finetuning_args) + if optimizer is None: + decay_params, nodecay_params = [], [] + decay_param_names = self.get_decay_parameter_names(model) + for name, param in model.named_parameters(): + if param.requires_grad: + if name in decay_param_names: + decay_params.append(param) + else: + nodecay_params.append(param) + + optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) + param_groups = [ + dict(params=nodecay_params), + dict(params=decay_params, weight_decay=training_args.weight_decay), + ] + optimizer = optim_class(param_groups, **optim_kwargs) + + return optimizer + + def create_scheduler( + self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer" + ) -> "torch.optim.lr_scheduler.LRScheduler": + create_custom_scheduler(training_args, num_training_steps, optimizer) + lr_scheduler = get_scheduler( + training_args.lr_scheduler_type, + optimizer=optimizer, + num_warmup_steps=training_args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + ) + return lr_scheduler + @torch.no_grad() def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: r""" diff --git a/src/llmtuner/train/ppo/workflow.py b/src/llmtuner/train/ppo/workflow.py index 0e03086b30..d5854073e3 100644 --- a/src/llmtuner/train/ppo/workflow.py +++ b/src/llmtuner/train/ppo/workflow.py @@ -1,19 +1,15 @@ # Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py -import math from typing import TYPE_CHECKING, List, Optional -from torch.optim import AdamW from transformers import DataCollatorWithPadding -from transformers.optimization import get_scheduler -from trl import PPOConfig from ...data import get_dataset from ...extras.callbacks import FixValueHeadModelCallback from ...extras.misc import fix_valuehead_checkpoint from ...extras.ploting import plot_loss from ...model import load_model, load_tokenizer -from ..utils import create_custom_optimzer, create_custom_scheduler, create_ref_model, create_reward_model +from ..utils import create_ref_model, create_reward_model from .trainer import CustomPPOTrainer @@ -42,46 +38,6 @@ def run_ppo( ref_model = create_ref_model(model_args, finetuning_args, add_valuehead=True) reward_model = create_reward_model(model, model_args, finetuning_args) - # Create ppo config - backward_batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps - ppo_config = PPOConfig( - model_name=model_args.model_name_or_path, - learning_rate=training_args.learning_rate, - mini_batch_size=training_args.per_device_train_batch_size, - batch_size=backward_batch_size * finetuning_args.ppo_buffer_size, - gradient_accumulation_steps=training_args.gradient_accumulation_steps, - ppo_epochs=finetuning_args.ppo_epochs, - max_grad_norm=training_args.max_grad_norm, - seed=training_args.seed, - optimize_device_cache=True, - target=finetuning_args.ppo_target, - use_score_scaling=finetuning_args.ppo_score_norm, - use_score_norm=finetuning_args.ppo_score_norm, - whiten_rewards=finetuning_args.ppo_whiten_rewards, - accelerator_kwargs={"step_scheduler_with_optimizer": False}, - log_with=training_args.report_to[0] if training_args.report_to is not None else None, - project_kwargs={"logging_dir": training_args.logging_dir}, - ) - - # Create optimizer and scheduler - if training_args.max_steps > 0: - num_training_steps = training_args.max_steps - else: - total_train_batch_size = backward_batch_size * finetuning_args.ppo_buffer_size * training_args.world_size - num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size) - - optimizer = create_custom_optimzer(model, training_args, finetuning_args) - if optimizer is None: - optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate) - - create_custom_scheduler(training_args, num_training_steps, optimizer) - lr_scheduler = get_scheduler( - training_args.lr_scheduler_type, - optimizer=optimizer, - num_warmup_steps=training_args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - ) - # Initialize our Trainer ppo_trainer = CustomPPOTrainer( model_args=model_args, @@ -89,15 +45,12 @@ def run_ppo( finetuning_args=finetuning_args, generating_args=generating_args, callbacks=callbacks + [FixValueHeadModelCallback()], - reward_model=reward_model, - config=ppo_config, model=model, + reward_model=reward_model, ref_model=ref_model, tokenizer=tokenizer, dataset=dataset, data_collator=data_collator, - optimizer=optimizer, - lr_scheduler=lr_scheduler, ) # Training diff --git a/src/llmtuner/train/utils.py b/src/llmtuner/train/utils.py index 73854a5e73..8f218a78c3 100644 --- a/src/llmtuner/train/utils.py +++ b/src/llmtuner/train/utils.py @@ -70,7 +70,7 @@ def create_modelcard_and_push( def create_ref_model( model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False -) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]: +) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]: r""" Creates reference model for PPO/DPO training. Evaluation mode is not supported. @@ -105,7 +105,7 @@ def create_ref_model( def create_reward_model( model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments" -) -> "AutoModelForCausalLMWithValueHead": +) -> Optional["AutoModelForCausalLMWithValueHead"]: r""" Creates reward model for PPO training. """