From 891826d6ad7de03b464e2ab9ab5ea5ac1279e454 Mon Sep 17 00:00:00 2001
From: Sunil Kumar
Date: Thu, 6 Feb 2025 12:33:30 -0800
Subject: [PATCH 1/3] changes based on r1 v - do generation one at a time.

---
 trl/trainer/qwen_grpo_trainer.py | 57 ++++++++++++++++++++++----------
 1 file changed, 40 insertions(+), 17 deletions(-)

diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index 303746643b..1fd11ba3a8 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -18,6 +18,7 @@
 from collections import defaultdict
 from typing import Any, Callable, Optional, Union
 from unittest.mock import patch
+import copy
 
 import torch
 import torch.utils.data
@@ -418,40 +419,62 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         else:
             # Regular generation path
             with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
-                prompt_completion_ids = unwrapped_model.generate(
-                    **prompt_inputs, generation_config=self.generation_config
-                )
+                # Generate num_generations completions one at a time, each using temp_generation_config
+                num_generations = self.generation_config.num_return_sequences
+                temp_generation_config = copy.deepcopy(self.generation_config)
+                temp_generation_config.num_return_sequences = 1
+
+                all_completions = []
+
+                for i in range(num_generations):
+                    completion = unwrapped_model.generate(**prompt_inputs, generation_config=temp_generation_config)
+                    all_completions.append(completion)
+
+                # Right-pad all completions to the same length so they can be concatenated
+                max_length = max(completion.size(1) for completion in all_completions)
+                padded_completions = []
+
+                for completion in all_completions:
+                    if completion.size(1) < max_length:
+                        padding = torch.full((completion.size(0), max_length - completion.size(1)),
+                                             self.processing_class.tokenizer.pad_token_id,
+                                             dtype=completion.dtype,
+                                             device=completion.device)
+                        padded_completion = torch.cat([completion, padding], dim=1)
+                    else:
+                        padded_completion = completion
+                    padded_completions.append(padded_completion)
+
+                # Concatenate the padded completions along the batch dimension
+                prompt_completion_ids = torch.cat(padded_completions, dim=0)
 
         prompt_length = prompt_inputs["input_ids"].size(1)
         completion_ids = prompt_completion_ids[:, prompt_length:]
 
         # Get the per-token log probabilities for the completions for the model and the reference model
-        def get_per_token_logps(model, input_ids, logits_to_keep):
-            # NOTE: I had to do major surgery here to make this work. See
-            # https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L430 for the
-            # original implementation.
-
-            logits = model(input_ids).logits  # Expected shape: [B, L, V]
-            logits = logits[:, :-1, :]  # Shape becomes [B, L-1, V] - drop extra last token's logits.
-            logits = logits[:, -logits_to_keep:, :]  # Select the last `logits_to_keep` tokens.
-
+        def get_per_token_logps(model, input_ids):
+            logits = model(input_ids).logits  # (B, L, V)
+            logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+            input_ids = input_ids[:, 1:]  # (B, L-1), exclude the first input ID since we don't have logits for it
             # Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
             per_token_logps = []
-            for logits_row, input_ids_row in zip(logits, input_ids[:, -logits_to_keep:]):
+            for logits_row, input_ids_row in zip(logits, input_ids):
                 log_probs = logits_row.log_softmax(dim=-1)
                 token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
                 per_token_logps.append(token_log_prob)
             return torch.stack(per_token_logps)
 
-        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
-        per_token_logps = get_per_token_logps(model, prompt_completion_ids, logits_to_keep)
+        per_token_logps = get_per_token_logps(model, prompt_completion_ids)
+        # Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
+        per_token_logps = per_token_logps[:, prompt_length - 1:]
 
         with torch.inference_mode():
             if self.ref_model is not None:
-                ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, logits_to_keep)
+                ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids)
             else:
                 with self.accelerator.unwrap_model(model).disable_adapter():
-                    ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, logits_to_keep)
+                    ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids)
+            ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1:]
 
         # Compute the KL divergence between the model and the reference model
         per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
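
Note: the pad-and-concatenate step introduced above is easiest to see in isolation. Below is a minimal, standalone sketch of that logic; the random tensors stand in for the outputs of the N separate unwrapped_model.generate calls, and pad_token_id is an arbitrary stand-in for self.processing_class.tokenizer.pad_token_id. It is an illustration, not code from the patch.

    import torch

    pad_token_id = 0  # stand-in for self.processing_class.tokenizer.pad_token_id

    # Toy stand-ins for the outputs of N separate generate() calls (batch size 2, varying lengths).
    all_completions = [
        torch.randint(1, 100, (2, 5)),
        torch.randint(1, 100, (2, 7)),
        torch.randint(1, 100, (2, 6)),
    ]

    # Right-pad every completion to the longest length so they can be concatenated.
    max_length = max(completion.size(1) for completion in all_completions)
    padded_completions = []
    for completion in all_completions:
        if completion.size(1) < max_length:
            padding = torch.full(
                (completion.size(0), max_length - completion.size(1)),
                pad_token_id,
                dtype=completion.dtype,
                device=completion.device,
            )
            completion = torch.cat([completion, padding], dim=1)
        padded_completions.append(completion)

    # Concatenate along the batch dimension: (N * batch_size, max_length).
    prompt_completion_ids = torch.cat(padded_completions, dim=0)
    print(prompt_completion_ids.shape)  # torch.Size([6, 7])

Because every row shares the same prompt and padding is appended on the right, slicing prompt_completion_ids[:, prompt_length:] afterwards still isolates only completion tokens.
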
From e5e84c029e451c1bfdb610c24749f66b65ea7ff7 Mon Sep 17 00:00:00 2001
From: Sunil Kumar
Date: Thu, 6 Feb 2025 15:02:57 -0800
Subject: [PATCH 2/3] support q2vl

---
 trl/trainer/qwen_grpo_trainer.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index 1fd11ba3a8..339c529750 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -34,7 +34,9 @@
     PreTrainedModel,
     PreTrainedTokenizerBase,
     Qwen2_5_VLForConditionalGeneration,
+    Qwen2VLForConditionalGeneration,
     Qwen2_5_VLProcessor,
+    Qwen2VLProcessor,
     Trainer,
     TrainerCallback,
     is_wandb_available,
@@ -172,9 +174,9 @@ def __init__(
         # Models
         # Trained model
         model_init_kwargs = args.model_init_kwargs or {}
-        if not isinstance(model, Qwen2_5_VLForConditionalGeneration):
+        if not isinstance(model, Qwen2_5_VLForConditionalGeneration) and not isinstance(model, Qwen2VLForConditionalGeneration):
             raise ValueError(
-                "QwenGRPOTrainer does not support loading from a model ID. Please pass `model` as a `PreTrainedModel` object."
+                "QwenGRPOTrainer only supports passing a Qwen2_5_VLForConditionalGeneration or Qwen2VLForConditionalGeneration object as the `model` argument."
             )
         else:
             model_id = model.config._name_or_path
@@ -205,8 +207,8 @@ def __init__(
         self.ref_model = None
 
         # Processing class
-        if not isinstance(processing_class, Qwen2_5_VLProcessor):
-            raise ValueError("`processing_class` must be a `Qwen2_5_VLProcessor` object.")
+        if not isinstance(processing_class, Qwen2_5_VLProcessor) and not isinstance(processing_class, Qwen2VLProcessor):
+            raise ValueError("`processing_class` must be a `Qwen2_5_VLProcessor` or `Qwen2VLProcessor` object.")
 
         # Reward functions
         if not isinstance(reward_funcs, list):
@@ -502,7 +504,7 @@ def get_per_token_logps(model, input_ids):
                     print(
                         "using pretrained model as reward fn, this is unexpected as Im passing in functions as reward fns"
                     )
-                if is_conversational(inputs[0]):
+                if is_conversational(inputs[0]):
                     messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
                     texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
                 else:
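
Note: isinstance also accepts a tuple of classes, so the two checks added in this patch could be written more compactly. The sketch below is an equivalent formulation under that assumption, not what the patch actually commits; the helper name validate_qwen_classes is illustrative only, and the imports are the same transformers classes the patch adds.

    from transformers import (
        Qwen2_5_VLForConditionalGeneration,
        Qwen2_5_VLProcessor,
        Qwen2VLForConditionalGeneration,
        Qwen2VLProcessor,
    )

    ACCEPTED_MODEL_CLASSES = (Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration)
    ACCEPTED_PROCESSOR_CLASSES = (Qwen2_5_VLProcessor, Qwen2VLProcessor)

    def validate_qwen_classes(model, processing_class):
        # isinstance accepts a tuple, which keeps both accepted classes in one check.
        if not isinstance(model, ACCEPTED_MODEL_CLASSES):
            raise ValueError(
                "QwenGRPOTrainer only supports passing a Qwen2_5_VLForConditionalGeneration or "
                "Qwen2VLForConditionalGeneration object as the `model` argument."
            )
        if not isinstance(processing_class, ACCEPTED_PROCESSOR_CLASSES):
            raise ValueError("`processing_class` must be a `Qwen2_5_VLProcessor` or `Qwen2VLProcessor` object.")
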
From a35366b24ac8b1625d3390f7b0356ff5b931140a Mon Sep 17 00:00:00 2001
From: Sunil Kumar
Date: Fri, 7 Feb 2025 15:45:12 -0800
Subject: [PATCH 3/3] ensure everything is up to date for leo

---
 trl/trainer/qwen_grpo_trainer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/trl/trainer/qwen_grpo_trainer.py b/trl/trainer/qwen_grpo_trainer.py
index 339c529750..60ce3340da 100644
--- a/trl/trainer/qwen_grpo_trainer.py
+++ b/trl/trainer/qwen_grpo_trainer.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import time
 import os
 import textwrap
 import warnings
@@ -372,7 +372,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         if return_outputs:
             raise ValueError("The GRPOTrainer does not support returning outputs")
-
+        device = self.accelerator.device
         prompt_inputs, prompts = self.tokenize_and_inject_images(inputs=inputs, processing_class=self.processing_class)
         prompt_inputs = super()._prepare_inputs(prompt_inputs)
 
@@ -383,7 +383,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
             original_length = prompt_inputs["input_ids"].size(1)
             print(f"Truncating prompt from {original_length} to {self.max_prompt_length} tokens")
             prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
             prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]
-
+        # Generate completions using either vLLM or regular generation
         if self.args.use_vllm:
             print("You are using vLLM for inference. This probably won't work as I fixed this yet to work with VLM")
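
Note: the per-token log-probability path rewritten in PATCH 1/3 is the core of the loss, so a self-contained sketch of it follows, showing why the prompt_length - 1 slice lines up with the completion tokens. ToyLM, the vocabulary size, and the sequence lengths are made-up stand-ins for the actual Qwen policy and reference models; this is an illustration, not the trainer's code.

    import torch
    from types import SimpleNamespace

    class ToyLM(torch.nn.Module):
        """Stand-in for a causal LM: forward returns an object with a .logits attribute."""
        def __init__(self, vocab_size=100, hidden=16):
            super().__init__()
            self.embed = torch.nn.Embedding(vocab_size, hidden)
            self.head = torch.nn.Linear(hidden, vocab_size)

        def forward(self, input_ids):
            return SimpleNamespace(logits=self.head(self.embed(input_ids)))

    def get_per_token_logps(model, input_ids):
        logits = model(input_ids).logits  # (B, L, V)
        logits = logits[:, :-1, :]        # (B, L-1, V): the last position predicts a token we don't have
        input_ids = input_ids[:, 1:]      # (B, L-1): the first token has no preceding logits
        per_token_logps = []
        for logits_row, input_ids_row in zip(logits, input_ids):  # loop keeps peak memory low
            log_probs = logits_row.log_softmax(dim=-1)
            token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
            per_token_logps.append(token_log_prob)
        return torch.stack(per_token_logps)

    policy, ref = ToyLM(), ToyLM()
    prompt_length = 3
    prompt_completion_ids = torch.randint(0, 100, (4, 10))  # 3 prompt + 7 completion tokens per row

    # The -1 compensates for the one-token shift inside get_per_token_logps.
    per_token_logps = get_per_token_logps(policy, prompt_completion_ids)[:, prompt_length - 1:]
    with torch.inference_mode():
        ref_per_token_logps = get_per_token_logps(ref, prompt_completion_ids)[:, prompt_length - 1:]

    # Per-token KL estimate between policy and reference, as in the loss above (always >= 0).
    per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
    print(per_token_logps.shape, per_token_kl.shape)  # torch.Size([4, 7]) torch.Size([4, 7])
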