Skip to content

🎲 [GRPO] Shuffle mini batches #3391

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from transformers.utils import is_peft_available

from trl import GRPOConfig, GRPOTrainer
from trl.trainer.grpo_trainer import RepeatSampler
from trl.trainer.grpo_trainer import RepeatSampler, shuffle_tensor_dict, split_tensor_dict

from .testing_utils import require_vllm

Expand All @@ -33,6 +33,77 @@
from peft import LoraConfig, PeftModel


class SplitTensorDictTester(unittest.TestCase):
    def test_split_equal_chunks(self):
        """All tensors must be chunked along dim 0 with the same boundaries."""
        x = torch.arange(12).reshape(6, 2)
        y = torch.arange(6).reshape(6, 1)

        chunks = split_tensor_dict({"x": x, "y": y}, 3)

        self.assertEqual(len(chunks), 3)
        for chunk, expected_x, expected_y in zip(chunks, torch.chunk(x, 3, dim=0), torch.chunk(y, 3, dim=0)):
            self.assertTrue(torch.equal(chunk["x"], expected_x))
            self.assertTrue(torch.equal(chunk["y"], expected_y))

    def test_with_none_tensor(self):
        """A None entry must be carried through unchanged into every chunk."""
        x = torch.arange(12).reshape(6, 2)

        chunks = split_tensor_dict({"x": x, "y": None}, 2)

        self.assertEqual(len(chunks), 2)
        for chunk, expected_x in zip(chunks, torch.chunk(x, 2, dim=0)):
            self.assertTrue(torch.equal(chunk["x"], expected_x))
            self.assertIsNone(chunk["y"])


class ShuffleTensorDictTester(unittest.TestCase):
    def test_shuffle_preserves_shape(self):
        """Shuffling must not alter any tensor's shape."""
        x = torch.arange(6).reshape(3, 2)
        y = torch.arange(3).reshape(3, 1)

        result = shuffle_tensor_dict({"x": x.clone(), "y": y.clone()})

        self.assertEqual(result["x"].shape, x.shape)
        self.assertEqual(result["y"].shape, y.shape)

    def test_shuffle_consistent_across_tensors(self):
        """Every tensor must be permuted by the same permutation (rows stay paired)."""
        x = torch.tensor([[10, 11], [20, 21], [30, 31]])
        y = torch.tensor([[1], [2], [3]])

        result = shuffle_tensor_dict({"x": x.clone(), "y": y.clone()})

        # Each original x row is paired with a unique y value; the pairing
        # must survive the shuffle, whatever order the rows end up in.
        pairing = {(10, 11): 1, (20, 21): 2, (30, 31): 3}
        for x_row, y_row in zip(result["x"], result["y"]):
            key = tuple(x_row.tolist())
            self.assertIn(key, pairing, "Unexpected x row in shuffled output.")
            self.assertEqual(y_row.item(), pairing[key])

    def test_none_tensor_remains_none(self):
        """A None entry stays None while real tensors are still shuffled intact."""
        x = torch.arange(6).reshape(3, 2)

        result = shuffle_tensor_dict({"x": x.clone(), "y": None})

        self.assertIsNone(result["y"])
        self.assertEqual(result["x"].shape, x.shape)


class RepeatRandomSamplerTester(unittest.TestCase):
def test_sampler(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
Expand Down
25 changes: 24 additions & 1 deletion trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,28 @@ def split_tensor_dict(
]


def shuffle_tensor_dict(
    tensor_dict: dict[str, Optional[torch.Tensor]],
    generator: Optional[torch.Generator] = None,
) -> dict[str, Optional[torch.Tensor]]:
    """
    Shuffles a dictionary of tensors along the first dimension in unison.

    Args:
        tensor_dict: Mapping from keys to tensors (or `None`). All non-`None` tensors are assumed to share
            the same first-dimension size; `None` values are passed through unchanged.
        generator: Optional `torch.Generator` used to draw the permutation, for reproducible shuffling.
            Defaults to `None` (global RNG), matching the previous behavior.

    Returns:
        A new dictionary where every non-`None` tensor is reindexed by the same random permutation.

    Raises:
        ValueError: If `tensor_dict` contains no non-`None` tensor, since the batch size cannot be inferred.

    Example:
        >>> x = torch.arange(6).reshape(3, 2)
        >>> y = torch.arange(3).reshape(3, 1)
        >>> tensor_dict = {"x": x, "y": y}
        >>> shuffle_tensor_dict(tensor_dict)
        {'x': tensor([[2, 3],
                      [0, 1],
                      [4, 5]]),
         'y': tensor([[1],
                      [0],
                      [2]])}
    """
    first_tensor = next((tensor for tensor in tensor_dict.values() if tensor is not None), None)
    if first_tensor is None:
        # A bare next() would raise an opaque StopIteration here; fail with a clear message instead.
        raise ValueError("tensor_dict must contain at least one non-None tensor")
    batch_size = first_tensor.shape[0]
    permutation = torch.randperm(batch_size, generator=generator)
    return {key: tensor[permutation] if tensor is not None else None for key, tensor in tensor_dict.items()}


def nanmin(tensor: torch.Tensor) -> torch.Tensor:
"""
Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors.
Expand Down Expand Up @@ -945,6 +967,7 @@ def _prepare_inputs(
if self._step % generate_every == 0 or self._buffered_inputs is None:
# self._buffered_inputs=None can occur when resuming from a checkpoint
generation_batch = self._generate_and_score_completions(generation_batch)
generation_batch = shuffle_tensor_dict(generation_batch)
self._buffered_inputs = split_tensor_dict(generation_batch, self.args.steps_per_generation)
inputs = self._buffered_inputs[self._step % self.args.steps_per_generation]
self._step += 1
Expand Down Expand Up @@ -1090,7 +1113,7 @@ def _generate_and_score_completions(

with torch.no_grad():
# When using num_iterations == 1 and steps_per_generation <= gradient_accumulation_steps
# old_per_token_logps == per_token_logps, so we can skip its computation here, and use
# old_per_token_logps == per_token_logps, so we can skip its computation here, and use
# per_token_logps.detach() instead.
if self.num_iterations > 1 or self.args.steps_per_generation > self.args.gradient_accumulation_steps:
old_per_token_logps = self._get_per_token_logps(
Expand Down
Loading