From d5b47dc3b8a3ed3970bf8462657be20930cd997c Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Fri, 24 Jan 2025 13:37:17 -0600 Subject: [PATCH 1/4] initial commit with rocm fa update Signed-off-by: Aleksandr Malyshev --- vllm/attention/backends/rocm_flash_attn.py | 376 ++++++++++++++++----- vllm/model_executor/models/mllama.py | 63 ++-- 2 files changed, 333 insertions(+), 106 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e9f2808ff1674..09b441ad6093f 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -90,6 +90,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): seq_lens: Optional[List[int]] # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| @@ -100,30 +111,18 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): # |-- query_len ---| # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int + max_query_len: Optional[int] = None # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] = None # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool - + seq_start_loc: Optional[torch.Tensor] = None # (batch_size,) A tensor of context lengths (tokens that are computed # so far). - context_lens_tensor: Optional[torch.Tensor] + context_lens_tensor: Optional[torch.Tensor] = None # Max number of query tokens among request in the batch. max_decode_query_len: Optional[int] = None @@ -131,6 +130,23 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None + # Begin encoder attn & enc/dec cross-attn fields... 
+ + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + @property def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: if self.num_prefills == 0: @@ -141,10 +157,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: assert self.seq_lens is not None assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None assert self.block_tables is not None - assert self.seq_start_loc is not None self._cached_prefill_metadata = ROCmFlashAttentionMetadata( num_prefills=self.num_prefills, @@ -158,12 +171,20 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + query_start_loc=None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, - ) + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) return self._cached_prefill_metadata @property @@ -192,7 +213,12 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, - ) + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) # Batch may be composed of prefill|decodes, adjust query start indices # to refer to the start of decodes when the two are split apart. # E.g. in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6]. @@ -302,6 +328,97 @@ def _make_alibi_bias(alibi_slopes: torch.Tensor, return attn_biases +def _get_seq_len_block_table_args( + attn_metadata: ROCmFlashAttentionMetadata, + attn_type: str, +) -> tuple: + ''' + The particular choice of sequence-length + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. 
+ + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths + Encoder attn -> select encoder sequence lengths fields + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensors for query and key + * Appropriate max sequence-length scalar + ''' + + partial_prefix_sum = 0 + if attn_type == AttentionType.ENCODER: + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.encoder_seq_lens + ], + device=attn_metadata.encoder_seq_lens_tensor.device, + dtype=attn_metadata.encoder_seq_lens_tensor.dtype) + causal_mask = False + + # No block tables associated with encoder attention + return (query_seq_start_loc, attn_metadata.max_encoder_seq_len, + query_seq_start_loc, attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_lens, causal_mask) + elif attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + assert attn_metadata.seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + query_seq_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.seq_lens + ], + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + max_seq_len = attn_metadata.max_prefill_seq_len + causal_mask = True + + return (query_seq_start_loc, max_seq_len, query_seq_start_loc, + max_seq_len, attn_metadata.seq_lens, causal_mask) + elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None + query_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.seq_lens + ], + device=attn_metadata.encoder_seq_lens_tensor.device, + dtype=attn_metadata.encoder_seq_lens_tensor.dtype) + + partial_prefix_sum = 0 + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None + key_seq_start_loc = torch.tensor( + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.encoder_seq_lens + ], + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) + causal_mask = False + + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (query_start_loc, attn_metadata.max_prefill_seq_len, + key_seq_start_loc, attn_metadata.max_encoder_seq_len, + attn_metadata.seq_lens, causal_mask) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + class ROCmFlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: @@ -344,10 +461,13 @@ def __init__( if blocksparse_params is not None: raise ValueError( "ROCmFlashAttention does not support blocksparse attention.") - if logits_soft_cap is not None: - raise ValueError( - "ROCmFlashAttention does not support attention logits soft " - "capping.") + + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. 
+            self.logits_soft_cap = 0.0
+        else:
+            self.logits_soft_cap = logits_soft_cap
+        self.attn_type = attn_type
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -372,6 +492,14 @@ def __init__(
         # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
         self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN
         if self.use_triton_flash_attn:
+            if logits_soft_cap is not None:
+                raise ValueError(
+                    "ROCm Triton FlashAttention does not support attention"
+                    " logits soft capping."
+                    " Please try using the ROCm CK "
+                    "FA backend instead by setting the env var "
+                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
+
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
             self.attn_func = triton_attention
@@ -396,14 +524,13 @@ def __init__(
                     self.use_naive_attn = True
 
             if self.use_naive_attn:
-                self.attn_func = _sdpa_attention
-                logger.debug("Using naive attention in ROCmBackend")
+                if logits_soft_cap is not None:
+                    raise ValueError(
+                        "ROCm Naive FlashAttention does not support "
+                        "attention logits soft capping.")
 
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "ROCmFlashAttentionImpl")
+                self.attn_func = _sdpa_attention
+                logger.debug("Using naive (SDPA) attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
         """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
@@ -425,6 +552,37 @@ def forward(
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
 
+        For decoder-only models: query, key and value must be non-None.
+
+        For encoder/decoder models:
+        * ROCmFlashAttentionImpl.forward() may be invoked for both self- and
+          cross-attention layers.
+        * For self-attention: query, key and value must be non-None.
+        * For cross-attention:
+            * Query must be non-None
+            * During prefill, key and value must be non-None; key and value
+              get cached for use during decode.
+            * During decode, key and value may be None, since:
+              (1) key and value tensors were cached during prefill, and
+              (2) cross-attention key and value tensors do not grow during
+                  decode
+
+        A note on how the attn_type (attention type enum) argument impacts
+        attention forward() behavior:
+
+            * DECODER: normal decoder-only behavior;
+                use decoder self-attention block table
+            * ENCODER: no KV caching; pass encoder sequence
+                attributes (encoder_seq_lens/encoder_seq_lens_tensor/
+                max_encoder_seq_len) to kernel, in lieu of decoder
+                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
+            * ENCODER_DECODER: cross-attention behavior;
+                use cross-attention block table for caching KVs derived
+                from encoder hidden states; since KV sequence lengths
+                will match encoder sequence lengths, pass encoder sequence
+                attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
+                max_encoder_seq_len)
+
         Args:
             query: shape = [num_tokens, num_heads * head_size]
             key: shape = [num_tokens, num_kv_heads * head_size]
@@ -433,54 +591,80 @@ def forward(
                 NOTE: kv_cache will be an empty tensor with shape [0]
                 for profiling run.
             attn_metadata: Metadata for attention.
+            attn_type: Select attention type, between encoder attention,
+                       decoder self-attention, or encoder/decoder cross-
+                       attention. 
Defaults to decoder self-attention, + which is the vLLM default generally Returns: shape = [num_tokens, num_heads * head_size] """ - # Reminder: Please update docs/source/features/compatibility_matrix.md - # If the feature combo become valid - num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None - if kv_cache.numel() > 0: + if self.attn_type != AttentionType.ENCODER and kv_cache.numel() > 0: key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens + if key is not None and value is not None: + # Reshape the input keys and values and store them in the + # cache. If kv_cache is not provided, the new key and value + # tensors are not cached. This happens during the initial + # memory profiling run. + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping + if self.attn_type != AttentionType.ENCODER_DECODER else + attn_metadata.cross_slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + if self.attn_type != AttentionType.ENCODER: + num_prefill_tokens = attn_metadata.num_prefill_tokens + else: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens output = torch.empty_like(query) # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] # QKV for prefill. query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens + if key is not None and value is not None \ + and self.attn_type != AttentionType.ENCODER_DECODER: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] if prefill_meta := attn_metadata.prefill_metadata: # Prompt run. 
- assert prefill_meta.seq_lens is not None + # normal attention and DECODER + if self.attn_type == AttentionType.DECODER and ( + kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, + key_max_seq_len, seq_lens, + causal_mask) = (prefill_meta.seq_start_loc, + prefill_meta.max_prefill_seq_len, + prefill_meta.seq_start_loc, + prefill_meta.max_prefill_seq_len, + attn_metadata.seq_lens, True) + # prefix-enabled attention and ENCODER/ENCODER_DECODER + else: + (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, + key_max_seq_len, seq_lens, + causal_mask) = _get_seq_len_block_table_args( + prefill_meta, self.attn_type) + # Prompt run. if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the @@ -491,18 +675,18 @@ def forward( attn_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.seq_lens, + seq_lens, make_attn_mask=False) # type: ignore out, _ = self.attn_func( query, key, value, None, - prefill_meta.seq_start_loc, - prefill_meta.seq_start_loc, - prefill_meta.max_prefill_seq_len, - prefill_meta.max_prefill_seq_len, - True, + query_seq_start_loc, + key_seq_start_loc, + query_max_seq_len, + key_max_seq_len, + causal_mask, self.scale, attn_masks[0][None] if attn_masks is not None else None, @@ -526,11 +710,12 @@ def forward( query, key, value, - prefill_meta.seq_lens, - num_tokens, + query_seq_start_loc, + num_prefill_tokens, self.num_heads, self.head_size, self.scale, + causal_mask, attn_masks, ) else: @@ -538,19 +723,23 @@ def forward( q=query, k=key, v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, + cu_seqlens_q=query_seq_start_loc, + cu_seqlens_k=key_seq_start_loc, max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, + max_seqlen_k=key_max_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, ) # common code for prefill assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out + if output.shape[0] > num_prefill_tokens: + output[:num_prefill_tokens] = out + else: + output = out else: # prefix-enabled attention output[:num_prefill_tokens] = PagedAttention.forward_prefix( @@ -581,7 +770,10 @@ def forward( decode_query.dtype, head_size, block_size, gqa_ratio, decode_meta.max_decode_seq_len) if use_custom: - max_seq_len = decode_meta.max_decode_seq_len + max_seq_len = (decode_meta.max_decode_seq_len if + self.attn_type != AttentionType.ENCODER_DECODER + else decode_meta.max_encoder_seq_len) + assert max_seq_len is not None max_num_partitions = ( (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM) @@ -597,8 +789,12 @@ def forward( device=output.device, ) max_logits = torch.empty_like(exp_sums) + if num_prefill_tokens > 0: + out = output[num_prefill_tokens:] + else: + out = output ops.paged_attention_rocm( - output[num_prefill_tokens:], + out, exp_sums, max_logits, tmp_output, @@ -607,8 +803,12 @@ def forward( value_cache, self.num_kv_heads, self.scale, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, + decode_meta.block_tables + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if self.attn_type != AttentionType.ENCODER_DECODER else + 
decode_meta.encoder_seq_lens_tensor, block_size, max_seq_len, self.alibi_slopes, @@ -621,9 +821,15 @@ def forward( decode_query, key_cache, value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - decode_meta.max_decode_seq_len, + decode_meta.block_tables + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, + decode_meta.max_decode_seq_len + if self.attn_type != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -633,7 +839,7 @@ def forward( ) # Reshape the output tensor. - return output.view(num_tokens, hidden_size) + return output.view(-1, self.num_heads * self.head_size) def _sdpa_attention( diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 2554281610a30..f5314bcb305e8 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -51,6 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.platforms import current_platform from vllm.sequence import SequenceData from vllm.utils import is_list_of @@ -63,6 +64,8 @@ MLLAMA_IMAGE_TOKEN_ID = 128256 MLLAMA_IMAGE_TOKEN = "<|image|>" +iteration = 0 +layer = 0 class MllamaImagePixelInputs(TypedDict): type: Literal["pixel_values"] @@ -831,34 +834,44 @@ def _attention_with_mask( ) -> torch.Tensor: # Skip writing kv-cache for the initial profiling run. if len(kv_cache.shape) > 1: - if self.attn.backend in (_Backend.FLASH_ATTN, - _Backend.FLASH_ATTN_VLLM_V1): - cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) - cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) - torch.ops._C_cache_ops.reshape_and_cache_flash( - cached_k, - cached_v, - kv_cache[0], - kv_cache[1], - attn_metadata. - cross_slot_mapping, # type: ignore[union-attr] - "auto", - 1.0, - 1.0, - ) - elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA): + i = torch.ones(1, dtype=torch.float32) + if current_platform.is_rocm(): key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_local_key_value_heads, self.head_dim) cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) PagedAttention.write_to_paged_cache( cached_k, cached_v, key_cache, value_cache, - attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0) + attn_metadata.cross_slot_mapping, "auto", i, i) else: - raise ValueError( - f"Unsupported Attention backend {self.attn.backend} " - "enum found. Expected the Attention backend to be " - "FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.") + if self.attn.backend in (_Backend.FLASH_ATTN, + _Backend.FLASH_ATTN_VLLM_V1): + cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) + cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) + torch.ops._C_cache_ops.reshape_and_cache_flash( + cached_k, + cached_v, + kv_cache[0], + kv_cache[1], + attn_metadata. 
+ cross_slot_mapping, # type: ignore[union-attr] + "auto", + 1.0, + 1.0, + ) + elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA): + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_local_key_value_heads, self.head_dim) + cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) + cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) + PagedAttention.write_to_paged_cache( + cached_k, cached_v, key_cache, value_cache, + attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0) + else: + raise ValueError( + f"Unsupported Attention backend {self.attn.backend} " + "enum found. Expected the Attention backend to be " + "FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.") # We have to call torch.sdpa for prefill when using a # custom cross-attention mask. Because the mask is not a @@ -1451,6 +1464,14 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, loaded_weight, shard_id) break else: + orig_name = name + from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + logger.debug("Missing name %s, orig name %s", name, + orig_name) + continue + param = params_dict.pop(name) weight_loader = getattr(param, "weight_loader", default_weight_loader) From d8aa3de4cb1fe12b519ce1ba9a9b74490b294882 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Fri, 24 Jan 2025 14:31:05 -0600 Subject: [PATCH 2/4] linters Signed-off-by: Aleksandr Malyshev --- vllm/model_executor/models/mllama.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index b766e3930cfe5..81ab118f336cb 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -67,6 +67,7 @@ iteration = 0 layer = 0 + class MllamaImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor @@ -845,9 +846,11 @@ def _attention_with_mask( attn_metadata.cross_slot_mapping, "auto", i, i) else: if self.attn.backend in (_Backend.FLASH_ATTN, - _Backend.FLASH_ATTN_VLLM_V1): - cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) - cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) + _Backend.FLASH_ATTN_VLLM_V1): + cached_k = torch.cat( + [k[s:e] for s, e in kv_range_for_decode]) + cached_v = torch.cat( + [v[s:e] for s, e in kv_range_for_decode]) torch.ops._C_cache_ops.reshape_and_cache_flash( cached_k, cached_v, @@ -859,11 +862,15 @@ def _attention_with_mask( i, i, ) - elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA): + elif self.attn.backend in (_Backend.XFORMERS, + _Backend.TORCH_SDPA): key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_local_key_value_heads, self.head_dim) - cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) - cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) + kv_cache, self.num_local_key_value_heads, + self.head_dim) + cached_k = torch.cat( + [k[s:e] for s, e in kv_range_for_decode]) + cached_v = torch.cat( + [v[s:e] for s, e in kv_range_for_decode]) PagedAttention.write_to_paged_cache( cached_k, cached_v, key_cache, value_cache, attn_metadata.cross_slot_mapping, "auto", i, i) @@ -871,7 +878,8 @@ def _attention_with_mask( raise ValueError( f"Unsupported Attention backend {self.attn.backend} " "enum found. 
Expected the Attention backend to be " - "FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.") + "FLASH_ATTN, FLASH_ATTN_VLLM_V1, " + "XFORMERS or TORCH_SDPA.") # We have to call torch.sdpa for prefill when using a # custom cross-attention mask. Because the mask is not a @@ -1465,7 +1473,8 @@ def load_weights(self, weights: Iterable[Tuple[str, break else: orig_name = name - from vllm.model_executor.model_loader.weight_utils import maybe_remap_kv_scale_name + from vllm.model_executor.model_loader.weight_utils import ( + maybe_remap_kv_scale_name) name = maybe_remap_kv_scale_name(name, params_dict) if name is None: logger.debug("Missing name %s, orig name %s", name, From f1f62e65f8d0fc3f868545ec19ddea7eb514aa3e Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Mon, 27 Jan 2025 12:56:52 -0600 Subject: [PATCH 3/4] replying to comments Signed-off-by: Aleksandr Malyshev --- vllm/model_executor/models/mllama.py | 67 ++++++++++------------------ 1 file changed, 24 insertions(+), 43 deletions(-) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 81ab118f336cb..1f58fb541ff7c 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -48,10 +48,10 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.platforms import current_platform from vllm.sequence import SequenceData from vllm.utils import is_list_of @@ -64,9 +64,6 @@ MLLAMA_IMAGE_TOKEN_ID = 128256 MLLAMA_IMAGE_TOKEN = "<|image|>" -iteration = 0 -layer = 0 - class MllamaImagePixelInputs(TypedDict): type: Literal["pixel_values"] @@ -836,7 +833,23 @@ def _attention_with_mask( # Skip writing kv-cache for the initial profiling run. if len(kv_cache.shape) > 1: i = torch.ones(1, dtype=torch.float32) - if current_platform.is_rocm(): + if self.attn.backend in (_Backend.FLASH_ATTN, + _Backend.FLASH_ATTN_VLLM_V1): + cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) + cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) + torch.ops._C_cache_ops.reshape_and_cache_flash( + cached_k, + cached_v, + kv_cache[0], + kv_cache[1], + attn_metadata. + cross_slot_mapping, # type: ignore[union-attr] + "auto", + i, + i, + ) + elif self.attn.backend in (_Backend.XFORMERS, _Backend.ROCM_FLASH, + _Backend.TORCH_SDPA): key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_local_key_value_heads, self.head_dim) cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) @@ -845,41 +858,11 @@ def _attention_with_mask( cached_k, cached_v, key_cache, value_cache, attn_metadata.cross_slot_mapping, "auto", i, i) else: - if self.attn.backend in (_Backend.FLASH_ATTN, - _Backend.FLASH_ATTN_VLLM_V1): - cached_k = torch.cat( - [k[s:e] for s, e in kv_range_for_decode]) - cached_v = torch.cat( - [v[s:e] for s, e in kv_range_for_decode]) - torch.ops._C_cache_ops.reshape_and_cache_flash( - cached_k, - cached_v, - kv_cache[0], - kv_cache[1], - attn_metadata. 
- cross_slot_mapping, # type: ignore[union-attr] - "auto", - i, - i, - ) - elif self.attn.backend in (_Backend.XFORMERS, - _Backend.TORCH_SDPA): - key_cache, value_cache = PagedAttention.split_kv_cache( - kv_cache, self.num_local_key_value_heads, - self.head_dim) - cached_k = torch.cat( - [k[s:e] for s, e in kv_range_for_decode]) - cached_v = torch.cat( - [v[s:e] for s, e in kv_range_for_decode]) - PagedAttention.write_to_paged_cache( - cached_k, cached_v, key_cache, value_cache, - attn_metadata.cross_slot_mapping, "auto", i, i) - else: - raise ValueError( - f"Unsupported Attention backend {self.attn.backend} " - "enum found. Expected the Attention backend to be " - "FLASH_ATTN, FLASH_ATTN_VLLM_V1, " - "XFORMERS or TORCH_SDPA.") + raise ValueError( + f"Unsupported Attention backend {self.attn.backend} " + "enum found. Expected the Attention backend to be " + "FLASH_ATTN, FLASH_ATTN_VLLM_V1, " + "XFORMERS or TORCH_SDPA.") # We have to call torch.sdpa for prefill when using a # custom cross-attention mask. Because the mask is not a @@ -1473,8 +1456,6 @@ def load_weights(self, weights: Iterable[Tuple[str, break else: orig_name = name - from vllm.model_executor.model_loader.weight_utils import ( - maybe_remap_kv_scale_name) name = maybe_remap_kv_scale_name(name, params_dict) if name is None: logger.debug("Missing name %s, orig name %s", name, From e096bc81334c36167e464c0fccca63ba041c7d2d Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Tue, 28 Jan 2025 13:46:02 -0600 Subject: [PATCH 4/4] linter Signed-off-by: Aleksandr Malyshev --- vllm/attention/backends/rocm_flash_attn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index aef795afe00ae..12110ec7356d5 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -772,9 +772,9 @@ def forward( decode_query.dtype, head_size, block_size, gqa_ratio, decode_meta.max_decode_seq_len) if use_custom: - max_seq_len = (decode_meta.max_decode_seq_len if - self.attn_type != AttentionType.ENCODER_DECODER - else decode_meta.max_encoder_seq_len) + max_seq_len = (decode_meta.max_decode_seq_len if self.attn_type + != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len) assert max_seq_len is not None max_num_partitions = ( (max_seq_len + _PARTITION_SIZE_ROCM - 1) //