Commit
[Feature] [Spec decode]: Enable MLPSpeculator/Medusa and `prompt_logprobs` with ChunkedPrefill (vllm-project#10132)

Signed-off-by: NickLucche <[email protected]>
Signed-off-by: wallashss <[email protected]>
Co-authored-by: wallashss <[email protected]>
NickLucche and wallashss authored Jan 27, 2025
1 parent 2bc3fbb commit 6116ca8
Showing 16 changed files with 469 additions and 166 deletions.
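
The changes below are test-only; what they exercise is that an engine running speculative decoding together with chunked prefill can now return logprobs and prompt_logprobs. As a rough sketch of that combination (not part of this commit; model names, chunk size, and sampling values are illustrative, loosely mirroring the test parametrizations below):

from vllm import LLM, SamplingParams

# Sketch only: a small draft model stands in here, but an MLPSpeculator or
# Medusa head configured the same way is what several of the tests cover.
llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=3,
    disable_logprobs_during_spec_decoding=False,  # keep real logprobs
    enable_chunked_prefill=True,
    max_num_batched_tokens=32,  # small chunk size, for illustration only
    max_num_seqs=32,
    enforce_eager=True,
)

params = SamplingParams(
    temperature=0.0,
    max_tokens=32,
    logprobs=2,         # logprobs for generated tokens
    prompt_logprobs=2,  # prompt logprobs, now usable with chunked prefill
)

for out in llm.generate(["The future of AI is"], params):
    print(out.prompt_logprobs)
    print(out.outputs[0].logprobs)
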
19 changes: 16 additions & 3 deletions tests/spec_decode/e2e/conftest.py
@@ -2,6 +2,7 @@
from typing import List, Optional, Sequence, Tuple, Union

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
@@ -154,6 +155,8 @@ def _check_logprobs_when_output_disabled(
spec_pos_logprob) = next(iter(spec_pos_logprobs.items()))
assert spec_pos_logprob.rank == -1
assert spec_pos_logprob.logprob == 0.0
if isinstance(spec_pos_logprob_token_id, torch.Tensor):
spec_pos_logprob_token_id = spec_pos_logprob_token_id.item()
assert spec_pos_logprob_token_id in baseline_pos_logprobs


@@ -244,7 +247,8 @@ def run_equality_correctness_test_tp(model,
batch_size: int,
max_output_len: int,
seed: int = 0,
temperature: float = 0.0):
temperature: float = 0.0,
logprobs: Optional[int] = None):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
@@ -257,7 +261,6 @@ def run_equality_correctness_test_tp(model,
results = []

prompts = [prompt for prompt, _ in zip(cycle(PROMPTS), range(batch_size))]

for args, env in ((arg1, env1), (arg2, env2)):
with RemoteOpenAIServer(model,
args,
@@ -269,12 +272,14 @@ def run_equality_correctness_test_tp(model,
prompt=prompts,
max_tokens=max_output_len,
seed=seed,
temperature=temperature)
temperature=temperature,
logprobs=logprobs)

results.append({
"test":
"seeded_sampling",
"text": [choice.text for choice in completion.choices],
"logprobs": [choice.logprobs for choice in completion.choices],
"finish_reason":
[choice.finish_reason for choice in completion.choices],
"usage":
@@ -284,7 +289,15 @@ def run_equality_correctness_test_tp(model,
n = len(results) // 2
arg1_results = results[:n]
arg2_results = results[n:]
# Separate logprobs to avoid asserting exact equality.
arg1_logprobs = [r.pop("logprobs") for r in arg1_results]
arg2_logprobs = [r.pop("logprobs") for r in arg2_results]

for arg1_result, arg2_result in zip(arg1_results, arg2_results):
assert arg1_result == arg2_result, (
f"Results for {model=} are not the same with {arg1=} and {arg2=}. "
f"{arg1_result=} != {arg2_result=}")
if logprobs:
for logs1, logs2 in zip(arg1_logprobs, arg2_logprobs):
for l1, l2 in zip(logs1, logs2):
assert l1.tokens == l2.tokens
10 changes: 9 additions & 1 deletion tests/spec_decode/e2e/test_integration_dist_tp2.py
@@ -2,6 +2,8 @@
tensor parallelism.
"""

from typing import Optional

import pytest
import torch

@@ -154,15 +156,20 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs,
"--speculative-draft-tensor-parallel-size",
"1",
])])
@pytest.mark.parametrize("logprobs", [None, 2])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
logprobs: Optional[int],
batch_size: int, seed: int):
"""Verify spec decode works well with same and different TP size for
the draft model with chunked prefill.
"""
if logprobs:
test_llm_kwargs.extend(
["--disable_logprobs_during_spec_decoding", "False"])
run_equality_correctness_test_tp(model,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -171,4 +178,5 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs,
batch_size,
max_output_len=32,
seed=seed,
temperature=0.0)
temperature=0.0,
logprobs=logprobs)
16 changes: 10 additions & 6 deletions tests/spec_decode/e2e/test_logprobs.py
@@ -4,26 +4,27 @@

from vllm import SamplingParams

from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model_name": "JackFram/llama-68m",
"model_name": "JackFram/llama-160m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
"enforce_eager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-160m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
}, {
"speculative_model": "JackFram/llama-160m",
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True,
}])
@@ -36,12 +37,15 @@
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 12])
def test_logprobs_equality(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int, logprobs: int):
"""Verify output logprobs are equal with and without speculative decoding.
seed: int, logprobs: int, prefill_chunk_size: int):
"""Verify output logprobs are equal with and without speculative decoding,
as well as with and without chunked prefill.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, common_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
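
The tests above (and the Medusa tests below) import a maybe_enable_chunked_prefill helper from ..utils whose body is not part of this diff. A minimal sketch of what it presumably does, inferred from the parametrizations (a prefill_chunk_size of -1 leaves chunked prefill disabled, a positive value enables it and caps the per-step token budget at the chunk size):

from typing import Any, Dict

def maybe_enable_chunked_prefill(prefill_chunk_size: int,
                                 llm_kwargs: Dict[str, Any]) -> None:
    # Assumed behaviour: mutate the kwargs dict in place, since the tests
    # pass common_llm_kwargs / test_llm_kwargs straight into this call and
    # then hand the same dict to the runner.
    if prefill_chunk_size > 0:
        llm_kwargs.update(
            enable_chunked_prefill=True,
            # Bound the per-step batched-token budget so prefills are
            # actually split into chunks of the requested size.
            max_num_batched_tokens=prefill_chunk_size,
            max_num_seqs=prefill_chunk_size,
        )
    else:
        llm_kwargs["enable_chunked_prefill"] = False
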
31 changes: 24 additions & 7 deletions tests/spec_decode/e2e/test_medusa_correctness.py
@@ -21,6 +21,7 @@

import pytest

from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test

# main model
@@ -67,12 +68,14 @@
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int):
seed: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -119,12 +122,15 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, logprobs: int):
seed: int, logprobs: int,
prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -167,12 +173,14 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_correctness_cuda_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
seed: int, prefill_chunk_size: int):
"""Verify greedy equality with cuda graph enabled and different
batch sizes."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -217,13 +225,15 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
seed: int, prefill_chunk_size: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -267,13 +277,15 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int,
seed: int):
seed: int, prefill_chunk_size: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -313,14 +325,17 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
output_len: int, seed: int,
prefill_chunk_size: int):
"""Verify that medusa speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
batch sizes.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,
@@ -361,12 +376,14 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
32,
])
@pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32])
def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int,
output_len: int, seed: int):
output_len: int, seed: int, prefill_chunk_size: int):
"""Verify that speculative decoding generates the same output
with batch expansion scorer and mqa scorer.
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
per_test_common_llm_kwargs,