Commit 99b4cf5

[Bugfix] Fix speculative decoding with MLPSpeculator with padded vocabulary (vllm-project#7218)
Signed-off-by: Travis Johnson <[email protected]>
1 parent e02ac55 commit 99b4cf5

4 files changed: +66 −5 lines

tests/spec_decode/e2e/test_mlp_correctness.py

+60
@@ -19,8 +19,12 @@
 correctess for the target model outputs.
 """
 
+from unittest.mock import patch
+
 import pytest
 
+from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size
+
 from .conftest import (run_equality_correctness_test,
                        run_greedy_equality_correctness_test)
 
@@ -178,6 +182,62 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                          force_output_len=True)
 
 
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "block_size": 8,
+        # 2 for small prompt, 256//8 for generated.
+        "num_gpu_blocks_override": 2 + 256 // 8,
+        "max_model_len": (2 + 256 // 8) * 8,
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+    },
+])
+@pytest.mark.parametrize(
+    "output_len",
+    [
+        # Use small output len for fast test.
+        128,
+    ])
+@pytest.mark.parametrize("batch_size", [4])
+@pytest.mark.parametrize("seed", [1])
+def test_mlp_e2e_greedy_correctness_with_padding(baseline_llm_generator,
+                                                 test_llm_generator,
+                                                 batch_size: int,
+                                                 output_len: int):
+    """Verify greedy equality when the vocab dimension is padded
+    """
+
+    # Default pad_to is 64, test model has vocab_size of 32000
+    def patched_pad_vocab_size(vocab_size, pad_to=None):
+        return pad_vocab_size(vocab_size, pad_to=32064)
+
+    with patch(
+            "vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size",
+            patched_pad_vocab_size):
+        run_greedy_equality_correctness_test(baseline_llm_generator,
+                                             test_llm_generator,
+                                             batch_size,
+                                             max_output_len=output_len,
+                                             force_output_len=True)
+
+
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
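
The new test exercises the padded-vocabulary path deliberately: the test model's real vocab size (32000) is already a multiple of the default pad_to of 64, so no padding would normally be added; patching pad_to to 32064 forces 64 padded slots onto the end of the vocab dimension. A standalone sketch of the round-up behaviour this relies on (the helper below is illustrative, not the vLLM implementation):

def round_up_vocab(vocab_size: int, pad_to: int) -> int:
    # Round vocab_size up to the nearest multiple of pad_to.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

assert round_up_vocab(32000, 64) == 32000      # default padding: already aligned, nothing added
assert round_up_vocab(32000, 32064) == 32064   # patched padding: 64 extra padded vocab slots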

vllm/model_executor/layers/logits_processor.py

+1-1
@@ -91,7 +91,7 @@ def _get_logits(self, hidden_states: torch.Tensor,
         logits = tensor_model_parallel_all_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
+            logits = logits[..., :self.org_vocab_size]
         return logits
 
     def extra_repr(self) -> str:
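
The one-character change is what actually removes the padding for the speculator: when the logits carry an extra speculative-token dimension, slicing with [:, :org_vocab_size] cuts the wrong axis, while [..., :org_vocab_size] always trims the last (vocab) dimension. A rough shape sketch of the difference in plain PyTorch (illustrative shapes, not vLLM code):

import torch

org_vocab_size, padded_vocab_size = 32000, 32064
logits = torch.randn(4, 3, padded_vocab_size)   # [batch, num_spec_tokens, padded_vocab]

sliced_dim1 = logits[:, :org_vocab_size]        # old slice: trims dim 1, vocab padding survives
sliced_last = logits[..., :org_vocab_size]      # new slice: trims the last (vocab) dim

assert sliced_dim1.shape == (4, 3, padded_vocab_size)
assert sliced_last.shape == (4, 3, org_vocab_size)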

vllm/model_executor/layers/rejection_sampler.py

+2-2
@@ -78,8 +78,8 @@ def forward(
         # Only perform shape/dtype/device checking in strict mode, as it adds
         # overhead.
         if self._strict_mode:
-            self._raise_if_incorrect_input(target_probs, bonus_token_ids,
-                                           draft_probs, draft_token_ids)
+            self._raise_if_incorrect_input(target_probs, draft_token_ids,
+                                           bonus_token_ids, draft_probs)
 
         accepted, recovered_token_ids = (
             self._batch_modified_rejection_sampling(

vllm/model_executor/models/mlp_speculator.py

+3-2
@@ -175,13 +175,14 @@ def generate_proposals(
             states.add_(z, alpha=self.emb_weight / self.state_weight)
 
             states = self.activation(self.ln[head_index](states))  # b k d
-            # TODO: not yet supporting top_k_tokens_per_head
             previous_hidden_states = states
+            # TODO: not yet supporting top_k_tokens_per_head
+            states = states.flatten(0, 1)
 
             logits = self.logits_processor(self.head[head_index], states,
                                            sampling_metadata)
 
-            output = self.sampler(logits.flatten(0, 1), sampling_metadata)
+            output = self.sampler(logits, sampling_metadata)
             last_tokens = output.sampled_token_ids
             next_tokens.append(output)
0 commit comments
