Support vllm #6

Open
wants to merge 3 commits into new_base_trainer
38 changes: 23 additions & 15 deletions trl/trainer/qwen_grpo_trainer.py
@@ -338,9 +338,7 @@ def data_collator(features): # No data collation is needed in GRPO
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
self.num_generations = args.num_generations # = G in the GRPO paper
self.use_vllm = args.use_vllm

if self.use_vllm:
raise ValueError("vLLM not supported yet.")
print(f"use_vllm: {self.use_vllm}")

self.beta = args.beta

@@ -393,7 +391,6 @@ def data_collator(features): # No data collation is needed in GRPO
set_seed(args.seed, device_specific=True)

if self.use_vllm:
raise ValueError("vLLM not supported yet.")
if not is_vllm_available():
raise ImportError(
"vLLM is not available and `use_vllm` is set to True. Please install vLLM with "
@@ -404,9 +401,12 @@ def data_collator(features): # No data collation is needed in GRPO
vllm_device = self.args.vllm_device
if vllm_device == "auto":
if torch.cuda.device_count() == 1:
print("Only one GPU available, sharing it between vLLM and training.")
vllm_device = "cuda:0" # particular case when training with onyl 1 GPU: share it
else:
vllm_device = f"cuda:{self.accelerator.num_processes}" # take the next GPU idx
print(f"Using GPU {vllm_device} for vLLM.")

# Check that the requested device is available
if vllm_device.split(":")[0] == "cuda" and int(vllm_device.split(":")[1]) >= torch.cuda.device_count():
raise ValueError(
@@ -432,7 +432,7 @@ def data_collator(features): # No data collation is needed in GRPO
return_value=None,
)
with world_size_patch, profiling_patch:
self.llm = LLM(
self.vlm = LLM(
model=model.name_or_path,
device=vllm_device,
gpu_memory_utilization=self.args.vllm_gpu_memory_utilization,
@@ -442,6 +442,8 @@ def data_collator(features): # No data collation is needed in GRPO
# This is particularly useful here because we generate completions from the same prompts.
enable_prefix_caching=True,
max_model_len=self.args.vllm_max_model_len,
# Set to 1 since we only have one image per prompt for now. Allowing more reserves extra resources, which is wasteful until we need it.
limit_mm_per_prompt={"image": 1, "video": 0},
)
self.sampling_params = SamplingParams(
temperature=args.temperature,
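For context, the engine construction above maps onto the following minimal standalone sketch, assuming a vLLM build with multimodal support. The checkpoint name, prompt text, image path, and sampling values are placeholders rather than values from this PR, and the prompt is expected to already contain the image placeholder tokens produced by the model's processor.

```python
# Standalone sketch of the vLLM setup used in this PR (not trainer code).
# Placeholders: model name, prompt, image path, sampling values.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2-VL-2B-Instruct",             # placeholder VLM checkpoint
    gpu_memory_utilization=0.3,                    # leave headroom when sharing the GPU with training
    enable_prefix_caching=True,                    # identical prompt prefixes are reused across the G samples
    max_model_len=4096,
    limit_mm_per_prompt={"image": 1, "video": 0},  # mirrors the trainer: one image, no video, per prompt
)
sampling_params = SamplingParams(temperature=0.7, max_tokens=128)

# Multimodal requests pair the already-templated prompt text (including the model's
# image placeholder tokens) with the raw image under multi_modal_data.
requests = [{
    "prompt": "<templated prompt with image placeholder tokens>",
    "multi_modal_data": {"image": Image.open("example.jpg")},
}]
outputs = llm.generate(requests, sampling_params=sampling_params, use_tqdm=False)
completion_ids = [out.token_ids for req in outputs for out in req.outputs]
```

Prefix caching pays off here because GRPO generates several completions from the same prompt, so the shared prefix is computed once.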
@@ -543,7 +545,6 @@ def _get_per_token_logps(
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens

def _move_model_to_vllm(self):
raise ValueError("vLLM not supported yet.")
with unwrap_model_for_generation(
self.model,
self.accelerator,
@@ -568,13 +569,13 @@ def _move_model_to_vllm(self):
else:
state_dict = unwrapped_model.state_dict()
if self.accelerator.is_main_process:
llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
llm_model.load_weights(state_dict.items())
vlm_model = self.vlm.llm_engine.model_executor.driver_worker.model_runner.model
vlm_model.load_weights(state_dict.items())

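The weight-sync step above boils down to the following sketch: push the trainer's current weights into the colocated vLLM engine so the next round of generation uses them. The attribute chain into the engine is vLLM-internal and version-dependent, and the helper name is hypothetical; it only mirrors the trainer code above.

```python
def sync_weights_to_vllm(unwrapped_model, vlm, is_main_process: bool) -> None:
    # Assumes an already-unwrapped, non-sharded model (see unwrap_model_for_generation above).
    state_dict = unwrapped_model.state_dict()
    if is_main_process:
        # Internal vLLM path; differs across vLLM versions.
        inner = vlm.llm_engine.model_executor.driver_worker.model_runner.model
        inner.load_weights(state_dict.items())  # accepts an iterable of (name, tensor) pairs
```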
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
device = self.accelerator.device

prompt_inputs, prompts_text, prompts = self.tokenize_and_inject_images(
prompt_inputs, vllm_inputs, prompts_text, prompts = self.tokenize_and_inject_images(
inputs=inputs, processing_class=self.processing_class
)

@@ -588,28 +589,35 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
)

if self.max_prompt_length is not None:
if self.use_vllm:
raise ValueError(
"max_prompt_length is not supported when using vLLM. Please set it to None if vLLM is used. This is because we don't control tokenization when using vLLM."
)

prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]

# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
raise ValueError("vLLM not supported yet.")
if self.use_vllm:
# First, have main process load weights if needed
if self.state.global_step != self._last_loaded_step:
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step

# Generate completions using vLLM: gather all prompts and use them in a single call in the main process
# Generate completions using vLLM: gather all prompt inputs and use them in a single call in the main process
all_vllm_inputs = gather_object(vllm_inputs)
all_prompts_text = gather_object(prompts_text)

if self.accelerator.is_main_process:
outputs = self.llm.generate(
all_prompts_text,
outputs = self.vlm.generate(
all_vllm_inputs,
sampling_params=self.sampling_params,
use_tqdm=False,
)
completion_ids = [out.token_ids for completions in outputs for out in completions.outputs]
else:
completion_ids = [None] * len(all_prompts_text)

# Broadcast the completions from the main process to all processes, ensuring each process receives its
# corresponding slice.
completion_ids = broadcast_object_list(completion_ids, from_process=0)
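The gather / generate-on-main / broadcast / slice flow above is easier to read in isolation. A sketch of the pattern follows, assuming one completion per gathered prompt (the prompts are already repeated per generation upstream); the function name is illustrative, not part of this PR.

```python
from accelerate import Accelerator
from accelerate.utils import broadcast_object_list, gather_object

def distributed_generate(accelerator: Accelerator, vlm, sampling_params, local_vllm_inputs):
    all_inputs = gather_object(local_vllm_inputs)  # every rank receives the full list of requests
    if accelerator.is_main_process:
        outputs = vlm.generate(all_inputs, sampling_params=sampling_params, use_tqdm=False)
        completion_ids = [out.token_ids for req in outputs for out in req.outputs]
    else:
        completion_ids = [None] * len(all_inputs)  # placeholders, overwritten by the broadcast
    completion_ids = broadcast_object_list(completion_ids, from_process=0)

    # Each rank keeps only the slice corresponding to its own prompts.
    n = len(local_vllm_inputs)
    start = accelerator.process_index * n
    return completion_ids[start : start + n]
```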
Expand All @@ -621,7 +629,7 @@ def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[s

# Pad the completions, and concatenate them with the prompts
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
completion_ids = pad(completion_ids, padding_value=self.processing_class.tokenizer.pad_token_id)
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
else:
# Regular generation path