From ba59b78a9c5a8e9a535a0be081a438951679347f Mon Sep 17 00:00:00 2001
From: Sage Moore
Date: Thu, 13 Feb 2025 22:21:50 -0800
Subject: [PATCH] [ROCm][V1] Add initial ROCm support to V1 (#12790)

---
 requirements-rocm-build.txt              |  16 ++
 vllm/attention/ops/prefix_prefill.py     |   6 +-
 vllm/platforms/rocm.py                   |  45 ++++--
 vllm/v1/attention/backends/flash_attn.py |   5 +-
 vllm/v1/attention/backends/rocm_attn.py  | 182 +++++++++++++++++++++++
 5 files changed, 236 insertions(+), 18 deletions(-)
 create mode 100644 requirements-rocm-build.txt
 create mode 100644 vllm/v1/attention/backends/rocm_attn.py

diff --git a/requirements-rocm-build.txt b/requirements-rocm-build.txt
new file mode 100644
index 0000000000000..00ae0340fc529
--- /dev/null
+++ b/requirements-rocm-build.txt
@@ -0,0 +1,16 @@
+# Common dependencies
+-r requirements-common.txt
+
+--extra-index-url https://download.pytorch.org/whl/rocm6.2
+torch==2.5.1
+torchvision==0.20.1
+torchaudio==2.5.1
+
+cmake>=3.26
+ninja
+packaging
+setuptools>=61
+setuptools-scm>=8
+wheel
+jinja2
+amdsmi==6.2.4
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index 5fca1639363e0..362c46a95f322 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -718,7 +718,8 @@ def context_attention_fwd(q,
                           k_scale: torch.Tensor,
                           v_scale: torch.Tensor,
                           alibi_slopes=None,
-                          sliding_window=None):
+                          sliding_window=None,
+                          sm_scale=None):
 
     q_dtype_is_f32 = q.dtype is torch.float32
     # need to reduce num. blocks when using fp32
@@ -759,7 +760,8 @@ def context_attention_fwd(q,
     # round up Lk to a power of 2 - this is required for Triton block size
     Lk_padded = triton.next_power_of_2(Lk)
 
-    sm_scale = 1.0 / (Lq**0.5)
+    if sm_scale is None:
+        sm_scale = 1.0 / (Lq**0.5)
     batch, head = b_seq_len.shape[0], q.shape[1]
     num_queries_per_kv = q.shape[1] // k.shape[1]
 
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 13aebc605af74..d57cce4231dc0 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 from functools import lru_cache
 from typing import TYPE_CHECKING, Dict, List, Optional
 
@@ -29,12 +28,6 @@
 except ImportError as e:
     logger.warning("Failed to import from vllm._rocm_C with %r", e)
 
-if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
-    logger.warning("`fork` method is not supported by ROCm. "
-                   "VLLM_WORKER_MULTIPROC_METHOD is overridden to"
-                   " `spawn` instead.")
-    os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-
 # Models not supported by ROCm.
 _ROCM_UNSUPPORTED_MODELS: List[str] = []
 
@@ -84,6 +77,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
             return "vllm.attention.backends.triton_mla.TritonMLABackend"
         selected_backend = (_Backend.ROCM_FLASH if selected_backend
                             == _Backend.FLASH_ATTN else selected_backend)
+        if envs.VLLM_USE_V1:
+            logger.info("Using ROCm Attention backend on V1 engine.")
+            return "vllm.v1.attention.backends.rocm_attn.ROCmAttentionBackend"
         if selected_backend == _Backend.ROCM_FLASH:
             if not cls.has_device_capability(90):
                 # not Instinct series GPUs.
@@ -102,7 +98,11 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
     @classmethod
     @lru_cache(maxsize=8)
     def get_device_name(cls, device_id: int = 0) -> str:
-        return torch.cuda.get_device_name(device_id)
+        # NOTE: When using V1 this function is called when overriding the
+        # engine args. Calling torch.cuda.get_device_name(device_id) here
+        # will result in the ROCm context being initialized before other
+        # processes can be created.
+        return "AMD"
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
@@ -129,15 +129,30 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":
             if scheduler_config.is_multi_step:
-                parallel_config.worker_cls = \
-                    "vllm.worker.multi_step_worker.MultiStepWorker"
+                if envs.VLLM_USE_V1:
+                    raise NotImplementedError(
+                        "Multi-step scheduling is not supported (and not "
+                        "needed) on VLLM V1. Please launch without "
+                        "--num-scheduler-steps.")
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.multi_step_worker.MultiStepWorker"
             elif vllm_config.speculative_config:
-                parallel_config.worker_cls = \
-                    "vllm.spec_decode.spec_decode_worker.create_spec_worker"
-                parallel_config.sd_worker_cls = \
-                    "vllm.worker.worker.Worker"
+                if envs.VLLM_USE_V1:
+                    raise NotImplementedError(
+                        "Speculative decoding is not yet supported on VLLM V1."
+                    )
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.spec_decode.spec_decode_worker.create_spec_worker"
+                    parallel_config.sd_worker_cls = \
+                        "vllm.worker.worker.Worker"
             else:
-                parallel_config.worker_cls = "vllm.worker.worker.Worker"
+                if envs.VLLM_USE_V1:
+                    parallel_config.worker_cls = \
+                        "vllm.v1.worker.gpu_worker.Worker"
+                else:
+                    parallel_config.worker_cls = "vllm.worker.worker.Worker"
 
     @classmethod
     def verify_model_arch(cls, model_arch: str) -> None:
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 5cb1e2fd26a5c..b1b5cc359251a 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -12,8 +12,11 @@
                                               AttentionMetadata, AttentionType)
 from vllm.attention.backends.utils import get_flash_attn_version
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.vllm_flash_attn import flash_attn_varlen_func
+
+if current_platform.is_cuda():
+    from vllm.vllm_flash_attn import flash_attn_varlen_func
 
 logger = init_logger(__name__)
 
diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py
new file mode 100644
index 0000000000000..5f3eb37514d85
--- /dev/null
+++ b/vllm/v1/attention/backends/rocm_attn.py
@@ -0,0 +1,182 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Attention layer with PagedAttention on rocm"""
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+
+from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
+                                              AttentionMetadata, AttentionType)
+from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.attention.ops.prefix_prefill import context_attention_fwd
+from vllm.logger import init_logger
+from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
+
+logger = init_logger(__name__)
+
+
+class ROCmAttentionBackend(AttentionBackend):
+
+    accept_output_buffer: bool = True
+
+    @staticmethod
+    def get_supported_head_sizes() -> List[int]:
+        return [32, 64, 96, 128, 160, 192, 224, 256]
+
+    @staticmethod
+    def get_name() -> str:
+        return "ROCM_ATTN_VLLM_V1"
+
+    @staticmethod
+    def get_impl_cls() -> Type["ROCmAttentionImpl"]:
+        return ROCmAttentionImpl
+
+    @staticmethod
+    def get_metadata_cls() -> Type["AttentionMetadata"]:
+        return FlashAttentionMetadata
+
+    @staticmethod
+    def get_kv_cache_shape(
+        num_blocks: int,
+        block_size: int,
+        num_kv_heads: int,
+        head_size: int,
+    ) -> Tuple[int, ...]:
+        if block_size % 16 != 0:
+            raise ValueError("Block size must be a multiple of 16.")
+        return (2, num_blocks, block_size, num_kv_heads, head_size)
+
+    @staticmethod
+    def use_cascade_attention(*args, **kwargs) -> bool:
+        return False
+
+
+class ROCmAttentionImpl(AttentionImpl):
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: Optional[List[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
+        attn_type: AttentionType = AttentionType.DECODER,
+    ) -> None:
+        if blocksparse_params is not None:
+            raise ValueError(
+                "ROCmAttention does not support block-sparse attention.")
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_kv_heads
+        if alibi_slopes is not None:
+            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
+        self.alibi_slopes = alibi_slopes
+        if sliding_window is None:
+            self.sliding_window = (-1, -1)
+        else:
+            self.sliding_window = (sliding_window - 1, 0)
+        self.kv_cache_dtype = kv_cache_dtype
+
+        assert self.num_heads % self.num_kv_heads == 0
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+
+        support_head_sizes = ROCmAttentionBackend.get_supported_head_sizes()
+        if head_size not in support_head_sizes:
+            raise ValueError(
+                f"Head size {head_size} is not supported by ROCmAttention. "
+                f"Supported head sizes are: {support_head_sizes}.")
+
+        if attn_type != AttentionType.DECODER:
+            raise NotImplementedError("Encoder self-attention and "
+                                      "encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "ROCmAttentionImpl")
+
+    def forward(
+        self,
+        layer: torch.nn.Module,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: FlashAttentionMetadata,
+        output: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention.
+
+        Args:
+            query: shape = [num_tokens, num_heads, head_size]
+            key: shape = [num_tokens, num_kv_heads, head_size]
+            value: shape = [num_tokens, num_kv_heads, head_size]
+            kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size]
+            attn_metadata: Metadata for attention.
+        Returns:
+            shape = [num_tokens, num_heads * head_size]
+        """
+        assert output is not None, "Output tensor must be provided."
+
+        if attn_metadata is None:
+            # Profiling run.
+            return output
+
+        assert attn_metadata.use_cascade is False
+
+        # IMPORTANT!
+        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
+        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
+        # in this method. For example, `view` and `slice` (or `[:n]`) operations
+        # are surprisingly slow even in the case they do not invoke any GPU ops.
+        # Minimize the PyTorch ops in this method as much as possible.
+        # Whenever making a change in this method, please benchmark the
+        # performance to make sure it does not introduce any overhead.
+
+        num_actual_tokens = attn_metadata.num_actual_tokens
+        key_cache, value_cache = PagedAttention.split_kv_cache(
+            kv_cache, self.num_kv_heads, self.head_size)
+
+        # Reshape the input keys and values and store them in the cache.
+        PagedAttention.write_to_paged_cache(
+            key,
+            value,
+            key_cache,
+            value_cache,
+            attn_metadata.slot_mapping,
+            self.kv_cache_dtype,
+            layer._k_scale,
+            layer._v_scale,
+        )
+
+        # TODO(sage): Refactor the context_attention_fwd kernel so that this
+        # overhead can be removed.
+        context_lens = torch.empty_like(attn_metadata.seq_lens)
+        batch_size = len(attn_metadata.query_start_loc) - 1
+        assert len(context_lens) == batch_size
+        for i in range(batch_size):
+            query_start = attn_metadata.query_start_loc[i]
+            query_end = attn_metadata.query_start_loc[i + 1]
+            context_lens[i] = attn_metadata.seq_lens[i] - (query_end -
+                                                           query_start)
+
+        # Compute attention and update output up to `num_actual_tokens`.
+        context_attention_fwd(q=query[:num_actual_tokens],
+                              k=key[:num_actual_tokens],
+                              v=value[:num_actual_tokens],
+                              o=output[:num_actual_tokens],
+                              kv_cache_dtype=self.kv_cache_dtype,
+                              k_cache=key_cache,
+                              v_cache=value_cache,
+                              b_loc=attn_metadata.block_table,
+                              b_start_loc=attn_metadata.query_start_loc,
+                              b_seq_len=attn_metadata.seq_lens,
+                              b_ctx_len=context_lens,
+                              max_input_len=attn_metadata.max_query_len,
+                              k_scale=layer._k_scale,
+                              v_scale=layer._v_scale,
+                              alibi_slopes=self.alibi_slopes,
+                              sliding_window=self.sliding_window[0],
+                              sm_scale=self.scale)
+        return output
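
Editorial note (not part of the patch): ROCmAttentionBackend.get_kv_cache_shape() above lays out the
per-layer paged KV cache as (2, num_blocks, block_size, num_kv_heads, head_size), with block_size
required to be a multiple of 16. The sketch below uses made-up sizes, purely for illustration, to show
what that shape implies for memory footprint.

    import torch

    # Shape returned by ROCmAttentionBackend.get_kv_cache_shape() in the patch:
    # (2, num_blocks, block_size, num_kv_heads, head_size), block_size % 16 == 0.
    num_blocks, block_size, num_kv_heads, head_size = 8192, 16, 8, 128  # made-up sizes
    shape = (2, num_blocks, block_size, num_kv_heads, head_size)

    kv_cache = torch.empty(shape, dtype=torch.float16)  # K and V blocks for one layer
    gib = kv_cache.numel() * kv_cache.element_size() / 2**30
    print(f"{shape} -> {gib:.2f} GiB per layer")  # -> 0.50 GiB per layer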
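
Editorial note on the per-request context-length loop at the end of ROCmAttentionImpl.forward() (again
not part of the patch): context_attention_fwd() takes b_ctx_len[i], the number of already-cached tokens
for request i, which the patch computes as seq_lens[i] minus the number of new query tokens, and the
TODO flags the Python loop as temporary overhead. The sketch below reproduces that computation with
hypothetical metadata values and shows one possible vectorized form; it assumes query_start_loc has
batch_size + 1 entries and seq_lens has batch_size entries, which is what the assert in the patch implies.

    import torch

    # Hypothetical metadata for a batch of 3 requests: request i has seq_lens[i]
    # tokens in total, of which the last (query_start_loc[i + 1] -
    # query_start_loc[i]) are the new query tokens for this step.
    query_start_loc = torch.tensor([0, 4, 6, 11])  # batch_size + 1 entries
    seq_lens = torch.tensor([16, 6, 27])           # batch_size entries

    # Loop form, as in the patch: cached context = total tokens - new query tokens.
    batch_size = len(query_start_loc) - 1
    context_lens = torch.empty_like(seq_lens)
    for i in range(batch_size):
        num_query_tokens = query_start_loc[i + 1] - query_start_loc[i]
        context_lens[i] = seq_lens[i] - num_query_tokens

    # One possible vectorized form: same result, no Python loop over the batch.
    context_lens_vec = seq_lens - (query_start_loc[1:] - query_start_loc[:-1])

    assert torch.equal(context_lens, context_lens_vec)
    print(context_lens.tolist())  # [12, 4, 22]

Whether this per-batch work is better vectorized on the host or folded into the kernel-side metadata, as
the TODO suggests, is a profiling question; the NOTE in forward() about CPU overhead under piece-wise
CUDA graphs applies either way.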