feat: use async vllm engine (only used in unit tests) #418


Open · wants to merge 3 commits into base: main
1 change: 1 addition & 0 deletions examples/configs/grpo_math_1B.yaml
Collaborator
do you mind adding this key to all the grpo recipes under examples/config/recipes/grpo*.yaml?

@@ -100,6 +100,7 @@ policy:
     stop_token_ids: null
     stop_strings: null
     vllm_cfg:
+      async_engine: false
       precision: ${policy.precision}
       tensor_parallel_size: 1
       gpu_memory_utilization: 0.6
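For context on what the flag controls, the sketch below shows one plausible way a generation worker could branch on the new vllm_cfg.async_engine setting. This is not code from this PR or from nemo_rl: the helper name build_engine and the model_name key are hypothetical, and only async_engine, tensor_parallel_size, and gpu_memory_utilization come from the config shown above. It uses vLLM's public LLM and AsyncLLMEngine entry points.

from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


def build_engine(vllm_cfg: dict):
    # Hypothetical helper: branch on the async_engine flag added in this PR.
    if vllm_cfg.get("async_engine", False):
        # Async engine: generation is exposed as async generators, which is what
        # makes it possible to stream results back through Ray.
        return AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(
                model=vllm_cfg["model_name"],  # hypothetical key
                tensor_parallel_size=vllm_cfg["tensor_parallel_size"],
                gpu_memory_utilization=vllm_cfg["gpu_memory_utilization"],
            )
        )
    # Default (async_engine: false): the synchronous offline-inference engine.
    return LLM(
        model=vllm_cfg["model_name"],  # hypothetical key
        tensor_parallel_size=vllm_cfg["tensor_parallel_size"],
        gpu_memory_utilization=vllm_cfg["gpu_memory_utilization"],
    )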
40 changes: 33 additions & 7 deletions nemo_rl/distributed/worker_groups.py
@@ -48,24 +48,50 @@ def get_results(self, worker_group):
         worker group each worker belongs to, then selects only the first result from each group.
 
         Args:
-            worker_group: The RayWorkerGroup that created this bundle
+            worker_group: The RayWorkerGroup that spawned the futures. The
+                mapping contained in worker_group.worker_to_tied_group_index
+                is required for the deduplication path.
 
         Returns:
             List of results, deduplicated by tied workers if respect_tied_workers is True
         """
-        # Basic case: Get all results
-        all_results = ray.get(self.futures)
+        from ray._raylet import ObjectRef, ObjectRefGenerator
+
+        # Flatten futures into a list of ObjectRefs
+        object_refs: list[ObjectRef] = []
+
+        # Map each ObjectRef back to the originating worker
+        ref_owner_indices: list[int] = []
+        has_generator = False
+
+        for idx, fut in enumerate(self.futures):
+            if isinstance(fut, ObjectRefGenerator):
+                # ray.get cannot be called directly on the generator object; it must be iterated to obtain the individual ObjectRef instances first.
+                for generated_ref in fut:
+                    object_refs.append(generated_ref)
+                    ref_owner_indices.append(self.used_workers[idx])
+                has_generator = True
+            else:
+                object_refs.append(fut)
+                ref_owner_indices.append(self.used_workers[idx])
+
+        # Retrieve the concrete results.
+        all_results = ray.get(object_refs)
+
+        # If an expanded generator was present, we are in streaming mode:
+        # every ObjectRef now corresponds to a unique, ordered chunk of data.
+        if has_generator:
+            return all_results
 
         # If we don't need to deduplicate by tied workers, return all results
         if not self.respect_tied_workers:
             return all_results
 
-        if not self.used_workers:
+        if not ref_owner_indices:
             return all_results
 
-        # Create tied worker sets based on used workers
+        # Create tied worker sets based on originating workers of each result
         active_tied_workers = {}
-        for i, worker_idx in enumerate(self.used_workers):
+        for i, worker_idx in enumerate(ref_owner_indices):
             tied_worker_idx = worker_group.worker_to_tied_group_index.get(worker_idx)
             if tied_worker_idx is None:
                 continue
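To make the generator branch above concrete, here is a minimal, self-contained sketch (not part of the PR) of where an ObjectRefGenerator comes from and why it must be iterated before calling ray.get. It assumes a recent Ray release (2.10 or newer) in which a generator actor method returns an ObjectRefGenerator by default; the StreamingWorker actor is hypothetical.

import ray
from ray._raylet import ObjectRefGenerator

ray.init()


@ray.remote
class StreamingWorker:
    def generate(self, n: int):
        # A generator actor method: each yield is shipped back as its own ObjectRef.
        for i in range(n):
            yield f"chunk-{i}"


worker = StreamingWorker.remote()
future = worker.generate.remote(3)

# The call returns an ObjectRefGenerator rather than a plain ObjectRef.
assert isinstance(future, ObjectRefGenerator)

# ray.get(future) would fail; iterating yields the per-chunk ObjectRefs first,
# mirroring the isinstance(fut, ObjectRefGenerator) branch in get_results above.
chunks = ray.get(list(future))
print(chunks)  # ['chunk-0', 'chunk-1', 'chunk-2']

Iterating the generator blocks only until each next chunk's reference is available, which is what lets a caller consume streamed generation output in order.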