Commit 57b7be0

[Speculative decoding] [Multi-Step] decouple should_modify_greedy_probs_inplace (vllm-project#6971)

1 parent 99b4cf5
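Previously the sampler reused `include_gpu_probs_tensor` to decide whether greedy probabilities should also be rewritten in place; this commit adds a dedicated `should_modify_greedy_probs_inplace` flag so the two behaviors can be toggled independently. A minimal sketch of the decoupled configuration, mirroring the new test below (the bare `Sampler()` construction is illustrative, not the full vLLM setup):

    from vllm.model_executor.layers.sampler import Sampler

    sampler = Sampler()
    # Return sampled token ids/probs as GPU tensors (needed by multi-step workers)...
    sampler.include_gpu_probs_tensor = True
    # ...without also forcing greedy probabilities to be rewritten in place.
    sampler.should_modify_greedy_probs_inplace = False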

File tree

8 files changed, +52 −3 lines changed

tests/samplers/test_sampler.py (+26 −1)

@@ -1,7 +1,7 @@
 import itertools
 import random
 from typing import Dict, List, Optional, Tuple
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
 import torch
@@ -703,3 +703,28 @@ def test_sampling_params(sampling_params: List[SamplingParams]):
 
     assert tokens1[0] == tokens2[1]
     assert tokens1[1] == tokens2[0]
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_sampler_include_gpu_probs_tensor(device: str):
+    set_random_seed(42)
+    torch.set_default_device(device)
+    batch_size = random.randint(1, 256)
+    _, fake_logits, sampler = _prepare_test(batch_size)
+    sampler.include_gpu_probs_tensor = True
+    sampler.should_modify_greedy_probs_inplace = False
+
+    sampling_params = SamplingParams(temperature=0)
+
+    mock_inplace = Mock()
+    with patch(
+            "vllm.model_executor.layers.sampler._modify_greedy_probs_inplace",
+            mock_inplace):
+
+        sampler_output = _do_sample(batch_size, fake_logits, sampler,
+                                    sampling_params, device)
+        mock_inplace.assert_not_called()
+
+    assert sampler_output.sampled_token_probs is not None
+    assert sampler_output.logprobs is not None
+    assert sampler_output.sampled_token_ids is not None

vllm/lora/layers.py (+4)

@@ -1067,6 +1067,10 @@ def org_vocab_size(self):
     def include_gpu_probs_tensor(self):
         return self.base_layer.include_gpu_probs_tensor
 
+    @property
+    def should_modify_greedy_probs_inplace(self):
+        return self.base_layer.should_modify_greedy_probs_inplace
+
     def create_lora_weights(
         self,
         max_loras: int,

vllm/model_executor/layers/sampler.py (+2 −2)

@@ -51,6 +51,7 @@ def __init__(self):
         # containing the sampled token ids and probabilities. This is used by
         # speculative decoding.
         self.include_gpu_probs_tensor = False
+        self.should_modify_greedy_probs_inplace = False
 
     def _init_sampling_tensors(
         self,
@@ -177,8 +178,7 @@ def _should_modify_greedy_probs_inplace(self) -> bool:
         This is used by speculative decoding, which requires that the sampling
         method be encoded into the probability distribution.
         """
-        # Modify greedy probs if include_gpu_probs_tensor is set.
-        return self.include_gpu_probs_tensor
+        return self.should_modify_greedy_probs_inplace
 
 
 def _get_bin_counts_and_mask(
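For context, `_should_modify_greedy_probs_inplace` gates the in-place rewrite that encodes greedy sampling into the returned probability tensor, as the docstring above notes. The sketch below only illustrates that idea (it is not vLLM's actual `_modify_greedy_probs_inplace`): greedy rows are made one-hot on the sampled token so downstream speculative-decoding verification sees a deterministic distribution.

    import torch

    def modify_greedy_probs_sketch(probs: torch.Tensor,
                                   sampled_ids: torch.Tensor) -> None:
        # Illustrative only: zero each row, then put all mass on the sampled token.
        probs.zero_()
        probs.scatter_(1, sampled_ids.unsqueeze(-1), 1.0)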

vllm/spec_decode/medusa_worker.py (+3)

@@ -35,6 +35,9 @@ def init_device(self):
     def set_include_gpu_probs_tensor(self):
         pass
 
+    def set_should_modify_greedy_probs_inplace(self):
+        pass
+
     @torch.inference_mode()
     def sampler_output(
         self,

vllm/spec_decode/multi_step_worker.py (+4)

@@ -46,6 +46,10 @@ def set_include_gpu_probs_tensor(self) -> None:
         # Need include_gpu_probs_tensor for MultiStepWorker
         self.model_runner.model.sampler.include_gpu_probs_tensor = True
 
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        self.model_runner.model.sampler.should_modify_greedy_probs_inplace = (
+            True)
+
     @torch.inference_mode()
     def sampler_output(
         self,

vllm/spec_decode/proposer_worker_base.py (+4)

@@ -28,6 +28,10 @@ def set_include_gpu_probs_tensor(self) -> None:
         """Implementation optional"""
         pass
 
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        """Implementation optional"""
+        pass
+
 
 class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC):
     """Proposer worker which does not use a model with kvcache"""

vllm/spec_decode/smaller_tp_proposer_worker.py (+6)

@@ -83,6 +83,12 @@ def set_include_gpu_probs_tensor(self) -> None:
         # Need include_gpu_probs_tensor for multi_step_worker
         self._worker.set_include_gpu_probs_tensor()
 
+    def set_should_modify_greedy_probs_inplace(self) -> None:
+        if self._is_dummy:
+            return
+
+        self._worker.set_should_modify_greedy_probs_inplace()
+
     def load_model(self) -> None:
         if self._is_dummy:
             return

vllm/spec_decode/spec_decode_worker.py (+3)

@@ -295,7 +295,10 @@ def _configure_model_sampler_for_spec_decode(self):
         """
         (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor
          ) = True
+        (self.scorer_worker.model_runner.model.sampler.
+         should_modify_greedy_probs_inplace) = True
         self.proposer_worker.set_include_gpu_probs_tensor()
+        self.proposer_worker.set_should_modify_greedy_probs_inplace()
 
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of cache blocks to use.
