[vllm] Support speculative decoding in vllm rolling batch (#2413)
xyang16 authored Oct 2, 2024
1 parent 77041b5 commit 0631414
Showing 5 changed files with 79 additions and 0 deletions.
@@ -59,6 +59,22 @@ class VllmRbProperties(Properties):
enable_prefix_caching: Optional[bool] = False
disable_sliding_window: Optional[bool] = False
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
use_v2_block_manager: bool = False

# Speculative decoding configuration.
speculative_model: Optional[str] = None
speculative_model_quantization: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None
spec_decoding_acceptance_method: str = 'rejection_sampler'
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
qlora_adapter_name_or_path: Optional[str] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None

@field_validator('engine')
def validate_engine(cls, engine):
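For orientation, a rough sketch (not part of this commit) of how a populated VllmRbProperties flows into the engine-args helper changed below. The engine and model_id_or_path fields belong to the base Properties class and are assumptions here; the speculative values mirror the EAGLE test config added further down.

# Sketch only: build the properties object directly and convert it to vLLM
# EngineArgs. Field names outside this diff are assumptions.
props = VllmRbProperties(
    engine="Python",
    model_id_or_path="JackFram/llama-68m",
    speculative_model="abhigoyal/vllm-eagle-llama-68m-random",
    num_speculative_tokens=4,
    use_v2_block_manager=True,
)
engine_args = get_engine_args_from_config(props)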
@@ -266,6 +266,27 @@ def get_engine_args_from_config(config: VllmRbProperties) -> EngineArgs:
enable_prefix_caching=config.enable_prefix_caching,
disable_sliding_window=config.disable_sliding_window,
max_num_seqs=config.max_rolling_batch_size,
use_v2_block_manager=config.use_v2_block_manager,
speculative_model=config.speculative_model,
speculative_model_quantization=config.speculative_model_quantization,
speculative_draft_tensor_parallel_size=config.speculative_draft_tensor_parallel_size,
num_speculative_tokens=config.num_speculative_tokens,
speculative_max_model_len=config.speculative_max_model_len,
speculative_disable_by_batch_size=config.speculative_disable_by_batch_size,
ngram_prompt_lookup_max=config.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=config.ngram_prompt_lookup_min,
spec_decoding_acceptance_method=config.spec_decoding_acceptance_method,
typical_acceptance_sampler_posterior_threshold=config.typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=config.typical_acceptance_sampler_posterior_alpha,
qlora_adapter_name_or_path=config.qlora_adapter_name_or_path,
disable_logprobs_during_spec_decoding=config.disable_logprobs_during_spec_decoding,
)


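As a standalone illustration (not part of this commit), the same arguments drive vLLM's offline entry point directly; in the vLLM versions this targets, LLM(...) forwards keyword arguments to EngineArgs. JackFram/llama-68m stands in here for the s3://djl-llm/llama-68m/ copy used by the tests, and the EAGLE draft model is the one referenced in prepare.py below.

from vllm import LLM, SamplingParams

# Target model plus an EAGLE-style draft model, mirroring the test config.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="abhigoyal/vllm-eagle-llama-68m-random",
    num_speculative_tokens=4,
    use_v2_block_manager=True,
    tensor_parallel_size=1,
    max_num_seqs=4,
)
outputs = llm.generate(["The future of AI is"], SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)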
12 changes: 12 additions & 0 deletions tests/integration/llm/client.py
@@ -463,6 +463,18 @@ def get_model_name():
"seq_length": [256],
"tokenizer": "tiiuae/falcon-11B"
},
"llama-68m-speculative-medusa": {
"max_memory_per_gpu": [25.0],
"batch_size": [1, 4],
"seq_length": [256],
"tokenizer": "JackFram/llama-68m"
},
"llama-68m-speculative-eagle": {
"max_memory_per_gpu": [25.0],
"batch_size": [1, 4],
"seq_length": [256],
"tokenizer": "JackFram/llama-68m"
},
"llama-7b-unmerged-lora": {
"max_memory_per_gpu": [15.0, 15.0],
"batch_size": [3],
18 changes: 18 additions & 0 deletions tests/integration/llm/prepare.py
@@ -625,6 +625,24 @@
"option.tensor_parallel_degree": 4,
"option.enable_chunked_prefill": "true",
},
"llama-68m-speculative-medusa": {
"option.model_id": "s3://djl-llm/llama-68m/",
"option.task": "text-generation",
"option.speculative_model": "s3://djl-llm/llama-2-tiny/",
"option.num_speculative_tokens": 4,
"option.use_v2_block_manager": True,
"option.tensor_parallel_degree": 1,
"option.max_rolling_batch_size": 4,
},
"llama-68m-speculative-eagle": {
"option.model_id": "s3://djl-llm/llama-68m/",
"option.task": "text-generation",
"option.speculative_model": "abhigoyal/vllm-eagle-llama-68m-random",
"option.num_speculative_tokens": 4,
"option.use_v2_block_manager": True,
"option.tensor_parallel_degree": 1,
"option.max_rolling_batch_size": 4,
},
"llama-7b-unmerged-lora": {
"option.model_id": "s3://djl-llm/huggyllama-llama-7b",
"option.tensor_parallel_degree": "max",
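The two entries above exercise draft-model speculation. The new ngram_prompt_lookup_* options also cover vLLM's prompt-lookup mode; a hypothetical entry in the same style (not part of this commit) could look like the following, where "[ngram]" is vLLM's sentinel for prompt-lookup speculation and the model path simply mirrors the entries above.

# Hypothetical prepare.py-style entry; not added by this commit.
ngram_spec_model = {
    "llama-68m-speculative-ngram": {
        "option.model_id": "s3://djl-llm/llama-68m/",
        "option.task": "text-generation",
        "option.speculative_model": "[ngram]",
        "option.num_speculative_tokens": 4,
        "option.ngram_prompt_lookup_max": 4,
        "option.use_v2_block_manager": True,
        "option.tensor_parallel_degree": 1,
        "option.max_rolling_batch_size": 4,
    }
}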
12 changes: 12 additions & 0 deletions tests/integration/tests.py
@@ -604,6 +604,18 @@ def test_falcon_11b_chunked_prefill(self):
client.run(
"vllm falcon-11b-chunked-prefill --in_tokens 1200".split())

def test_llama_68m_speculative_medusa(self):
with Runner('lmi', 'llama-68m-speculative-medusa') as r:
prepare.build_vllm_model("llama-68m-speculative-medusa")
r.launch()
client.run("vllm llama-68m-speculative-medusa".split())

def test_llama_68m_speculative_eagle(self):
with Runner('lmi', 'llama-68m-speculative-eagle') as r:
prepare.build_vllm_model("llama-68m-speculative-eagle")
r.launch()
client.run("vllm llama-68m-speculative-eagle".split())


@pytest.mark.vllm
@pytest.mark.lora