Add verifiers support #5

Open

wants to merge 5 commits into base: support_vllm
2 changes: 2 additions & 0 deletions trl/environment/__init__.py
@@ -19,10 +19,12 @@

_import_structure = {
"base_environment": ["TextEnvironment", "TextHistory"],
"env_protocol": ["Environment"],
}

if TYPE_CHECKING:
from .base_environment import TextEnvironment, TextHistory
from .env_protocol import Environment
else:
import sys

12 changes: 12 additions & 0 deletions trl/environment/env_protocol.py
@@ -0,0 +1,12 @@
from typing import Any, List, Protocol


class Environment(Protocol):
"""
A protocol describing the minimal interface an environment must implement to
integrate with the trainer. An environment may run arbitrary multi-step logic,
but must ultimately return token sequences, analogous to selecting token_ids
from the output of vllm.LLM.generate(). https://docs.vllm.ai/en/stable/api/offline_inference/llm.html
"""

def generate(self, conversations, vlm_inputs, vlm, sampling_params) -> List[Any]: ...
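For orientation, here is a minimal sketch (not part of this diff) of an environment that satisfies this protocol together with the `prepare_data` call the trainer below makes. `SingleTurnEnv` and its body are hypothetical, and a text-only prompt format is assumed; a real environment for this trainer would also handle image inputs.

```python
from typing import Any, List


class SingleTurnEnv:
    """Hypothetical single-turn environment matching the Environment protocol."""

    def prepare_data(self, inputs, processing_class):
        # inputs: list of dataset examples, each carrying a conversational "prompt".
        conversations = [example["prompt"] for example in inputs]
        prompts_text = [
            processing_class.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
            for conv in conversations
        ]
        # Tokenized prompts, later used to compute log-probs on the base model.
        prompt_inputs = processing_class(text=prompts_text, return_tensors="pt", padding=True)
        # Whatever format generate() below expects; here, plain prompt strings.
        env_inputs = prompts_text
        return conversations, prompts_text, prompt_inputs, env_inputs

    def generate(self, conversations, vlm_inputs, vlm, sampling_params) -> List[Any]:
        # A single vLLM call; a multi-step environment could loop or call tools here,
        # as long as it ultimately returns one token-id sequence per prompt.
        outputs = vlm.generate(vlm_inputs, sampling_params=sampling_params, use_tqdm=False)
        return [out.token_ids for request in outputs for out in request.outputs]
```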
143 changes: 78 additions & 65 deletions trl/trainer/qwen_grpo_trainer.py
@@ -15,6 +15,7 @@
import textwrap
import warnings
from collections import defaultdict
from copy import deepcopy
from typing import Any, Callable, Optional, Sized, Union
from unittest.mock import patch

@@ -30,7 +31,6 @@
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
GenerationConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
Qwen2_5_VLForConditionalGeneration,
@@ -43,6 +43,7 @@
from transformers.utils import is_peft_available

from ..data_utils import apply_chat_template, is_conversational
from ..environment.env_protocol import Environment
from ..import_utils import is_vllm_available
from ..models import (
create_reference_model,
@@ -218,7 +219,7 @@ def __init__(
model: PreTrainedModel,
reward_funcs: Union[RewardFunc, list[RewardFunc]],
processing_class: PreTrainedTokenizerBase,
tokenize_and_inject_images: Callable,
env: Environment,
args: GRPOConfig = None,
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
@@ -228,9 +229,6 @@
peft_config: Optional["PeftConfig"] = None,
shuffle_dataset: bool = True,
):
# Add shuffle_dataset to instance variables
self.shuffle_dataset = shuffle_dataset

# Args
if args is None:
model_name = model if isinstance(model, str) else model.config._name_or_path
@@ -327,8 +325,6 @@ def __init__(
reward_processing_classes[i] = reward_processing_class
self.reward_processing_classes = reward_processing_classes

self.tokenize_and_inject_images = tokenize_and_inject_images

# Data collator
def data_collator(features): # No data collation is needed in GRPO
return features
@@ -338,7 +334,7 @@ def data_collator(features): # No data collation is needed in GRPO
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
self.num_generations = args.num_generations # = G in the GRPO paper
self.use_vllm = args.use_vllm
print(f"use_vllm: {self.use_vllm}")
self.shuffle_dataset = shuffle_dataset

self.beta = args.beta

@@ -457,12 +453,9 @@ def data_collator(features): # No data collation is needed in GRPO
# synchronize all processes after vLLM has been fully initialized.
self.accelerator.wait_for_everyone()
else:
self.generation_config = GenerationConfig(
max_new_tokens=self.max_completion_length,
do_sample=True,
temperature=args.temperature,
pad_token_id=processing_class.tokenizer.pad_token_id,
)
raise ValueError("use_vllm must be True")

self.env = env

# Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
# model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
@@ -575,12 +568,35 @@ def _move_model_to_vllm(self):
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device

prompt_inputs, vllm_inputs, prompts_text, prompts = self.tokenize_and_inject_images(
if not self.env:
raise ValueError("No environment provided. Only supporting envs now. ")

# TODO: This is a hack that we should probably fix.
# Without this, each GPU receives different inputs, which breaks the advantage computation.
# Simple synchronization of inputs across processes.
if self.accelerator.num_processes > 1:
# Make sure all processes have a non-None value to gather
# Use an empty list for non-main processes
local_inputs = inputs if self.accelerator.process_index == 0 else []

# Gather from all processes using torch.distributed.gather_object
all_inputs = gather_object(local_inputs)

# each process takes the inputs from process 0 as its inputs
inputs = deepcopy(all_inputs)

self.accelerator.wait_for_everyone()

# conversations: list of conversations
# prompts_text: list of prompts as strings
# prompt_inputs: tokenized data (with image tokens injected) that we will use to compute log probs on the base model.
# env_inputs: data in the format our env/vllm expects
conversations, prompts_text, prompt_inputs, env_inputs = self.env.prepare_data(
inputs=inputs, processing_class=self.processing_class
)

# unpack prompt_inputs
prompt_inputs = super()._prepare_inputs(prompt_inputs)

prompt_ids, prompt_mask, pixel_values, image_grid_thw = (
prompt_inputs["input_ids"],
prompt_inputs["attention_mask"],
@@ -589,70 +605,51 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
)

if self.max_prompt_length is not None:
if self.use_vllm:
raise ValueError(
"max_prompt_length is not supported when using vLLM. Please set it to None if vLLM is used. This is because we don't control tokenization when using vLLM."
)

prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
raise ValueError("max_prompt_length is not supported.")

# Generate completions using either vLLM or regular generation
# Generate completions using vLLM
if self.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step

# Generate completions using vLLM: gather all prompt inputs and use them in a single call in the main process
all_vllm_inputs = gather_object(vllm_inputs)
all_prompts_text = gather_object(prompts_text)
all_env_inputs = gather_object(env_inputs)
all_conversations = gather_object(conversations)

if self.accelerator.is_main_process:
outputs = self.vlm.generate(
all_vllm_inputs,
sampling_params=self.sampling_params,
use_tqdm=False,
)
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
if self.env is None:
raise ValueError("No environment provided. Only supporting envs now.")
else:
completion_ids = self.env.generate(
conversations=all_conversations,
vlm_inputs=all_env_inputs,
vlm=self.vlm,
sampling_params=self.sampling_params,
)

else:
completion_ids = [None] * len(all_prompts_text)
completion_ids = [None] * len(all_env_inputs)

# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
self.accelerator.process_index * len(inputs),
(self.accelerator.process_index + 1) * len(inputs),
)
completion_ids = completion_ids[process_slice]

# Pad the completions, and concatenate them with the prompts
eos_idx = torch.tensor([len(ids) - 1 for ids in completion_ids], device=device)

# Pad completion_ids to uniform length, mask from last output token (EOS)
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.tokenizer.pad_token_id)
sequence_indices = torch.arange(completion_ids.size(1), device=device).expand(completion_ids.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
else:
# Regular generation path
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
prompt_completion_ids = unwrapped_model.generate(
prompt_ids,
attention_mask=prompt_mask,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
generation_config=self.generation_config,
)

# Compute prompt length and extract completion ids
prompt_length = prompt_ids.size(1)
prompt_ids = prompt_completion_ids[:, :prompt_length]
completion_ids = prompt_completion_ids[:, prompt_length:]

# Mask everything after the first EOS token
is_eos = completion_ids == self.processing_class.tokenizer.eos_token_id
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
raise ValueError("Attempted to generate with HF. Only supporting vllm now.")

# Concatenate prompt_mask with completion_mask for logit computation
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B*G, P+C)
@@ -684,7 +681,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
if is_conversational(inputs[0]):
completions = []
for prompt, completion in zip(prompts, completions_text):
for prompt, completion in zip(conversations, completions_text):
bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
if isinstance(bootstrap, list):
if len(bootstrap) > 1:
@@ -695,16 +692,16 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
else:
completions = completions_text

rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
rewards_per_func = torch.zeros(len(conversations), len(self.reward_funcs), device=device)
for i, (reward_func, reward_processing_class) in enumerate(
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
messages = [{"messages": p + c} for p, c in zip(conversations, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
texts = [p + c for p, c in zip(conversations, completions)]
reward_inputs = reward_processing_class(
texts,
return_tensors="pt",
@@ -720,13 +717,29 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
reward_kwargs["prompts_text"] = prompts_text
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
output_reward_func = reward_func(prompts=conversations, completions=completions, **reward_kwargs)
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
# completions may be distributed across processes
rewards_per_func = gather(rewards_per_func)

# # DEBUG: Verify prompt consistency across completions in each group
# TODO: remove this probably?
# if self.accelerator.is_main_process:
# all_prompts = gather_object(prompts_text)

# if not len(all_prompts) == self.num_generations:
# raise ValueError(
# f"We should have one prompt per generation, but we have {len(all_prompts)} prompts and {self.num_generations} generations"
# )
# if not len(set(all_prompts)) == 1:
# raise ValueError(f"All prompts should be the same. {all_prompts=}")
# print("PASSED PROMPT CONSISTENCY CHECK")

# # Add synchronization point to prevent processes from getting out of sync
# self.accelerator.wait_for_everyone()

# Apply weights to each reward function's output and sum
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)

@@ -741,8 +754,8 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:

# Slice to keep only the local part of the data
process_slice = slice(
self.accelerator.process_index * len(prompts),
(self.accelerator.process_index + 1) * len(prompts),
self.accelerator.process_index * len(conversations),
(self.accelerator.process_index + 1) * len(conversations),
)
advantages = advantages[process_slice]

Expand Down