                            get_all_seq_ids)
 from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                          SpeculativeScorer, SpeculativeScores)
-from vllm.spec_decode.util import (nvtx_range, sampler_output_to_torch,
-                                   split_batch_by_proposal_len)
+from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
 from vllm.worker.worker_base import WorkerBase
 
 SeqId = int
@@ -88,17 +87,25 @@ def score_proposals(
         assert len(target_sampler_output) == 1, "expected single-step output"
         target_sampler_output = target_sampler_output[0]
 
-        (all_tokens, all_probs, spec_logprobs,
-         all_hidden_states) = self._contract_batch(
-             contracted_bs=len(execute_model_req.seq_group_metadata_list),
-             target_sampler_output=target_sampler_output,
-             proposals=proposals,
-             num_scoring_tokens=num_scoring_tokens,
-             non_spec_indices=non_spec_indices,
-             spec_indices=spec_indices,
-             k=execute_model_req.num_lookahead_slots,
-         )
-
+        if not non_spec_indices:
+            # All sequence groups in batch have spec decoding enabled
+            contracted = self._contract_batch_all_spec(
+                target_sampler_output=target_sampler_output,
+                proposals=proposals,
+            )
+        else:
+            # Batch has a mix of spec decode enabled and disabled seq groups
+            contracted = self._contract_batch(
+                contracted_bs=len(execute_model_req.seq_group_metadata_list),
+                target_sampler_output=target_sampler_output,
+                proposals=proposals,
+                num_scoring_tokens=num_scoring_tokens,
+                non_spec_indices=non_spec_indices,
+                spec_indices=spec_indices,
+                k=execute_model_req.num_lookahead_slots,
+            )
+
+        all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted
         return SpeculativeScores(
             probs=all_probs,
             token_ids=all_tokens,
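For orientation, a minimal sketch (annotation, not part of the diff) of the shape arithmetic behind the new all-spec fast path: when every sequence group is speculative, the flat sampler output covering batch_size * (k + 1) positions can be contracted with a plain reshape instead of per-index writes. Tensor names and sizes below are illustrative.

# Illustrative only: shape math behind the fast path, not vLLM code.
import torch

batch_size, k, vocab_size = 2, 3, 8

# The expanded batch scores batch_size * (k + 1) target positions in one pass.
flat_token_ids = torch.arange(batch_size * (k + 1))
flat_probs = torch.rand(batch_size * (k + 1), vocab_size)

# With every sequence speculative, contraction is a pure reshape.
token_ids = flat_token_ids.reshape(batch_size, k + 1)
probs = flat_probs.reshape(batch_size, k + 1, vocab_size)

assert token_ids.shape == (batch_size, k + 1)
assert probs.shape == (batch_size, k + 1, vocab_size)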
@@ -121,14 +128,9 @@ def _expand_batch(
         # proposal len. This adds some complexity (splitting the batch into spec
         # and non spec sequences) and should be removed in the future. It can be
         # done by supporting per-sequence proposal lens.
-        spec_seqs, spec_indices = split_batch_by_proposal_len(
-            seq_group_metadata_list,
-            proposal_lens_list,
-            select_proposal_len_zero=False)
-        non_spec_seqs, non_spec_indices = split_batch_by_proposal_len(
-            seq_group_metadata_list,
-            proposal_lens_list,
-            select_proposal_len_zero=True)
+        (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
+            split_batch_by_proposal_len(
+                seq_group_metadata_list, proposal_lens_list)
 
         target_seq_group_metadata_list = self._create_scoring_model_input(
             seq_group_metadata_list=spec_seqs,
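The call above relies on the updated split_batch_by_proposal_len returning both partitions in one pass. A toy stand-in (hypothetical helper name, not vLLM's implementation) showing the ((spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices)) convention the diff assumes:

# Hypothetical stand-in for the updated helper, shown only to illustrate the
# ((spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices)) convention.
from typing import List, Tuple


def split_by_proposal_len_sketch(
    seq_groups: List[str],
    proposal_lens: List[int],
) -> Tuple[Tuple[List[str], List[int]], Tuple[List[str], List[int]]]:
    spec: Tuple[List[str], List[int]] = ([], [])
    non_spec: Tuple[List[str], List[int]] = ([], [])
    for i, (group, proposal_len) in enumerate(zip(seq_groups, proposal_lens)):
        bucket = spec if proposal_len > 0 else non_spec
        bucket[0].append(group)
        bucket[1].append(i)
    return spec, non_spec


(spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \
    split_by_proposal_len_sketch(["a", "b", "c"], [3, 0, 3])
assert spec_indices == [0, 2] and non_spec_indices == [1]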
@@ -171,7 +173,7 @@ def _contract_batch(
         # The number of tokens in the expanded batch used for speculation is
         # equal to the total expanded batch size minus the number of samples for
         # non-speculative sequences.
-        non_spec_expanded_bs, _ = non_spec_target_token_ids.shape
+        non_spec_expanded_bs = len(non_spec_target_token_ids)
         spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs
 
         target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1)
@@ -181,7 +183,7 @@ def _contract_batch(
 
         if target_hidden_states is not None:
             target_hidden_states = target_hidden_states.reshape(
-                spec_expanded_bs, k + 1, target_hidden_states.shape[-1])
+                *target_token_ids.shape, target_hidden_states.shape[-1])
 
         all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1),
                                                fill_value=-1)
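Two idioms the one-line changes above depend on, sketched as a quick check (illustrative sizes, not part of the diff): len() on a tensor is the size of its first dimension, and unpacking target_token_ids.shape into reshape reproduces the (spec_expanded_bs, k + 1, hidden) layout without repeating the arithmetic.

# Quick check of the idioms used above; tensor sizes are arbitrary examples.
import torch

spec_expanded_bs, k, hidden = 4, 3, 16

non_spec_target_token_ids = torch.zeros(5, dtype=torch.long)
assert len(non_spec_target_token_ids) == non_spec_target_token_ids.shape[0]

target_token_ids = torch.zeros(spec_expanded_bs, k + 1, dtype=torch.long)
target_hidden_states = torch.zeros(spec_expanded_bs * (k + 1), hidden)
reshaped = target_hidden_states.reshape(*target_token_ids.shape, hidden)
assert reshaped.shape == (spec_expanded_bs, k + 1, hidden)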
@@ -196,24 +198,58 @@ def _contract_batch(
             all_hidden_states = None
 
         if non_spec_indices:
-            all_tokens[non_spec_indices, :1] = non_spec_target_token_ids
-            all_probs[non_spec_indices, :1, :] = non_spec_target_probs
-            all_logprobs[non_spec_indices, :1, :] = non_spec_target_logprobs
-
+            all_tokens[non_spec_indices, :1] = \
+                non_spec_target_token_ids.unsqueeze(1)
+            all_probs[non_spec_indices, :1, :] = \
+                non_spec_target_probs.unsqueeze(1)
+            all_logprobs[non_spec_indices, :1, :] = \
+                non_spec_target_logprobs.unsqueeze(1)
             if all_hidden_states is not None:
-                all_hidden_states[
-                    non_spec_indices, :1, :] = non_spec_target_hidden_states
+                assert non_spec_target_hidden_states is not None
+                all_hidden_states[non_spec_indices, :1, :] = \
+                    non_spec_target_hidden_states.unsqueeze(1)
 
         if spec_indices:
             all_tokens[spec_indices] = target_token_ids
             all_probs[spec_indices] = target_probs
             all_logprobs[spec_indices] = target_logprobs
-
             if all_hidden_states is not None:
                 all_hidden_states[spec_indices] = target_hidden_states
 
         return all_tokens, all_probs, all_logprobs, all_hidden_states
 
+    def _contract_batch_all_spec(
+        self,
+        target_sampler_output: SamplerOutput,
+        proposals: SpeculativeProposals,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
+               Optional[torch.Tensor]]:
+        """Contract the expanded batch back into its original size.
+        This maps the scores of speculative tokens back to their original
+        sequences.
+
+        It assumes all sequences in the batch were previously expanded.
+        """
+
+        # Map distinct sequences used to score each token
+        # of shape [batch_size * k + 1] back to [batch_size, k + 1].
+        contracted_bs, k = proposals.proposal_token_ids.shape
+
+        # Reshape tensors to original batch size
+        target_token_ids = target_sampler_output.sampled_token_ids.reshape(
+            contracted_bs, k + 1)
+        target_probs = target_sampler_output.sampled_token_probs.reshape(
+            *target_token_ids.shape, self._vocab_size)
+        target_logprobs = target_sampler_output.logprobs.reshape(
+            target_probs.shape)
+        target_hidden_states = target_sampler_output.hidden_states
+        if target_hidden_states is not None:
+            target_hidden_states = target_hidden_states.reshape(
+                *target_token_ids.shape, target_hidden_states.shape[-1])
+
+        return (target_token_ids, target_probs, target_logprobs,
+                target_hidden_states)
+
     def _create_scoring_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
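A small sketch (not part of the diff) of why the unsqueeze(1) calls are needed in the non-speculative branch above: the per-sequence results are now 1-D (tokens) and 2-D (probs), so they need a length-1 step dimension before being written into the [:, :1] slice of the contracted tensors. Sizes are illustrative.

# Illustrative: writing 1-D non-speculative results into the [bs, k + 1] grid.
import torch

contracted_bs, k, vocab_size = 3, 2, 5
non_spec_indices = [1]

all_tokens = torch.full((contracted_bs, k + 1), -1, dtype=torch.long)
all_probs = torch.zeros(contracted_bs, k + 1, vocab_size)

# One sampled token / prob row per non-speculative sequence (1-D / 2-D).
non_spec_target_token_ids = torch.tensor([7])
non_spec_target_probs = torch.rand(1, vocab_size)

# unsqueeze(1) adds the length-1 "step" dimension expected by the :1 slice.
all_tokens[non_spec_indices, :1] = non_spec_target_token_ids.unsqueeze(1)
all_probs[non_spec_indices, :1, :] = non_spec_target_probs.unsqueeze(1)

assert all_tokens[1, 0] == 7 and all_tokens[0, 0] == -1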
@@ -345,8 +381,9 @@ def _create_single_target_seq_group_metadata(
             token_chunk_size=1,
         )
 
+    @staticmethod
     def _split_scoring_output(
-        self, sampler_output: SamplerOutput, num_scoring_tokens: int
+        sampler_output: SamplerOutput, num_scoring_tokens: int
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                Optional[torch.Tensor], torch.Tensor, torch.Tensor,
                torch.Tensor, Optional[torch.Tensor]]:
@@ -361,10 +398,9 @@ def _split_scoring_output(
         #
         # First samples are from speculative scoring, latter samples are non-
         # speculative samples.
-        split_sizes = [
-            num_scoring_tokens,
-            sampler_output.sampled_token_ids.numel() - num_scoring_tokens
-        ]
+        split_sizes = (num_scoring_tokens,
+                       sampler_output.sampled_token_ids.numel() -
+                       num_scoring_tokens)
         (spec_probs, non_spec_probs
          ) = sampler_output.sampled_token_probs.split(split_sizes)
         (spec_sampled_tokens, non_spec_sampled_tokens
@@ -382,32 +418,13 @@ def _split_scoring_output(
         else:
             spec_hidden_states, non_spec_hidden_states = None, None
 
-        # Convert scores to tensors.
-        sampler_output.sampled_token_probs = spec_probs
-        sampler_output.sampled_token_ids = spec_sampled_tokens
-        sampler_output.logprobs = spec_logprobs
-        sampler_output.hidden_states = spec_hidden_states
-        (target_token_ids, target_probs, target_logprobs,
-         target_hidden_states) = sampler_output_to_torch([sampler_output],
-                                                         True)
-
-        # Convert non-speculative output tokens to tensors.
-        sampler_output.sampled_token_probs = non_spec_probs
-        sampler_output.sampled_token_ids = non_spec_sampled_tokens
-        sampler_output.logprobs = non_spec_logprobs
-        sampler_output.hidden_states = non_spec_hidden_states
-        (non_spec_target_token_ids, non_spec_target_probs,
-         non_spec_target_logprobs,
-         non_spec_target_hidden_states) = sampler_output_to_torch(
-             [sampler_output], True)
-
-        return (target_token_ids, target_probs, target_logprobs,
-                target_hidden_states, non_spec_target_token_ids,
-                non_spec_target_probs, non_spec_target_logprobs,
-                non_spec_target_hidden_states)
+        return (spec_sampled_tokens, spec_probs, spec_logprobs,
+                spec_hidden_states, non_spec_sampled_tokens, non_spec_probs,
+                non_spec_logprobs, non_spec_hidden_states)
 
+    @staticmethod
     def _create_target_seq_id_iterator(
-            self, seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
+            seq_ids: List[SeqId]) -> Iterator[TargetSeqId]:
         """Create an iterator for creating target sequence ids.
         Target sequence ids are distinct from sequence ids because we create a
         distinct target sequence id for each proposal token to be scored.
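For reference, a minimal sketch (not part of the diff) of the split that _split_scoring_output now returns directly: torch.Tensor.split accepts a tuple of section sizes, so the speculative and non-speculative halves come back as flat tensors without rebuilding SamplerOutput objects. Sizes are illustrative.

# Illustrative: splitting the flat scoring output into spec / non-spec parts.
import torch

num_scoring_tokens, num_non_spec, vocab_size = 6, 2, 10
sampled_token_ids = torch.arange(num_scoring_tokens + num_non_spec)
sampled_token_probs = torch.rand(num_scoring_tokens + num_non_spec, vocab_size)

split_sizes = (num_scoring_tokens,
               sampled_token_ids.numel() - num_scoring_tokens)
spec_tokens, non_spec_tokens = sampled_token_ids.split(split_sizes)
spec_probs, non_spec_probs = sampled_token_probs.split(split_sizes)

assert spec_tokens.shape == (num_scoring_tokens,)
assert non_spec_probs.shape == (num_non_spec, vocab_size)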
@@ -417,8 +434,8 @@ def _create_target_seq_id_iterator(
         """
         return count(start=max(seq_ids) + 1)
 
+    @staticmethod
     def _get_token_ids_to_score(
-            self,
             full_spec_token_ids: List[TokenId]  # shape: [k]
     ) -> List[List[TokenId]]:
         """Given an int tensor of proposal token ids, return a list of
@@ -439,8 +456,6 @@ def _get_token_ids_to_score(
         empty_token_ids: List[TokenId] = []
 
         token_ids_to_score = [empty_token_ids]
-        token_ids_to_score.extend([
-            full_spec_token_ids[:i + 1]
-            for i in range(len(full_spec_token_ids))
-        ])
+        token_ids_to_score.extend(full_spec_token_ids[:i + 1]
+                                  for i in range(len(full_spec_token_ids)))
         return token_ids_to_score
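To make the rewritten extend concrete, a small usage sketch mirroring the new body (hypothetical standalone function, same logic as the method): for proposals [10, 20, 30] it yields the empty prefix plus each cumulative prefix, one per scoring position.

# Illustrative: cumulative prefixes produced by _get_token_ids_to_score.
from typing import List

TokenId = int


def token_ids_to_score_sketch(
        full_spec_token_ids: List[TokenId]) -> List[List[TokenId]]:
    empty_token_ids: List[TokenId] = []
    token_ids_to_score = [empty_token_ids]
    token_ids_to_score.extend(full_spec_token_ids[:i + 1]
                              for i in range(len(full_spec_token_ids)))
    return token_ids_to_score


assert token_ids_to_score_sketch([10, 20, 30]) == \
    [[], [10], [10, 20], [10, 20, 30]]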