
Commit 1856aff

[Spec Decoding] Streamline batch expansion tensor manipulation (vllm-project#7851)
1 parent 70c094a commit 1856aff

5 files changed: +118 -125 lines changed

tests/spec_decode/test_utils.py

+13 -18

@@ -55,10 +55,9 @@ def fake_sequence_group_metadata():
 
 def test_filter_zero_length_proposals(fake_sequence_group_metadata):
     proposal_lens = [0, 1, 0]
-    filtered_groups, indices = split_batch_by_proposal_len(
-        fake_sequence_group_metadata,
-        proposal_lens,
-        select_proposal_len_zero=True)
+    _, (filtered_groups,
+        indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
+                                               proposal_lens)
 
     expected_groups = [
         fake_sequence_group_metadata[0], fake_sequence_group_metadata[2]
@@ -71,10 +70,9 @@ def test_filter_zero_length_proposals(fake_sequence_group_metadata):
 
 def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
     proposal_lens = [0, 1, 2]
-    filtered_groups, indices = split_batch_by_proposal_len(
-        fake_sequence_group_metadata,
-        proposal_lens,
-        select_proposal_len_zero=False)
+    (filtered_groups,
+     indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
+                                               proposal_lens)
 
     expected_groups = [
         fake_sequence_group_metadata[1], fake_sequence_group_metadata[2]
@@ -86,30 +84,27 @@ def test_filter_non_zero_length_proposals(fake_sequence_group_metadata):
 
 
 def test_empty_inputs():
-    filtered_groups, indices = split_batch_by_proposal_len(
-        [], [], select_proposal_len_zero=True)
+    _, (filtered_groups, indices) = split_batch_by_proposal_len([], [])
 
     assert filtered_groups == []
     assert indices == []
 
 
 def test_all_zero_with_non_zero_filter(fake_sequence_group_metadata):
     proposal_lens = [0, 0, 0]
-    filtered_groups, indices = split_batch_by_proposal_len(
-        fake_sequence_group_metadata,
-        proposal_lens,
-        select_proposal_len_zero=False)
+    (filtered_groups,
+     indices), _ = split_batch_by_proposal_len(fake_sequence_group_metadata,
+                                               proposal_lens)
 
     assert filtered_groups == []
     assert indices == []
 
 
 def test_all_non_zero_with_zero_filter(fake_sequence_group_metadata):
     proposal_lens = [1, 1, 1]
-    filtered_groups, indices = split_batch_by_proposal_len(
-        fake_sequence_group_metadata,
-        proposal_lens,
-        select_proposal_len_zero=True)
+    _, (filtered_groups,
+        indices) = split_batch_by_proposal_len(fake_sequence_group_metadata,
+                                               proposal_lens)
 
     assert filtered_groups == []
     assert indices == []
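
Note: the tests above now unpack a single call. split_batch_by_proposal_len no
longer takes a select_proposal_len_zero flag; it splits the batch once and
returns ((spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices)). The
diff for vllm/spec_decode/util.py is not shown on this page, so the sketch
below is inferred purely from the call sites in this commit, not copied from
the actual implementation.

# Minimal sketch of the assumed new helper signature, inferred from its call
# sites; the real version lives in vllm/spec_decode/util.py (not shown here).
from typing import Any, List, Tuple

def split_batch_by_proposal_len(
    seq_group_metadata_list: List[Any],
    proposal_lens: List[int],
) -> Tuple[Tuple[List[Any], List[int]], Tuple[List[Any], List[int]]]:
    """Split the batch into (spec, non_spec) halves, each a
    (seq_groups, original_indices) pair."""
    spec_seqs: List[Any] = []
    spec_indices: List[int] = []
    non_spec_seqs: List[Any] = []
    non_spec_indices: List[int] = []
    for i, (metadata, proposal_len) in enumerate(
            zip(seq_group_metadata_list, proposal_lens)):
        if proposal_len > 0:
            spec_seqs.append(metadata)
            spec_indices.append(i)
        else:
            non_spec_seqs.append(metadata)
            non_spec_indices.append(i)
    return (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices)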

vllm/spec_decode/batch_expansion.py

+79 -64

@@ -10,8 +10,7 @@
                            get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
-from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
-                                   split_batch_by_proposal_len)
+from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
 from vllm.worker.worker_base import WorkerBase
 
 SeqId = int
@@ -88,17 +87,25 @@ def score_proposals(
         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]
 
-        (all_tokens, all_probs, spec_logprobs,
-         all_hidden_states) = self._contract_batch(
-             contracted_bs=len(execute_model_req.seq_group_metadata_list),
-             target_sampler_output=target_sampler_output,
-             proposals=proposals,
-             num_scoring_tokens=num_scoring_tokens,
-             non_spec_indices=non_spec_indices,
-             spec_indices=spec_indices,
-             k=execute_model_req.num_lookahead_slots,
-         )
-
+        if not non_spec_indices:
+            # All sequence groups in batch have spec decoding enabled
+            contracted = self._contract_batch_all_spec(
+                target_sampler_output=target_sampler_output,
+                proposals=proposals,
+            )
+        else:
+            # Batch has a mix of spec decode enabled and disabled seq groups
+            contracted = self._contract_batch(
+                contracted_bs=len(execute_model_req.seq_group_metadata_list),
+                target_sampler_output=target_sampler_output,
+                proposals=proposals,
+                num_scoring_tokens=num_scoring_tokens,
+                non_spec_indices=non_spec_indices,
+                spec_indices=spec_indices,
+                k=execute_model_req.num_lookahead_slots,
+            )
+
+        all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted
         return SpeculativeScores(
             probs=all_probs,
             token_ids=all_tokens,
@@ -121,14 +128,9 @@ def _expand_batch(
         # proposal len. This adds some complexity (splitting the batch into spec
         # and non spec sequences) and should be removed in the future. It can be
         # done by supporting per-sequence proposal lens.
-        spec_seqs, spec_indices = split_batch_by_proposal_len(
-            seq_group_metadata_list,
-            proposal_lens_list,
-            select_proposal_len_zero=False)
-        non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
-            seq_group_metadata_list,
-            proposal_lens_list,
-            select_proposal_len_zero=True)
+        (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
+            split_batch_by_proposal_len(
+                seq_group_metadata_list, proposal_lens_list)
 
         target_seq_group_metadata_list = self._create_scoring_model_input(
             seq_group_metadata_list=spec_seqs,
@@ -171,7 +173,7 @@ def _contract_batch(
         # The number of tokens in the expanded batch used for speculation is
         # equal to the total expanded batch size minus the number of samples for
         # non-speculative sequences.
-        non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
+        non_spec_expanded_bs = len(non_spec_target_token_ids)
         spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
 
         target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
@@ -181,7 +183,7 @@ def _contract_batch(
 
         if target_hidden_states is not None:
             target_hidden_states = target_hidden_states.reshape(
-                spec_expanded_bs, k + 1, target_hidden_states.shape[-1])
+                *target_token_ids.shape, target_hidden_states.shape[-1])
 
         all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                                fill_value=-1)
@@ -196,24 +198,58 @@ def _contract_batch(
             all_hidden_states = None
 
         if non_spec_indices:
-            all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
-            all_probs[non_spec_indices, :1, :] = non_spec_target_probs
-            all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
-
+            all_tokens[non_spec_indices, :1] = \
+                non_spec_target_token_ids.unsqueeze(1)
+            all_probs[non_spec_indices, :1, :] = \
+                non_spec_target_probs.unsqueeze(1)
+            all_logprobs[non_spec_indices, :1, :] = \
+                non_spec_target_logprobs.unsqueeze(1)
             if all_hidden_states is not None:
-                all_hidden_states[
-                    non_spec_indices, :1, :] = non_spec_target_hidden_states
+                assert non_spec_target_hidden_states is not None
+                all_hidden_states[non_spec_indices, :1, :] = \
+                    non_spec_target_hidden_states.unsqueeze(1)
 
         if spec_indices:
             all_tokens[spec_indices] = target_token_ids
             all_probs[spec_indices] = target_probs
             all_logprobs[spec_indices] = target_logprobs
-
             if all_hidden_states is not None:
                 all_hidden_states[spec_indices] = target_hidden_states
 
         return all_tokens, all_probs, all_logprobs, all_hidden_states
 
+    def _contract_batch_all_spec(
+        self,
+        target_sampler_output: SamplerOutput,
+        proposals: SpeculativeProposals,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
+               Optional[torch.Tensor]]:
+        """Contract the expanded batch back into its original size.
+        This maps the scores of speculative tokens back to their original
+        sequences.
+
+        It assumes all sequences in the batch were previously expanded.
+        """
+
+        # Map distinct sequences used to score each token
+        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
+        contracted_bs, k = proposals.proposal_token_ids.shape
+
+        # Reshape tensors to original batch size
+        target_token_ids = target_sampler_output.sampled_token_ids.reshape(
+            contracted_bs, k + 1)
+        target_probs = target_sampler_output.sampled_token_probs.reshape(
+            *target_token_ids.shape, self._vocab_size)
+        target_logprobs = target_sampler_output.logprobs.reshape(
+            target_probs.shape)
+        target_hidden_states = target_sampler_output.hidden_states
+        if target_hidden_states is not None:
+            target_hidden_states = target_hidden_states.reshape(
+                *target_token_ids.shape, target_hidden_states.shape[-1])
+
+        return (target_token_ids, target_probs, target_logprobs,
+                target_hidden_states)
+
     def _create_scoring_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -345,8 +381,9 @@ def _create_single_target_seq_group_metadata(
             token_chunk_size=1,
         )
 
+    @staticmethod
     def _split_scoring_output(
-        self, sampler_output: SamplerOutput, num_scoring_tokens: int
+        sampler_output: SamplerOutput, num_scoring_tokens: int
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                Optional[torch.Tensor], torch.Tensor, torch.Tensor,
                torch.Tensor, Optional[torch.Tensor]]:
@@ -361,10 +398,9 @@ def _split_scoring_output(
         #
         # First samples are from speculative scoring, latter samples are non-
         # speculative samples.
-        split_sizes = [
-            num_scoring_tokens,
-            sampler_output.sampled_token_ids.numel() - num_scoring_tokens
-        ]
+        split_sizes = (num_scoring_tokens,
+                       sampler_output.sampled_token_ids.numel() -
+                       num_scoring_tokens)
         (spec_probs, non_spec_probs
          ) = sampler_output.sampled_token_probs.split(split_sizes)
         (spec_sampled_tokens, non_spec_sampled_tokens
@@ -382,32 +418,13 @@ def _split_scoring_output(
         else:
             spec_hidden_states, non_spec_hidden_states = None, None
 
-        # Convert scores to tensors.
-        sampler_output.sampled_token_probs = spec_probs
-        sampler_output.sampled_token_ids = spec_sampled_tokens
-        sampler_output.logprobs = spec_logprobs
-        sampler_output.hidden_states = spec_hidden_states
-        (target_token_ids, target_probs, target_logprobs,
-         target_hidden_states) = sampler_output_to_torch([sampler_output],
-                                                         True)
-
-        # Convert non-speculative output tokens to tensors.
-        sampler_output.sampled_token_probs = non_spec_probs
-        sampler_output.sampled_token_ids = non_spec_sampled_tokens
-        sampler_output.logprobs = non_spec_logprobs
-        sampler_output.hidden_states = non_spec_hidden_states
-        (non_spec_target_token_ids, non_spec_target_probs,
-         non_spec_target_logprobs,
-         non_spec_target_hidden_states) = sampler_output_to_torch(
-             [sampler_output], True)
-
-        return (target_token_ids, target_probs, target_logprobs,
-                target_hidden_states, non_spec_target_token_ids,
-                non_spec_target_probs, non_spec_target_logprobs,
-                non_spec_target_hidden_states)
+        return (spec_sampled_tokens, spec_probs, spec_logprobs,
+                spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
+                non_spec_logprobs, non_spec_hidden_states)
 
+    @staticmethod
     def _create_target_seq_id_iterator(
-            self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
+            seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
         """Create an iterator for creating target sequence ids.
         Target sequence ids are distinct from sequence ids because we create a
         distinct target sequence id for each proposal token to be scored.
@@ -417,8 +434,8 @@ def _create_target_seq_id_iterator(
         """
         return count(start=max(seq_ids) + 1)
 
+    @staticmethod
     def _get_token_ids_to_score(
-            self,
             full_spec_token_ids: List[TokenId]  # shape: [k]
     ) -> List[List[TokenId]]:
         """Given an int tensor of proposal token ids, return a list of
@@ -439,8 +456,6 @@ def _get_token_ids_to_score(
         empty_token_ids: List[TokenId] = []
 
         token_ids_to_score = [empty_token_ids]
-        token_ids_to_score.extend([
-            full_spec_token_ids[:i + 1]
-            for i in range(len(full_spec_token_ids))
-        ])
+        token_ids_to_score.extend(full_spec_token_ids[:i + 1]
+                                  for i in range(len(full_spec_token_ids)))
         return token_ids_to_score
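
The main change in this file is the all-spec fast path: when every sequence in
the batch has a nonzero proposal length, the expanded batch is laid out
contiguously as batch_size * (k + 1) rows, so _contract_batch_all_spec can
contract it with plain reshapes and no masked index assignment. Relatedly,
because _split_scoring_output now returns the raw split tensors instead of
round-tripping them through sampler_output_to_torch, the non-spec tensors
arrive with one fewer dimension, which is why _contract_batch gains the
unsqueeze(1) calls before its [:, :1] slice assignments. The snippet below is
a standalone illustration of both points; bs, k, and vocab are hypothetical
sizes, not values from the commit.

# Standalone illustration; bs, k, and vocab are hypothetical sizes.
import torch

bs, k, vocab = 2, 3, 8
# The expanded batch scores bs * (k + 1) rows per step: one per proposal
# token plus one bonus position per sequence.
flat_tokens = torch.arange(bs * (k + 1))
flat_probs = torch.rand(bs * (k + 1), vocab)

# All-spec contraction is pure reshaping.
tokens = flat_tokens.reshape(bs, k + 1)
probs = flat_probs.reshape(*tokens.shape, vocab)
assert tokens.shape == (bs, k + 1)
assert probs.shape == (bs, k + 1, vocab)

# A 1-D non-spec tensor needs a singleton step dimension to broadcast into
# a [:, :1] slice, hence the added unsqueeze(1).
all_tokens = torch.full((4, k + 1), -1)
non_spec_indices = [1, 3]
non_spec_token_ids = torch.tensor([7, 9])  # 1-D after the split
all_tokens[non_spec_indices, :1] = non_spec_token_ids.unsqueeze(1)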

vllm/spec_decode/spec_decode_worker.py

+9 -16

@@ -365,12 +365,13 @@ def execute_model(
         # used during the prefill phase.
         # 2. Auto-disable enabled: The running queue size exceeds
         #    the specified threshold.
-        # 3. No request: There are no requests in the batch.
+        # 3. No request: There are no requests in the batch, or
+        #    none of the requests in the batch have spec decoding enabled.
         # In any of these cases, the proposer and scorer workers
         # are called normally.
-        no_spec = num_lookahead_slots == 0 or len(
-            execute_model_req.seq_group_metadata_list
-        ) == 0 or disable_all_speculation
+        no_spec = num_lookahead_slots == 0 or disable_all_speculation or all(
+            sgm.num_speculative_tokens == 0
+            for sgm in execute_model_req.seq_group_metadata_list)
 
         # Broadcast how many lookahead slots are scheduled for this step, and
         # whether all speculation is disabled, to all non-driver workers.
@@ -415,10 +416,8 @@ def _should_disable_all_speculation(
             self, execute_model_req: ExecuteModelRequest) -> bool:
         # When the batch size is too large, disable speculative decoding
         # to stop trading off throughput for latency.
-        disable_all_speculation = (execute_model_req.running_queue_size >=
-                                   self.disable_by_batch_size)
-
-        return disable_all_speculation
+        return (execute_model_req.running_queue_size >=
+                self.disable_by_batch_size)
 
     def _maybe_disable_speculative_tokens(
             self, disable_all_speculation: bool,
@@ -621,14 +620,8 @@ def _verify_tokens(
         # proposal len. This adds some complexity (splitting the batch into spec
         # and non spec sequences) and should be removed in the future. It can be
         # done by supporting per-sequence proposal lens.
-        _, spec_indices = split_batch_by_proposal_len(
-            seq_group_metadata_list,
-            proposal_lens_list,
-            select_proposal_len_zero=False)
-        _, non_spec_indices = split_batch_by_proposal_len(
-            seq_group_metadata_list,
-            proposal_lens_list,
-            select_proposal_len_zero=True)
+        (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len(
+            seq_group_metadata_list, proposal_lens_list)
         original_indices = spec_indices + non_spec_indices
 
         # Get probabilities of target model, excluding bonus token.
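
One subtlety in the reworked no_spec condition: the explicit
len(seq_group_metadata_list) == 0 check was dropped, but the empty-batch case
is still covered because all() over an empty iterable is vacuously True.

# all() over an empty iterable is vacuously True, so an empty batch still
# takes the no-spec path without the old explicit length check.
assert all(sgm.num_speculative_tokens == 0 for sgm in [])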

vllm/spec_decode/top1_proposer.py

+1 -1

@@ -138,7 +138,7 @@ def _split_by_proposal_len(
 
         # Currently only proposal lens of 0 or the global batch proposal len
         # are supported.
-        # If max_proposal_len is defined, then we shall no exceed this
+        # If max_proposal_len is defined, then we shall not exceed this
         # quota for nonzero_proposal
         new_k = 0
         if (self.max_proposal_len is None
