Improvements based on R1-V #2

Open · wants to merge 3 commits into base: support_vlm
trl/trainer/qwen_grpo_trainer.py (75 changes: 50 additions & 25 deletions)
@@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import os
import textwrap
import warnings
from collections import defaultdict
from typing import Any, Callable, Optional, Union
from unittest.mock import patch
import copy

import torch
import torch.utils.data
@@ -33,7 +34,9 @@
PreTrainedModel,
PreTrainedTokenizerBase,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
Qwen2_5_VLProcessor,
Qwen2VLProcessor,
Trainer,
TrainerCallback,
is_wandb_available,
@@ -171,9 +174,9 @@ def __init__(
# Models
# Trained model
model_init_kwargs = args.model_init_kwargs or {}
if not isinstance(model, Qwen2_5_VLForConditionalGeneration):
if not isinstance(model, Qwen2_5_VLForConditionalGeneration) and not isinstance(model, Qwen2VLForConditionalGeneration):
raise ValueError(
"QwenGRPOTrainer does not support loading from a model ID. Please pass `model` as a `PreTrainedModel` object."
"QwenGRPOTrainer only support passing a Qwen2_5_VLForConditionalGeneration or Qwen2VLForConditionalGeneration object as the `model` argument."
)
else:
model_id = model.config._name_or_path
@@ -204,8 +207,8 @@ def __init__(
self.ref_model = None

# Processing class
if not isinstance(processing_class, Qwen2_5_VLProcessor):
raise ValueError("`processing_class` must be a `Qwen2_5_VLProcessor` object.")
if not isinstance(processing_class, Qwen2_5_VLProcessor) and not isinstance(processing_class, Qwen2VLProcessor):
raise ValueError("`processing_class` must be a `Qwen2_5_VLProcessor` or `Qwen2VLProcessor` object.")

# Reward functions
if not isinstance(reward_funcs, list):
@@ -369,7 +372,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")

device = self.accelerator.device
prompt_inputs, prompts = self.tokenize_and_inject_images(inputs=inputs, processing_class=self.processing_class)
prompt_inputs = super()._prepare_inputs(prompt_inputs)
@@ -380,7 +383,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
print(f"Truncating prompt from {original_length} to {self.max_prompt_length} tokens")
prompt_inputs["input_ids"] = prompt_inputs["input_ids"][:, -self.max_prompt_length :]
prompt_inputs["attention_mask"] = prompt_inputs["attention_mask"][:, -self.max_prompt_length :]

# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
print("You are using vLLM for inference. This probably won't work as I fixed this yet to work with VLM")
@@ -418,40 +421,62 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
else:
# Regular generation path
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
**prompt_inputs, generation_config=self.generation_config
)
# Generate N times, one sequence per call, using temp_generation_config
num_generations = self.generation_config.num_return_sequences
temp_generation_config = copy.deepcopy(self.generation_config)
temp_generation_config.num_return_sequences = 1

all_completions = []

for i in range(num_generations):
completion = unwrapped_model.generate(**prompt_inputs, generation_config=temp_generation_config)
all_completions.append(completion)

# Stack all completions and pad if needed
max_length = max(completion.size(1) for completion in all_completions)
padded_completions = []

for completion in all_completions:
if completion.size(1) < max_length:
padding = torch.full((completion.size(0), max_length - completion.size(1)),
self.processing_class.tokenizer.pad_token_id,
dtype=completion.dtype,
device=completion.device)
padded_completion = torch.cat([completion, padding], dim=1)
else:
padded_completion = completion
padded_completions.append(padded_completion)

# Stack all padded completions
prompt_completion_ids = torch.cat(padded_completions, dim=0)
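# Shape check: prompt_completion_ids is (num_generations * batch_size, max_length);
# each row still contains the prompt tokens followed by one right-padded completion.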

prompt_length = prompt_inputs["input_ids"].size(1)
completion_ids = prompt_completion_ids[:, prompt_length:]

# Get the per-token log probabilities for the completions for the model and the reference model
def get_per_token_logps(model, input_ids, logits_to_keep):
# NOTE: I had to do major surgery here to make this work. See
# https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py#L430 for the
# original implementation.

logits = model(input_ids).logits # Expected shape: [B, L, V]
logits = logits[:, :-1, :] # Shape becomes [B, L-1, V] - drop extra last token's logits.
logits = logits[:, -logits_to_keep:, :] # Select the last `logits_to_keep` tokens.

def get_per_token_logps(model, input_ids):
logits = model(input_ids).logits # (B, L, V)
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
input_ids = input_ids[:, 1:] # (B, L-1), exclude the first input ID since we don't have logits for it
# Compute the log probabilities for the input tokens. Use a loop to reduce memory peak.
per_token_logps = []
for logits_row, input_ids_row in zip(logits, input_ids[:, -logits_to_keep:]):
for logits_row, input_ids_row in zip(logits, input_ids):
log_probs = logits_row.log_softmax(dim=-1)
token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
per_token_logps.append(token_log_prob)
return torch.stack(per_token_logps)

logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
per_token_logps = get_per_token_logps(model, prompt_completion_ids, logits_to_keep)
per_token_logps = get_per_token_logps(model, prompt_completion_ids)
# Get rid of the prompt (-1 because of the shift done in get_per_token_logps)
per_token_logps = per_token_logps[:, prompt_length - 1:]
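# After this slice, per_token_logps lines up with completion_ids: one log-prob per completion token.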

with torch.inference_mode():
if self.ref_model is not None:
ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids, logits_to_keep)
ref_per_token_logps = get_per_token_logps(self.ref_model, prompt_completion_ids)
else:
with self.accelerator.unwrap_model(model).disable_adapter():
ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids, logits_to_keep)
ref_per_token_logps = get_per_token_logps(model, prompt_completion_ids)
ref_per_token_logps = ref_per_token_logps[:, prompt_length - 1:]

# Compute the KL divergence between the model and the reference model
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
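For readers checking the math: this line is the per-token KL estimator used by GRPO (Schulman's non-negative "k3" estimator). Writing $r_t = \log \pi_{\mathrm{ref}}(o_t \mid q, o_{<t}) - \log \pi_\theta(o_t \mid q, o_{<t})$, it computes

$$
\hat{\mathbb{D}}_{\mathrm{KL}}\!\left[\pi_\theta \,\|\, \pi_{\mathrm{ref}}\right]_t = e^{r_t} - r_t - 1 \;\ge\; 0,
$$

which is an unbiased estimate of the true KL when the completion tokens are sampled from $\pi_\theta$.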
@@ -479,7 +504,7 @@ def get_per_token_logps(model, input_ids, logits_to_keep):
print(
"using pretrained model as reward fn, this is unexpected as Im passing in functions as reward fns"
)
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
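Putting the type-check change in context, the sketch below shows how a Qwen2-VL model and processor would now satisfy the relaxed isinstance guards in `QwenGRPOTrainer.__init__`. The checkpoint name and loading flags are illustrative assumptions, not taken from this PR.

```python
# Minimal sketch of the relaxed type checks introduced by this PR.
# "Qwen/Qwen2-VL-2B-Instruct" is an assumed example checkpoint, not taken from the diff.
from transformers import (
    AutoProcessor,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2_5_VLProcessor,
    Qwen2VLForConditionalGeneration,
    Qwen2VLProcessor,
)

model_id = "Qwen/Qwen2-VL-2B-Instruct"
model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
processor = AutoProcessor.from_pretrained(model_id)  # resolves to a Qwen2VLProcessor

# Mirror the updated guards: either the Qwen2-VL or the Qwen2.5-VL classes pass.
assert isinstance(model, (Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration))
assert isinstance(processor, (Qwen2_5_VLProcessor, Qwen2VLProcessor))

# `model` and `processor` are then what gets passed to QwenGRPOTrainer as the
# `model` and `processing_class` arguments respectively.
```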