From b87c21fc89c772d231cae97346e0457ef3bb1bf9 Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Mon, 3 Mar 2025 15:40:04 +0800 Subject: [PATCH 01/33] [Misc][Platform] Move use allgather to platform (#14010) Signed-off-by: Mengqing Cao --- vllm/model_executor/layers/logits_processor.py | 10 +++------- vllm/platforms/interface.py | 13 +++++++++++++ vllm/platforms/neuron.py | 4 ++++ vllm/platforms/tpu.py | 4 ++++ 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 2f39a0e878540..4a359725bad0f 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -8,7 +8,6 @@ import torch.nn as nn import vllm.envs as envs -from vllm.config import get_current_vllm_config from vllm.distributed import (tensor_model_parallel_all_gather, tensor_model_parallel_gather) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -51,11 +50,7 @@ def __init__(self, # Soft cap the logits. Used in Gemma 2. self.soft_cap = soft_cap # Whether to use gather or all-gather to gather the logits. - parallel_config = get_current_vllm_config().parallel_config - self.use_all_gather = current_platform.is_tpu() \ - or current_platform.is_neuron() \ - or envs.VLLM_USE_V1 \ - or parallel_config.distributed_executor_backend == "external_launcher" # noqa + self.use_all_gather = current_platform.use_all_gather() def forward( self, @@ -83,7 +78,8 @@ def forward( logits *= self.scale # Apply logits processors (if any). - if sampling_metadata is not None: + if sampling_metadata is not None and \ + sampling_metadata.seq_groups is not None: logits = _apply_logits_processors(logits, sampling_metadata) return logits diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index d81a66e4bcb1d..e7e55e11775c5 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -330,6 +330,19 @@ def get_device_communicator_cls(cls) -> str: """ return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" # noqa + @classmethod + def use_all_gather(cls) -> bool: + """ + Whether to use allgather in LogitsProcessor to gather the logits. 
+ """ + import vllm.envs as envs + from vllm.config import get_current_vllm_config + + parallel_config = get_current_vllm_config().parallel_config + return (envs.VLLM_USE_V1 + or parallel_config.distributed_executor_backend + == "external_launcher") + class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 5a03f5f7acbc1..b2eadb7932f33 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -55,3 +55,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: def is_pin_memory_available(cls) -> bool: logger.warning("Pin memory is not supported on Neuron.") return False + + @classmethod + def use_all_gather(cls) -> bool: + return True diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index cdf835a52c0c1..0b66b52713e97 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -119,3 +119,7 @@ def is_pin_memory_available(cls): @classmethod def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa + + @classmethod + def use_all_gather(cls) -> bool: + return True From f35f8e2242db224a92a14e084d502eec67d56da9 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 3 Mar 2025 00:43:14 -0800 Subject: [PATCH 02/33] [Build] Make sure local main branch is synced when VLLM_USE_PRECOMPILED=1 (#13921) Signed-off-by: Cody Yu --- setup.py | 28 ++++++++++++++++++- tests/standalone_tests/python_only_compile.sh | 2 +- vllm/envs.py | 8 +++++- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index cd17709b57ef3..1a6f2ffd8524a 100755 --- a/setup.py +++ b/setup.py @@ -2,6 +2,7 @@ import ctypes import importlib.util +import json import logging import os import re @@ -269,9 +270,32 @@ class repackage_wheel(build_ext): """Extracts libraries and other files from an existing wheel.""" def get_base_commit_in_main_branch(self) -> str: - import subprocess + # Force to use the nightly wheel. This is mainly used for CI testing. + if envs.VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: + return "nightly" try: + # Get the latest commit hash of the upstream main branch. + resp_json = subprocess.check_output([ + "curl", "-s", + "https://api.github.com/repos/vllm-project/vllm/commits/main" + ]).decode("utf-8") + upstream_main_commit = json.loads(resp_json)["sha"] + + # Check if the local main branch is up-to-date. This is to ensure + # the base commit we found is the most recent commit on the main + # branch. + local_main_commit = subprocess.check_output( + ["git", "rev-parse", "main"]).decode("utf-8").strip() + if local_main_commit != upstream_main_commit: + raise ValueError( + f"Local main branch ({local_main_commit}) is not " + "up-to-date with upstream main branch " + f"({upstream_main_commit}). Please pull the latest " + "changes from upstream main branch first.") + + # Then get the commit hash of the current branch that is the same as + # the upstream main commit. current_branch = subprocess.check_output( ["git", "branch", "--show-current"]).decode("utf-8").strip() @@ -279,6 +303,8 @@ def get_base_commit_in_main_branch(self) -> str: ["git", "merge-base", "main", current_branch]).decode("utf-8").strip() return base_commit + except ValueError as err: + raise ValueError(err) from None except Exception as err: logger.warning( "Failed to get the base commit in the main branch. 
" diff --git a/tests/standalone_tests/python_only_compile.sh b/tests/standalone_tests/python_only_compile.sh index f00895c0997f1..ec1bcbcc58a0f 100644 --- a/tests/standalone_tests/python_only_compile.sh +++ b/tests/standalone_tests/python_only_compile.sh @@ -18,7 +18,7 @@ apt autoremove -y echo 'import os; os.system("touch /tmp/changed.file")' >> vllm/__init__.py -VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e . +VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL=1 VLLM_USE_PRECOMPILED=1 pip3 install -vvv -e . # Run the script python3 -c 'import vllm' diff --git a/vllm/envs.py b/vllm/envs.py index bf64cd70674da..f6c038967b698 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -60,12 +60,12 @@ MAX_JOBS: Optional[str] = None NVCC_THREADS: Optional[str] = None VLLM_USE_PRECOMPILED: bool = False + VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL: bool = False VLLM_NO_DEPRECATION_WARNING: bool = False VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False CMAKE_BUILD_TYPE: Optional[str] = None VERBOSE: bool = False VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False - VLLM_TEST_FORCE_FP8_MARLIN: bool = False VLLM_RPC_TIMEOUT: int = 10000 # ms VLLM_PLUGINS: Optional[list[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None @@ -148,6 +148,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")) or bool( os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")), + # Whether to force using nightly wheel in python build. + # This is used for testing the nightly wheel in python build. + "VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL": + lambda: bool(int(os.getenv("VLLM_TEST_USE_PRECOMPILED_NIGHTLY_WHEEL", "0")) + ), + # CMake build type # If not set, defaults to "Debug" or "RelWithDebInfo" # Available options: "Debug", "Release", "RelWithDebInfo" From 4167252eafa214d6aa4ecfb627d7b16a31095f08 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Mon, 3 Mar 2025 16:15:27 +0000 Subject: [PATCH 03/33] [V1] Refactor parallel sampling support (#13774) Signed-off-by: Mark McLoughlin --- vllm/v1/engine/async_llm.py | 61 ++--- vllm/v1/engine/llm_engine.py | 74 ++---- vllm/v1/engine/output_processor.py | 181 +++++++++------ vllm/v1/engine/parallel_sampling.py | 344 ++++------------------------ vllm/v1/metrics/stats.py | 5 +- 5 files changed, 201 insertions(+), 464 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index ab3cdc4ee295d..954f74c3fdaef 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -25,7 +25,7 @@ from vllm.utils import cdiv, kill_process_tree from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor -from vllm.v1.engine.parallel_sampling import generate_parallel_sampling_async +from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger, @@ -145,25 +145,30 @@ async def add_request( """Add new request to the AsyncLLM.""" # 1) Create a new output queue for the request. - if self.output_processor.is_request_active(request_id): - raise ValueError(f"Request id {request_id} already running.") queue: asyncio.Queue[RequestOutput] = asyncio.Queue() - # 2) Convert Input --> Request. 
- request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - trace_headers, - prompt_adapter_request, - priority) + # 2) Fan out child requests (for n>1) + parent_req = ParentRequest.from_params(request_id, params) + n = params.n if isinstance(params, SamplingParams) else 1 + for idx in range(n): + if parent_req is not None: + request_id, params = parent_req.get_child_info(idx) - # 3) Add the request to OutputProcessor (this process). - self.output_processor.add_request(request, queue) + # 3) Convert Input --> Request. + request = self.processor.process_inputs(request_id, prompt, params, + arrival_time, lora_request, + trace_headers, + prompt_adapter_request, + priority) - # 4) Add the EngineCoreRequest to EngineCore (separate process). - await self.engine_core.add_request_async(request) + # 4) Add the request to OutputProcessor (this process). + self.output_processor.add_request(request, parent_req, idx, queue) - if self.log_requests: - logger.info("Added request %s.", request_id) + # 5) Add the EngineCoreRequest to EngineCore (separate process). + await self.engine_core.add_request_async(request) + + if self.log_requests: + logger.info("Added request %s.", request_id) return queue @@ -172,7 +177,7 @@ async def add_request( # requests we don't need to send multiple messages to core proc, # and so we don't need multiple streams which then get # re-multiplexed in the API server anyhow. - async def _generate( + async def generate( self, prompt: PromptType, sampling_params: SamplingParams, @@ -243,30 +248,6 @@ async def _generate( await self.abort(request_id) raise - def generate( - self, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> AsyncGenerator[RequestOutput, None]: - kwargs = dict(prompt=prompt, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority) - if sampling_params.n is None or sampling_params.n == 1: - return self._generate(**kwargs) - else: - # Special handling for parallel sampling requests - return generate_parallel_sampling_async(generate=self._generate, - **kwargs) - async def _run_output_handler(self): """Background loop: pulls from EngineCore and pushes to AsyncStreams.""" diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 2e76694a7f512..99b97ac8e6c46 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -22,7 +22,7 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.engine.core_client import EngineCoreClient from vllm.v1.engine.output_processor import OutputProcessor -from vllm.v1.engine.parallel_sampling import SyncParallelSamplingManager +from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor @@ -50,9 +50,6 @@ def __init__( self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config - # Bookkeeping for parallel sampling requests - self.parallel_manager = SyncParallelSamplingManager() - # important: init dp group before init the engine_core self.parallel_config = vllm_config.parallel_config self.dp_enabled = self.parallel_config.data_parallel_size > 1 # noqa @@ -120,8 +117,7 @@ def 
from_engine_args( multiprocess_mode=enable_multiprocessing) def get_num_unfinished_requests(self) -> int: - return self.parallel_manager.get_num_unfinished_requests( - self.output_processor.get_num_unfinished_requests()) + return self.output_processor.get_num_unfinished_requests() def has_unfinished_requests(self) -> bool: has_unfinished = self.output_processor.has_unfinished_requests() @@ -157,48 +153,25 @@ def add_request( prompt_adapter_request: Optional[PromptAdapterRequest] = None, priority: int = 0, ) -> None: - """Add request.""" - kwargs = dict(request_id=request_id, - prompt=prompt, - params=params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority) - # Handle parallel sampling requests differently. - if params is None or isinstance(params, - PoolingParams) or params.n == 1: - self._add_request(**kwargs) - else: - # Special handling for parallel sampling requests - self.parallel_manager.add_request_parallel_sampling( - add_request=self._add_request, **kwargs) - - def _add_request( - self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> None: - """Add request, `n=1`""" - # 1) Process raw inputs into the request. - request = self.processor.process_inputs(request_id, prompt, params, - arrival_time, lora_request, - trace_headers, - prompt_adapter_request, - priority) - - # 2) Make a new RequestState and queue. - self.output_processor.add_request(request) - - # 3) Add the request to EngineCore. - self.engine_core.add_request(request) + # 1) Fan out child requests (for n>1) + parent_req = ParentRequest.from_params(request_id, params) + n = params.n if isinstance(params, SamplingParams) else 1 + for idx in range(n): + if parent_req is not None: + request_id, params = parent_req.get_child_info(idx) + + # 2) Process raw inputs into the request. + request = self.processor.process_inputs(request_id, prompt, params, + arrival_time, lora_request, + trace_headers, + prompt_adapter_request, + priority) + + # 3) Make a new RequestState and queue. + self.output_processor.add_request(request, parent_req, idx) + + # 3) Add the request to EngineCore. + self.engine_core.add_request(request) def step(self) -> list[RequestOutput]: @@ -217,10 +190,7 @@ def step(self) -> list[RequestOutput]: # 3) Abort any reqs that finished due to stop strings. 
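# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the n>1 fan-out that add_request now
# performs: one parent request becomes n child requests, each carrying n=1,
# a unique "{index}_{request_id}" id and, when the request is seeded,
# seed + index. SimpleParams below is an illustrative stand-in for
# vllm.SamplingParams; the real code derives the same values through
# ParentRequest.from_params() / get_child_info().
from copy import copy
from dataclasses import dataclass
from typing import Iterator, Optional

@dataclass
class SimpleParams:  # stand-in for SamplingParams
    n: int = 1
    seed: Optional[int] = None

def fan_out(request_id: str,
            params: SimpleParams) -> Iterator[tuple[str, SimpleParams]]:
    """Yield one (child_request_id, child_params) pair per requested sample."""
    for index in range(params.n):
        child = copy(params)
        child.n = 1  # each child produces a single completion
        if params.seed is not None:
            child.seed = params.seed + index  # decorrelate the child samples
        yield f"{index}_{request_id}", child

# Example: fan_out("req-42", SimpleParams(n=3, seed=7)) yields
# ("0_req-42", seed=7), ("1_req-42", seed=8), ("2_req-42", seed=9).
# ---------------------------------------------------------------------------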
self.engine_core.abort_requests(processed_outputs.reqs_to_abort) - request_outputs = processed_outputs.request_outputs - - # 4) Process unfinished parallel sampling requests - return self.parallel_manager.step(request_outputs) + return processed_outputs.request_outputs def get_model_config(self): return self.model_config diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 22bbb8a0f5b47..4e1d1e3bf51bc 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -4,13 +4,14 @@ from dataclasses import dataclass from typing import Optional, Union -from vllm.outputs import RequestOutput +from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import RequestOutputKind from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason from vllm.v1.engine.detokenizer import IncrementalDetokenizer from vllm.v1.engine.logprobs import LogprobsProcessor +from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates, RequestStateStats) @@ -27,6 +28,8 @@ class RequestState: def __init__( self, request_id: str, + parent_req: Optional[ParentRequest], + request_index: int, lora_name: Optional[str], output_kind: RequestOutputKind, prompt: Optional[str], @@ -38,6 +41,8 @@ def __init__( log_stats: bool, ): self.request_id = request_id + self.parent_req = parent_req + self.request_index = request_index self.lora_name = lora_name self.output_kind = output_kind self.prompt = prompt @@ -56,11 +61,15 @@ def from_new_request( cls, tokenizer: AnyTokenizer, request: EngineCoreRequest, + parent_req: Optional[ParentRequest], + request_index: int, queue: Optional[asyncio.Queue[RequestOutput]], log_stats: bool, ) -> "RequestState": return cls( request_id=request.request_id, + parent_req=parent_req, + request_index=request_index, lora_name=(request.lora_request.name if request.lora_request is not None else None), output_kind=request.sampling_params.output_kind, @@ -79,6 +88,88 @@ def from_new_request( log_stats=log_stats, ) + def make_request_output( + self, + new_token_ids: list[int], + finish_reason: Optional[FinishReason], + stop_reason: Union[int, str, None], + ) -> Optional[RequestOutput]: + + finished = finish_reason is not None + output_kind = self.output_kind + final_only = output_kind == RequestOutputKind.FINAL_ONLY + + # In follow up, we will switch to invariant where EngineCore + # does not stream partial prefills. + if not finished and (self.is_prefilling or final_only): + # Only the final output is required in FINAL_ONLY mode. 
+ return None + + def new_request_output(request_id: str) -> RequestOutput: + return self._new_request_output(request_id, finished) + + completion_output = self._new_completion_output( + new_token_ids, finish_reason, stop_reason) + + if self.parent_req is not None: + return self.parent_req.make_request_output(final_only, + completion_output, + new_request_output) + + request_output = new_request_output(self.request_id) + request_output.outputs.append(completion_output) + return request_output + + def _new_request_output( + self, + request_id: str, + finished: bool, + ) -> RequestOutput: + + if self.output_kind == RequestOutputKind.DELTA: + # Side effect: logprobs processor forgets prompt logprobs + prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs() + else: + prompt_logprobs = self.logprobs_processor.prompt_logprobs + + return RequestOutput( + request_id=request_id, + prompt=self.prompt, + prompt_token_ids=self.prompt_token_ids, + prompt_logprobs=prompt_logprobs, + outputs=[], + finished=finished, + ) + + def _new_completion_output( + self, + token_ids: list[int], + finish_reason: Optional[FinishReason], + stop_reason: Union[int, str, None], + ) -> CompletionOutput: + + finished = finish_reason is not None + delta = self.output_kind == RequestOutputKind.DELTA + + # Prepare text and token_ids, based on delta mode + text = self.detokenizer.get_next_output_text(finished, delta) + if not delta: + token_ids = self.detokenizer.output_token_ids + + # Prepare logprobs, based on delta mode + logprobs = self.logprobs_processor.logprobs + if delta and logprobs: + logprobs = logprobs[-len(token_ids):] + + return CompletionOutput( + index=self.request_index, + text=text, + token_ids=token_ids, + logprobs=logprobs, + cumulative_logprob=self.logprobs_processor.cumulative_logprob, + finish_reason=str(finish_reason) if finished else None, + stop_reason=stop_reason if finished else None) + class OutputProcessor: """Process EngineCoreOutputs into RequestOutputs.""" @@ -93,9 +184,6 @@ def __init__( self.request_states: dict[str, RequestState] = {} self.lora_states = LoRARequestStates() - def is_request_active(self, request_id: str) -> bool: - return request_id in self.request_states - def get_num_unfinished_requests(self): return len(self.request_states) @@ -114,6 +202,8 @@ def abort_requests( def add_request( self, request: EngineCoreRequest, + parent_req: Optional[ParentRequest] = None, + request_index: int = 0, queue: Optional[asyncio.Queue[RequestOutput]] = None, ) -> None: request_id = request.request_id @@ -123,6 +213,8 @@ def add_request( req_state = RequestState.from_new_request( tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request), request=request, + parent_req=parent_req, + request_index=request_index, queue=queue, log_stats=self.log_stats) self.request_states[request_id] = req_state @@ -202,8 +294,8 @@ def process_outputs( req_state.logprobs_processor.update_from_output(engine_core_output) # 4) Create and handle RequestOutput objects. - if request_output := self._make_request_output( - req_state, new_token_ids, finish_reason, stop_reason): + if request_output := req_state.make_request_output( + new_token_ids, finish_reason, stop_reason): if req_state.queue is not None: # AsyncLLM: put into queue for handling by generate(). req_state.queue.put_nowait(request_output) @@ -211,18 +303,17 @@ def process_outputs( # LLMEngine: return list of RequestOutputs. request_outputs.append(request_output) - # Free completed requests. 
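# ---------------------------------------------------------------------------
# A compact sketch, under illustrative names, of the non-streaming
# (FINAL_ONLY) aggregation path that parent_req.make_request_output (called
# above) implements: child completions may finish in any order, are buffered
# until all n have arrived, and are then sorted by their child index before a
# single parent output is returned. Completion/Aggregated below are
# stand-ins for vLLM's CompletionOutput/RequestOutput.
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Completion:  # stand-in for CompletionOutput
    index: int
    text: str

@dataclass
class Aggregated:  # stand-in for the parent-level RequestOutput
    request_id: str
    outputs: list[Completion] = field(default_factory=list)

class ParentAggregator:

    def __init__(self, request_id: str, n: int):
        self.request_id = request_id
        self.n = n
        self._pending: Optional[Aggregated] = None

    def add_child(self, completion: Completion) -> Optional[Aggregated]:
        """Buffer one child completion; emit the parent output once full."""
        agg = self._pending or Aggregated(self.request_id)
        agg.outputs.append(completion)
        if len(agg.outputs) != self.n:
            self._pending = agg  # still waiting for sibling completions
            return None
        self._pending = None
        agg.outputs.sort(key=lambda c: c.index)  # restore child order
        return agg

# Example: feeding Completion(2, "c"), Completion(0, "a"), Completion(1, "b")
# into ParentAggregator("req-42", n=3) returns a single Aggregated whose
# outputs read ["a", "b", "c"] only once the third child has arrived.
# ---------------------------------------------------------------------------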
- if request_output.finished: - self.request_states.pop(req_id) - if not engine_core_output.finished: - # If req not finished in EngineCore, but Detokenizer - # detected stop string, abort needed in EngineCore. - reqs_to_abort.append(req_id) + # Free completed requests. + if finish_reason is not None: + self.request_states.pop(req_id) + if not engine_core_output.finished: + # If req not finished in EngineCore, but Detokenizer + # detected stop string, abort needed in EngineCore. + reqs_to_abort.append(req_id) - # Track per-request stats - self._update_stats_from_finished(req_state, request_output, - finish_reason, - iteration_stats) + # Track per-request stats + self._update_stats_from_finished(req_state, finish_reason, + iteration_stats) self.lora_states.update_iteration_stats(iteration_stats) @@ -249,7 +340,6 @@ def _update_stats_from_output(self, req_state: RequestState, req_state.stats, lora_stats) def _update_stats_from_finished(self, req_state: RequestState, - request_output: RequestOutput, finish_reason: Optional[FinishReason], iteration_stats: Optional[IterationStats]): if iteration_stats is None: @@ -257,55 +347,8 @@ def _update_stats_from_finished(self, req_state: RequestState, assert finish_reason is not None assert req_state.stats is not None - iteration_stats.update_from_finished_request(finish_reason, - request_output, - req_state.stats) + iteration_stats.update_from_finished_request( + finish_reason=finish_reason, + num_prompt_tokens=len(req_state.prompt_token_ids), + req_stats=req_state.stats) self.lora_states.finish_request(req_state) - - @staticmethod - def _make_request_output( - request_state: RequestState, - new_token_ids: list[int], - finish_reason: Optional[FinishReason], - stop_reason: Union[int, str, None], - ) -> Optional[RequestOutput]: - - finished = finish_reason is not None - output_kind = request_state.output_kind - # In follow up, we will switch to invariant where EngineCore - # does not stream partial prefills. - if not finished and (request_state.is_prefilling - or output_kind == RequestOutputKind.FINAL_ONLY): - # Only the final output is required in FINAL_ONLY mode. 
- return None - - detokenizer = request_state.detokenizer - logprobs_processor = request_state.logprobs_processor - - delta = output_kind == RequestOutputKind.DELTA - logprobs = logprobs_processor.logprobs - if delta: - if logprobs: - logprobs = logprobs[-len(new_token_ids):] - # Side effect: logprobs processor forgets prompt logprobs - prompt_logprobs = logprobs_processor.pop_prompt_logprobs() - else: - prompt_logprobs = logprobs_processor.prompt_logprobs - - request_output = RequestOutput.new( - request_id=request_state.request_id, - prompt=request_state.prompt, - prompt_token_ids=request_state.prompt_token_ids, - text=detokenizer.get_next_output_text(finished, delta), - token_ids=new_token_ids if delta else detokenizer.output_token_ids, - logprobs=logprobs, - prompt_logprobs=prompt_logprobs, - cumulative_logprob=logprobs_processor.cumulative_logprob, - finished=finished, - ) - if finished: - completion_output = request_output.outputs[0] - completion_output.finish_reason = str(finish_reason) - completion_output.stop_reason = stop_reason - - return request_output diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index 291360771b54f..adced8973b033 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -1,69 +1,46 @@ # SPDX-License-Identifier: Apache-2.0 -from collections.abc import AsyncGenerator, Mapping from copy import copy -from typing import Optional, Protocol, Union +from typing import Callable, Optional, Union -from vllm.inputs import PromptType -from vllm.lora.request import LoRARequest -from vllm.outputs import RequestOutput +from vllm.outputs import CompletionOutput, RequestOutput from vllm.pooling_params import PoolingParams -from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import RequestOutputKind, SamplingParams -from vllm.utils import merge_async_iterators +from vllm.sampling_params import SamplingParams -class AsyncGenerateMethodType(Protocol): - - def __call__(self, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0) -> AsyncGenerator[RequestOutput, None]: - ... - - -class SyncAddRequestMethodType(Protocol): - - def __call__(self, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0) -> None: - ... - - -class ParallelSamplingRequest: +class ParentRequest: """Info, state & processing for parallel sampling request. - + Store parent request ID and sampling params. Facilitate generating child request sampling params. - Transform child request outputs into parent request - outputs. - When stream mode is disabled, then `self.request_output` - aggregates child request completions. 
""" request_id: str sampling_params: SamplingParams + + # To aggregate child completions when not streaming + output_aggregator: Optional[RequestOutput] + + # To efficiently obtain child sampling params cached_child_sampling_params: Optional[SamplingParams] - request_output: Optional[RequestOutput] - num_finished_completions: int def __init__(self, request_id: str, sampling_params: SamplingParams) -> None: self.request_id = request_id self.sampling_params = sampling_params + + self.output_aggregator = None self.cached_child_sampling_params = None - self.request_output = None - self.num_finished_completions = 0 + + @classmethod + def from_params( + cls, + request_id: str, + params: Union[SamplingParams, PoolingParams], + ) -> Optional['ParentRequest']: + if not isinstance(params, SamplingParams) or params.n == 1: + return None + return cls(request_id, params) def _get_child_sampling_params( self, @@ -96,47 +73,6 @@ def _get_child_sampling_params( child_sampling_params.seed = seed + index return child_sampling_params - def _add_output( - self, - child_req_output: RequestOutput, - index: int, - ) -> None: - """Aggregate a parallel sampling child - request output. - - Non-stream-mode (`output_kind == FINAL_ONLY`) - only. Inject correct parent request ID and - completion index. - - Args: - child_req_output: a single request output - from a parallel sampling - child request. - index: index within `n` child - """ - self.num_finished_completions += 1 - new_completion = child_req_output.outputs[0] - new_completion.index = index - if self.request_output is None: - # Save the first request output; reinstate - # original request ID; metrics are not - # supported for parallel sampling - child_req_output.request_id = self.request_id - child_req_output.metrics = None - self.request_output = child_req_output - else: - # Aggregate additional completion into request output - # Note: will be sorted by index later - self.request_output.outputs.append(new_completion) - - def _get_final_request_output(self) -> RequestOutput: - """Invariant: parent completion outputs sorted by index""" - assert self.request_output is not None - self.request_output.finished = True - self.request_output.outputs = sorted(self.request_output.outputs, - key=lambda x: x.index) - return self.request_output - def get_child_info(self, index: int) -> tuple[str, SamplingParams]: """Get child request ID and sampling params. @@ -149,227 +85,35 @@ def get_child_info(self, index: int) -> tuple[str, SamplingParams]: return (f"{index}_{self.request_id}", self._get_child_sampling_params(index)) - def process_output( - self, - child_req_output: RequestOutput, - index: int, - ) -> Optional[RequestOutput]: - """Filter, aggregate and transform parallel sampling - child request outputs. - - If the parent request has `stream=false` - (`output_kind == FINAL_ONLY`), each child will also have - `output_kind == FINAL_ONLY`. All child request outputs - must be aggregated into a single request output, with - multiple completions. This request output is only returned - once `n` completions are aggregated. - - If the parent request has `stream=true` - (`output_kind == DELTA`), each child will also have - `output_kind == DELTA`. All child request outputs - must be streamed directly to the caller. - - Args: - child_req_output: a single child request output - index: index within `n` child requests - - Returns: - `None`, unless a processed request output is ready to - send back to the caller. 
- """ - if self.output_kind != RequestOutputKind.FINAL_ONLY: - # stream=true: return child completions immediately - child_req_output.request_id = self.request_id - child_req_output.outputs[0].index = index - if child_req_output.finished: - # Parent request is complete if all child requests are - # complete. - self.num_finished_completions += 1 - child_req_output.finished = ( - self.num_finished_completions == self.n) - return child_req_output - - # stream=false: aggregate child completions - self._add_output(child_req_output, index) - if self.num_finished_completions == self.n: - # Return aggregated request output after obtaining - # all completions - return self._get_final_request_output() - return None - - async def wrap_child_async_generator( - self, - child_gen: AsyncGenerator[RequestOutput, None], - index: int, - ) -> AsyncGenerator[RequestOutput, None]: - """Output generator for a single parallel sampling - child request. - - Each parallel sampling request triggers at - least two child requests. This generator - yields zero or more request outputs to - return to the caller, as they become - available. - - Args: - child_gen: generator for child request - outputs. - index: index within the `n` child requests - - Returns: - Yields zero or more request outputs to return - to the caller. - """ - async for out in child_gen: - if req_out := self.process_output(out, index): - yield req_out - @property def n(self) -> int: return self.sampling_params.n - @property - def output_kind(self) -> RequestOutputKind: - return self.sampling_params.output_kind - - -class SyncParallelSamplingManager: - - def __init__(self): - # Parent req ID -> parent request manager - self.parent_reqs: dict[str, ParallelSamplingRequest] = {} - # Child req ID -> (child req index, parent req ID) - self.child_reqs: dict[str, tuple[int, str]] = {} - - def _register_parent_request(self, req: ParallelSamplingRequest) -> None: - """Register parallel sampling parent request.""" - self.parent_reqs[req.request_id] = req - - def _register_child_request(self, req_id: str, child_req_id: str, - index: int) -> None: - """Register parallel sampling child request with parent. - - Args: - req_id: parent request ID - child_req_id: child request ID - index: child request index within `n` child requests - """ - self.child_reqs[child_req_id] = (index, req_id) - - def get_num_unfinished_requests(self, num_core_reqs: int) -> int: - """Get the number of unfinished requests, correcting for parallel - sampling. - - Args: - num_core_reqs: The number of unfinished requests in the engine core. 
- - Returns: - Number of unfinished requests, where each parallel sampling req - counts as 1 - """ - return num_core_reqs + len(self.parent_reqs) - len(self.child_reqs) - - def add_request_parallel_sampling( + def make_request_output( self, - add_request: SyncAddRequestMethodType, - request_id: str, - prompt: PromptType, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, - ) -> None: - """Add sync parallel sampling request.""" - req = ParallelSamplingRequest(request_id, params) - self._register_parent_request(req) - # Add n child requests with unique request IDs & random seeds and n=1 - for idx in range(req.n): - child_req_id, child_params = req.get_child_info(idx) - self._register_child_request(request_id, child_req_id, idx) - add_request(request_id=child_req_id, - prompt=prompt, - params=child_params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority) # type: ignore - - def step( - self, - outputs: list[RequestOutput], - ) -> list[RequestOutput]: - """Build parallel sampling request outputs. - - Extract child request outputs, aggregate them - into parent request output, and return parent - output when complete. - - Do not modify `n=1` requests. - - Args: - outputs: step request outputs. Mix of child request - outputs & `n=1` request outputs. + final_only: bool, + completion_output: CompletionOutput, + new_request_output: Callable[[str], RequestOutput], + ) -> Optional[RequestOutput]: + # Use an existing RequestOutput if we're aggregating + request_output = self.output_aggregator - Return: - List of parallel sampling parent request outputs & - unmodified `n=1` request outputs passed-thru from input. - """ - if not (self.parent_reqs and outputs): - # Return unmodified - return outputs - agg_outputs = [] - for output in outputs: - req_id = output.request_id - if child_req_entry := self.child_reqs.get(req_id, None): - # For each parallel sampling child request output: - (index, parent_req_id) = child_req_entry - req = self.parent_reqs[parent_req_id] - # Update parallel sampling request - if out := req.process_output(output, index): - # Return parent request output if complete; - # cleanup parent request bookkeeping. - agg_outputs.append(out) - del self.parent_reqs[parent_req_id] - # Cleanup child request bookkeeping. 
- del self.child_reqs[req_id] - else: - # Not a parallel sampling request output - agg_outputs.append(output) - return agg_outputs + # Make new RequestOutput otherwise + if request_output is None: + request_output = new_request_output(self.request_id) + # Add a new completion + request_output.outputs.append(completion_output) -async def generate_parallel_sampling_async( - generate: AsyncGenerateMethodType, - prompt: PromptType, - sampling_params: SamplingParams, - request_id: str, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - priority: int = 0, -) -> AsyncGenerator[RequestOutput, None]: - """Generate completions for async parallel sampling requests.""" - parent_req = ParallelSamplingRequest(request_id, sampling_params) + # If not streaming, aggregate until all child requests complete + if final_only and len(request_output.outputs) != self.n: + self.output_aggregator = request_output + return None - # Aggregate generators for n child requests - gens: list[AsyncGenerator[RequestOutput, None]] = [] - for idx in range(parent_req.n): - child_req_id, child_params = parent_req.get_child_info(idx) - child_gen = generate( - prompt=prompt, - sampling_params=child_params, - request_id=child_req_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - priority=priority, - ) # type: ignore - gen = parent_req.wrap_child_async_generator(child_gen, idx) - gens.append(gen) + # We're done aggregating + self.output_aggregator = None - # Merge generators - async for _, out in merge_async_iterators(*gens): - yield out + # Parent completion output list must be sorted by index + request_output.outputs = sorted(request_output.outputs, + key=lambda x: x.index) + return request_output diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index 625edb607467b..abdca95670e11 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: - from vllm.outputs import RequestOutput from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason from vllm.v1.output_processor import RequestState @@ -150,7 +149,7 @@ def update_from_events(self, req_id: str, events: list["EngineCoreEvent"], self.num_preempted_reqs += 1 def update_from_finished_request(self, finish_reason: "FinishReason", - request_output: "RequestOutput", + num_prompt_tokens: int, req_stats: RequestStateStats): e2e_latency = self._time_since(req_stats.arrival_time) @@ -172,7 +171,7 @@ def update_from_finished_request(self, finish_reason: "FinishReason", finished_req = \ FinishedRequestStats(finish_reason=finish_reason, e2e_latency=e2e_latency, - num_prompt_tokens=len(request_output.prompt_token_ids), + num_prompt_tokens=num_prompt_tokens, num_generation_tokens=req_stats.num_generation_tokens, queued_time=queued_time, prefill_time=prefill_time, From 98175b281681a91812c26e4f336d2ab9772a4afe Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 3 Mar 2025 17:03:05 +0000 Subject: [PATCH 04/33] Improve the docs for `TransformersModel` (#14147) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- docs/source/models/supported_models.md | 68 +++++++++++++++++++------- 1 file changed, 49 insertions(+), 19 deletions(-) diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 
0e93a15b84fc9..29ed24cfdb5c3 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -14,8 +14,11 @@ Alongside each architecture, we include some popular models that use it. By default, vLLM loads models from [HuggingFace (HF) Hub](https://huggingface.co/models). -To determine whether a given model is supported, you can check the `config.json` file inside the HF repository. -If the `"architectures"` field contains a model architecture listed below, then it should be supported in theory. +To determine whether a given model is natively supported, you can check the `config.json` file inside the HF repository. +If the `"architectures"` field contains a model architecture listed below, then it should be natively supported. + +Models do not _need_ to be natively supported to be used in vLLM. +The enables you to run models directly using their Transformers implementation (or even remote code on the Hugging Face Model Hub!). :::{tip} The easiest way to check if your model is really supported at runtime is to run the program below: @@ -40,33 +43,41 @@ If vLLM successfully returns text (for generative models) or hidden states (for Otherwise, please refer to [Adding a New Model](#new-model) for instructions on how to implement your model in vLLM. Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) to request vLLM support. +(transformers-fallback)= + ### Transformers fallback -`vllm` can fallback to models that are available in `transformers`. This does not work for all models for now, but most decoder language models are supported, and vision language model support is planned! +vLLM can fallback to model implementations that are available in Transformers. This does not work for all models for now, but most decoder language models are supported, and vision language model support is planned! -To check if the backend is `transformers`, you can simply do this: +To check if the backend is Transformers, you can simply do this: ```python from vllm import LLM llm = LLM(model=..., task="generate") # Name or path of your model -llm.apply_model(lambda model: print(model.__class__)) +llm.apply_model(lambda model: print(type(model))) ``` -If it is `TransformersModel` then it means it's based on `transformers`! +If it is `TransformersModel` then it means it's based on Transformers! -#### Supported features +:::{note} +vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM. +::: -##### Quantization +#### Supported features -Transformers fallback has supported most of available quantization in vLLM (except GGUF). See [Quantization page](#quantization-index) for more information about supported quantization in vllm. +The Transformers fallback explicitly supports the following features: -##### LoRA +- (except GGUF) +- +- (pipeline parallel coming soon !) -Transformers fallback has supported LoRA. The usage way is identical to how LoRA works with models supported by vLLM. If you encounter any issues, please open an issue. +#### Remote code -##### Remote code +Earlier we mentioned that the Transformers fallback enables you to run remote code models directly in vLLM. +If you are interested in this feature, this section is for you! -This fallback also means that any model on the hub that can be used in `transformers` with `trust_remote_code=True` that correctly implements attention can be used in production! 
+Simply set `trust_remote_code=True` and vLLM will run any model on the Model Hub that is compatible with Transformers. +Provided that the model writer implements their model in a compatible way, this means that you can run new models before they are officially supported in Transformers or vLLM! ```python from vllm import LLM @@ -74,16 +85,17 @@ llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of llm.apply_model(lambda model: print(model.__class__)) ``` -A model just needs the following two things: +To make your model compatible with the Transformers fallback, it needs: + +```{code-block} python +:caption: modeling_my_model.py -```python from transformers import PreTrainedModel from torch import nn class MyAttention(nn.Module): def forward(self, hidden_states, **kwargs): # <- kwargs are required - ... attention_interface = attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( @@ -102,8 +114,26 @@ class MyModel(PreTrainedModel): Here is what happens in the background: 1. The config is loaded -2. `MyModel` python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`. -3. The `TransformersModel` backend is used. See `/model_executors/models/transformers`, which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`. +2. `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`. +3. The `TransformersModel` backend is used. See , which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`. + +To make your model compatible with tensor parallel, it needs: + +```{code-block} python +:caption: configuration_my_model.py + +from transformers import PretrainedConfig + +class MyConfig(PretrainedConfig): + base_model_tp_plan = { + "layers.*.self_attn.q_proj": "colwise", + ... + } +``` + +:::{tip} +`base_model_tp_plan` is a `dict` that maps fully qualified layer name patterns to tensor parallel styles (currently only `"colwise"` and `"rowwise"` are supported). +::: That's it! @@ -893,7 +923,7 @@ Currently the PaliGemma model series is implemented without PrefixLM attention m ::: :::{note} -To use Qwen2.5-VL series models, you have to install Huggingface `transformers` library from source via `pip install git+https://github.com/huggingface/transformers`. +To use Qwen2.5-VL series models, you have to install Hugging Face Transformers library from source via `pip install git+https://github.com/huggingface/transformers`. 
::: ### Pooling Models From 848a6438aed2ef8d0de3d51e8ed2f970abf57853 Mon Sep 17 00:00:00 2001 From: TJian Date: Tue, 4 Mar 2025 01:24:45 +0800 Subject: [PATCH 05/33] [ROCm] Faster Custom Paged Attention kernels (#12348) --- .buildkite/run-amd-test.sh | 1 - .../kernels/benchmark_paged_attention.py | 71 +- csrc/rocm/attention.cu | 1506 ++++++++++++----- requirements-rocm.txt | 2 +- tests/kernels/test_attention.py | 8 +- vllm/attention/backends/rocm_flash_attn.py | 4 +- 6 files changed, 1145 insertions(+), 447 deletions(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 35d2ba1f8bab4..96fcafc9dc1c1 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -77,7 +77,6 @@ echo "Commands:$commands" #ignore certain kernels tests if [[ $commands == *" kernels "* ]]; then commands="${commands} \ - --ignore=kernels/test_attention.py \ --ignore=kernels/test_attention_selector.py \ --ignore=kernels/test_blocksparse_attention.py \ --ignore=kernels/test_causal_conv1d.py \ diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index d00e848243611..221d7b7d5d91b 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -11,8 +11,9 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, create_kv_caches_with_random) -NUM_BLOCKS = 1024 +NUM_BLOCKS = 128 * 1024 PARTITION_SIZE = 512 +PARTITION_SIZE_ROCM = 256 @torch.inference_mode() @@ -80,6 +81,12 @@ def main( # Prepare for the paged attention kernel. output = torch.empty_like(query) if version == "v2": + if current_platform.is_rocm(): + global PARTITION_SIZE + if not args.custom_paged_attn: + PARTITION_SIZE = 1024 + else: + PARTITION_SIZE = PARTITION_SIZE_ROCM num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), @@ -123,25 +130,46 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: v_scale, ) elif version == "v2": - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - ) + if not args.custom_paged_attn: + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + else: + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) else: raise ValueError(f"Invalid version: {version}") torch.cuda.synchronize() @@ -195,6 +223,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: help="Data type for kv cache storage. If 'auto', will use model " "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. 
" "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") + parser.add_argument("--custom-paged-attn", + action="store_true", + help="Use custom paged attention") args = parser.parse_args() print(args) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 82f7104a9e5ac..86029da141b36 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -17,6 +17,7 @@ #include #include #include +#include #include #include "cuda_compat.h" @@ -50,6 +51,9 @@ using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; using float16x4 = __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; typedef float16x4 _Half4; +using float16x2 = + __attribute__((__vector_size__(2 * sizeof(_Float16)))) _Float16; +typedef float16x2 _Half2; typedef struct _Half8 { _Half4 xy[2]; } _Half8; @@ -62,23 +66,17 @@ typedef struct _B16x8 { } _B16x8; using _B8x8 = uint2; +using _B8x4 = int32_t; // used in builtins +using bit8_t = uint8_t; -////// Non temporal load stores /////// - -template -__device__ __forceinline__ T load(T* addr) { - return addr[0]; -} - -template -__device__ __forceinline__ void store(T value, T* addr) { - addr[0] = value; -} +typedef struct _B8x16 { + _B8x8 xy[2]; +} _B8x16; template -__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, - const _B16x4& inpB, - const floatx4& inpC) { +__device__ __forceinline__ floatx4 gcn_mfma4x4x4_instr(const _B16x4& inpA, + const _B16x4& inpB, + const floatx4& inpC) { if constexpr (std::is_same::value) { return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid, blgp); @@ -90,6 +88,21 @@ __device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, } } +template +__device__ __forceinline__ floatx4 gcn_mfma16x16x16_instr(const _B16x4& inpA, + const _B16x4& inpB, + const floatx4& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x16f16(inpA, inpB, inpC, absz, cbid, + blgp); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(inpA, inpB, inpC, absz, + cbid, blgp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + template __device__ __forceinline__ float to_float(const T& inp) { if constexpr (std::is_same::value) { @@ -121,17 +134,22 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { } t16; _B16x4 ret; if constexpr (std::is_same::value) { - #pragma unroll - for (int i = 0; i < 4; i++) { - t16.f = (_Float16)inp[i]; - ret[i] = t16.u; - } - return ret; + union h2cvt { + __half2 h2[2]; + _B16x4 b16x4; + } u; + u.h2[0] = __float22half2_rn(make_float2(inp[0], inp[1])); + u.h2[1] = __float22half2_rn(make_float2(inp[2], inp[3])); + return u.b16x4; } else if constexpr (std::is_same::value) { - #pragma unroll for (int i = 0; i < 4; i++) { - t16.b = __float2bfloat16(inp[i]); - ret[i] = t16.u; + union fcvt { + uint32_t u32; + float f32; + } u; + u.f32 = inp[i]; + u.u32 += 0x7fff + ((u.u32 >> 16) & 1); // BF16 RNE with no nan/inf check + ret[i] = uint16_t(u.u32 >> 16); } return ret; } else { @@ -149,21 +167,25 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, } t1, t2, res; _B16x4 ret; if constexpr (std::is_same::value) { - #pragma unroll - for (int i = 0; i < 4; i++) { - t1.u = inp1[i]; - t2.u = inp2[i]; - res.f = t1.f + t2.f; - ret[i] = res.u; - } - return ret; + union h2cvt { + _B16x4 b16x4; + __half2 h2[2]; + } u1, u2, s; + u1.b16x4 = inp1; + u2.b16x4 = inp2; + s.h2[0] = u1.h2[0] + u2.h2[0]; + s.h2[1] = u1.h2[1] + u2.h2[1]; + return s.b16x4; } else if constexpr 
(std::is_same::value) { - #pragma unroll for (int i = 0; i < 4; i++) { - t1.u = inp1[i]; - t2.u = inp2[i]; - res.b = t1.b + t2.b; - ret[i] = res.u; + union fcvt { + float f32; + uint32_t i32; + } u1, u2, s; + u1.i32 = uint32_t(inp1[i]) << 16; + u2.i32 = uint32_t(inp2[i]) << 16; + s.f32 = u1.f32 + u2.f32; + ret[i] = uint16_t(s.i32 >> 16); } return ret; } else { @@ -171,53 +193,600 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, } } -template -__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, - const float scale) { - union alignas(16) { - uint4 u4; - _B16x8 u16x8; - vllm::bf16_8_t b16x8; - } tmp; +__device__ __forceinline__ floatx4 to_float_fp8x4(const _B8x4& inp) { + // From MI300+ platforms, we have v_cvt_pk_f32_fp8 instruction + // to convert 2 packed fp8 to 2 packed fp32 values. + // However, in MI200 platforms, we only have v_cvt_f32_fp8 + // to convert fp8 values individually. So we added + // #else case for fewer instructions (# inst=2) in MI300+, + // and fallback to + // #if case for other platforms (# inst=4). + #if defined(__gfx90a__) + float4 f32x4 = vllm::fp8::vec_conversion( + *reinterpret_cast(&inp)); + return *reinterpret_cast(&f32x4); + #else // MI3xx+ optimized builtins + const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, false); + const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, true); + floatx4 ret; + ret[0] = f0[0]; + ret[1] = f0[1]; + ret[2] = f1[0]; + ret[3] = f1[1]; + return ret; + #endif +} + +template +__device__ __forceinline__ _B16x4 from_floatx4_rtz(const floatx4& inp) { + _B16x4 ret; if constexpr (std::is_same::value) { - tmp.u4 = vllm::fp8::scaled_convert(input, scale); - return tmp.u16x8; + union h2cvt { + _Half2 h2[2]; + _B16x4 b16x4; + } u; + u.h2[0] = __builtin_amdgcn_cvt_pkrtz(inp[0], inp[1]); + u.h2[1] = __builtin_amdgcn_cvt_pkrtz(inp[2], inp[3]); + return u.b16x4; } else if constexpr (std::is_same::value) { - tmp.b16x8 = vllm::fp8::scaled_convert( - input, scale); - return tmp.u16x8; + for (int i = 0; i < 4; i++) { + union fcvt { + uint32_t i32; + float f32; + } u; + u.f32 = inp[i]; + ret[i] = uint16_t(u.i32 >> 16); + } + return ret; } else { static_assert(false, "unsupported 16b dtype"); } } -/////////////////////////////////////// +template +__device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { + union { + _B8x8 b8x8; + _B8x4 b8x4[2]; + } tmp; + tmp.b8x8 = input; + _B16x8 ret; + for (int i = 0; i < 2; i++) { + ret.xy[i] = from_floatx4_rtz(to_float_fp8x4(tmp.b8x4[i])); + } + return ret; +} + +// grid (num_seqs, num_partitions,num_kv_heads) +// block (256) +// clang-format off +template +__global__ +__launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, 
max_num_partitions, head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + // clang-format on + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane4id = laneid % 4; + const int lane16id = laneid % 16; + const int rowid = laneid / 16; + + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + + constexpr int T_PAR_SIZE = 256; // token partition size set to 256 + + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; + + const int partition_start_token_idx = + partition_idx * T_PAR_SIZE; // partition_size; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + + constexpr int GQA_RATIO4 = DIVIDE_ROUND_UP(GQA_RATIO, 4); + + __shared__ float shared_qk_max[NWARPS][16 + 1]; + __shared__ float shared_exp_sum[NWARPS][16 + 1]; + // shared_logits is used for multiple purposes + __shared__ _B16x4 shared_logits[NWARPS][4][16][4]; + + // for QK mfma16x16, layout is QHead/Tokenx16 across every 16 lanes, 16 Bytes + // HeadElements in each lane, 4x16B HeadElements across 4 rows of warp + constexpr int ROWS_PER_WARP = + WARP_SIZE / 16; // rows refers to 16 lanes; refer DDP (Data Parallel + // Processing) terminology + constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = + 16 / sizeof(cache_t); // 8 for 16 bit cache type, 16 for 8 bit types + constexpr int QKHE_PER_FETCH = + CONTIGUOUS_KV_ELEMS_16B_LOAD * + ROWS_PER_WARP; // each fetch across a warp fetches these many elements + constexpr int QK_SIZE_RATIO = + sizeof(scalar_t) / + sizeof(cache_t); // 1 for 16bit types, 2 for 8bit types + constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 4xQKHE_16B across + // warp + + _B16x8 Qlocal[QKHELOOP] + [QK_SIZE_RATIO]; // note that 16 contiguous elements of Q should + // be fetched per lane for 8 bit cache types : + // QK_SIZE_RATIO changes for this + + constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); + + constexpr int TOKENS_PER_WARP = + T_PAR_SIZE / + NWARPS; // sub partition of tokens per warp for qk calculation + constexpr int TLOOP = + TOKENS_PER_WARP / + 16; // each mfma16x16x16 instruction processes 16 tokens + + // can be interpreted as B8x16 for 8 bit types + _B16x8 Klocal[TLOOP][QKHELOOP]; + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + const int total_num_heads = gridDim.z * GQA_RATIO; + + // for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps + // each mfma takes QH16xT16x16HE across warp + // repeat mfmas across QKHELOOP dimension + // output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens + // across 4 rows x 4 tokens per lane + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; + + int kphysical_block_number[TLOOP]; + + // fetch k physical block numbers + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kblock_idx = (kglobal_token_idx < context_len) + ? 
kglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; + } + + // fetch Q in shared across warps and then write to registers + const int local_qhead_idx = 4 * warpid + rowid; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const int64_t seq_idx64 = static_cast(seq_idx); + const scalar_t* q_ptr = + q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; + + const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; + if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) { + const scalar_t* q_fetch_ptr = q_ptr + qhead_element; + const _B16x8* q_fetch_ptr_16B = + reinterpret_cast(q_fetch_ptr); + _B16x8 tmp = *q_fetch_ptr_16B; + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const int offset1 = + lane16id / + 4; // 16 contiguous chunks of head elems are spread across 4x4lanes + shared_logits[offset1][lane4id][local_qhead_idx][0] = tmp.xy[0]; + shared_logits[offset1][lane4id][local_qhead_idx][1] = tmp.xy[1]; + } else { + for (int i = 0; i < 2; i++) { + const int head_elem = lane16id * 2 + i; // element id in _B16x4 terms + const int offset3 = head_elem % 4; + const int offset2 = (head_elem / 4) % 4; + const int offset1 = head_elem / 4 / 4; + shared_logits[offset1][offset2][local_qhead_idx][offset3] = tmp.xy[i]; + } + } + } + __syncthreads(); + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for (int i = 0; i < 2; i++) { + Qlocal[qkhe_depth][qkratio].xy[i] = + shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO] + [2 * qkratio + i]; + } + } + } + + constexpr int KX = + 16 / sizeof(cache_t); // vLLM defines x as 16 Bytes of kv cache elements + const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; + + const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + // fetch K values + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int64_t kblock_number = + static_cast(kphysical_block_number[token_depth]); + const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; + const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; + + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH; + const int offset1 = head_elem / KX; + const int offset2 = head_elem % KX; + const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; + const _B16x8* k_fetch_ptr_16B = + reinterpret_cast(k_fetch_ptr); + Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B; + } + } + + float alibi_slope; + if constexpr (ALIBI_ENABLED) { + const int alibi_head_idx = wg_start_head_idx + lane16id; + alibi_slope = (lane16id < GQA_RATIO) ? 
alibi_slopes[alibi_head_idx] : 0.f; + } + + constexpr int VTOKENS_PER_LANE = + TOKENS_PER_WARP / ROWS_PER_WARP; // 64/4 = 16 contiguous vtokens per lane + constexpr int VBLOCKS_PER_LANE = + 1; // assumes block size >=16, each lane can correspond to 1 block only + constexpr int VTLOOP = NWARPS; // corresponds to tokens across warps + constexpr int VTLANELOOP = DIVIDE_ROUND_UP( + VTOKENS_PER_LANE, + CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes + // minimum block size is 16 + constexpr int VHELOOP = HEAD_SIZE / 16 / NWARPS; + + int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; + + // fetch v physical block numbers + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; + vblock_depth++) { + const int vlocal_token_idx = + vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP + + rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE; + // Safe to use an int32_t here assuming we are working with < 2 billion + // tokens + const int vglobal_token_idx = + partition_start_token_idx + vlocal_token_idx; + const int vblock_idx = (vglobal_token_idx < context_len) + ? vglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + vphysical_block_number[vtoken_depth][vblock_depth] = + block_table_seq[vblock_idx]; + } + } + + _B16x8 Vlocal[VTLOOP][VHELOOP][VTLANELOOP]; // this could be B8x16 too + + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + + ((rowid * VTOKENS_PER_LANE) % BLOCK_SIZE); -// grid (num_seqs, num_partitions,num_heads/gqa_ratio) -// block (partition size) + // v fetches are 16head elems across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id; + const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int vblock_depth = 0; + const int64_t vblock_number = static_cast( + vphysical_block_number[vtoken_depth][vblock_depth]); + const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); + + const cache_t* v_fetch_ptr = + v_ptr3 + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const _B16x8* v_fetch_ptr_16B = + reinterpret_cast(v_fetch_ptr); + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B; + } + } + } + + // calculate post qk mfma scale + float scale2 = scale; + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + // multiply by k_scale if fp8 kv cache + scale2 *= *k_scale; + } + + floatx4 d_out[TLOOP]; + // qk mfma + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + d_out[token_depth] = {0}; + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for (int i = 0; i < 2; i++) { + d_out[token_depth] = gcn_mfma16x16x16_instr( + Klocal[token_depth][qkhe_depth].xy[i], + Qlocal[qkhe_depth][qkratio].xy[i], d_out[token_depth]); + } + } + } else { // kv cache dtype fp8 + auto Ktmp = Klocal[token_depth][qkhe_depth]; + _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp); + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio]; + _B16x8 Klocaltmp = convert_b8x8_custom(Ktmp8x8); + for (int i = 0; i < 2; i++) { + d_out[token_depth] = gcn_mfma16x16x16_instr( + Klocaltmp.xy[i], 
Qlocal[qkhe_depth][qkratio].xy[i], + d_out[token_depth]); + } + } + } + } + d_out[token_depth] *= scale2; + } + + const int qkout_token_idx = + partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4; + + // apply alibi + if constexpr (ALIBI_ENABLED) { + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + const int alibi_offset = local_token_idx - context_len + 1; + for (int i = 0; i < 4; i++) { + d_out[token_depth][i] += alibi_slope * (alibi_offset + i); + } + } + } + + // calculate qk_max and exp_sum per warp and write to shared memory + float qk_max = -FLT_MAX; + float exp_sum = 0.0f; + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 4; i++) { + const float tmp = (local_token_idx + i < context_len) + ? d_out[token_depth][i] + : -FLT_MAX; + qk_max = fmaxf(qk_max, tmp); + } + } + + for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); + } + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 4; i++) { + const float tmp = (local_token_idx + i < context_len) + ? __expf(d_out[token_depth][i] - qk_max) + : 0.0f; + d_out[token_depth][i] = tmp; + exp_sum += tmp; + } + } + + for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) { + exp_sum += __shfl_xor(exp_sum, mask); + } + + __syncthreads(); // sync before writing to shared mem + + float* shared_mem = reinterpret_cast(shared_logits); + if (laneid < 16) { + const int qk_max_offset = warpid * 16 + lane16id; + shared_mem[qk_max_offset] = qk_max; + const int exp_sum_offset = NWARPS * 16 + qk_max_offset; + shared_mem[exp_sum_offset] = exp_sum; + } + + __syncthreads(); + + // calculate partition qk_max and exp_sum + float partition_qk_max = -FLT_MAX; + float warp_qk_max_exp[NWARPS]; + float partition_exp_sum = 0.0f; + + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = shared_mem[w * 16 + lane16id]; + partition_qk_max = fmaxf(partition_qk_max, warp_qk_max_exp[w]); + } + + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = __expf(warp_qk_max_exp[w] - partition_qk_max); + partition_exp_sum += + shared_mem[NWARPS * 16 + w * 16 + lane16id] * warp_qk_max_exp[w]; + } + + const float inv_sum_scale = + __fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid]; + + __syncthreads(); + + // disable rtz conversion due to its impact on accuracy. 
+ constexpr bool LOGITS_RTZ_CONVERSION = false; + + // write logits to shared mem + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + d_out[token_depth] *= inv_sum_scale; + if constexpr (LOGITS_RTZ_CONVERSION) { + // use rtz conversion for better performance, with negligible impact on + // accuracy + shared_logits[warpid][token_depth][lane16id][rowid] = + from_floatx4_rtz(d_out[token_depth]); + } else { + shared_logits[warpid][token_depth][lane16id][rowid] = + from_floatx4(d_out[token_depth]); + } + } + + // write out partition max_logits and exp_sum + if (threadIdx.x < GQA_RATIO) { + const int qhead_idx = lane16id; + const int64_t offset = static_cast(seq_idx) * + static_cast(total_num_heads) * + static_cast(max_num_partitions) + + (static_cast(wg_start_head_idx) + + static_cast(qhead_idx)) * + static_cast(max_num_partitions) + + static_cast(partition_idx); + max_logits[offset] = partition_qk_max; + exp_sums[offset] = partition_exp_sum; + } + + __syncthreads(); + + constexpr int ELEMS8_ELEMS4_RATIO = 8 / 4; + constexpr int ELEMS16_ELEMS8_RATIO = 16 / 8; + + _B16x4 outelems[VHELOOP]; + // Softmax V mfma + // v layout: 16he across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + floatx4 tmp_out = {0}; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { + const int offset = rowid * VTLANELOOP * ELEMS8_ELEMS4_RATIO + + vfetch_depth * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + // output format is 16 qheads across 16 lanes, 16 head elems spread + // across 4 rows + tmp_out = gcn_mfma16x16x16_instr( + Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + tmp_out); + } + } + // KV cache fp8 + } else { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + _B16x8 Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth]; + // reinterpret V format as 16 elements of 8bits + _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); + for (int j = 0; j < ELEMS16_ELEMS8_RATIO; j++) { + _B8x8 Vtmp8x8 = Vtmp8x16.xy[j]; + _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); + for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { + const int offset = + rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO + + j * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + // output format is 16 qheads across 16 lanes, 16 head elems + // spread across 4 rows + tmp_out = gcn_mfma16x16x16_instr( + Vlocaltmp.xy[i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + tmp_out); + } + } + } + } + } + // apply post Softmax V mfma v_scale + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + tmp_out *= *v_scale; + } + outelems[vhe_depth] = from_floatx4(tmp_out); + } + + __syncthreads(); + + // store Softmax-V mfma output to shared mem + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + // lane16 id head dimension; rowid head element dimension + shared_logits[warpid][vhe_depth][lane16id][rowid] = outelems[vhe_depth]; + } + + __syncthreads(); + + // write to tmp_out with coalesced writes after reading from shared mem + if (warpid == 0) { + _B16x8 vout[GQA_RATIO4]; + // each lane writes out 16Bytes of tmp_out along head elem dimension + 
const int head_elem_idx = lane16id * 8; + if (head_elem_idx < HEAD_SIZE) { + for (int h = 0; h < GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + const int offset1 = (head_elem_idx / 16) % 4; + const int offset2 = head_elem_idx / 16 / NWARPS; + const int offset3 = (head_elem_idx / 4) % 4; + for (int i = 0; i < 2; i++) { + vout[h].xy[i] = + shared_logits[offset1][offset2][local_head_idx][offset3 + i]; + } + } + + const int64_t hsz_maxp_mult = + static_cast(HEAD_SIZE * max_num_partitions); + scalar_t* out_ptr = out + seq_idx * total_num_heads * hsz_maxp_mult + + partition_idx * HEAD_SIZE; + for (int h = 0; h < GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + if (local_head_idx < GQA_RATIO) { + const int64_t out_head_idx = + static_cast(wg_start_head_idx + local_head_idx); + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; + _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); + *out_ptr_B16x8 = vout[h]; + } + } + } + } +} + +// grid (num_seqs, num_partitions, num_kv_heads) +// block (256 : partition size) +// each WG handles 1 partition per sequence +// clang-format off template -__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, - // head_size] - scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr) { + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + // clang-format on constexpr int NWARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -234,29 +803,37 @@ __global__ __launch_bounds__(NUM_THREADS) 
void paged_attention_ll4mi_QKV_kernel( if (partition_start_token_idx >= context_len) { return; } - constexpr int QHLOOP = - DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads, - // total qheads =8, so qhloop is 2 + // every 4 lanes fetch 4 different qheads + // qhloop = num loops over qhead dimension + constexpr int QHLOOP = DIVIDE_ROUND_UP(GQA_RATIO, 4); constexpr int GQA_RATIO4 = 4 * QHLOOP; __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; _B16x8 Qlocal[QHLOOP]; constexpr int x = 16 / sizeof(scalar_t); + // kheloop = num loops over head_size for 16Bytes of Q/dequantized K elements constexpr int KHELOOP = HEAD_SIZE / x; _B16x8 Klocal[KHELOOP]; _B8x8 Klocalb8[KHELOOP]; - constexpr int VHELOOP = - HEAD_SIZE / - WARP_SIZE; // v head_size dimension is distributed across lanes - constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 - // 8xtokens + // for SoftMax-V Gemm, V head_size dimension is distributed across warp + // vheloop = num loops to cover v head size dimension + constexpr int VHELOOP = HEAD_SIZE / WARP_SIZE; + // softmax out has warp_size tokens across warp + // vtloop = num loops to cover warp_size(64) tokens with 16Bytes of + // dequantized V elements + constexpr int VTLOOP = WARP_SIZE / 8; + // num vblocks to cover warp_size(64) v elements + constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; + int vphysical_blocks[VBLOCKS]; _B16x8 Vlocal[VHELOOP][VTLOOP]; _B8x8 Vlocalb8[VHELOOP][VTLOOP]; - floatx4 dout[QHLOOP]; + floatx4 d_out[QHLOOP]; float qk_max[QHLOOP]; - #pragma unroll + + __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; + for (int h = 0; h < QHLOOP; h++) { - dout[h] = {0}; + d_out[h] = {0}; qk_max[h] = -FLT_MAX; } @@ -278,25 +855,24 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const int last_ctx_block = num_context_blocks - 1; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - + // token id within partition const int local_token_idx = threadIdx.x; + // token id within sequence const int global_token_idx = partition_start_token_idx + local_token_idx; + // fetch block number for k const int block_idx = (global_token_idx < context_len) ? 
global_token_idx / BLOCK_SIZE : last_ctx_block; - // fetch block number for q and k - // int32 physical_block_number leads to overflow when multiplied with - // kv_block_stride + + // fetch k physical block number + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride const int64_t physical_block_number = static_cast(block_table[block_idx]); // fetch vphysical block numbers up front - constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; - int vphysical_blocks[VBLOCKS]; - const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; - #pragma unroll for (int b = 0; b < VBLOCKS; b++) { const int vblock_idx = warp_start_block_idx + b; const int vblock_idx_ctx = @@ -304,12 +880,13 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( vphysical_blocks[b] = block_table[vblock_idx_ctx]; } - // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems + // fetch q elements + // every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elems const scalar_t* q_ptr = q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; const _B16x8* q_ptrh8 = reinterpret_cast(q_ptr); const int qhead_elemh8 = laneid / 4; - #pragma unroll + for (int h = 0; h < QHLOOP - 1; h++) { const int qhead_idx = h * 4 + lane4id; Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; @@ -323,22 +900,24 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( Qlocal[QHLOOP - 1].xy[1] = {0}; } + // fetch k elements const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + wg_start_kv_head_idx * kv_head_stride; - const int physical_block_offset = - local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset - // is already cast as _H8 + // physical_block_offset is already cast in terms of _B16x8 + const int physical_block_offset = local_token_idx % BLOCK_SIZE; + + // each K fetch is for 8 elements of cache_t which are later dequantized to + // scalar_t for fp8 if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); - #pragma unroll for (int d = 0; d < KHELOOP; d++) { Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; } } else { + // vllm defines X as 16 Bytes of elements of cache_t constexpr int X = 16 / sizeof(cache_t); const cache_t* k_ptr2 = k_ptr + physical_block_offset * X; - #pragma unroll for (int d = 0; d < KHELOOP; d++) { const int head_elem = d * 8; const int offset1 = head_elem / X; @@ -348,9 +927,9 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } + // optional alibi fetch float alibi_slope[QHLOOP]; - if (alibi_slopes != nullptr) { - #pragma unroll + if constexpr (ALIBI_ENABLED) { for (int h = 0; h < QHLOOP; h++) { const int qhead_idx = h * 4 + lane4id; alibi_slope[h] = (qhead_idx < GQA_RATIO) @@ -360,10 +939,10 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + // fetch vcache in kv cache auto case if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); // iterate over each v block - #pragma unroll for (int b = 0; b < VBLOCKS; b++) { // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride @@ -372,21 +951,20 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const _B16x8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; // iterate over each 
head elem (within head_size) - #pragma unroll for (int h = 0; h < VHELOOP; h++) { const int head_size_elem = h * WARP_SIZE + laneid; const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; // iterate over all velems within block - #pragma unroll for (int d = 0; d < BLOCK_SIZE / 8; d++) { Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; } } } - } else { + } // if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) + // fetch vcache in fp8 case + else { // if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); // iterate over each v block - #pragma unroll for (int b = 0; b < VBLOCKS; b++) { // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride @@ -395,164 +973,153 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const _B8x8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; // iterate over each head elem (within head_size) - #pragma unroll for (int h = 0; h < VHELOOP; h++) { const int head_size_elem = h * WARP_SIZE + laneid; const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; // iterate over all velems within block - #pragma unroll for (int d = 0; d < BLOCK_SIZE / 8; d++) { - // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; - const _B8x8 Vlocalb8 = v_ptrh8be[d]; - Vlocal[h][b * BLOCK_SIZE / 8 + d] = - scaled_convert_b8x8(Vlocalb8, *v_scale_ptr); + Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; } } } } + #define QK_mfma(x) \ + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { \ + Klocal[x] = convert_b8x8_custom(Klocalb8[x]); \ + } \ + for (int h = 0; h < QHLOOP; h++) { \ + d_out[h] = gcn_mfma4x4x4_instr( \ + Qlocal[h].xy[0], Klocal[x].xy[0], d_out[h]); \ + d_out[h] = gcn_mfma4x4x4_instr( \ + Qlocal[h].xy[1], Klocal[x].xy[1], d_out[h]); \ + } + // QK mfma with Q mfma block broadcast + // Q values across head_size dimension stored across lanes + // K values across head_size dimension are stored depthwise within lane + // Q broadcast with absz, cbid of mfma instruction + QK_mfma(0); + QK_mfma(1); + QK_mfma(2); + QK_mfma(3); + QK_mfma(4); + QK_mfma(5); + QK_mfma(6); + QK_mfma(7); + // below only needed for head size 128 + if constexpr (KHELOOP > 8) { + QK_mfma(8); + QK_mfma(9); + QK_mfma(10); + QK_mfma(11); + QK_mfma(12); + QK_mfma(13); + QK_mfma(14); + QK_mfma(15); + } + #undef QK_mfma + + float scale2 = scale; if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - #pragma unroll - for (int d = 0; d < KHELOOP; d++) { - Klocal[d] = - scaled_convert_b8x8(Klocalb8[d], *k_scale_ptr); - } + // post mfma scaling for fp8 + scale2 *= *k_scale; } - #pragma unroll for (int h = 0; h < QHLOOP; h++) { - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[0].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[0].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[1].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[1].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[2].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[2].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[3].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[3].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[4].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[4].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[5].xy[0], dout[h]); - dout[h] = 
gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[5].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[6].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[6].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[7].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[7].xy[1], dout[h]); - if constexpr (KHELOOP > 8) { - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[8].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[8].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[9].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[9].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[10].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[10].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[11].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[11].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[12].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[12].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[13].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[13].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[14].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[14].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[15].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[15].xy[1], dout[h]); - } // KHELOOP>8 - dout[h] *= scale; + d_out[h] *= scale2; } - // transpose dout so that 4 token ids are in each lane, and 4 heads are across - // 4 lanes - #pragma unroll + + // transpose d_out so that 4 token ids are in each lane, and 4 heads are + // across 4 lanes for (int h = 0; h < QHLOOP; h++) { floatx4 tmp = {0}; - #pragma unroll for (int i = 0; i < 4; i++) { const float B = (lane4id == i) ? 1.0f : 0.0f; - // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f; - tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0); - // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0); + tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(d_out[h][i], B, tmp, 0, 0, 0); } - dout[h] = tmp; + d_out[h] = tmp; } const int lane4_token_idx = 4 * (global_token_idx >> 2); - const int alibi_offset = lane4_token_idx - context_len + 1; - if (alibi_slopes != nullptr) { - #pragma unroll + + if constexpr (ALIBI_ENABLED) { + const int alibi_offset = lane4_token_idx - context_len + 1; for (int h = 0; h < QHLOOP; h++) { - #pragma unroll for (int i = 0; i < 4; i++) { - dout[h][i] += alibi_slope[h] * (alibi_offset + i); + d_out[h][i] += alibi_slope[h] * (alibi_offset + i); } } } - #pragma unroll + const int bpermute_mask = 4 * (16 * ((laneid >> 2) % 4) + lane4id); + for (int h = 0; h < QHLOOP; h++) { qk_max[h] = -FLT_MAX; - #pragma unroll for (int i = 0; i < 4; i++) { qk_max[h] = (lane4_token_idx + i < context_len) - ? fmaxf(qk_max[h], dout[h][i]) + ? 
fmaxf(qk_max[h], d_out[h][i]) : qk_max[h]; } - #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { - qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); - } + + // for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + // qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); + // } + // faster version of above code with dpp + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); + + auto tmp = __builtin_amdgcn_ds_bpermute( + bpermute_mask, *reinterpret_cast(&qk_max[h])); + qk_max[h] = *reinterpret_cast(&tmp); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); } float exp_sum[QHLOOP]; - #pragma unroll for (int h = 0; h < QHLOOP; h++) { exp_sum[h] = 0.0f; - #pragma unroll for (int i = 0; i < 4; i++) { - dout[h][i] = (lane4_token_idx + i < context_len) - ? __expf(dout[h][i] - qk_max[h]) - : 0.0f; - exp_sum[h] += dout[h][i]; - } - #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { - exp_sum[h] += __shfl_xor(exp_sum[h], mask); + d_out[h][i] = (lane4_token_idx + i < context_len) + ? __expf(d_out[h][i] - qk_max[h]) + : 0.0f; + exp_sum[h] += d_out[h][i]; } + // for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + // exp_sum[h] += __shfl_xor(exp_sum[h], mask); + // } + // faster version of above code with dpp + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); + + auto tmp = __builtin_amdgcn_ds_bpermute( + bpermute_mask, *reinterpret_cast(&exp_sum[h])); + exp_sum[h] = *reinterpret_cast(&tmp); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); } - #pragma unroll - for (int h = 0; h < QHLOOP; h++) { - const int head_idx = 4 * h + lane4id; - shared_qk_max[warpid][head_idx] = qk_max[h]; - shared_exp_sum[warpid][head_idx] = exp_sum[h]; + if (laneid < 4) { + for (int h = 0; h < QHLOOP; h++) { + const int head_idx = 4 * h + lane4id; + shared_qk_max[warpid][head_idx] = qk_max[h]; + shared_exp_sum[warpid][head_idx] = exp_sum[h]; + } } } // warp within context @@ -563,18 +1130,16 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( max_logits + seq_idx * num_heads * max_num_partitions + partition_idx; float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx; - #pragma unroll + // calculate qk_max and exp_sums for partition for (int h = 0; h < QHLOOP; h++) { float global_qk_max = -FLT_MAX; float warp_qk_max[NWARPS]; const int head_idx = 4 * h + lane4id; - #pragma unroll for (int w = 0; w < NWARPS; w++) { warp_qk_max[w] = shared_qk_max[w][head_idx]; global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]); } float global_exp_sum = 0.0f; - #pragma unroll for (int w = 0; w < NWARPS; w++) { global_exp_sum += shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max); @@ -587,101 +1152,94 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( 
} const float global_inv_sum_scale = __fdividef(1.f, global_exp_sum + 1e-6f) * __expf(qk_max[h] - global_qk_max); - dout[h] *= global_inv_sum_scale; + d_out[h] *= global_inv_sum_scale; } + constexpr bool LOGITS_RTZ_CONVERSION = false; // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there // are 4x16 tokens across warp _B16x4 logits[QHLOOP]; - #pragma unroll for (int h = 0; h < QHLOOP; h++) { - logits[h] = from_floatx4(dout[h]); + if constexpr (LOGITS_RTZ_CONVERSION) { + // use rtz for faster performance with no perceivable accuracy loss + logits[h] = from_floatx4_rtz(d_out[h]); + } else { + logits[h] = from_floatx4(d_out[h]); + } } - __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; - if (warp_start_token_idx >= context_len) { // warp out of context - #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { - #pragma unroll for (int vh = 0; vh < VHELOOP; vh++) { vout_shared[qh][vh][laneid][warpid] = {0}; } } } else { // warp in context - // iterate across heads - #pragma unroll - for (int qh = 0; qh < QHLOOP; qh++) { - // iterate over each v head elem (within head_size) - #pragma unroll - for (int vh = 0; vh < VHELOOP; vh++) { - floatx4 acc = {0}; - // iterate over tokens - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][5].xy[0], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][5].xy[1], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][6].xy[0], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][6].xy[1], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][7].xy[0], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][7].xy[1], acc); - vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc); + #define SV_mfma(x) \ + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { \ + Vlocal[vh][x] = convert_b8x8_custom(Vlocalb8[vh][x]); \ + } \ + for (int qh = 0; qh < QHLOOP; qh++) { \ + acc[qh] = gcn_mfma4x4x4_instr( \ + logits[qh], Vlocal[vh][x].xy[0], acc[qh]); \ + acc[qh] = gcn_mfma4x4x4_instr( \ + logits[qh], Vlocal[vh][x].xy[1], acc[qh]); \ + } + + for (int vh = 0; vh < VHELOOP; vh++) { + floatx4 acc[QHLOOP]; + for (int qh = 0; qh < QHLOOP; qh++) { + acc[qh] = {0}; + } + // SoftMax-V calculation + // logits -> token dimension is distributed across lanes + // Vlocal -> token dimension is depthwise within lane + // uses mfma instruction block broadcast for logits + SV_mfma(0); + SV_mfma(1); + SV_mfma(2); + SV_mfma(3); + SV_mfma(4); + SV_mfma(5); + SV_mfma(6); + SV_mfma(7); + + for (int qh = 0; qh < QHLOOP; qh++) { + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + // post mfma v scale for fp8 + acc[qh] *= *v_scale; + } + vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc[qh]); } } + + #undef SV_mfma } // warp in context __syncthreads(); + // final write to tmp_out after vout accumulation if (warpid == 0) { _B16x4 vout[QHLOOP][VHELOOP]; // iterate across heads - scalar_t* out_ptr; - int 
out_num_partitions; - if (context_len > partition_size) { - out_num_partitions = max_num_partitions; - out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - partition_idx * HEAD_SIZE; - } else { - out_num_partitions = 1; - out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; - } - #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { - // iterate over each v head elem (within head_size) - #pragma unroll + // iterate over each v head elem (within head_size) for (int vh = 0; vh < VHELOOP; vh++) { vout[qh][vh] = {0}; - #pragma unroll for (int w = 0; w < NWARPS; w++) { vout[qh][vh] = addx4(vout[qh][vh], vout_shared[qh][vh][laneid][w]); } + } + } + + scalar_t* out_ptr = out + + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + const int out_num_partitions = max_num_partitions; + bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); + for (int qh = 0; qh < QHLOOP; qh++) { + for (int vh = 0; vh < VHELOOP; vh++) { const int head_size_elem = vh * WARP_SIZE + laneid; - bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); - #pragma unroll for (int i = 0; i < 4; i++) { const int head_idx = 4 * qh + i; if (head_idx < GQA_RATIO) { @@ -692,15 +1250,15 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } - } + } // warpid == 0 } // Grid: (num_heads, num_seqs). -template +template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] const float* __restrict__ exp_sums, // [num_seqs, num_heads, // max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, @@ -714,18 +1272,13 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const int seq_idx = blockIdx.y; const int context_len = context_lens[seq_idx]; const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); - if (num_partitions == 1) { - // if num_partitions==1, main kernel will write to out directly, no work in - // reduction kernel - return; - } - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; __shared__ float shared_global_exp_sum; - __shared__ float shared_exp_sums[2 * WARP_SIZE]; + // max num partitions supported is warp_size * NPAR_LOOPS + __shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE]; if (warpid == 0) { const float* max_logits_ptr = max_logits + @@ -734,14 +1287,25 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( // valid partition is the last valid partition in case threadid > num // partitions - const int valid_partition = - (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1; - const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions) - ? WARP_SIZE + threadIdx.x - : num_partitions - 1; - float reg_max_logit = max_logits_ptr[valid_partition]; - float reg_max_logit2 = max_logits_ptr[valid_partition2]; - float max_logit = fmaxf(reg_max_logit, reg_max_logit2); + int valid_partition[NPAR_LOOPS]; + float reg_max_logit[NPAR_LOOPS]; + const int last_valid_partition = num_partitions - 1; + + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + valid_partition[i] = + (partition_no < num_partitions) ? 
partition_no : last_valid_partition; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + reg_max_logit[i] = max_logits_ptr[valid_partition[i]]; + } + float max_logit = reg_max_logit[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + max_logit = fmaxf(max_logit, reg_max_logit[i]); + } #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { @@ -752,17 +1316,28 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; - float global_exp_sum = 0.0f; - float rescaled_exp_sum = exp_sums_ptr[valid_partition]; - float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2]; - rescaled_exp_sum *= - (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f; - rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions) - ? expf(reg_max_logit2 - max_logit) - : 0.0f; - global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2; - shared_exp_sums[threadIdx.x] = rescaled_exp_sum; - shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2; + float rescaled_exp_sum[NPAR_LOOPS]; + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + rescaled_exp_sum[i] *= (partition_no < num_partitions) + ? expf(reg_max_logit[i] - max_logit) + : 0.0f; + } + float global_exp_sum = rescaled_exp_sum[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + global_exp_sum += rescaled_exp_sum[i]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + shared_exp_sums[partition_no] = rescaled_exp_sum[i]; + } #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { @@ -839,82 +1414,117 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( } } - if (num_partitions > MAX_NPAR) { - idx = 0; + for (int p = 1; p < NPAR_LOOPS; p++) { + if (num_partitions > p * MAX_NPAR) { + idx = 0; #pragma unroll - for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE; - j += HEAD_SIZE) { - // lastj is last valid partition - const int lastj_offset = - (j < num_partition_offset) ? j : last_partition_offset; - tmps[idx] = tmp_out_ptr[lastj_offset]; - idx++; - } + for (int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? 
j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } #pragma unroll - for (int j = 0; j < MAX_NPAR; j++) { - acc += to_float(tmps[j]) * shared_exp_sums[j + MAX_NPAR]; + for (int j = 0; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR]; + } } } const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f); acc *= inv_global_exp_sum; - scalar_t* out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - out_ptr[threadIdx.x] = from_float(acc); + + OUTT* out_ptr = out + static_cast(seq_idx) * num_heads * HEAD_SIZE + + static_cast(head_idx) * HEAD_SIZE; + if constexpr (std::is_same::value) { + out_ptr[threadIdx.x] = + __hip_cvt_float_to_fp8(acc, vllm::fp8::fp8_type::__default_saturation, + vllm::fp8::fp8_type::__default_interpret); + } else { + out_ptr[threadIdx.x] = from_float(acc); + } } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +// clang-format off template -__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, - // head_size] - scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale) { + UNREACHABLE_CODE +} + +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int num_kv_heads, + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + 
const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] int max_ctx_blocks, const float* k_scale, const float* v_scale) { UNREACHABLE_CODE } // Grid: (num_heads, num_seqs). -template +template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, - // max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] const int* __restrict__ context_lens, // [num_seqs] const int max_num_partitions) { UNREACHABLE_CODE } +// clang-format on #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support -#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ - paged_attention_ll4mi_QKV_kernel \ +#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma16_kernel \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ @@ -922,8 +1532,27 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ k_scale_ptr, v_scale_ptr); +#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma4_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ + k_scale_ptr, v_scale_ptr); + +#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ + paged_attention_ll4mi_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ + context_lens_ptr, max_num_partitions); + template + int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD, + bool ALIBI_ENABLED> void paged_attention_custom_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, @@ -945,7 +1574,6 @@ void paged_attention_custom_launcher( ? 
reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); @@ -956,109 +1584,143 @@ void paged_attention_custom_launcher( int* context_lens_ptr = context_lens.data_ptr(); const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + OUTT* out_ptr = reinterpret_cast(out.data_ptr()); const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + + // partition size is fixed at 256 since both mfma4 and mfma16 kernels support + // it mfma4 kernel also supports partition size 512 + constexpr int PARTITION_SIZE = 256; const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); const int gqa_ratio = num_heads / num_kv_heads; assert(num_heads % num_kv_heads == 0); assert(head_size == HEAD_SIZE); - assert(max_num_partitions <= 128); - constexpr int NTHR = PARTITION_SIZE; + constexpr int NTHR = 256; dim3 grid(num_seqs, max_num_partitions, num_kv_heads); dim3 block(NTHR); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 switch (gqa_ratio) { case 1: - LAUNCH_CUSTOM_ATTENTION(1); + LAUNCH_CUSTOM_ATTENTION_MFMA4(1); break; case 2: - LAUNCH_CUSTOM_ATTENTION(2); + LAUNCH_CUSTOM_ATTENTION_MFMA4(2); break; case 3: - LAUNCH_CUSTOM_ATTENTION(3); + LAUNCH_CUSTOM_ATTENTION_MFMA4(3); break; case 4: - LAUNCH_CUSTOM_ATTENTION(4); + LAUNCH_CUSTOM_ATTENTION_MFMA4(4); break; case 5: - LAUNCH_CUSTOM_ATTENTION(5); + LAUNCH_CUSTOM_ATTENTION_MFMA16(5); break; case 6: - LAUNCH_CUSTOM_ATTENTION(6); + LAUNCH_CUSTOM_ATTENTION_MFMA16(6); break; case 7: - LAUNCH_CUSTOM_ATTENTION(7); + LAUNCH_CUSTOM_ATTENTION_MFMA16(7); break; case 8: - LAUNCH_CUSTOM_ATTENTION(8); + LAUNCH_CUSTOM_ATTENTION_MFMA16(8); break; case 9: - LAUNCH_CUSTOM_ATTENTION(9); + LAUNCH_CUSTOM_ATTENTION_MFMA16(9); break; case 10: - LAUNCH_CUSTOM_ATTENTION(10); + LAUNCH_CUSTOM_ATTENTION_MFMA16(10); break; case 11: - LAUNCH_CUSTOM_ATTENTION(11); + LAUNCH_CUSTOM_ATTENTION_MFMA16(11); break; case 12: - LAUNCH_CUSTOM_ATTENTION(12); + LAUNCH_CUSTOM_ATTENTION_MFMA16(12); break; case 13: - LAUNCH_CUSTOM_ATTENTION(13); + LAUNCH_CUSTOM_ATTENTION_MFMA16(13); break; case 14: - LAUNCH_CUSTOM_ATTENTION(14); + LAUNCH_CUSTOM_ATTENTION_MFMA16(14); break; case 15: - LAUNCH_CUSTOM_ATTENTION(15); + LAUNCH_CUSTOM_ATTENTION_MFMA16(15); break; case 16: - LAUNCH_CUSTOM_ATTENTION(16); + LAUNCH_CUSTOM_ATTENTION_MFMA16(16); break; default: TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); break; } - // dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG); - // dim3 block2(1024); - // LAUNCH_CUSTOM_ATTENTION2; - - // reduction kernel is only required if max_context_len > partition size, - // otherwise main kernel writes directly to final output - // note there are cases with graphing where max_context_len is the max - // supported by graphing, not the actual max among all the sequences: in that - // case reduction kernel will still run but return immediately - if (max_context_len > PARTITION_SIZE) { - dim3 reduce_grid(num_heads, num_seqs); - dim3 reduce_block(head_size); - paged_attention_ll4mi_reduce_kernel - <<>>( - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, - context_lens_ptr, max_num_partitions); 
+ + dim3 reduce_grid(num_heads, num_seqs); + dim3 reduce_block(head_size); + const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE); + // reduction kernel supports upto 8 NPAR_loops * 64 (warp_size) * 256 + // (partition size) = 128K context length + switch (npar_loops) { + case 1: + LAUNCH_CUSTOM_REDUCTION(1); + break; + case 2: + LAUNCH_CUSTOM_REDUCTION(2); + break; + case 3: + LAUNCH_CUSTOM_REDUCTION(3); + break; + case 4: + LAUNCH_CUSTOM_REDUCTION(4); + break; + case 5: + LAUNCH_CUSTOM_REDUCTION(5); + break; + case 6: + LAUNCH_CUSTOM_REDUCTION(6); + break; + case 7: + LAUNCH_CUSTOM_REDUCTION(7); + break; + case 8: + LAUNCH_CUSTOM_REDUCTION(8); + break; + default: + TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); + break; } } -#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ - paged_attention_custom_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, max_context_len, \ +#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, \ + ALIBI_ENABLED) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ alibi_slopes, k_scale, v_scale); -#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ - switch (block_size) { \ - case 16: \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ - break; \ - case 32: \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + PSIZE) \ + if (alibi_slopes) { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, true); \ + } else { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, false); \ + } + +#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ + switch (block_size) { \ + case 16: \ + CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 16, HEAD_SIZE, 256); \ + break; \ + case 32: \ + CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 32, HEAD_SIZE, 256); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } #define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ @@ -1074,24 +1736,24 @@ void paged_attention_custom_launcher( break; \ } +// clang-format off void paged_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& - tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& - key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& - value_cache, // [num_blocks, num_heads, head_size, block_size] - int64_t num_kv_heads, double scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, + double scale, + torch::Tensor& block_tables, // [num_seqs, 
max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] int64_t block_size, int64_t max_context_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale) { + // clang-format on const int head_size = query.size(2); if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Half) { diff --git a/requirements-rocm.txt b/requirements-rocm.txt index d86e039c2326f..97985655cbf64 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -11,4 +11,4 @@ peft pytest-asyncio tensorizer>=2.9.0 runai-model-streamer==0.11.0 -runai-model-streamer-s3==0.11.0 +runai-model-streamer-s3==0.11.0 \ No newline at end of file diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 0fe10d76909ea..fc549d7a7c18d 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -25,6 +25,7 @@ # Reduce NUM_BLOCKS when it happens. NUM_BLOCKS = 4321 # Arbitrary values for testing PARTITION_SIZE = 512 +PARTITION_SIZE_ROCM = 256 # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16} DTYPES = [ torch.half, torch.bfloat16, torch.float @@ -146,6 +147,8 @@ def test_paged_attention( or (version == "rocm" and head_size not in (64, 128))): pytest.skip() + global PARTITION_SIZE + current_platform.seed_everything(seed) torch.set_default_device(device) scale = float(1.0 / (head_size**0.5)) @@ -214,6 +217,9 @@ def test_paged_attention( and block_size == BLOCK_SIZES[0])) elif version in ("v2", "rocm"): + if current_platform.is_rocm() and version == "rocm": + PARTITION_SIZE = PARTITION_SIZE_ROCM + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -432,4 +438,4 @@ def test_multi_query_kv_attention( ) atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) \ No newline at end of file diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 3f40686ee2fda..02a2a48fe8593 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -22,7 +22,7 @@ logger = init_logger(__name__) -_PARTITION_SIZE_ROCM = 512 +_PARTITION_SIZE_ROCM = 256 _GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName _ON_NAVI = "gfx1" in _GPU_ARCH _ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx942"]) @@ -885,4 +885,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) + and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) \ No newline at end of file From 91373a0d15f3d47a45ab5145467940a1664900fc Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Mon, 3 Mar 2025 17:48:11 +0000 Subject: [PATCH 06/33] Fix `head_dim` not existing in all model configs (Transformers backend) (#14141) Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/model_executor/models/transformers.py | 27 ++++++++++++---------- 1 file changed, 15 insertions(+), 12 deletions(-) 
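
Background for this fix: not every Hugging Face config exposes a `head_dim`
attribute, so the Transformers backend stops reading `config.head_dim` and
`config.num_attention_heads` directly and instead asks vLLM's `ModelConfig`
helpers (`get_head_size`, `get_num_attention_heads`, `get_num_kv_heads`,
`get_vocab_size`) for the per-head and vocab sizes. As a rough sketch of the
usual fallback such a helper relies on (illustration only; `cfg` is a
hypothetical config object and this function is not part of the patch):

    def head_size(cfg) -> int:
        # Prefer an explicit head_dim when the config defines one.
        explicit = getattr(cfg, "head_dim", None)
        if explicit is not None:
            return explicit
        # Otherwise assume hidden_size is split evenly across attention heads.
        return cfg.hidden_size // cfg.num_attention_heads

Routing through the helpers keeps the attention setup working for models whose
configs omit `head_dim`.
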
diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index 61cfc566dd31a..a6bfdebb1a7e3 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -25,7 +25,6 @@ from vllm.attention import Attention from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.distributed.utils import divide from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, @@ -128,10 +127,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config self.config = config - self.vocab_size = config.vocab_size - self.unpadded_vocab_size = config.vocab_size + self.vocab_size = model_config.get_vocab_size() + self.unpadded_vocab_size = model_config.get_vocab_size() self.model: PreTrainedModel = AutoModel.from_config( self.config, @@ -145,15 +146,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.apply_base_model_tp_plan(self.model) # Attention modifications (assumes 1 attention op per hidden layer) - tp_size = get_tensor_model_parallel_world_size() + num_heads = model_config.get_num_attention_heads(parallel_config) + head_size = model_config.get_head_size() + num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.attention_instances = [ Attention( - num_heads=divide(config.num_attention_heads, tp_size), - head_size=config.head_dim, + num_heads=num_heads, + head_size=head_size, # NOTE: We use Llama scale as default, if it's set by # Transformers, it's updated in vllm_flash_attention_forward - scale=config.head_dim**-0.5, - num_kv_heads=divide(config.num_key_value_heads, tp_size), + scale=head_size**-0.5, + num_kv_heads=num_kv_heads, cache_config=cache_config, quant_config=self.quant_config, prefix=f"{i}.attn") for i in range(config.num_hidden_layers) @@ -163,7 +166,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.replace_vocab_embed_class(self.model) # ForCausalLM modifications - self.lm_head = ParallelLMHead(config.vocab_size, + self.lm_head = ParallelLMHead(self.vocab_size, config.hidden_size, quant_config=self.quant_config, prefix=maybe_prefix(prefix, "lm_head")) @@ -172,7 +175,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.vocab_size, logit_scale) + self.vocab_size, logit_scale) self.sampler = get_sampler() def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""): @@ -203,12 +206,12 @@ def replace_vocab_embed_class(self, module: nn.Module): new_module = VocabParallelEmbedding( self.vocab_size, self.config.hidden_size, - org_num_embeddings=self.config.vocab_size, + org_num_embeddings=self.vocab_size, quant_config=None, ) log_replacement("input embedding", self.model.get_input_embeddings(), new_module) - self.model.set_input_embeddings(new_module) + module.set_input_embeddings(new_module) def forward( self, From c41d27156b7c9123bd38387afca639631dfc2ed0 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Mon, 3 Mar 2025 17:50:22 +0000 Subject: [PATCH 07/33] [V0][Metrics] Remove unimplemented `vllm:tokens_total` (#14134) Signed-off-by: Mark McLoughlin --- 
vllm/engine/metrics.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index cb3ca7a118819..9379ba6146316 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -115,10 +115,6 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig): name="vllm:generation_tokens_total", documentation="Number of generation tokens processed.", labelnames=labelnames) - self.counter_tokens = self._counter_cls( - name="vllm:tokens_total", - documentation="Number of prefill plus generation tokens processed.", - labelnames=labelnames) buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096] if not vllm_config.model_config.enforce_eager: buckets = vllm_config.compilation_config.\ From 2dfdfed8a0fe5517f8d4050740c251b1c1d35eeb Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Mon, 3 Mar 2025 18:25:46 +0000 Subject: [PATCH 08/33] [V0][Metrics] Deprecate some KV/prefix cache metrics (#14136) Signed-off-by: Mark McLoughlin --- vllm/engine/metrics.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 9379ba6146316..08be2cbc0b9d5 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -74,31 +74,51 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig): ], multiprocess_mode="livemostrecent", ) + + # Deprecated in 0.8 - KV cache offloading is not used in V1 + # TODO: in 0.9, only enable if show_hidden_metrics=True self.gauge_scheduler_swapped = self._gauge_cls( name="vllm:num_requests_swapped", - documentation="Number of requests swapped to CPU.", + documentation=( + "Number of requests swapped to CPU. " + "DEPRECATED: KV cache offloading is not used in V1"), labelnames=labelnames, multiprocess_mode="sum") + # KV Cache Usage in % self.gauge_gpu_cache_usage = self._gauge_cls( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 1 means 100 percent usage.", labelnames=labelnames, multiprocess_mode="sum") + + # Deprecated in 0.8 - KV cache offloading is not used in V1 + # TODO: in 0.9, only enable if show_hidden_metrics=True self.gauge_cpu_cache_usage = self._gauge_cls( name="vllm:cpu_cache_usage_perc", - documentation="CPU KV-cache usage. 1 means 100 percent usage.", + documentation=( + "CPU KV-cache usage. 1 means 100 percent usage. " + "DEPRECATED: KV cache offloading is not used in V1"), labelnames=labelnames, multiprocess_mode="sum") - # Prefix caching block hit rate + + # Deprecated in 0.8 - KV cache offloading is not used in V1 + # TODO: in 0.9, only enable if show_hidden_metrics=True self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls( name="vllm:cpu_prefix_cache_hit_rate", - documentation="CPU prefix cache block hit rate.", + documentation=( + "CPU prefix cache block hit rate. " + "DEPRECATED: KV cache offloading is not used in V1"), labelnames=labelnames, multiprocess_mode="sum") + + # Deprecated in 0.8 - replaced by queries+hits counters in V1 + # TODO: in 0.9, only enable if show_hidden_metrics=True self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls( name="vllm:gpu_prefix_cache_hit_rate", - documentation="GPU prefix cache block hit rate.", + documentation=("GPU prefix cache block hit rate. 
" + "DEPRECATED: use vllm:gpu_prefix_cache_queries and " + "vllm:gpu_prefix_cache_queries in V1"), labelnames=labelnames, multiprocess_mode="sum") From 872db2be0e8499960cedbbc2fbcbcb2837b53be2 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 3 Mar 2025 10:34:14 -0800 Subject: [PATCH 09/33] [V1] Simplify stats logging (#14082) Signed-off-by: Nick Hill --- vllm/v1/engine/async_llm.py | 21 +++++++++++---------- vllm/v1/engine/core.py | 17 ++++------------- vllm/v1/metrics/loggers.py | 29 +++++++++++++++-------------- 3 files changed, 30 insertions(+), 37 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 954f74c3fdaef..4c9d4cb467ae9 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import logging import os from collections.abc import AsyncGenerator, Mapping from typing import Optional, Union @@ -57,10 +58,9 @@ def __init__( self.log_stats = log_stats self.stat_loggers: list[StatLoggerBase] = [] if self.log_stats: - self.stat_loggers.extend([ - LoggingStatLogger(), - PrometheusStatLogger(vllm_config), - ]) + if logger.isEnabledFor(logging.INFO): + self.stat_loggers.append(LoggingStatLogger()) + self.stat_loggers.append(PrometheusStatLogger(vllm_config)) # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( @@ -287,7 +287,7 @@ async def _run_output_handler(self): # 4) Logging. # TODO(rob): make into a coroutine and launch it in # background thread once Prometheus overhead is non-trivial. - self._log_stats( + self._record_stats( scheduler_stats=outputs.scheduler_stats, iteration_stats=iteration_stats, ) @@ -306,7 +306,7 @@ async def abort(self, request_id: str) -> None: if self.log_requests: logger.info("Aborted request %s.", request_id) - def _log_stats( + def _record_stats( self, scheduler_stats: Optional[SchedulerStats], iteration_stats: Optional[IterationStats], @@ -316,9 +316,9 @@ def _log_stats( assert scheduler_stats is not None assert iteration_stats is not None - for logger in self.stat_loggers: - logger.log(scheduler_stats=scheduler_stats, - iteration_stats=iteration_stats) + for stat_logger in self.stat_loggers: + stat_logger.record(scheduler_stats=scheduler_stats, + iteration_stats=iteration_stats) def encode( self, @@ -354,7 +354,8 @@ async def do_log_stats( scheduler_outputs=None, model_output=None, ) -> None: - logger.debug("Called do_log_stats.") + for stat_logger in self.stat_loggers: + stat_logger.log() async def check_health(self) -> None: logger.debug("Called check_health.") diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index b9bf8fac40f60..b78b903b805fb 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -316,19 +316,10 @@ def run_busy_loop(self): # Loop until process is sent a SIGINT or SIGTERM while True: # 1) Poll the input queue until there is work to do. - if not self.scheduler.has_unfinished_requests(): - while True: - try: - req = self.input_queue.get(timeout=POLLING_TIMEOUT_S) - self._handle_client_request(*req) - break - except queue.Empty: - logger.debug("EngineCore busy loop waiting.") - # Break out the loop so we can log_stats in step(). - if self.log_stats: - break - except BaseException: - raise + while not self.scheduler.has_unfinished_requests(): + logger.debug("EngineCore busy loop waiting.") + req = self.input_queue.get() + self._handle_client_request(*req) # 2) Handle any new client requests. 
while not self.input_queue.empty(): diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 5a2a1c30a9d58..7f6de79104841 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -21,15 +21,19 @@ class StatLoggerBase(ABC): @abstractmethod - def log(self, scheduler_stats: SchedulerStats, - iteration_stats: IterationStats): + def record(self, scheduler_stats: SchedulerStats, + iteration_stats: IterationStats): ... + def log(self): # noqa + pass + class LoggingStatLogger(StatLoggerBase): def __init__(self): self._reset(time.monotonic()) + self.last_scheduler_stats = SchedulerStats() def _reset(self, now): self.last_log_time = now @@ -41,11 +45,6 @@ def _reset(self, now): # Prefix cache metrics. TODO: Make the interval configurable. self.prefix_caching_metrics = PrefixCachingMetrics() - def _local_interval_elapsed(self, now: float) -> bool: - # Log every _LOCAL_LOGGING_INTERVAL_SEC. - elapsed_time = now - self.last_log_time - return elapsed_time > _LOCAL_LOGGING_INTERVAL_SEC - def _track_iteration_stats(self, iteration_stats: IterationStats): # Save tracked stats for token counters. self.num_prompt_tokens.append(iteration_stats.num_prompt_tokens) @@ -56,24 +55,26 @@ def _get_throughput(self, tracked_stats: list[int], now: float) -> float: # Compute summary metrics for tracked stats return float(np.sum(tracked_stats) / (now - self.last_log_time)) - def log(self, scheduler_stats: SchedulerStats, - iteration_stats: IterationStats): + def record(self, scheduler_stats: SchedulerStats, + iteration_stats: IterationStats): """Log Stats to standard output.""" self._track_iteration_stats(iteration_stats) self.prefix_caching_metrics.observe(scheduler_stats.prefix_cache_stats) - now = time.monotonic() - if not self._local_interval_elapsed(now): - return + self.last_scheduler_stats = scheduler_stats + def log(self): + now = time.monotonic() prompt_throughput = self._get_throughput(self.num_prompt_tokens, now) generation_throughput = self._get_throughput( self.num_generation_tokens, now) self._reset(now) + scheduler_stats = self.last_scheduler_stats + # Format and print output. 
logger.info( "Avg prompt throughput: %.1f tokens/s, " @@ -274,8 +275,8 @@ def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): labelnames=metrics_info.keys()).labels(**metrics_info) info_gauge.set(1) - def log(self, scheduler_stats: SchedulerStats, - iteration_stats: IterationStats): + def record(self, scheduler_stats: SchedulerStats, + iteration_stats: IterationStats): """Log to prometheus.""" self.gauge_scheduler_running.set(scheduler_stats.num_running_reqs) self.gauge_scheduler_waiting.set(scheduler_stats.num_waiting_reqs) From ae122b1cbde96c871fb74611363e04eecfbcce03 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Mon, 3 Mar 2025 19:04:45 +0000 Subject: [PATCH 10/33] [WIP][[V1][Metrics] Implement max_num_generation_tokens, request_params_n, and request_params_max_tokens metrics (#14055) Signed-off-by: Mark McLoughlin --- tests/entrypoints/openai/test_metrics.py | 6 +++ vllm/v1/engine/output_processor.py | 13 ++++++ vllm/v1/engine/parallel_sampling.py | 39 +++++++++++++++++- vllm/v1/metrics/loggers.py | 50 ++++++++++++++++++++++++ vllm/v1/metrics/stats.py | 5 +++ 5 files changed, 111 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index 39ce4ba23548d..2bffd0ce138e6 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -239,6 +239,12 @@ async def test_metrics_counts(server: RemoteOpenAIServer, "vllm:request_generation_tokens_sum", "vllm:request_generation_tokens_bucket", "vllm:request_generation_tokens_count", + "vllm:request_params_n_sum", + "vllm:request_params_n_bucket", + "vllm:request_params_n_count", + "vllm:request_params_max_tokens_sum", + "vllm:request_params_max_tokens_bucket", + "vllm:request_params_max_tokens_count", "vllm:time_to_first_token_seconds_sum", "vllm:time_to_first_token_seconds_bucket", "vllm:time_to_first_token_seconds_count", diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 4e1d1e3bf51bc..75c638a854f8f 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -36,6 +36,7 @@ def __init__( prompt_token_ids: list[int], logprobs_processor: LogprobsProcessor, detokenizer: IncrementalDetokenizer, + max_tokens_param: Optional[int], arrival_time: float, queue: Optional[asyncio.Queue[RequestOutput]], log_stats: bool, @@ -50,6 +51,7 @@ def __init__( self.prompt_len = len(prompt_token_ids) self.logprobs_processor = logprobs_processor self.detokenizer = detokenizer + self.max_tokens_param = max_tokens_param self.is_prefilling = True self.queue = queue @@ -83,6 +85,8 @@ def from_new_request( tokenizer=tokenizer, request=request, ), + max_tokens_param=(request.sampling_params.max_tokens if + request.sampling_params is not None else None), arrival_time=request.arrival_time, queue=queue, log_stats=log_stats, @@ -198,6 +202,8 @@ def abort_requests( req_state = self.request_states.pop(request_id, None) if req_state is not None: self.lora_states.abort_request(req_state) + if req_state.parent_req is not None: + req_state.parent_req.finish_child_request(request_id) def add_request( self, @@ -310,6 +316,8 @@ def process_outputs( # If req not finished in EngineCore, but Detokenizer # detected stop string, abort needed in EngineCore. 
reqs_to_abort.append(req_id) + if req_state.parent_req is not None: + req_state.parent_req.finish_child_request(req_id) # Track per-request stats self._update_stats_from_finished(req_state, finish_reason, @@ -350,5 +358,10 @@ def _update_stats_from_finished(self, req_state: RequestState, iteration_stats.update_from_finished_request( finish_reason=finish_reason, num_prompt_tokens=len(req_state.prompt_token_ids), + max_tokens_param=req_state.max_tokens_param, req_stats=req_state.stats) self.lora_states.finish_request(req_state) + + ParentRequest.observe_finished_request( + req_state.parent_req, iteration_stats, + req_state.stats.num_generation_tokens) diff --git a/vllm/v1/engine/parallel_sampling.py b/vllm/v1/engine/parallel_sampling.py index adced8973b033..4e2c78173b513 100644 --- a/vllm/v1/engine/parallel_sampling.py +++ b/vllm/v1/engine/parallel_sampling.py @@ -6,6 +6,7 @@ from vllm.outputs import CompletionOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams +from vllm.v1.metrics.stats import IterationStats class ParentRequest: @@ -18,9 +19,15 @@ class ParentRequest: request_id: str sampling_params: SamplingParams + # To track the completion of child requests + child_requests: set[str] + # To aggregate child completions when not streaming output_aggregator: Optional[RequestOutput] + # To find the max number of generated tokens across all children + max_num_generation_tokens: int + # To efficiently obtain child sampling params cached_child_sampling_params: Optional[SamplingParams] @@ -29,7 +36,9 @@ def __init__(self, request_id: str, self.request_id = request_id self.sampling_params = sampling_params + self.child_requests = set() self.output_aggregator = None + self.max_num_generation_tokens = 0 self.cached_child_sampling_params = None @classmethod @@ -82,8 +91,12 @@ def get_child_info(self, index: int) -> tuple[str, SamplingParams]: Returns: (request ID, sampling_params) tuple """ - return (f"{index}_{self.request_id}", - self._get_child_sampling_params(index)) + child_req_id = f"{index}_{self.request_id}" + self.child_requests.add(child_req_id) + return (child_req_id, self._get_child_sampling_params(index)) + + def finish_child_request(self, req_id: str): + self.child_requests.remove(req_id) @property def n(self) -> int: @@ -117,3 +130,25 @@ def make_request_output( request_output.outputs = sorted(request_output.outputs, key=lambda x: x.index) return request_output + + def observe_num_generation_tokens(self, num_generation_tokens: int): + self.max_num_generation_tokens = max(num_generation_tokens, + self.max_num_generation_tokens) + return self.max_num_generation_tokens + + @staticmethod + def observe_finished_request(parent_req: Optional['ParentRequest'], + iteration_stats: IterationStats, + num_generation_tokens: int): + + n_param = parent_req.n if parent_req is not None else 1 + + if parent_req is not None: + num_generation_tokens = parent_req.observe_num_generation_tokens( + num_generation_tokens) + + # Child requests finished, we can now record to iteration stats + if parent_req is None or not parent_req.child_requests: + iteration_stats.max_num_generation_tokens_iter.append( + num_generation_tokens) + iteration_stats.n_params_iter.append(n_param) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 7f6de79104841..d02b9a5da2793 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -106,6 +106,9 @@ def __init__(self, vllm_config: VllmConfig): max_model_len = 
vllm_config.model_config.max_model_len + # + # Scheduler state + # self.gauge_scheduler_running = prometheus_client.Gauge( name="vllm:num_requests_running", documentation="Number of requests in model execution batches.", @@ -116,6 +119,9 @@ def __init__(self, vllm_config: VllmConfig): documentation="Number of requests waiting to be processed.", labelnames=labelnames).labels(*labelvalues) + # + # GPU cache + # self.gauge_gpu_cache_usage = prometheus_client.Gauge( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 1 means 100 percent usage.", @@ -133,6 +139,9 @@ def __init__(self, vllm_config: VllmConfig): "GPU prefix cache hits, in terms of number of cached blocks.", labelnames=labelnames).labels(*labelvalues) + # + # Counters + # self.counter_num_preempted_reqs = prometheus_client.Counter( name="vllm:num_preemptions_total", documentation="Cumulative number of preemption from the engine.", @@ -159,6 +168,9 @@ def __init__(self, vllm_config: VllmConfig): reason] = counter_request_success_base.labels(*(labelvalues + [str(reason)])) + # + # Histograms of counts + # self.histogram_num_prompt_tokens_request = \ prometheus_client.Histogram( name="vllm:request_prompt_tokens", @@ -180,6 +192,31 @@ def __init__(self, vllm_config: VllmConfig): buckets=build_cudagraph_buckets(vllm_config), labelnames=labelnames).labels(*labelvalues) + self.histogram_max_num_generation_tokens_request = \ + prometheus_client.Histogram( + name="vllm:request_max_num_generation_tokens", + documentation= + "Histogram of maximum number of requested generation tokens.", + buckets=build_1_2_5_buckets(max_model_len), + labelnames=labelnames).labels(*labelvalues) + + self.histogram_n_request = \ + prometheus_client.Histogram( + name="vllm:request_params_n", + documentation="Histogram of the n request parameter.", + buckets=[1, 2, 5, 10, 20], + labelnames=labelnames).labels(*labelvalues) + + self.histogram_max_tokens_request = \ + prometheus_client.Histogram( + name="vllm:request_params_max_tokens", + documentation="Histogram of the max_tokens request parameter.", + buckets=build_1_2_5_buckets(max_model_len), + labelnames=labelnames).labels(*labelvalues) + + # + # Histogram of timing intervals + # self.histogram_time_to_first_token = \ prometheus_client.Histogram( name="vllm:time_to_first_token_seconds", @@ -239,6 +276,9 @@ def __init__(self, vllm_config: VllmConfig): buckets=request_latency_buckets, labelnames=labelnames).labels(*labelvalues) + # + # LoRA metrics + # self.gauge_lora_info: Optional[prometheus_client.Gauge] = None if vllm_config.lora_config is not None: self.labelname_max_lora = "max_lora" @@ -255,6 +295,9 @@ def __init__(self, vllm_config: VllmConfig): self.labelname_running_lora_adapters, ]) + # + # Cache config info metric + # self.log_metrics_info("cache_config", vllm_config.cache_config) def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo): @@ -296,6 +339,11 @@ def record(self, scheduler_stats: SchedulerStats, iteration_stats.num_prompt_tokens + \ iteration_stats.num_generation_tokens) + for max_gen_tokens in iteration_stats.max_num_generation_tokens_iter: + self.histogram_max_num_generation_tokens_request.observe( + max_gen_tokens) + for n_param in iteration_stats.n_params_iter: + self.histogram_n_request.observe(n_param) for ttft in iteration_stats.time_to_first_tokens_iter: self.histogram_time_to_first_token.observe(ttft) for tpot in iteration_stats.time_per_output_tokens_iter: @@ -317,6 +365,8 @@ def record(self, scheduler_stats: SchedulerStats, 
finished_request.num_prompt_tokens) self.histogram_num_generation_tokens_request.observe( finished_request.num_generation_tokens) + self.histogram_max_tokens_request.observe( + finished_request.max_tokens_param) if self.gauge_lora_info is not None: running_lora_adapters = \ diff --git a/vllm/v1/metrics/stats.py b/vllm/v1/metrics/stats.py index abdca95670e11..14ec7d2d7463f 100644 --- a/vllm/v1/metrics/stats.py +++ b/vllm/v1/metrics/stats.py @@ -66,6 +66,7 @@ class FinishedRequestStats: e2e_latency: float = 0.0 num_prompt_tokens: int = 0 num_generation_tokens: int = 0 + max_tokens_param: Optional[int] = None queued_time: float = 0.0 prefill_time: float = 0.0 inference_time: float = 0.0 @@ -81,6 +82,8 @@ def __init__(self): self.num_prompt_tokens = 0 self.num_preempted_reqs = 0 self.finished_requests: list[FinishedRequestStats] = [] + self.max_num_generation_tokens_iter: list[int] = [] + self.n_params_iter: list[int] = [] self.time_to_first_tokens_iter: list[float] = [] self.time_per_output_tokens_iter: list[float] = [] self.waiting_lora_adapters: dict[str, int] = {} @@ -150,6 +153,7 @@ def update_from_events(self, req_id: str, events: list["EngineCoreEvent"], def update_from_finished_request(self, finish_reason: "FinishReason", num_prompt_tokens: int, + max_tokens_param: Optional[int], req_stats: RequestStateStats): e2e_latency = self._time_since(req_stats.arrival_time) @@ -173,6 +177,7 @@ def update_from_finished_request(self, finish_reason: "FinishReason", e2e_latency=e2e_latency, num_prompt_tokens=num_prompt_tokens, num_generation_tokens=req_stats.num_generation_tokens, + max_tokens_param=max_tokens_param, queued_time=queued_time, prefill_time=prefill_time, inference_time=inference_time, From 2b04c209ee98174f29f1fc98f0dc3222d652a7bd Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 3 Mar 2025 16:20:24 -0500 Subject: [PATCH 11/33] [Bugfix] Allow shared_experts skip quantization for DeepSeekV2/V3 (#14100) Signed-off-by: mgoin --- vllm/model_executor/models/deepseek_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 7ff61f9a1826f..cf244ff572c30 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -145,6 +145,7 @@ def __init__( hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=False, + prefix=f"{prefix}.shared_experts", ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: From 19d98e0c7db96713f0e2201649159431177a56e2 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 3 Mar 2025 16:29:53 -0500 Subject: [PATCH 12/33] [Kernel] Optimize moe intermediate_cache usage (#13625) Signed-off-by: mgoin --- .../layers/fused_moe/fused_moe.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 00260313e72eb..5336b3c100235 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1240,15 +1240,20 @@ def fused_experts_impl(hidden_states: torch.Tensor, config = get_config_func(M) - intermediate_cache1 = torch.empty((M, top_k_num, N), - device=hidden_states.device, - dtype=hidden_states.dtype) + # We can reuse the memory between these because by the time we need + # cache3, we're done with cache1 + cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) + 
intermediate_cache1 = cache13[:M * top_k_num * N].view( + (M, topk_ids.shape[1], N)) + intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view( + (M, topk_ids.shape[1], w2.shape[1])) + + # This needs separate memory since it's used concurrently with cache1 intermediate_cache2 = torch.empty((M * top_k_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) - intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype) if hidden_states.dtype == torch.bfloat16: compute_type = tl.bfloat16 From cd1d3c3df845fc6baba4ab5ba4d168f3d632b92d Mon Sep 17 00:00:00 2001 From: Qubitium-ModelCloud Date: Tue, 4 Mar 2025 05:59:09 +0800 Subject: [PATCH 13/33] [Docs] Add GPTQModel (#14056) Signed-off-by: mgoin Co-authored-by: mgoin --- docs/source/features/quantization/auto_awq.md | 2 +- .../source/features/quantization/gptqmodel.md | 83 +++++++++++++++++++ docs/source/features/quantization/index.md | 1 + 3 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 docs/source/features/quantization/gptqmodel.md diff --git a/docs/source/features/quantization/auto_awq.md b/docs/source/features/quantization/auto_awq.md index 7001ec91467f1..b703d01953193 100644 --- a/docs/source/features/quantization/auto_awq.md +++ b/docs/source/features/quantization/auto_awq.md @@ -3,7 +3,7 @@ # AutoAWQ To create a new 4-bit quantized model, you can leverage [AutoAWQ](https://github.com/casper-hansen/AutoAWQ). -Quantizing reduces the model's precision from FP16 to INT4 which effectively reduces the file size by ~70%. +Quantization reduces the model's precision from BF16/FP16 to INT4 which effectively reduces the total model memory footprint. The main benefits are lower latency and memory usage. You can quantize your own models by installing AutoAWQ or picking one of the [6500+ models on Huggingface](https://huggingface.co/models?sort=trending&search=awq). diff --git a/docs/source/features/quantization/gptqmodel.md b/docs/source/features/quantization/gptqmodel.md new file mode 100644 index 0000000000000..34adf6512b7e2 --- /dev/null +++ b/docs/source/features/quantization/gptqmodel.md @@ -0,0 +1,83 @@ +(gptqmodel)= + +# GPTQModel + +To create a new 4-bit or 8-bit GPTQ quantized model, you can leverage [GPTQModel](https://github.com/ModelCloud/GPTQModel) from ModelCloud.AI. + +Quantization reduces the model's precision from BF16/FP16 (16-bits) to INT4 (4-bits) or INT8 (8-bits) which significantly reduces the +total model memory footprint while at-the-same-time increasing inference performance. + +Compatible GPTQModel quantized models can leverage the `Marlin` and `Machete` vLLM custom kernels to maximize batching +transactions-per-second `tps` and token-latency performance for both Ampere (A100+) and Hopper (H100+) Nvidia GPUs. +These two kernels are highly optimized by vLLM and NeuralMagic (now part of Redhat) to allow world-class inference performance of quantized GPTQ +models. + +GPTQModel is one of the few quantization toolkits in the world that allows `Dynamic` per-module quantization where different layers and/or modules within a llm model can be further optimized with custom quantization parameters. `Dynamic` quantization +is fully integrated into vLLM and backed up by support from the ModelCloud.AI team. Please refer to [GPTQModel readme](https://github.com/ModelCloud/GPTQModel?tab=readme-ov-file#dynamic-quantization-per-module-quantizeconfig-override) +for more details on this and other advanced features. 
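A hedged sketch of what such a per-module override can look like (the `dynamic` argument, its `+:`/`-:` prefixes, and the regex keys are assumed from the readme's description of this feature; treat the GPTQModel readme as the authoritative reference):

```python
from gptqmodel import GPTQModel, QuantizeConfig

# Assumed illustration of Dynamic per-module overrides: 4-bit is the default,
# MLP modules are raised to 8-bit, and the first decoder layer is left
# unquantized. Exact key syntax may differ from this sketch.
dynamic = {
    r"+:.*\.mlp\..*": {"bits": 8, "group_size": 64},
    r"-:.*\.layers\.0\..*": {},
}

quant_config = QuantizeConfig(bits=4, group_size=128, dynamic=dynamic)
model = GPTQModel.load("meta-llama/Llama-3.2-1B-Instruct", quant_config)
```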
+ +You can quantize your own models by installing [GPTQModel](https://github.com/ModelCloud/GPTQModel) or picking one of the [5000+ models on Huggingface](https://huggingface.co/models?sort=trending&search=gptq). + +```console +pip install -U gptqmodel --no-build-isolation -v +``` + +After installing GPTQModel, you are ready to quantize a model. Please refer to the [GPTQModel readme](https://github.com/ModelCloud/GPTQModel/?tab=readme-ov-file#quantization) for further details. + +Here is an example of how to quantize `meta-llama/Llama-3.2-1B-Instruct`: + +```python +from datasets import load_dataset +from gptqmodel import GPTQModel, QuantizeConfig + +model_id = "meta-llama/Llama-3.2-1B-Instruct" +quant_path = "Llama-3.2-1B-Instruct-gptqmodel-4bit" + +calibration_dataset = load_dataset( + "allenai/c4", + data_files="en/c4-train.00001-of-01024.json.gz", + split="train" + ).select(range(1024))["text"] + +quant_config = QuantizeConfig(bits=4, group_size=128) + +model = GPTQModel.load(model_id, quant_config) + +# increase `batch_size` to match gpu/vram specs to speed up quantization +model.quantize(calibration_dataset, batch_size=2) + +model.save(quant_path) +``` + +To run an GPTQModel quantized model with vLLM, you can use [DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2](https://huggingface.co/ModelCloud/DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2) with the following command: + +```console +python examples/offline_inference/llm_engine_example.py --model DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2 +``` + +GPTQModel quantized models are also supported directly through the LLM entrypoint: + +```python +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.6, top_p=0.9) + +# Create an LLM. +llm = LLM(model="DeepSeek-R1-Distill-Qwen-7B-gptqmodel-4bit-vortex-v2") +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") +``` diff --git a/docs/source/features/quantization/index.md b/docs/source/features/quantization/index.md index 1c98620aa2145..65f438f599f19 100644 --- a/docs/source/features/quantization/index.md +++ b/docs/source/features/quantization/index.md @@ -12,6 +12,7 @@ supported_hardware auto_awq bnb gguf +gptqmodel int4 int8 fp8 From 79e4937c65d5f553f878293a0da50f83b3773141 Mon Sep 17 00:00:00 2001 From: iefgnoix Date: Mon, 3 Mar 2025 15:00:55 -0800 Subject: [PATCH 14/33] [v1] Add comments to the new ragged paged attention Pallas kernel (#14155) Signed-off-by: Xiongfei Wei Co-authored-by: Michael Goin --- vllm/v1/attention/backends/pallas.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index bf4a05daf2d5a..543e8487e28b8 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -11,6 +11,7 @@ AttentionLayer, AttentionType) from vllm.attention.backends.utils import CommonAttentionState +# These are the 2 tunable parameters of the paged attention Pallas kernel. 
NUM_QUERIES_PER_BLOCK = 16 NUM_KV_PAGES_PER_BLOCK = 128 @@ -154,6 +155,9 @@ def forward( write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) query = query * self.scale + # use_kernel switches between using kernel or reference implementation + # (non kernel: https://github.com/pytorch/xla/blob/cee0820e78fc9675e2d0511db891fd44342e890d/torch_xla/experimental/custom_kernel.py#L890). + use_kernel = False output = torch.ops.xla.ragged_paged_attention( query, key_cache, @@ -164,7 +168,7 @@ def forward( attn_metadata.num_seqs, num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK, num_queries_per_block=NUM_QUERIES_PER_BLOCK, - use_kernel=False, + use_kernel=use_kernel, ) return output.reshape(num_tokens, hidden_size) From c060b7140854e2289250d01bc63204a5863fbb21 Mon Sep 17 00:00:00 2001 From: Travis Johnson Date: Mon, 3 Mar 2025 17:04:52 -0700 Subject: [PATCH 15/33] [Model] Add support for GraniteMoeShared models (#13313) Signed-off-by: Travis Johnson Co-authored-by: Cyrus Leung --- docs/source/models/supported_models.md | 5 + tests/models/registry.py | 2 + .../model_executor/models/granitemoeshared.py | 343 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 4 files changed, 351 insertions(+) create mode 100644 vllm/model_executor/models/granitemoeshared.py diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 29ed24cfdb5c3..409a4d1210bc3 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -298,6 +298,11 @@ See [this page](#generative-models) for more information on how to use generativ * `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. * ✅︎ * ✅︎ +- * `GraniteMoeSharedForCausalLM` + * Granite MoE Shared + * `ibm-research/moe-7b-1b-active-shared-experts` (test model) + * ✅︎ + * ✅︎ - * `GritLM` * GritLM * `parasail-ai/GritLM-7B-vllm`. diff --git a/tests/models/registry.py b/tests/models/registry.py index b5ded20c5af58..97db33b46fade 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -131,6 +131,8 @@ def check_available_online( "GPTNeoXForCausalLM": _HfExamplesInfo("EleutherAI/pythia-160m"), "GraniteForCausalLM": _HfExamplesInfo("ibm/PowerLM-3b"), "GraniteMoeForCausalLM": _HfExamplesInfo("ibm/PowerMoE-3b"), + "GraniteMoeSharedForCausalLM": _HfExamplesInfo("ibm-research/moe-7b-1b-active-shared-experts", # noqa: E501 + min_transformers_version="4.49"), # noqa: E501 "Grok1ModelForCausalLM": _HfExamplesInfo("hpcai-tech/grok-1", trust_remote_code=True), "InternLMForCausalLM": _HfExamplesInfo("internlm/internlm-chat-7b", diff --git a/vllm/model_executor/models/granitemoeshared.py b/vllm/model_executor/models/granitemoeshared.py new file mode 100644 index 0000000000000..7e2e4cdcbfa36 --- /dev/null +++ b/vllm/model_executor/models/granitemoeshared.py @@ -0,0 +1,343 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Inference-only GraniteMoeShared model. + +The architecture is the same as granitemoe but with the addition of shared +experts. 
+""" +from typing import Iterable, Optional, Set, Tuple + +import torch +from torch import nn +from transformers.models.granitemoeshared import GraniteMoeSharedConfig + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from . import mixtral +from .granitemoe import GraniteMoeAttention, GraniteMoeMoE +from .interfaces import SupportsLoRA, SupportsPP +from .utils import make_layers, maybe_prefix + + +class GraniteMoeSharedMLP(nn.Module): + + def __init__( + self, + config: GraniteMoeSharedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + self.input_size = config.hidden_size + self.hidden_size = config.shared_intermediate_size + self.input_linear = MergedColumnParallelLinear( + input_size=self.input_size, + output_sizes=[self.hidden_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.input_linear") + self.output_linear = RowParallelLinear( + self.hidden_size, + self.input_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.output_linear") + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.input_linear(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states, _ = self.output_linear(hidden_states) + return hidden_states + + +class GraniteMoeSharedDecoderLayer(nn.Module): + + def __init__( + self, + config: GraniteMoeSharedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = GraniteMoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + attention_multiplier=config.attention_multiplier) + self.block_sparse_moe = GraniteMoeMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe") + self.shared_mlp = None if \ + getattr(config, 'shared_intermediate_size', 0) == 0 \ + else GraniteMoeSharedMLP( + config, + quant_config=quant_config, + prefix=f"{prefix}.shared_mlp" + ) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + self.residual_multiplier = config.residual_multiplier + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + if self.shared_mlp is None: + hidden_states = self.block_sparse_moe(hidden_states) + else: + # create a copy since block_sparse_moe modifies in-place + moe_hidden_states = hidden_states.clone() + moe_hidden_states = self.block_sparse_moe(moe_hidden_states) + hidden_states = moe_hidden_states + self.shared_mlp(hidden_states) + del moe_hidden_states + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states + + +@support_torch_compile +class GraniteMoeSharedModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + self.embedding_multiplier = config.embedding_multiplier + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + 
lambda prefix: GraniteMoeSharedDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states *= self.embedding_multiplier + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer(positions, hidden_states) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class GraniteMoeSharedForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + self.quant_config = quant_config + + self.model = GraniteMoeSharedModel(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head")) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + scale=1 / + self.config.logits_scaling) + + self.sampler = get_sampler() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + 
device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + new_weights = {} + for n, p in weights: + if n.endswith('.block_sparse_moe.input_linear.weight'): + for e in range(p.size(0)): + w1_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w1.weight") + w3_name = n.replace( + '.block_sparse_moe.input_linear.weight', + f".block_sparse_moe.experts.{e}.w3.weight") + w1_param, w3_param = p[e].chunk(2, dim=0) + assert w1_name not in new_weights + assert w3_name not in new_weights + new_weights[w1_name] = w1_param + new_weights[w3_name] = w3_param + elif n.endswith('.block_sparse_moe.output_linear.weight'): + for e in range(p.size(0)): + w2_name = n.replace( + '.block_sparse_moe.output_linear.weight', + f".block_sparse_moe.experts.{e}.w2.weight") + w2_param = p[e] + assert w2_name not in new_weights + new_weights[w2_name] = w2_param + elif n.endswith('.block_sparse_moe.router.layer.weight'): + gate_name = n.replace('.block_sparse_moe.router.layer.weight', + ".block_sparse_moe.gate.weight") + assert gate_name not in new_weights + new_weights[gate_name] = p + elif n == 'lm_head.weight' and self.config.tie_word_embeddings: + pass + else: + new_weights[n] = p + return mixtral.MixtralForCausalLM.load_weights(self, + new_weights.items()) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 4551d81e8a5df..3a7fcdcf7b370 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -60,6 +60,7 @@ "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), + "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"), # noqa: E501 "GritLM": ("gritlm", "GritLM"), "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"), "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), From bb5b640359cc6695cb7818a24680e226f72a4da7 Mon Sep 17 00:00:00 2001 From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Date: Mon, 3 Mar 2025 19:30:23 -0600 Subject: [PATCH 16/33] [core] moe fp8 block quant tuning support (#14068) Signed-off-by: Divakar Verma --- benchmarks/kernels/benchmark_moe.py | 98 +++++++++++++------ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 88 ++++++++++++----- 2 files changed, 129 insertions(+), 57 deletions(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index c862dec81fccd..0d2d304156a5b 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -40,6 +40,7 @@ def benchmark_config( use_fp8_w8a8: bool, use_int8_w8a16: bool, num_iters: int = 100, + block_quant_shape: List[int] = None, ) -> float: init_dtype = torch.float16 if use_fp8_w8a8 else dtype x = torch.randn(num_tokens, hidden_size, dtype=dtype) @@ -81,8 +82,24 @@ def benchmark_config( dtype=torch.float32) w2_scale = torch.randn((hidden_size, num_experts), 
dtype=torch.float32) if use_fp8_w8a8: - w1_scale = torch.randn(num_experts, dtype=torch.float32) - w2_scale = torch.randn(num_experts, dtype=torch.float32) + if block_quant_shape: + block_n, block_k = block_quant_shape[0], block_quant_shape[1] + E = num_experts + N = shard_intermediate_size // 2 + K = hidden_size + factor_for_scale = 1e-2 + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1), + dtype=torch.float32) * factor_for_scale + w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2), + dtype=torch.float32) * factor_for_scale + else: + w1_scale = torch.randn(num_experts, dtype=torch.float32) + w2_scale = torch.randn(num_experts, dtype=torch.float32) + a1_scale = torch.randn(1, dtype=torch.float32) a2_scale = torch.randn(1, dtype=torch.float32) @@ -111,6 +128,7 @@ def run(): w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_quant_shape, ) # JIT compilation & warmup @@ -175,7 +193,8 @@ def get_rocm_tuning_space(use_fp16): return param_ranges -def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]: +def get_configs_compute_bound(use_fp16, + block_quant_shape) -> list[dict[str, int]]: configs: list[BenchmarkConfig] = [] if current_platform.is_rocm(): @@ -204,17 +223,27 @@ def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]: for config_values in product(*values): config = dict(zip(keys, config_values)) configs.append(config) + + # Remove configs that are not compatible with fp8 block quantization + # BLOCK_SIZE_K must be a multiple of block_k + # BLOCK_SIZE_N must be a multiple of block_n + if block_quant_shape is not None and not use_fp16: + block_n, block_k = block_quant_shape[0], block_quant_shape[1] + for config in configs[:]: + if config["BLOCK_SIZE_K"] % block_k != 0 or config[ + "BLOCK_SIZE_N"] % block_n != 0: + configs.remove(config) return configs def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size, - search_space, is_fp16): + search_space, is_fp16, topk): N1, K1 = shard_intermediate_size, hidden_size N2, K2 = hidden_size, shard_intermediate_size // 2 - pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space, - is_fp16) - pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space, - is_fp16) + pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1, + search_space, is_fp16) + pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2, + search_space, is_fp16) search_space = merge_unique_dicts(pruned_space_1, pruned_space_2) return search_space @@ -372,6 +401,7 @@ def tune( use_fp8_w8a8: bool, use_int8_w8a16: bool, search_space: list[dict[str, int]], + block_quant_shape: list[int], ) -> dict[str, int]: best_config = None best_time = float("inf") @@ -380,21 +410,23 @@ def tune( search_space = prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size, search_space, - is_fp16) + is_fp16, topk) with torch.cuda.device(self.device_id): for config in tqdm(search_space): try: - kernel_time = benchmark_config(config, - num_tokens, - num_experts, - shard_intermediate_size, - hidden_size, - topk, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - num_iters=20) + kernel_time = benchmark_config( + config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=20, + block_quant_shape=block_quant_shape) except 
triton.runtime.autotuner.OutOfResources: # Some configurations may be invalid and fail to compile. continue @@ -436,8 +468,8 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int, shard_intermediate_size: int, hidden_size: int, topk: int, - dtype: torch.dtype, use_fp8_w8a8: bool, - use_int8_w8a16: bool) -> None: + dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + block_quant_shape: List[int]) -> None: dtype_str = get_config_dtype_str(dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8) @@ -445,7 +477,7 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int, # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. filename = get_config_file_name(num_experts, shard_intermediate_size // 2, - dtype_str) + dtype_str, block_quant_shape) print(f"Writing best config to {filename}...") with open(filename, "w") as f: @@ -455,7 +487,7 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int, def main(args: argparse.Namespace): print(args) - + block_quant_shape = None config = AutoConfig.from_pretrained( args.model, trust_remote_code=args.trust_remote_code) if config.architectures[0] == "DbrxForCausalLM": @@ -474,6 +506,7 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size + block_quant_shape = config.quantization_config['weight_block_size'] else: # Default: Mixtral. E = config.num_local_experts @@ -511,27 +544,30 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]: if args.tune: is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) - search_space = get_configs_compute_bound(is_fp16) + search_space = get_configs_compute_bound(is_fp16, block_quant_shape) print(f"Start tuning over {len(search_space)} configurations...") start = time.time() configs = _distribute( - "tune", [(batch_size, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space) - for batch_size in batch_sizes]) + "tune", + [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype, + use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape) + for batch_size in batch_sizes]) best_configs = { M: sort_config(config) for M, config in zip(batch_sizes, configs) } save_configs(best_configs, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8_w8a8, use_int8_w8a16) + topk, dtype, use_fp8_w8a8, use_int8_w8a16, + block_quant_shape) end = time.time() print(f"Tuning took {end - start:.2f} seconds") else: outputs = _distribute( - "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8_w8a8, use_int8_w8a16) - for batch_size in batch_sizes]) + "benchmark", + [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype, + use_fp8_w8a8, use_int8_w8a16, block_quant_shape) + for batch_size in batch_sizes]) for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): print(f"Batch size: {batch_size}, config: {config}") diff --git a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json index 2b1167fc71e2f..63e118746fd86 100644 --- 
a/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +++ b/vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json @@ -1,28 +1,28 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "2": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 2, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, @@ -31,15 +31,15 @@ "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, + "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 2, "waves_per_eu": 0 @@ -49,13 +49,13 @@ "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 2, "num_stages": 2, "waves_per_eu": 0 }, "32": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 2, @@ -64,7 +64,7 @@ }, "48": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 2, @@ -73,7 +73,7 @@ }, "64": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, @@ -82,46 +82,82 @@ }, "96": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 4, + "GROUP_SIZE_M": 8, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0 }, "128": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 2, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, "num_stages": 2, "waves_per_eu": 0 }, "256": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, + "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 2, "waves_per_eu": 0 }, "512": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, - "num_warps": 8, + "num_warps": 4, "num_stages": 2, "waves_per_eu": 0 }, "1024": { "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 8, - "num_warps": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, "num_stages": 2, "waves_per_eu": 0 } From 989f4f430cd74a14d539d8b59b9d239301f1bdcd Mon Sep 17 00:00:00 2001 From: Cody Yu 
Date: Mon, 3 Mar 2025 19:09:34 -0800 Subject: [PATCH 17/33] [Misc] Remove lru_cache in NvmlCudaPlatform (#14156) Signed-off-by: Cody Yu --- vllm/platforms/cuda.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index bffa113cab899..00bbfec1ef7ca 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -4,7 +4,7 @@ """ import os -from functools import lru_cache, wraps +from functools import wraps from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar, Union) @@ -284,7 +284,6 @@ def get_device_communicator_cls(cls) -> str: class NvmlCudaPlatform(CudaPlatformBase): @classmethod - @lru_cache(maxsize=8) @with_nvml_context def get_device_capability(cls, device_id: int = 0 @@ -298,7 +297,6 @@ def get_device_capability(cls, return None @classmethod - @lru_cache(maxsize=8) @with_nvml_context def has_device_capability( cls, @@ -311,14 +309,12 @@ def has_device_capability( return False @classmethod - @lru_cache(maxsize=8) @with_nvml_context def get_device_name(cls, device_id: int = 0) -> str: physical_device_id = device_id_to_physical_device_id(device_id) return cls._get_physical_device_name(physical_device_id) @classmethod - @lru_cache(maxsize=8) @with_nvml_context def get_device_uuid(cls, device_id: int = 0) -> str: physical_device_id = device_id_to_physical_device_id(device_id) @@ -326,7 +322,6 @@ def get_device_uuid(cls, device_id: int = 0) -> str: return pynvml.nvmlDeviceGetUUID(handle) @classmethod - @lru_cache(maxsize=8) @with_nvml_context def get_device_total_memory(cls, device_id: int = 0) -> int: physical_device_id = device_id_to_physical_device_id(device_id) From bf13d40972357e41779bc7dbfe729246b1c247c1 Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Mon, 3 Mar 2025 19:44:17 -0800 Subject: [PATCH 18/33] [core] Pass all driver env vars to ray workers unless excluded (#14099) Signed-off-by: Rui Qiao --- vllm/executor/ray_distributed_executor.py | 30 ++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index 108f606e2fb80..3b1735fdcf7a7 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio +import json import os from collections import defaultdict from dataclasses import dataclass @@ -48,6 +49,24 @@ class RayWorkerMetaData: class RayDistributedExecutor(DistributedExecutorBase): + """Ray-based distributed executor""" + + # These env vars are worker-specific, therefore are NOT copied + # from the driver to the workers + WORKER_SPECIFIC_ENV_VARS = { + "VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES" + } + + config_home = envs.VLLM_CONFIG_ROOT + # This file contains a list of env vars that should not be copied + # from the driver to the Ray workers. 
+ non_carry_over_env_vars_file = os.path.join( + config_home, "ray_non_carry_over_env_vars.json") + if os.path.exists(non_carry_over_env_vars_file): + with open(non_carry_over_env_vars_file) as f: + non_carry_over_env_vars = set(json.load(f)) + else: + non_carry_over_env_vars = set() uses_ray: bool = True @@ -311,9 +330,9 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): # Environment variables to copy from driver to workers env_vars_to_copy = [ - "VLLM_ATTENTION_BACKEND", "TPU_CHIPS_PER_HOST_BOUNDS", - "TPU_HOST_BOUNDS", "VLLM_USE_V1", "VLLM_TRACE_FUNCTION", - "VLLM_TORCH_PROFILER_DIR", "VLLM_TEST_ENABLE_EP" + v for v in envs.environment_variables + if v not in self.WORKER_SPECIFIC_ENV_VARS + and v not in self.non_carry_over_env_vars ] # Copy existing env vars to each worker's args @@ -323,9 +342,14 @@ def sort_by_driver_then_worker_ip(item: RayWorkerMetaData): if name in os.environ: args[name] = os.environ[name] + logger.info("non_carry_over_env_vars from config: %s", + self.non_carry_over_env_vars) logger.info( "Copying the following environment variables to workers: %s", [v for v in env_vars_to_copy if v in os.environ]) + logger.info( + "If certain env vars should NOT be copied to workers, add them to " + "%s file", self.non_carry_over_env_vars_file) self._env_vars_for_all_workers = ( all_args_to_update_environment_variables) From 66233af7b6e4217653f1a9952180d68376af7d2a Mon Sep 17 00:00:00 2001 From: Zhanwen Chen Date: Tue, 4 Mar 2025 00:09:22 -0500 Subject: [PATCH 19/33] Use math.prod instead of np.prod for trivial ops (#14142) --- vllm/worker/cache_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 3960392cf74ef..004b4e4b757fd 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """CacheEngine class for managing the KV cache.""" +from math import prod from typing import List -import numpy as np import torch from vllm import envs @@ -90,7 +90,7 @@ def _allocate_kv_cache( # NOTE this assumption currently only holds for MLA so we only apply # this optimization when `use_mla` is true entry_shape = kv_cache_shape[2:] - entry_size = np.prod(entry_shape) + entry_size = prod(entry_shape) alloc_entry_size = align_to_256bytes(entry_size, self.dtype) alloc_shape = (*kv_cache_shape[:2], alloc_entry_size) else: From f78c0be80a8341167a5ebf20ce4eb62421a351a6 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 4 Mar 2025 00:11:03 -0500 Subject: [PATCH 20/33] Fix benchmark_moe.py tuning for CUDA devices (#14164) --- benchmarks/kernels/benchmark_moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 0d2d304156a5b..bb28c32798e2c 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -2,6 +2,7 @@ import argparse import time +from contextlib import nullcontext from datetime import datetime from itertools import product from typing import Any, TypedDict @@ -412,7 +413,8 @@ def tune( hidden_size, search_space, is_fp16, topk) - with torch.cuda.device(self.device_id): + with torch.cuda.device(self.device_id) if current_platform.is_rocm( + ) else nullcontext(): for config in tqdm(search_space): try: kernel_time = benchmark_config( From ac65bc92dfedc6218d3b389e7cc2947faf77d902 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Mar 2025 18:39:16 +0800 Subject: [PATCH 21/33] [platform] 
add debug logging during inferring the device type (#14195) Signed-off-by: youkaichao --- vllm/platforms/__init__.py | 64 ++++++++++++++++++++++++++++++++------ 1 file changed, 55 insertions(+), 9 deletions(-) diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 48cf8f7a323a6..89e69c7f5780d 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -32,6 +32,7 @@ def vllm_version_matches_substr(substr: str) -> bool: def tpu_platform_plugin() -> Optional[str]: is_tpu = False + logger.debug("Checking if TPU platform is available.") try: # While it's technically possible to install libtpu on a # non-TPU machine, this is a very uncommon scenario. Therefore, @@ -39,7 +40,9 @@ def tpu_platform_plugin() -> Optional[str]: # has TPUs. import libtpu # noqa: F401 is_tpu = True - except Exception: + logger.debug("Confirmed TPU platform is available.") + except Exception as e: + logger.debug("TPU platform is not available because: %s", str(e)) pass return "vllm.platforms.tpu.TpuPlatform" if is_tpu else None @@ -47,7 +50,7 @@ def tpu_platform_plugin() -> Optional[str]: def cuda_platform_plugin() -> Optional[str]: is_cuda = False - + logger.debug("Checking if CUDA platform is available.") try: from vllm.utils import import_pynvml pynvml = import_pynvml() @@ -60,9 +63,19 @@ def cuda_platform_plugin() -> Optional[str]: # on a GPU machine, even if in a cpu build. is_cuda = (pynvml.nvmlDeviceGetCount() > 0 and not vllm_version_matches_substr("cpu")) + if pynvml.nvmlDeviceGetCount() <= 0: + logger.debug( + "CUDA platform is not available because no GPU is found.") + if vllm_version_matches_substr("cpu"): + logger.debug("CUDA platform is not available because" + " vLLM is built with CPU.") + if is_cuda: + logger.debug("Confirmed CUDA platform is available.") finally: pynvml.nvmlShutdown() except Exception as e: + logger.debug("Exception happens when checking CUDA platform: %s", + str(e)) if "nvml" not in e.__class__.__name__.lower(): # If the error is not related to NVML, re-raise it. 
raise e @@ -75,23 +88,28 @@ def cuda_is_jetson() -> bool: or os.path.exists("/sys/class/tegra-firmware") if cuda_is_jetson(): + logger.debug("Confirmed CUDA platform is available on Jetson.") is_cuda = True + else: + logger.debug("CUDA platform is not available because: %s", str(e)) return "vllm.platforms.cuda.CudaPlatform" if is_cuda else None def rocm_platform_plugin() -> Optional[str]: is_rocm = False - + logger.debug("Checking if ROCm platform is available.") try: import amdsmi amdsmi.amdsmi_init() try: if len(amdsmi.amdsmi_get_processor_handles()) > 0: is_rocm = True + logger.debug("Confirmed ROCm platform is available.") finally: amdsmi.amdsmi_shut_down() - except Exception: + except Exception as e: + logger.debug("ROCm platform is not available because: %s", str(e)) pass return "vllm.platforms.rocm.RocmPlatform" if is_rocm else None @@ -99,10 +117,17 @@ def rocm_platform_plugin() -> Optional[str]: def hpu_platform_plugin() -> Optional[str]: is_hpu = False + logger.debug("Checking if HPU platform is available.") try: from importlib import util is_hpu = util.find_spec('habana_frameworks') is not None - except Exception: + if is_hpu: + logger.debug("Confirmed HPU platform is available.") + else: + logger.debug("HPU platform is not available because " + "habana_frameworks is not found.") + except Exception as e: + logger.debug("HPU platform is not available because: %s", str(e)) pass return "vllm.platforms.hpu.HpuPlatform" if is_hpu else None @@ -110,7 +135,7 @@ def hpu_platform_plugin() -> Optional[str]: def xpu_platform_plugin() -> Optional[str]: is_xpu = False - + logger.debug("Checking if XPU platform is available.") try: # installed IPEX if the machine has XPUs. import intel_extension_for_pytorch # noqa: F401 @@ -118,7 +143,9 @@ def xpu_platform_plugin() -> Optional[str]: import torch if hasattr(torch, 'xpu') and torch.xpu.is_available(): is_xpu = True - except Exception: + logger.debug("Confirmed XPU platform is available.") + except Exception as e: + logger.debug("XPU platform is not available because: %s", str(e)) pass return "vllm.platforms.xpu.XPUPlatform" if is_xpu else None @@ -126,13 +153,21 @@ def xpu_platform_plugin() -> Optional[str]: def cpu_platform_plugin() -> Optional[str]: is_cpu = False + logger.debug("Checking if CPU platform is available.") try: is_cpu = vllm_version_matches_substr("cpu") + if is_cpu: + logger.debug("Confirmed CPU platform is available because" + " vLLM is built with CPU.") if not is_cpu: import platform is_cpu = platform.machine().lower().startswith("arm") + if is_cpu: + logger.debug("Confirmed CPU platform is available" + " because the machine is ARM.") - except Exception: + except Exception as e: + logger.debug("CPU platform is not available because: %s", str(e)) pass return "vllm.platforms.cpu.CpuPlatform" if is_cpu else None @@ -140,10 +175,14 @@ def cpu_platform_plugin() -> Optional[str]: def neuron_platform_plugin() -> Optional[str]: is_neuron = False + logger.debug("Checking if Neuron platform is available.") try: import transformers_neuronx # noqa: F401 is_neuron = True - except ImportError: + logger.debug("Confirmed Neuron platform is available because" + " transformers_neuronx is found.") + except ImportError as e: + logger.debug("Neuron platform is not available because: %s", str(e)) pass return "vllm.platforms.neuron.NeuronPlatform" if is_neuron else None @@ -151,8 +190,15 @@ def neuron_platform_plugin() -> Optional[str]: def openvino_platform_plugin() -> Optional[str]: is_openvino = False + logger.debug("Checking if OpenVINO 
platform is available.") with suppress(Exception): is_openvino = vllm_version_matches_substr("openvino") + if is_openvino: + logger.debug("Confirmed OpenVINO platform is available" + " because vLLM is built with OpenVINO.") + if not is_openvino: + logger.debug("OpenVINO platform is not available because" + " vLLM is not built with OpenVINO.") return "vllm.platforms.openvino.OpenVinoPlatform" if is_openvino else None From 71c4b40562eb308d5fd93091373e4f913463a9eb Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Mar 2025 18:54:19 +0800 Subject: [PATCH 22/33] [sleep mode] error out with expandable_segments (#14189) Signed-off-by: youkaichao --- vllm/device_allocator/cumem.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/device_allocator/cumem.py b/vllm/device_allocator/cumem.py index 7f63fc1437872..0291fd9e1c88f 100644 --- a/vllm/device_allocator/cumem.py +++ b/vllm/device_allocator/cumem.py @@ -8,6 +8,7 @@ # not sure why, they are created from a different context. # the only successful approach is to call cuda driver API in C. import dataclasses +import os from contextlib import contextmanager from typing import Any, Callable, Dict, Optional, Tuple, Union @@ -140,6 +141,12 @@ def get_instance() -> "CuMemAllocator": return CuMemAllocator.instance def __init__(self): + conf = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "") + assert "expandable_segments:True" not in conf, \ + ("Expandable segments are not compatible with memory pool. " + "Please track https://github.com/pytorch/pytorch/issues/147851 " + "for the latest updates.") + self.pointer_to_data: Dict[int, AllocationData] = {} self.current_tag: str = CuMemAllocator.default_tag self.allocator_and_pools: Dict[str, Any] = {} From 3610fb49302867af5b2598b218b3011bc9ed52aa Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Mar 2025 20:47:06 +0800 Subject: [PATCH 23/33] [doc] add "Failed to infer device type" to faq (#14200) Signed-off-by: youkaichao --- docs/source/getting_started/troubleshooting.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/getting_started/troubleshooting.md b/docs/source/getting_started/troubleshooting.md index 92103e65bbbb7..fdfaf9f932698 100644 --- a/docs/source/getting_started/troubleshooting.md +++ b/docs/source/getting_started/troubleshooting.md @@ -254,6 +254,10 @@ ValueError: Model architectures [''] are not supported for now. Supported But you are sure that the model is in the [list of supported models](#supported-models), there may be some issue with vLLM's model resolution. In that case, please follow [these steps](#model-resolution) to explicitly specify the vLLM implementation for the model. +## Failed to infer device type + +If you see an error like `RuntimeError: Failed to infer device type`, it means that vLLM failed to infer the device type of the runtime environment. You can check [the code](gh-file:vllm/platforms/__init__.py) to see how vLLM infers the device type and why it is not working as expected. After [this PR](gh-pr:14195), you can also set the environment variable `VLLM_LOGGING_LEVEL=DEBUG` to see more detailed logs to help debug the issue. + ## Known Issues - In `v0.5.2`, `v0.5.3`, and `v0.5.3.post1`, there is a bug caused by [zmq](https://github.com/zeromq/pyzmq/issues/2000) , which can occasionally cause vLLM to hang depending on the machine configuration. The solution is to upgrade to the latest version of `vllm` to include the [fix](gh-pr:6759). 
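A minimal sketch of the debugging step described in the troubleshooting entry above, assuming a working vLLM install: it enables the `VLLM_LOGGING_LEVEL=DEBUG` setting mentioned in the doc before vLLM is imported, so the platform-detection debug logs added in the patch above are printed, then prints whichever platform was inferred.

```python
# Minimal sketch (assumes vLLM is installed in the current environment).
# The env var must be set before any vllm module is imported, because the
# logging level is read when vLLM configures its loggers.
import os

os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"

# Accessing current_platform triggers device-type inference; with DEBUG
# logging enabled, the per-platform "Checking if ... is available" messages
# are emitted, which helps explain a "Failed to infer device type" error.
from vllm.platforms import current_platform

print(current_platform)
```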
From 6247bae6c69de73e2c8a9964d23357bef724ee16 Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Tue, 4 Mar 2025 09:25:27 -0500 Subject: [PATCH 24/33] [Bugfix] Restrict MacOS CPU detection (#14210) Signed-off-by: mgoin --- vllm/platforms/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 89e69c7f5780d..74ef8bd1cff1a 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -160,11 +160,11 @@ def cpu_platform_plugin() -> Optional[str]: logger.debug("Confirmed CPU platform is available because" " vLLM is built with CPU.") if not is_cpu: - import platform - is_cpu = platform.machine().lower().startswith("arm") + import sys + is_cpu = sys.platform.startswith("darwin") if is_cpu: logger.debug("Confirmed CPU platform is available" - " because the machine is ARM.") + " because the machine is MacOS.") except Exception as e: logger.debug("CPU platform is not available because: %s", str(e)) From 5db6b2c9614d70246d3386b6fb0cc655c72aeb85 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Tue, 4 Mar 2025 07:06:47 -0800 Subject: [PATCH 25/33] [V1][BugFix] Fix remaining sync engine client shutdown errors/hangs (#13869) Signed-off-by: Nick Hill --- tests/v1/engine/test_llm_engine.py | 2 - vllm/utils.py | 22 ++++---- vllm/v1/engine/core_client.py | 84 ++++++++++++++++++++---------- 3 files changed, 68 insertions(+), 40 deletions(-) diff --git a/tests/v1/engine/test_llm_engine.py b/tests/v1/engine/test_llm_engine.py index 33c884e6de357..43b16d3e5a293 100644 --- a/tests/v1/engine/test_llm_engine.py +++ b/tests/v1/engine/test_llm_engine.py @@ -15,8 +15,6 @@ def _vllm_model(apc: bool, vllm_runner, monkeypatch): """Set up VllmRunner instance.""" monkeypatch.setenv("VLLM_USE_V1", "1") - # TODO(nick): Single-proc to work around a ZMQ shutdown hang for now. - monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0") return vllm_runner( MODEL, dtype=DTYPE, diff --git a/vllm/utils.py b/vllm/utils.py index 26c9e1a908371..66d629011dd11 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -500,6 +500,10 @@ def get_open_zmq_ipc_path() -> str: return f"ipc://{base_rpc_path}/{uuid4()}" +def get_open_zmq_inproc_path() -> str: + return f"inproc://{uuid4()}" + + def get_open_port() -> int: """ Get an open port for the vLLM process to listen on. 
@@ -2108,12 +2112,12 @@ def get_exception_traceback(): def make_zmq_socket( ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined] path: str, - type: Any, + socket_type: Any, ) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined] """Make a ZMQ socket with the proper bind/connect semantics.""" mem = psutil.virtual_memory() - socket = ctx.socket(type) + socket = ctx.socket(socket_type) # Calculate buffer size based on system memory total_mem = mem.total / 1024**3 @@ -2127,29 +2131,27 @@ def make_zmq_socket( else: buf_size = -1 # Use system default buffer size - if type == zmq.constants.PULL: + if socket_type == zmq.constants.PULL: socket.setsockopt(zmq.constants.RCVHWM, 0) socket.setsockopt(zmq.constants.RCVBUF, buf_size) socket.connect(path) - elif type == zmq.constants.PUSH: + elif socket_type == zmq.constants.PUSH: socket.setsockopt(zmq.constants.SNDHWM, 0) socket.setsockopt(zmq.constants.SNDBUF, buf_size) socket.bind(path) else: - raise ValueError(f"Unknown Socket Type: {type}") + raise ValueError(f"Unknown Socket Type: {socket_type}") return socket @contextlib.contextmanager -def zmq_socket_ctx( - path: str, - type: Any) -> Iterator[zmq.Socket]: # type: ignore[name-defined] +def zmq_socket_ctx(path: str, socket_type: Any) -> Iterator[zmq.Socket]: """Context manager for a ZMQ socket""" - ctx = zmq.Context(io_threads=2) # type: ignore[attr-defined] + ctx = zmq.Context() # type: ignore[attr-defined] try: - yield make_zmq_socket(ctx, path, type) + yield make_zmq_socket(ctx, path, socket_type) except KeyboardInterrupt: logger.debug("Got Keyboard Interrupt.") diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index cdce14afe0b3f..55057179f3a43 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -18,8 +18,8 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree, - make_zmq_socket) +from vllm.utils import (get_open_zmq_inproc_path, get_open_zmq_ipc_path, + kill_process_tree, make_zmq_socket) from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType, UtilityOutput) from vllm.v1.engine.core import EngineCore, EngineCoreProc @@ -202,10 +202,11 @@ class BackgroundResources: """Used as a finalizer for clean shutdown, avoiding circular reference back to the client object.""" - ctx: Union[zmq.Context, zmq.asyncio.Context] = None + ctx: Union[zmq.Context] = None output_socket: Union[zmq.Socket, zmq.asyncio.Socket] = None input_socket: Union[zmq.Socket, zmq.asyncio.Socket] = None proc_handle: Optional[BackgroundProcHandle] = None + shutdown_path: Optional[str] = None def __call__(self): """Clean up background resources.""" @@ -218,8 +219,13 @@ def __call__(self): self.output_socket.close(linger=0) if self.input_socket is not None: self.input_socket.close(linger=0) - if self.ctx is not None: - self.ctx.destroy(linger=0) + if self.shutdown_path is not None: + # We must ensure that the sync output socket is + # closed cleanly in its own thread. + with self.ctx.socket(zmq.PAIR) as shutdown_sender: + shutdown_sender.connect(self.shutdown_path) + # Send shutdown signal. + shutdown_sender.send(b'') class MPClient(EngineCoreClient): @@ -261,28 +267,23 @@ def sigusr1_handler(signum, frame): self.decoder = MsgpackDecoder(EngineCoreOutputs) # ZMQ setup. 
- self.ctx = ( - zmq.asyncio.Context() # type: ignore[attr-defined] - if asyncio_mode else zmq.Context()) # type: ignore[attr-defined] + sync_ctx = zmq.Context() + self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx # This will ensure resources created so far are closed # when the client is garbage collected, even if an # exception is raised mid-construction. - resources = BackgroundResources(ctx=self.ctx) - self._finalizer = weakref.finalize(self, resources) + self.resources = BackgroundResources(ctx=sync_ctx) + self._finalizer = weakref.finalize(self, self.resources) - # Paths and sockets for IPC. - output_path = get_open_zmq_ipc_path() + # Paths for IPC. + self.output_path = get_open_zmq_ipc_path() input_path = get_open_zmq_ipc_path() - resources.output_socket = make_zmq_socket(self.ctx, output_path, - zmq.constants.PULL) - resources.input_socket = make_zmq_socket(self.ctx, input_path, - zmq.constants.PUSH) # Start EngineCore in background process. - resources.proc_handle = BackgroundProcHandle( + self.resources.proc_handle = BackgroundProcHandle( input_path=input_path, - output_path=output_path, + output_path=self.output_path, process_name="EngineCore", target_fn=EngineCoreProc.run_engine_core, process_kwargs={ @@ -291,8 +292,10 @@ def sigusr1_handler(signum, frame): "log_stats": log_stats, }) - self.output_socket = resources.output_socket - self.input_socket = resources.input_socket + # Create input socket. + self.resources.input_socket = make_zmq_socket(self.ctx, input_path, + zmq.constants.PUSH) + self.input_socket = self.resources.input_socket self.utility_results: dict[int, AnyFuture] = {} def shutdown(self): @@ -325,27 +328,48 @@ def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor], # Ensure that the outputs socket processing thread does not have # a ref to the client which prevents gc. - output_socket = self.output_socket + ctx = self.ctx + output_path = self.output_path decoder = self.decoder utility_results = self.utility_results outputs_queue = self.outputs_queue + shutdown_path = get_open_zmq_inproc_path() + self.resources.shutdown_path = shutdown_path + def process_outputs_socket(): + shutdown_socket = ctx.socket(zmq.PAIR) + shutdown_socket.bind(shutdown_path) + out_socket = make_zmq_socket(ctx, output_path, zmq.constants.PULL) try: + poller = zmq.Poller() + poller.register(shutdown_socket) + poller.register(out_socket) while True: - (frame, ) = output_socket.recv_multipart(copy=False) + socks = poller.poll() + if not socks: + continue + if len(socks) == 2 or socks[0][0] == shutdown_socket: + # shutdown signal, exit thread. + break + + (frame, ) = out_socket.recv_multipart(copy=False) outputs = decoder.decode(frame.buffer) if outputs.utility_output: _process_utility_output(outputs.utility_output, utility_results) else: outputs_queue.put_nowait(outputs) - except zmq.error.ContextTerminated: - # Expected when the class is GC'd / during process termination. - pass + finally: + # Close sockets. + shutdown_socket.close(linger=0) + out_socket.close(linger=0) # Process outputs from engine in separate thread. - Thread(target=process_outputs_socket, daemon=True).start() + self.output_queue_thread = Thread(target=process_outputs_socket, + name="EngineCoreOutputQueueThread", + daemon=True) + self.output_queue_thread.start() def get_output(self) -> EngineCoreOutputs: return self.outputs_queue.get() @@ -424,10 +448,13 @@ async def _start_output_queue_task(self): # Perform IO in separate task to parallelize as much as possible. 
# Avoid task having direct reference back to the client. self.outputs_queue = asyncio.Queue() - output_socket = self.output_socket decoder = self.decoder utility_results = self.utility_results outputs_queue = self.outputs_queue + output_path = self.output_path + output_socket = make_zmq_socket(self.ctx, output_path, + zmq.constants.PULL) + self.resources.output_socket = output_socket async def process_outputs_socket(): while True: @@ -439,7 +466,8 @@ async def process_outputs_socket(): else: outputs_queue.put_nowait(outputs) - self.queue_task = asyncio.create_task(process_outputs_socket()) + self.queue_task = asyncio.create_task(process_outputs_socket(), + name="EngineCoreOutputQueueTask") async def get_output_async(self) -> EngineCoreOutputs: if self.outputs_queue is None: From c8525f06fcc99dbc3564eaed9985edf16d287b47 Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Tue, 4 Mar 2025 15:11:33 +0000 Subject: [PATCH 26/33] [V0][Metrics] Deprecate some questionable request time metrics (#14135) Signed-off-by: Mark McLoughlin --- vllm/engine/metrics.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index 08be2cbc0b9d5..70f36d1290ca3 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -197,24 +197,35 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig): "Histogram of time spent in DECODE phase for request.", labelnames=labelnames, buckets=request_latency_buckets) + # Deprecated in 0.8 - duplicates vllm:request_queue_time_seconds: + # TODO: in 0.9, only enable if show_hidden_metrics=True self.histogram_time_in_queue_request = self._histogram_cls( name="vllm:time_in_queue_requests", - documentation= - "Histogram of time the request spent in the queue in seconds.", + documentation=( + "Histogram of time the request spent in the queue in seconds. " + "DEPRECATED: use vllm:request_queue_time_seconds instead."), labelnames=labelnames, buckets=request_latency_buckets) + + # Deprecated in 0.8 - use prefill/decode/inference time metrics + # TODO: in 0.9, only enable if show_hidden_metrics=True self.histogram_model_forward_time_request = self._histogram_cls( name="vllm:model_forward_time_milliseconds", - documentation= - "Histogram of time spent in the model forward pass in ms.", + documentation=( + "Histogram of time spent in the model forward pass in ms. " + "DEPRECATED: use prefill/decode/inference time metrics instead." + ), labelnames=labelnames, buckets=build_1_2_3_5_8_buckets(3000)) self.histogram_model_execute_time_request = self._histogram_cls( name="vllm:model_execute_time_milliseconds", - documentation= - "Histogram of time spent in the model execute function in ms.", + documentation=( + "Histogram of time spent in the model execute function in ms." + "DEPRECATED: use prefill/decode/inference time metrics instead." 
+ ), labelnames=labelnames, buckets=build_1_2_3_5_8_buckets(3000)) + # Metadata self.histogram_num_prompt_tokens_request = self._histogram_cls( name="vllm:request_prompt_tokens", From b3cf368d79b4409120ab90f2096d254f52a43fcc Mon Sep 17 00:00:00 2001 From: lkchen Date: Tue, 4 Mar 2025 07:43:59 -0800 Subject: [PATCH 27/33] [V1][Molmo] Fix get_multimodal_embeddings() in molmo.py (#14161) --- examples/offline_inference/vision_language.py | 294 +++++++++++------- vllm/model_executor/models/aria.py | 4 +- vllm/model_executor/models/blip2.py | 4 +- vllm/model_executor/models/chameleon.py | 4 +- vllm/model_executor/models/deepseek_vl2.py | 4 +- vllm/model_executor/models/florence2.py | 4 +- vllm/model_executor/models/fuyu.py | 6 +- vllm/model_executor/models/glm4v.py | 4 +- vllm/model_executor/models/idefics3.py | 4 +- vllm/model_executor/models/interfaces.py | 18 +- vllm/model_executor/models/internvl.py | 4 +- vllm/model_executor/models/llava.py | 4 +- vllm/model_executor/models/llava_next.py | 4 +- .../model_executor/models/llava_next_video.py | 4 +- vllm/model_executor/models/molmo.py | 9 +- vllm/model_executor/models/paligemma.py | 4 +- vllm/model_executor/models/phi3v.py | 4 +- vllm/model_executor/models/pixtral.py | 4 +- vllm/model_executor/models/qwen2_audio.py | 4 +- vllm/model_executor/models/qwen_vl.py | 4 +- vllm/model_executor/models/ultravox.py | 4 +- vllm/model_executor/models/whisper.py | 4 +- 22 files changed, 249 insertions(+), 150 deletions(-) diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index e2ec36211b863..a0a71f18ed949 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -21,7 +21,7 @@ # Aria -def run_aria(question: str, modality: str): +def run_aria(questions: list[str], modality: str): assert modality == "image" model_name = "rhymes-ai/Aria" @@ -32,41 +32,42 @@ def run_aria(question: str, modality: str): dtype="bfloat16", disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) - prompt = (f"<|im_start|>user\n<|img|>{question}" - "<|im_end|>\n<|im_start|>assistant\n") + prompts = [(f"<|im_start|>user\n<|img|>{question}" + "<|im_end|>\n<|im_start|>assistant\n") + for question in questions] stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # BLIP-2 -def run_blip2(question: str, modality: str): +def run_blip2(questions: list[str], modality: str): assert modality == "image" # BLIP-2 prompt format is inaccurate on HuggingFace model repository. 
# See https://huggingface.co/Salesforce/blip2-opt-2.7b/discussions/15#64ff02f3f8cf9e4f5b038262 #noqa - prompt = f"Question: {question} Answer:" + prompts = [f"Question: {question} Answer:" for question in questions] llm = LLM(model="Salesforce/blip2-opt-2.7b", disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) stop_token_ids = None - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # Chameleon -def run_chameleon(question: str, modality: str): +def run_chameleon(questions: list[str], modality: str): assert modality == "image" - prompt = f"{question}" + prompts = [f"{question}" for question in questions] llm = LLM(model="facebook/chameleon-7b", max_model_len=4096, max_num_seqs=2, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) stop_token_ids = None - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # Deepseek-VL2 -def run_deepseek_vl2(question: str, modality: str): +def run_deepseek_vl2(questions: list[str], modality: str): assert modality == "image" model_name = "deepseek-ai/deepseek-vl2-tiny" @@ -77,9 +78,12 @@ def run_deepseek_vl2(question: str, modality: str): disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}) - prompt = f"<|User|>: \n{question}\n\n<|Assistant|>:" + prompts = [ + f"<|User|>: \n{question}\n\n<|Assistant|>:" + for question in questions + ] stop_token_ids = None - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # Florence2 @@ -99,20 +103,20 @@ def run_florence2(question: str, modality: str): # Fuyu -def run_fuyu(question: str, modality: str): +def run_fuyu(questions: list[str], modality: str): assert modality == "image" - prompt = f"{question}\n" + prompts = [f"{question}\n" for question in questions] llm = LLM(model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) stop_token_ids = None - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # GLM-4v -def run_glm4v(question: str, modality: str): +def run_glm4v(questions: list[str], modality: str): assert modality == "image" model_name = "THUDM/glm-4v-9b" @@ -124,15 +128,17 @@ def run_glm4v(question: str, modality: str): hf_overrides={"architectures": ["GLM4VForCausalLM"]}, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) - prompt = f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ - {question}<|assistant|>" + prompts = [ + f"<|user|>\n<|begin_of_image|><|endoftext|><|end_of_image|>\ + {question}<|assistant|>" for question in questions + ] stop_token_ids = [151329, 151336, 151338] - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # H2OVL-Mississippi -def run_h2ovl(question: str, modality: str): +def run_h2ovl(questions: list[str], modality: str): assert modality == "image" model_name = "h2oai/h2ovl-mississippi-800m" @@ -146,19 +152,24 @@ def run_h2ovl(question: str, modality: str): tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - messages = [{'role': 'user', 'content': f"\n{question}"}] - prompt = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + prompts = [ + tokenizer.apply_chat_template([{ + 'role': 'user', + 'content': f"\n{question}" + }], + tokenize=False, + add_generation_prompt=True) + for question in questions + ] # Stop tokens for H2OVL-Mississippi # https://huggingface.co/h2oai/h2ovl-mississippi-800m stop_token_ids = 
[tokenizer.eos_token_id] - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # Idefics3-8B-Llama3 -def run_idefics3(question: str, modality: str): +def run_idefics3(questions: list[str], modality: str): assert modality == "image" model_name = "HuggingFaceM4/Idefics3-8B-Llama3" @@ -176,15 +187,15 @@ def run_idefics3(question: str, modality: str): }, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, ) - prompt = ( + prompts = [( f"<|begin_of_text|>User:{question}\nAssistant:" - ) + ) for question in questions] stop_token_ids = None - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # InternVL -def run_internvl(question: str, modality: str): +def run_internvl(questions: list[str], modality: str): assert modality == "image" model_name = "OpenGVLab/InternVL2-2B" @@ -198,10 +209,15 @@ def run_internvl(question: str, modality: str): tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) - messages = [{'role': 'user', 'content': f"\n{question}"}] - prompt = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) + prompts = [ + tokenizer.apply_chat_template([{ + 'role': 'user', + 'content': f"\n{question}" + }], + tokenize=False, + add_generation_prompt=True) + for question in questions + ] # Stop tokens for InternVL # models variants may have different stop tokens @@ -209,71 +225,82 @@ def run_internvl(question: str, modality: str): # https://huggingface.co/OpenGVLab/InternVL2-2B/blob/main/conversation.py stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # LLaVA-1.5 -def run_llava(question: str, modality: str): +def run_llava(questions: list[str], modality: str): assert modality == "image" - prompt = f"USER: \n{question}\nASSISTANT:" + prompts = [ + f"USER: \n{question}\nASSISTANT:" for question in questions + ] llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) stop_token_ids = None - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # LLaVA-1.6/LLaVA-NeXT -def run_llava_next(question: str, modality: str): +def run_llava_next(questions: list[str], modality: str): assert modality == "image" - prompt = f"[INST] \n{question} [/INST]" + prompts = [f"[INST] \n{question} [/INST]" for question in questions] llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192, disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache) stop_token_ids = None - return llm, prompt, stop_token_ids + return llm, prompts, stop_token_ids # LlaVA-NeXT-Video # Currently only support for video input -def run_llava_next_video(question: str, modality: str): +def run_llava_next_video(questions: list[str], modality: str): assert modality == "video" - prompt = f"USER: