Skip to content

Commit

Permalink
Fix assertion failure in Qwen 1.5 with prefix caching enabled (vllm-p…
Browse files Browse the repository at this point in the history
…roject#3373)

Co-authored-by: Cade Daniel <[email protected]>
  • Loading branch information
chenxu2048 and cadedaniel authored Mar 14, 2024
1 parent dfc7740 commit 54be8a0
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 2 deletions.
43 changes: 43 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from vllm.config import ModelConfig


def test_get_sliding_window():
TEST_SLIDING_WINDOW = 4096
# Test that the sliding window is correctly computed.
# For Qwen1.5/Qwen2, get_sliding_window() should be None
# when use_sliding_window is False.
qwen2_model_config = ModelConfig(
"Qwen/Qwen1.5-7B",
"Qwen/Qwen1.5-7B",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
)

qwen2_model_config.hf_config.use_sliding_window = False
qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
assert qwen2_model_config.get_sliding_window() is None

qwen2_model_config.hf_config.use_sliding_window = True
assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW

mistral_model_config = ModelConfig(
"mistralai/Mistral-7B-v0.1",
"mistralai/Mistral-7B-v0.1",
tokenizer_mode="auto",
trust_remote_code=False,
download_dir=None,
load_format="dummy",
seed=0,
dtype="float16",
revision=None,
)
mistral_model_config.hf_config.sliding_window = None
assert mistral_model_config.get_sliding_window() is None

mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
14 changes: 12 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C

if not os.path.exists(model):
model_path = snapshot_download(model_id=model,
cache_dir=download_dir,
Expand Down Expand Up @@ -139,7 +140,7 @@ def _verify_load_format(self) -> None:
if (f not in rocm_not_supported_load_format)
]
raise ValueError(
f"load format \'{load_format}\' is not supported in ROCm. "
f"load format '{load_format}' is not supported in ROCm. "
f"Supported load format are "
f"{rocm_supported_load_format}")

Expand Down Expand Up @@ -232,6 +233,15 @@ def verify_with_parallel_config(
f"({pipeline_parallel_size}).")

def get_sliding_window(self) -> Optional[int]:
"""Get the sliding window size, or None if disabled.
"""

# Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
# addition to sliding window size. We check if that field is present
# and if it's False, return None.
if (hasattr(self.hf_config, "use_sliding_window")
and not self.hf_config.use_sliding_window):
return None
return getattr(self.hf_config, "sliding_window", None)

def get_vocab_size(self) -> int:
Expand Down Expand Up @@ -624,7 +634,7 @@ def _get_and_verify_dtype(
k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
]
raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
f"Supported dtypes are {rocm_supported_dtypes}")

# Verify the dtype.
Expand Down

0 comments on commit 54be8a0

Please sign in to comment.