From 64cdd323ac492602eba2afdbc54b88471987b87e Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Wed, 2 Oct 2024 22:57:38 -0500 Subject: [PATCH 1/6] llama3.2 + cross attn test --- tests/kernels/test_encoder_decoder_attn.py | 5 +- tests/kernels/utils.py | 9 +- vllm/attention/backends/rocm_flash_attn.py | 316 +++++++++++++++------ vllm/envs.py | 5 + vllm/model_executor/layers/linear.py | 8 +- vllm/worker/enc_dec_model_runner.py | 17 +- 6 files changed, 269 insertions(+), 91 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index b550a7fdd84f0..c8fefbff76eb3 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -21,7 +21,8 @@ from vllm.utils import is_hip # List of support backends for encoder/decoder models -LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] +LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] if not is_hip() \ + else [_Backend.ROCM_FLASH] HEAD_SIZES = [64, 256] @@ -807,7 +808,7 @@ def test_encoder_only( assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) -@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) +# @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 08004efe9e2f8..9f9e92f54080a 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -13,7 +13,7 @@ from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL, - make_tensor_with_pad) + STR_ROCM_FLASH_ATTN_VAL, make_tensor_with_pad) # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. @@ -524,8 +524,13 @@ def make_backend(backend_name: str) -> AttentionBackend: if backend_name == STR_XFORMERS_ATTN_VAL: # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs. from vllm.attention.backends.xformers import XFormersBackend - return XFormersBackend() + + if backend_name == STR_ROCM_FLASH_ATTN_VAL: + from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401 + ROCmFlashAttentionBackend) + return ROCmFlashAttentionBackend + raise AssertionError( f"Unrecognized backend_name {backend_name} for unit test") diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index b793ebf46d173..c9774dc929f10 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type +from tests.kernels.utils import ref_masked_attention import torch import vllm.envs as envs @@ -86,6 +87,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): seq_lens: Optional[List[int]] # seq_lens stored as a tensor. seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. 
+ use_cuda_graph: bool # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| @@ -96,32 +108,38 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): # |-- query_len ---| # Maximum query length in the batch. None for decoding. - max_query_len: Optional[int] - # Maximum sequence length among prefill batch. 0 if there are decoding - # requests only. - max_prefill_seq_len: int - # Maximum sequence length among decode batch. 0 if there are prefill - # requests only. - max_decode_seq_len: int + max_query_len: Optional[int] = None # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. - query_start_loc: Optional[torch.Tensor] + query_start_loc: Optional[torch.Tensor] = None # (batch_size + 1,). The cumulative sequence lengths of the sequences in # the batch, used to index into sequence. E.g., if the sequence length is # [4, 6], it is [0, 4, 10]. - seq_start_loc: Optional[torch.Tensor] - - # Whether or not if cuda graph is enabled. - # Cuda-graph is currently enabled for decoding only. - # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. - use_cuda_graph: bool + seq_start_loc: Optional[torch.Tensor] = None # (batch_size,) A tensor of context lengths (tokens that are computed # so far). - context_lens_tensor: Optional[torch.Tensor] + context_lens_tensor: Optional[torch.Tensor] = None _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + @property def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: if self.num_prefills == 0: @@ -132,10 +150,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: assert self.seq_lens is not None assert self.seq_lens_tensor is not None - assert self.query_start_loc is not None - assert self.context_lens_tensor is not None assert self.block_tables is not None - assert self.seq_start_loc is not None self._cached_prefill_metadata = ROCmFlashAttentionMetadata( num_prefills=self.num_prefills, @@ -147,11 +162,17 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, - query_start_loc=self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + query_start_loc=None if self.query_start_loc is None else self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None if self.seq_start_loc is None else self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=None if self.context_lens_tensor is None else self.context_lens_tensor[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, + # 
Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables ) return self._cached_prefill_metadata @@ -180,6 +201,12 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: context_lens_tensor=None, block_tables=self.block_tables[self.num_prefills:], use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables ) return self._cached_decode_metadata @@ -274,6 +301,76 @@ def _make_alibi_bias(alibi_slopes: torch.Tensor, return attn_biases +def _get_seq_len_block_table_args( + attn_metadata: ROCmFlashAttentionMetadata, + attn_type: AttentionType, +) -> tuple: + ''' + The particular choice of sequence-length + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths + Encoder attn -> select encoder sequence lengths fields + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensors for query and key + * Appropriate max sequence-length scalar + ''' + + + if attn_type == AttentionType.ENCODER: + query_seq_start_loc = torch.tensor([sum(attn_metadata.encoder_seq_lens[:i]) + for i in range(len(attn_metadata.encoder_seq_lens) + 1)], + device = attn_metadata.encoder_seq_lens_tensor.device, + dtype = attn_metadata.encoder_seq_lens_tensor.dtype) + causal_mask = False + + # No block tables associated with encoder attention + return (query_seq_start_loc, attn_metadata.max_encoder_seq_len, + query_seq_start_loc, attn_metadata.max_encoder_seq_len, + attn_metadata.encoder_seq_lens, causal_mask) + elif attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + query_seq_start_loc = torch.tensor([sum(attn_metadata.seq_lens[:i]) + for i in range(len(attn_metadata.seq_lens) + 1)], + device = attn_metadata.seq_lens_tensor.device, + dtype = attn_metadata.seq_lens_tensor.dtype) + max_seq_len = attn_metadata.max_prefill_seq_len + causal_mask = True + + return (query_seq_start_loc, max_seq_len, + query_seq_start_loc, max_seq_len, + attn_metadata.seq_lens, causal_mask) + elif attn_type == AttentionType.ENCODER_DECODER: + query_start_loc = torch.tensor([sum(attn_metadata.seq_lens[:i]) + for i in range(len(attn_metadata.seq_lens) + 1)], + device = attn_metadata.encoder_seq_lens_tensor.device, + dtype = attn_metadata.encoder_seq_lens_tensor.dtype) + + key_seq_start_loc = torch.tensor([sum(attn_metadata.encoder_seq_lens[:i]) + for i in range(len(attn_metadata.encoder_seq_lens) + 1)], + device = attn_metadata.seq_lens_tensor.device, + dtype = attn_metadata.seq_lens_tensor.dtype) + causal_mask = False + + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (query_start_loc, attn_metadata.max_prefill_seq_len, + key_seq_start_loc, 
attn_metadata.max_encoder_seq_len, + attn_metadata.seq_lens, causal_mask) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + class ROCmFlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: @@ -377,7 +474,7 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: None, :].expand(tokens, n_kv_heads, n_rep, head_dim).reshape(tokens, n_kv_heads * n_rep, head_dim)) - + def forward( self, query: torch.Tensor, @@ -391,64 +488,109 @@ def forward( ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. + For decoder-only models: query, key and value must be non-None. + + For encoder/decoder models: + * ROCmFlashAttentionImpl.forward() may be invoked for both self- and cross- + attention layers. + * For self-attention: query, key and value must be non-None. + * For cross-attention: + * Query must be non-None + * During prefill, key and value must be non-None; key and value + get cached for use during decode. + * During decode, key and value may be None, since: + (1) key and value tensors were cached during prefill, and + (2) cross-attention key and value tensors do not grow during + decode + + A note on how the attn_type (attention type enum) argument impacts + attention forward() behavior: + + * DECODER: normal decoder-only behavior; + use decoder self-attention block table + * ENCODER: no KV caching; pass encoder sequence + attributes (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) to kernel, in lieu of decoder + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) + * ENCODER_DECODER: cross-attention behavior; + use cross-attention block table for caching KVs derived + from encoder hidden states; since KV sequence lengths + will match encoder sequence lengths, pass encoder sequence + attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) + Args: query: shape = [num_tokens, num_heads * head_size] key: shape = [num_tokens, num_kv_heads * head_size] value: shape = [num_tokens, num_kv_heads * head_size] kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally Returns: shape = [num_tokens, num_heads * head_size] """ - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "ROCmFlashAttentionImpl") - - num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None - if kv_cache is not None: + if attn_type != AttentionType.ENCODER and kv_cache is not None: key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. 
- PagedAttention.write_to_paged_cache( - key, - value, - key_cache, - value_cache, - attn_metadata.slot_mapping, - self.kv_cache_dtype, - k_scale, - v_scale, - ) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens - assert value.shape[0] == num_prefill_tokens + num_decode_tokens - - output = torch.empty_like(query) + if key is not None and value is not None: + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping + if attn_type != AttentionType.ENCODER_DECODER else + attn_metadata.cross_slot_mapping, + self.kv_cache_dtype, + k_scale, + v_scale, + ) + + if attn_type != AttentionType.ENCODER: + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + else: + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = 0 + # Query for decode. KV is not needed because it is already cached. decode_query = query[num_prefill_tokens:] + # QKV for prefill. query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens + if key is not None and value is not None: + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] if prefill_meta := attn_metadata.prefill_metadata: + output = torch.empty_like(query) + ( + query_seq_start_loc, + query_max_seq_len, + key_seq_start_loc, + key_max_seq_len, + seq_lens, + causal_mask + ) = _get_seq_len_block_table_args(prefill_meta, attn_type) + # Prompt run. - assert prefill_meta.seq_lens is not None if kv_cache is None or prefill_meta.block_tables.numel() == 0: # triton attention # When block_tables are not filled, it means q and k are the @@ -459,18 +601,18 @@ def forward( attn_masks = _make_alibi_bias( self.alibi_slopes, query.dtype, - attn_metadata.seq_lens, + seq_lens, make_attn_mask=False) # type: ignore out, _ = self.attn_func( query, key, value, None, - prefill_meta.seq_start_loc, - prefill_meta.seq_start_loc, - prefill_meta.max_prefill_seq_len, - prefill_meta.max_prefill_seq_len, - True, + query_seq_start_loc, + key_seq_start_loc, + query_max_seq_len, + key_max_seq_len, + causal_mask, self.scale, attn_masks[0][None] if attn_masks is not None else None, @@ -494,11 +636,12 @@ def forward( query, key, value, - prefill_meta.seq_lens, + query_seq_start_loc, num_tokens, self.num_heads, self.head_size, self.scale, + causal_mask, attn_masks, ) else: @@ -506,10 +649,10 @@ def forward( q=query, k=key, v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, + cu_seqlens_q=query_seq_start_loc, + cu_seqlens_k=key_seq_start_loc, max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, + max_seqlen_k=key_max_seq_len, softmax_scale=self.scale, causal=True, window_size=self.sliding_window, @@ -545,14 +688,18 @@ def forward( if decode_meta := attn_metadata.decode_metadata: # Decoding run. 
# Whether to use rocm custom paged attention or not + output = torch.empty_like(decode_query) num_seqs, num_heads, head_size = decode_query.shape + num_tokens = num_seqs block_size = value_cache.shape[3] gqa_ratio = num_heads // self.num_kv_heads use_custom = _use_rocm_custom_paged_attention( decode_query.dtype, head_size, block_size, gqa_ratio, decode_meta.max_decode_seq_len) if use_custom: - max_seq_len = decode_meta.max_decode_seq_len + max_seq_len = (decode_meta.max_decode_seq_len + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len) max_num_partitions = ( (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM) @@ -573,7 +720,7 @@ def forward( else: out = output ops.paged_attention_rocm( - out, + output[num_prefill_tokens:], exp_sums, max_logits, tmp_output, @@ -582,8 +729,12 @@ def forward( value_cache, self.num_kv_heads, self.scale, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, + decode_meta.block_tables + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, block_size, max_seq_len, self.alibi_slopes, @@ -596,9 +747,15 @@ def forward( decode_query, key_cache, value_cache, - decode_meta.block_tables, - decode_meta.seq_lens_tensor, - decode_meta.max_decode_seq_len, + decode_meta.block_tables + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, + decode_meta.seq_lens_tensor + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, + decode_meta.max_decode_seq_len + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -608,7 +765,7 @@ def forward( ) # Reshape the output tensor. 
- return output.view(num_tokens, hidden_size) + return output.view(-1, self.num_heads * self.head_size) def _sdpa_attention( @@ -620,6 +777,7 @@ def _sdpa_attention( num_heads: int, head_size: int, scale: float, + is_causal: bool, attn_masks: Optional[List[torch.Tensor]] = None, ) -> torch.Tensor: start = 0 @@ -637,7 +795,7 @@ def _sdpa_attention( key[:, start:end, :], value[:, start:end, :], dropout_p=0.0, - is_causal=attn_masks is None, + is_causal=is_causal, attn_mask=attn_masks[i] if attn_masks else None, scale=scale).movedim(query.dim() - 2, 0) output[start:end, :, :] = sub_out @@ -653,4 +811,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, return (not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) + and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) \ No newline at end of file diff --git a/vllm/envs.py b/vllm/envs.py index ee4711dbec842..afae779fc80b2 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -74,6 +74,7 @@ VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS: int = 1 VLLM_MOE_PADDING: bool = False VLLM_FP8_PADDING: bool = True + VLLM_NO_TUNED_GEMM: bool = False def get_default_cache_root(): @@ -486,6 +487,10 @@ def get_default_config_root(): # Pad the weight for moe kernel or not "VLLM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_FP8_PADDING", "1"))), + + # Not supported by mllama3.2 + "VLLM_NO_TUNED_GEMM": + lambda: bool(int(os.getenv("VLLM_NO_TUNED_GEMM", "0"))), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index fce41b92d2fac..7fd78c71346c5 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -20,6 +20,9 @@ RowvLLMParameter) from vllm.model_executor.utils import set_weight_attrs +import torch.nn.functional as F +import vllm.envs as envs + logger = init_logger(__name__) WEIGHT_LOADER_V2_SUPPORTED = [ @@ -131,7 +134,10 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return tgemm.mm(x, layer.weight, bias) + if envs.VLLM_NO_TUNED_GEMM: + return F.linear(x, layer.weight, bias) + else: + return tgemm.mm(x, layer.weight, bias) class LinearBase(torch.nn.Module): diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index bd716ac3e7ec3..b7c980e542197 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -23,7 +23,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceGroupMetadata) -from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad +from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, is_hip, make_tensor_with_pad from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata, @@ -120,7 +120,7 @@ def __init__( def _maybe_force_supported_attention_backend(self): ''' - Force vLLM to use the XFormers attention backend, + Force vLLM to use the XFormers or ROCM attention backend, which is currently the only supported option. 
        '''
 
@@ -138,18 +138,21 @@ def raise_backend_err():
             # The user has not already specified an attention backend
             # override
             logger.info("EncoderDecoderModelRunner requires "
-                        "XFormers backend; overriding backend "
-                        "auto-selection and forcing XFormers.")
-            global_force_attn_backend(_Backend.XFORMERS)
+                        "XFormers or ROCM backend; overriding backend "
+                        "auto-selection and forcing XFormers or ROCM.")
+            global_force_attn_backend(_Backend.ROCM_FLASH
+                                      if is_hip() else _Backend.XFORMERS)
         elif is_forced_by_global:
             # Backend override enforced by global variable takes
             # precedence over vLLM backend environment variable.
-            if maybe_global_forced_backend != _Backend.XFORMERS:
+            if maybe_global_forced_backend != _Backend.XFORMERS and \
+                maybe_global_forced_backend != _Backend.ROCM_FLASH:
                 raise_backend_err()
         elif is_forced_by_env_var:
             # Backend override enforced by vLLM backend
             # environment variable
-            if maybe_env_var_forced_backend != _Backend.XFORMERS:
+            if maybe_env_var_forced_backend != _Backend.XFORMERS and \
+                maybe_env_var_forced_backend != _Backend.ROCM_FLASH:
                 raise_backend_err()
 
     def _list_to_int32_tensor(

From caf095bb96e2a020d2b8e7deb01a0a920639afa2 Mon Sep 17 00:00:00 2001
From: Aleksandr Malyshev
Date: Thu, 3 Oct 2024 17:20:23 -0500
Subject: [PATCH 2/6] lint issues fix

---
 tests/kernels/utils.py                     |  4 +-
 vllm/attention/backends/rocm_flash_attn.py | 67 ++++++++++++----------
 vllm/model_executor/layers/linear.py       |  5 +-
 vllm/worker/enc_dec_model_runner.py        |  3 +-
 4 files changed, 42 insertions(+), 37 deletions(-)

diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py
index 9f9e92f54080a..8aec937aaba9b 100644
--- a/tests/kernels/utils.py
+++ b/tests/kernels/utils.py
@@ -12,8 +12,8 @@
 from torch._prims_common import TensorLikeType
 
 from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
-from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
-                        make_tensor_with_pad)
+from vllm.utils import (STR_BACKEND_ENV_VAR, STR_ROCM_FLASH_ATTN_VAL,
+                        STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
 
 # For now, disable "test_aot_dispatch_dynamic" since there are some
 # bugs related to this test in PyTorch 2.4.
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index c9774dc929f10..9f351b404f6b2 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type -from tests.kernels.utils import ref_masked_attention import torch import vllm.envs as envs @@ -162,9 +161,12 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, - query_start_loc=None if self.query_start_loc is None else self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=None if self.seq_start_loc is None else self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=None if self.context_lens_tensor is None else self.context_lens_tensor[:self.num_prefills], + query_start_loc=None if self.query_start_loc is None + else self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None if self.seq_start_loc is None + else self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=None if self.context_lens_tensor is None + else self.context_lens_tensor[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, # Begin encoder & cross attn fields below... @@ -326,12 +328,13 @@ def _get_seq_len_block_table_args( * Appropriate max sequence-length scalar ''' - + partial_prefix_sum = 0 if attn_type == AttentionType.ENCODER: - query_seq_start_loc = torch.tensor([sum(attn_metadata.encoder_seq_lens[:i]) - for i in range(len(attn_metadata.encoder_seq_lens) + 1)], - device = attn_metadata.encoder_seq_lens_tensor.device, - dtype = attn_metadata.encoder_seq_lens_tensor.dtype) + query_seq_start_loc = torch.tensor( + [0] + [partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.encoder_seq_lens], + device = attn_metadata.encoder_seq_lens_tensor.device, + dtype = attn_metadata.encoder_seq_lens_tensor.dtype) causal_mask = False # No block tables associated with encoder attention @@ -341,10 +344,11 @@ def _get_seq_len_block_table_args( elif attn_type == AttentionType.DECODER: # Decoder self-attention # Choose max_seq_len based on whether we are in prompt_run - query_seq_start_loc = torch.tensor([sum(attn_metadata.seq_lens[:i]) - for i in range(len(attn_metadata.seq_lens) + 1)], - device = attn_metadata.seq_lens_tensor.device, - dtype = attn_metadata.seq_lens_tensor.dtype) + query_seq_start_loc = torch.tensor( + [0] + [partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.seq_lens], + device = attn_metadata.seq_lens_tensor.device, + dtype = attn_metadata.seq_lens_tensor.dtype) max_seq_len = attn_metadata.max_prefill_seq_len causal_mask = True @@ -352,15 +356,18 @@ def _get_seq_len_block_table_args( query_seq_start_loc, max_seq_len, attn_metadata.seq_lens, causal_mask) elif attn_type == AttentionType.ENCODER_DECODER: - query_start_loc = torch.tensor([sum(attn_metadata.seq_lens[:i]) - for i in range(len(attn_metadata.seq_lens) + 1)], - device = attn_metadata.encoder_seq_lens_tensor.device, - dtype = attn_metadata.encoder_seq_lens_tensor.dtype) - - key_seq_start_loc = torch.tensor([sum(attn_metadata.encoder_seq_lens[:i]) - for i in range(len(attn_metadata.encoder_seq_lens) + 1)], - device = attn_metadata.seq_lens_tensor.device, - dtype = attn_metadata.seq_lens_tensor.dtype) + query_start_loc = torch.tensor( + [0] + [partial_prefix_sum := 
partial_prefix_sum + i + for i in attn_metadata.seq_lens], + device = attn_metadata.encoder_seq_lens_tensor.device, + dtype = attn_metadata.encoder_seq_lens_tensor.dtype) + + partial_prefix_sum = 0 + key_seq_start_loc = torch.tensor( + [0] + [partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.encoder_seq_lens], + device = attn_metadata.seq_lens_tensor.device, + dtype = attn_metadata.seq_lens_tensor.dtype) causal_mask = False # Enc/dec cross-attention KVs match encoder sequence length; @@ -491,8 +498,8 @@ def forward( For decoder-only models: query, key and value must be non-None. For encoder/decoder models: - * ROCmFlashAttentionImpl.forward() may be invoked for both self- and cross- - attention layers. + * ROCmFlashAttentionImpl.forward() may be invoked for both self- and + cross-attention layers. * For self-attention: query, key and value must be non-None. * For cross-attention: * Query must be non-None @@ -546,9 +553,10 @@ def forward( kv_cache, self.num_kv_heads, self.head_size) if key is not None and value is not None: - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. + # Reshape the input keys and values and store them in the + # cache. If kv_cache is not provided, the new key and value + # tensors are not cached. This happens during the initial + # memory profiling run. PagedAttention.write_to_paged_cache( key, value, @@ -564,11 +572,9 @@ def forward( if attn_type != AttentionType.ENCODER: num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens else: assert attn_metadata.num_encoder_tokens is not None num_prefill_tokens = attn_metadata.num_encoder_tokens - num_decode_tokens = 0 # Query for decode. KV is not needed because it is already cached. 
decode_query = query[num_prefill_tokens:] @@ -637,7 +643,7 @@ def forward( key, value, query_seq_start_loc, - num_tokens, + num_prefill_tokens, self.num_heads, self.head_size, self.scale, @@ -690,7 +696,6 @@ def forward( # Whether to use rocm custom paged attention or not output = torch.empty_like(decode_query) num_seqs, num_heads, head_size = decode_query.shape - num_tokens = num_seqs block_size = value_cache.shape[3] gqa_ratio = num_heads // self.num_kv_heads use_custom = _use_rocm_custom_paged_attention( diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 7fd78c71346c5..01626236cc5a9 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -2,8 +2,10 @@ from typing import Dict, List, Optional, Tuple import torch +import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter +import vllm.envs as envs from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -20,9 +22,6 @@ RowvLLMParameter) from vllm.model_executor.utils import set_weight_attrs -import torch.nn.functional as F -import vllm.envs as envs - logger = init_logger(__name__) WEIGHT_LOADER_V2_SUPPORTED = [ diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index b7c980e542197..bb5fe19f1aae6 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -23,7 +23,8 @@ from vllm.sampling_params import SamplingParams from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceGroupMetadata) -from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, is_hip, make_tensor_with_pad +from vllm.utils import (STR_NOT_IMPL_ENC_DEC_BACKEND, is_hip, + make_tensor_with_pad) from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata, From 4b54363232a3756bbf01ec555d9c55473b6f3384 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Thu, 3 Oct 2024 17:58:48 -0500 Subject: [PATCH 3/6] mypy errors --- tests/kernels/test_encoder_decoder_attn.py | 1 - vllm/attention/backends/rocm_flash_attn.py | 11 ++++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index c8fefbff76eb3..f9b15bfb02605 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -808,7 +808,6 @@ def test_encoder_only( assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out) -# @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 9f351b404f6b2..af8c7c3577940 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -330,6 +330,8 @@ def _get_seq_len_block_table_args( partial_prefix_sum = 0 if attn_type == AttentionType.ENCODER: + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None query_seq_start_loc = torch.tensor( [0] + [partial_prefix_sum := partial_prefix_sum + i for i in attn_metadata.encoder_seq_lens], @@ -344,6 +346,8 @@ def _get_seq_len_block_table_args( elif attn_type == AttentionType.DECODER: # 
Decoder self-attention # Choose max_seq_len based on whether we are in prompt_run + assert attn_metadata.seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None query_seq_start_loc = torch.tensor( [0] + [partial_prefix_sum := partial_prefix_sum + i for i in attn_metadata.seq_lens], @@ -356,6 +360,8 @@ def _get_seq_len_block_table_args( query_seq_start_loc, max_seq_len, attn_metadata.seq_lens, causal_mask) elif attn_type == AttentionType.ENCODER_DECODER: + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens_tensor is not None query_start_loc = torch.tensor( [0] + [partial_prefix_sum := partial_prefix_sum + i for i in attn_metadata.seq_lens], @@ -363,6 +369,8 @@ def _get_seq_len_block_table_args( dtype = attn_metadata.encoder_seq_lens_tensor.dtype) partial_prefix_sum = 0 + assert attn_metadata.encoder_seq_lens is not None + assert attn_metadata.seq_lens_tensor is not None key_seq_start_loc = torch.tensor( [0] + [partial_prefix_sum := partial_prefix_sum + i for i in attn_metadata.encoder_seq_lens], @@ -481,7 +489,7 @@ def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: None, :].expand(tokens, n_kv_heads, n_rep, head_dim).reshape(tokens, n_kv_heads * n_rep, head_dim)) - + def forward( self, query: torch.Tensor, @@ -705,6 +713,7 @@ def forward( max_seq_len = (decode_meta.max_decode_seq_len if attn_type != AttentionType.ENCODER_DECODER else decode_meta.max_encoder_seq_len) + assert max_seq_len is not None max_num_partitions = ( (max_seq_len + _PARTITION_SIZE_ROCM - 1) // _PARTITION_SIZE_ROCM) From 34d265808950aaecc5b33ea1ad132e3146ea2602 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Fri, 4 Oct 2024 14:20:32 -0500 Subject: [PATCH 4/6] making yapf happy --- tests/kernels/utils.py | 6 +- vllm/attention/backends/rocm_flash_attn.py | 114 +++++++++++---------- vllm/worker/enc_dec_model_runner.py | 4 +- 3 files changed, 63 insertions(+), 61 deletions(-) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 8aec937aaba9b..d1de0b20be2f7 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -525,12 +525,12 @@ def make_backend(backend_name: str) -> AttentionBackend: # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs. 
from vllm.attention.backends.xformers import XFormersBackend return XFormersBackend() - - if backend_name == STR_ROCM_FLASH_ATTN_VAL: + + if backend_name == STR_ROCM_FLASH_ATTN_VAL: from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401 ROCmFlashAttentionBackend) return ROCmFlashAttentionBackend - + raise AssertionError( f"Unrecognized backend_name {backend_name} for unit test") diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index af8c7c3577940..417dbc6d1483c 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -138,7 +138,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): # and block tables cross_slot_mapping: Optional[torch.Tensor] = None cross_block_tables: Optional[torch.Tensor] = None - + @property def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: if self.num_prefills == 0: @@ -161,12 +161,12 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, - query_start_loc=None if self.query_start_loc is None - else self.query_start_loc[:self.num_prefills + 1], - seq_start_loc=None if self.seq_start_loc is None - else self.seq_start_loc[:self.num_prefills + 1], - context_lens_tensor=None if self.context_lens_tensor is None - else self.context_lens_tensor[:self.num_prefills], + query_start_loc=None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=None if self.seq_start_loc is None else + self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills], block_tables=self.block_tables[:self.num_prefills], use_cuda_graph=False, # Begin encoder & cross attn fields below... 
@@ -174,8 +174,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables - ) + cross_block_tables=self.cross_block_tables) return self._cached_prefill_metadata @property @@ -208,8 +207,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, max_encoder_seq_len=self.max_encoder_seq_len, cross_slot_mapping=self.cross_slot_mapping, - cross_block_tables=self.cross_block_tables - ) + cross_block_tables=self.cross_block_tables) return self._cached_decode_metadata def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", @@ -333,10 +331,12 @@ def _get_seq_len_block_table_args( assert attn_metadata.encoder_seq_lens is not None assert attn_metadata.encoder_seq_lens_tensor is not None query_seq_start_loc = torch.tensor( - [0] + [partial_prefix_sum := partial_prefix_sum + i - for i in attn_metadata.encoder_seq_lens], - device = attn_metadata.encoder_seq_lens_tensor.device, - dtype = attn_metadata.encoder_seq_lens_tensor.dtype) + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.encoder_seq_lens + ], + device=attn_metadata.encoder_seq_lens_tensor.device, + dtype=attn_metadata.encoder_seq_lens_tensor.dtype) causal_mask = False # No block tables associated with encoder attention @@ -349,33 +349,38 @@ def _get_seq_len_block_table_args( assert attn_metadata.seq_lens is not None assert attn_metadata.seq_lens_tensor is not None query_seq_start_loc = torch.tensor( - [0] + [partial_prefix_sum := partial_prefix_sum + i - for i in attn_metadata.seq_lens], - device = attn_metadata.seq_lens_tensor.device, - dtype = attn_metadata.seq_lens_tensor.dtype) + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.seq_lens + ], + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) max_seq_len = attn_metadata.max_prefill_seq_len causal_mask = True - return (query_seq_start_loc, max_seq_len, - query_seq_start_loc, max_seq_len, - attn_metadata.seq_lens, causal_mask) + return (query_seq_start_loc, max_seq_len, query_seq_start_loc, + max_seq_len, attn_metadata.seq_lens, causal_mask) elif attn_type == AttentionType.ENCODER_DECODER: assert attn_metadata.seq_lens is not None assert attn_metadata.encoder_seq_lens_tensor is not None query_start_loc = torch.tensor( - [0] + [partial_prefix_sum := partial_prefix_sum + i - for i in attn_metadata.seq_lens], - device = attn_metadata.encoder_seq_lens_tensor.device, - dtype = attn_metadata.encoder_seq_lens_tensor.dtype) + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.seq_lens + ], + device=attn_metadata.encoder_seq_lens_tensor.device, + dtype=attn_metadata.encoder_seq_lens_tensor.dtype) partial_prefix_sum = 0 assert attn_metadata.encoder_seq_lens is not None assert attn_metadata.seq_lens_tensor is not None key_seq_start_loc = torch.tensor( - [0] + [partial_prefix_sum := partial_prefix_sum + i - for i in attn_metadata.encoder_seq_lens], - device = attn_metadata.seq_lens_tensor.device, - dtype = attn_metadata.seq_lens_tensor.dtype) + [0] + [ + partial_prefix_sum := partial_prefix_sum + i + for i in attn_metadata.encoder_seq_lens + ], + device=attn_metadata.seq_lens_tensor.device, + dtype=attn_metadata.seq_lens_tensor.dtype) causal_mask = False # Enc/dec cross-attention KVs 
match encoder sequence length; @@ -386,6 +391,7 @@ def _get_seq_len_block_table_args( else: raise AttributeError(f"Invalid attention type {str(attn_type)}") + class ROCmFlashAttentionImpl(AttentionImpl): """ If the input tensors contain prompt tokens, the layout is as follows: @@ -547,7 +553,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - + query = query.view(-1, self.num_heads, self.head_size) if key is not None: assert value is not None @@ -561,9 +567,9 @@ def forward( kv_cache, self.num_kv_heads, self.head_size) if key is not None and value is not None: - # Reshape the input keys and values and store them in the + # Reshape the input keys and values and store them in the # cache. If kv_cache is not provided, the new key and value - # tensors are not cached. This happens during the initial + # tensors are not cached. This happens during the initial # memory profiling run. PagedAttention.write_to_paged_cache( key, @@ -571,8 +577,8 @@ def forward( key_cache, value_cache, attn_metadata.slot_mapping - if attn_type != AttentionType.ENCODER_DECODER else - attn_metadata.cross_slot_mapping, + if attn_type != AttentionType.ENCODER_DECODER else + attn_metadata.cross_slot_mapping, self.kv_cache_dtype, k_scale, v_scale, @@ -595,14 +601,10 @@ def forward( if prefill_meta := attn_metadata.prefill_metadata: output = torch.empty_like(query) - ( - query_seq_start_loc, - query_max_seq_len, - key_seq_start_loc, - key_max_seq_len, - seq_lens, - causal_mask - ) = _get_seq_len_block_table_args(prefill_meta, attn_type) + (query_seq_start_loc, query_max_seq_len, key_seq_start_loc, + key_max_seq_len, seq_lens, + causal_mask) = _get_seq_len_block_table_args( + prefill_meta, attn_type) # Prompt run. if kv_cache is None or prefill_meta.block_tables.numel() == 0: @@ -711,8 +713,8 @@ def forward( decode_meta.max_decode_seq_len) if use_custom: max_seq_len = (decode_meta.max_decode_seq_len - if attn_type != AttentionType.ENCODER_DECODER else - decode_meta.max_encoder_seq_len) + if attn_type != AttentionType.ENCODER_DECODER + else decode_meta.max_encoder_seq_len) assert max_seq_len is not None max_num_partitions = ( (max_seq_len + _PARTITION_SIZE_ROCM - 1) // @@ -744,11 +746,11 @@ def forward( self.num_kv_heads, self.scale, decode_meta.block_tables - if attn_type != AttentionType.ENCODER_DECODER else - decode_meta.cross_block_tables, + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, decode_meta.seq_lens_tensor - if attn_type != AttentionType.ENCODER_DECODER else - decode_meta.encoder_seq_lens_tensor, + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, block_size, max_seq_len, self.alibi_slopes, @@ -762,14 +764,14 @@ def forward( key_cache, value_cache, decode_meta.block_tables - if attn_type != AttentionType.ENCODER_DECODER else - decode_meta.cross_block_tables, + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.cross_block_tables, decode_meta.seq_lens_tensor - if attn_type != AttentionType.ENCODER_DECODER else - decode_meta.encoder_seq_lens_tensor, + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.encoder_seq_lens_tensor, decode_meta.max_decode_seq_len - if attn_type != AttentionType.ENCODER_DECODER else - decode_meta.max_encoder_seq_len, + if attn_type != AttentionType.ENCODER_DECODER else + decode_meta.max_encoder_seq_len, self.kv_cache_dtype, self.num_kv_heads, self.scale, @@ -825,4 +827,4 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, return (not 
_ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) \ No newline at end of file + and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index bb5fe19f1aae6..4606866bdba52 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -141,8 +141,8 @@ def raise_backend_err(): logger.info("EncoderDecoderModelRunner requires " "XFormers or ROCM backend; overriding backend " "auto-selection and forcing XFormers or ROCM.") - global_force_attn_backend(_Backend.ROCM_FLASH - if is_hip() else _Backend.XFORMERS) + global_force_attn_backend( + _Backend.ROCM_FLASH if is_hip() else _Backend.XFORMERS) elif is_forced_by_global: # Backend override enforced by global variable takes # precedence over vLLM backend environment variable. From 9ed31a8fd95585d49755ee8062dc395642997f3e Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Fri, 4 Oct 2024 16:36:27 -0500 Subject: [PATCH 5/6] cut off WA for tunned gemms --- vllm/envs.py | 5 ----- vllm/model_executor/layers/linear.py | 7 +------ vllm/model_executor/layers/tuned_gemm.py | 2 +- 3 files changed, 2 insertions(+), 12 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index afae779fc80b2..ee4711dbec842 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -74,7 +74,6 @@ VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS: int = 1 VLLM_MOE_PADDING: bool = False VLLM_FP8_PADDING: bool = True - VLLM_NO_TUNED_GEMM: bool = False def get_default_cache_root(): @@ -487,10 +486,6 @@ def get_default_config_root(): # Pad the weight for moe kernel or not "VLLM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_FP8_PADDING", "1"))), - - # Not supported by mllama3.2 - "VLLM_NO_TUNED_GEMM": - lambda: bool(int(os.getenv("VLLM_NO_TUNED_GEMM", "0"))), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 01626236cc5a9..fce41b92d2fac 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -2,10 +2,8 @@ from typing import Dict, List, Optional, Tuple import torch -import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter -import vllm.envs as envs from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, @@ -133,10 +131,7 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - if envs.VLLM_NO_TUNED_GEMM: - return F.linear(x, layer.weight, bias) - else: - return tgemm.mm(x, layer.weight, bias) + return tgemm.mm(x, layer.weight, bias) class LinearBase(torch.nn.Module): diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 7ea1d8d93ea2b..fc58c8c6d379e 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -84,7 +84,7 @@ def mm(self, inp, weights, bias=None): # uses this for linear units. 
However, sampler # will use torch.matmul with 2 dimensions only if inp.dim() == 3: - inp_view = inp.view(-1, inp.size(-1)) + inp_view = inp.reshape(-1, inp.size(-1)) batched = True else: inp_view = inp From cfe23d8441ea92e0ff55901fa9b364bf570e7968 Mon Sep 17 00:00:00 2001 From: Aleksandr Malyshev Date: Fri, 4 Oct 2024 17:55:26 -0500 Subject: [PATCH 6/6] try and catch for non continuous tensor --- vllm/model_executor/layers/tuned_gemm.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index fc58c8c6d379e..f765b8c39fa6c 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -84,8 +84,11 @@ def mm(self, inp, weights, bias=None): # uses this for linear units. However, sampler # will use torch.matmul with 2 dimensions only if inp.dim() == 3: - inp_view = inp.reshape(-1, inp.size(-1)) - batched = True + try: + inp_view = inp.view(-1, inp.size(-1)) + batched = True + except RuntimeError: + return F.linear(inp, weights, bias) else: inp_view = inp batched = False
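
Note (not part of the patch series): the prefix-sum start locations that the backend builds with the walrus expression, [0] + [partial_prefix_sum := partial_prefix_sum + i for i in seq_lens], are plain cumulative offsets, matching the query_start_loc/seq_start_loc docstrings in the metadata class (seq_lens [4, 6] -> [0, 4, 10]). A minimal standalone sketch, assuming int32 offsets for illustration:

import torch

def cumulative_start_locations(seq_lens):
    # Equivalent of the walrus-based prefix sum used for
    # query_seq_start_loc / key_seq_start_loc in rocm_flash_attn.py.
    starts = [0]
    for n in seq_lens:
        starts.append(starts[-1] + n)
    return torch.tensor(starts, dtype=torch.int32)

print(cumulative_start_locations([4, 6]))  # tensor([ 0,  4, 10], dtype=torch.int32)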