From 2550f14a77c84b93045f4603fdcf3bc310164b15 Mon Sep 17 00:00:00 2001
From: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com>
Date: Fri, 4 Oct 2024 16:03:09 -0700
Subject: [PATCH] llama3.2 + cross attn test (#220)

* llama3.2 + cross attn test
* lint issues fix
* mypy errors
* making yapf happy
* cut off WA for tuned gemms
* try and catch for non-contiguous tensor

---------

Co-authored-by: Aleksandr Malyshev
---
 tests/kernels/test_encoder_decoder_attn.py |   4 +-
 tests/kernels/utils.py                     |  11 +-
 vllm/attention/backends/rocm_flash_attn.py | 334 ++++++++++++++++-----
 vllm/model_executor/layers/tuned_gemm.py   |   7 +-
 vllm/worker/enc_dec_model_runner.py        |  18 +-
 5 files changed, 280 insertions(+), 94 deletions(-)

diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py
index b550a7fdd84f0..f9b15bfb02605 100644
--- a/tests/kernels/test_encoder_decoder_attn.py
+++ b/tests/kernels/test_encoder_decoder_attn.py
@@ -21,7 +21,8 @@
 from vllm.utils import is_hip
 
 # List of support backends for encoder/decoder models
-LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
+LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] if not is_hip() \
+    else [_Backend.ROCM_FLASH]
 
 HEAD_SIZES = [64, 256]
 
@@ -807,7 +808,6 @@ def test_encoder_only(
     assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
 
 
-@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py
index 08004efe9e2f8..d1de0b20be2f7 100644
--- a/tests/kernels/utils.py
+++ b/tests/kernels/utils.py
@@ -12,8 +12,8 @@
 from torch._prims_common import TensorLikeType
 
 from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
-from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
-                        make_tensor_with_pad)
+from vllm.utils import (STR_BACKEND_ENV_VAR, STR_ROCM_FLASH_ATTN_VAL,
+                        STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
 
 # For now, disable "test_aot_dispatch_dynamic" since there are some
 # bugs related to this test in PyTorch 2.4.
@@ -524,8 +524,13 @@ def make_backend(backend_name: str) -> AttentionBackend:
     if backend_name == STR_XFORMERS_ATTN_VAL:
         # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
         from vllm.attention.backends.xformers import XFormersBackend
-        return XFormersBackend()
+
+    if backend_name == STR_ROCM_FLASH_ATTN_VAL:
+        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
+            ROCmFlashAttentionBackend)
+        return ROCmFlashAttentionBackend
+
     raise AssertionError(
         f"Unrecognized backend_name {backend_name} for unit test")
 
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index b793ebf46d173..417dbc6d1483c 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -86,6 +86,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
     seq_lens: Optional[List[int]]
     # seq_lens stored as a tensor.
     seq_lens_tensor: Optional[torch.Tensor]
+    # Maximum sequence length among prefill batch. 0 if there are decoding
+    # requests only.
+    max_prefill_seq_len: int
+    # Maximum sequence length among decode batch. 0 if there are prefill
+    # requests only.
+    max_decode_seq_len: int
+
+    # Whether or not cuda graph is enabled.
+    # Cuda-graph is currently enabled for decoding only.
+    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
+    use_cuda_graph: bool
 
     # NOTE(sang): Definition of context_len, query_len, and seq_len.
     # |---------- N-1 iteration --------|
     # |---------------- N iteration ---------------------|
     # |- tokenA -|......................|-- newTokens ---|
     # |---------- context_len ----------|
     # |-------------------- seq_len ---------------------|
@@ -96,32 +107,38 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
     #                                   |-- query_len ---|
 
     # Maximum query length in the batch. None for decoding.
-    max_query_len: Optional[int]
-    # Maximum sequence length among prefill batch. 0 if there are decoding
-    # requests only.
-    max_prefill_seq_len: int
-    # Maximum sequence length among decode batch. 0 if there are prefill
-    # requests only.
-    max_decode_seq_len: int
+    max_query_len: Optional[int] = None
 
     # (batch_size + 1,). The cumulative subquery lengths of the sequences in
     # the batch, used to index into subquery. E.g., if the subquery length
     # is [4, 6], it is [0, 4, 10].
-    query_start_loc: Optional[torch.Tensor]
+    query_start_loc: Optional[torch.Tensor] = None
 
     # (batch_size + 1,). The cumulative sequence lengths of the sequences in
     # the batch, used to index into sequence. E.g., if the sequence length is
     # [4, 6], it is [0, 4, 10].
-    seq_start_loc: Optional[torch.Tensor]
-
-    # Whether or not if cuda graph is enabled.
-    # Cuda-graph is currently enabled for decoding only.
-    # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
-    use_cuda_graph: bool
+    seq_start_loc: Optional[torch.Tensor] = None
 
     # (batch_size,) A tensor of context lengths (tokens that are computed
     # so far).
-    context_lens_tensor: Optional[torch.Tensor]
+    context_lens_tensor: Optional[torch.Tensor] = None
 
     _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None
     _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None
+    # Begin encoder attn & enc/dec cross-attn fields...
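+    # (The fields below describe the encoder sequence(s) and the
+    #  cross-attention KV-cache mapping used by encoder/decoder models.)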
+
+    # Encoder sequence lengths representation
+    encoder_seq_lens: Optional[List[int]] = None
+    encoder_seq_lens_tensor: Optional[torch.Tensor] = None
+
+    # Maximum sequence length among encoder sequences
+    max_encoder_seq_len: Optional[int] = None
+
+    # Number of tokens input to encoder
+    num_encoder_tokens: Optional[int] = None
+
+    # Cross-attention memory-mapping data structures: slot mapping
+    # and block tables
+    cross_slot_mapping: Optional[torch.Tensor] = None
+    cross_block_tables: Optional[torch.Tensor] = None
+
     @property
     def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
         if self.num_prefills == 0:
@@ -132,10 +149,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
 
         assert self.seq_lens is not None
         assert self.seq_lens_tensor is not None
-        assert self.query_start_loc is not None
-        assert self.context_lens_tensor is not None
         assert self.block_tables is not None
-        assert self.seq_start_loc is not None
 
         self._cached_prefill_metadata = ROCmFlashAttentionMetadata(
             num_prefills=self.num_prefills,
@@ -147,12 +161,20 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
             max_decode_seq_len=0,
-            query_start_loc=self.query_start_loc[:self.num_prefills + 1],
-            seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
-            context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
+            query_start_loc=None if self.query_start_loc is None else
+            self.query_start_loc[:self.num_prefills + 1],
+            seq_start_loc=None if self.seq_start_loc is None else
+            self.seq_start_loc[:self.num_prefills + 1],
+            context_lens_tensor=None if self.context_lens_tensor is None else
+            self.context_lens_tensor[:self.num_prefills],
             block_tables=self.block_tables[:self.num_prefills],
             use_cuda_graph=False,
-        )
+            # Begin encoder & cross attn fields below...
+            encoder_seq_lens=self.encoder_seq_lens,
+            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
+            max_encoder_seq_len=self.max_encoder_seq_len,
+            cross_slot_mapping=self.cross_slot_mapping,
+            cross_block_tables=self.cross_block_tables)
         return self._cached_prefill_metadata
 
     @property
@@ -180,7 +202,12 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
             context_lens_tensor=None,
             block_tables=self.block_tables[self.num_prefills:],
             use_cuda_graph=self.use_cuda_graph,
-        )
+            # Begin encoder & cross attn fields below...
+            encoder_seq_lens=self.encoder_seq_lens,
+            encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
+            max_encoder_seq_len=self.max_encoder_seq_len,
+            cross_slot_mapping=self.cross_slot_mapping,
+            cross_block_tables=self.cross_block_tables)
         return self._cached_decode_metadata
 
     def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
@@ -274,6 +301,97 @@ def _make_alibi_bias(alibi_slopes: torch.Tensor,
     return attn_biases
 
 
+def _get_seq_len_block_table_args(
+    attn_metadata: ROCmFlashAttentionMetadata,
+    attn_type: AttentionType,
+) -> tuple:
+    '''
+    The particular choice of sequence-length
+    attributes which should be extracted from attn_metadata is dependent
+    on the type of attention operation.
+
+    Decoder attn -> select decoder self-attention fields
+    Encoder/decoder cross-attn -> select decoder sequence lengths for the
+                                  query and encoder sequence lengths for
+                                  the key
+    Encoder attn -> select encoder sequence-length fields
+
+    Arguments:
+
+    * attn_metadata: Attention metadata structure associated with attention op
+    * attn_type: encoder attention, decoder self-attention,
+                 encoder/decoder cross-attention
+
+    Returns:
+
+    * Sequence start-location tensors for query and key
+    * Max sequence-length scalars for query and key
+    * Sequence lengths used to build attention masks
+    * Whether the attention mask should be causal
+    '''
+
+    partial_prefix_sum = 0
+    if attn_type == AttentionType.ENCODER:
+        assert attn_metadata.encoder_seq_lens is not None
+        assert attn_metadata.encoder_seq_lens_tensor is not None
+        query_seq_start_loc = torch.tensor(
+            [0] + [
+                partial_prefix_sum := partial_prefix_sum + i
+                for i in attn_metadata.encoder_seq_lens
+            ],
+            device=attn_metadata.encoder_seq_lens_tensor.device,
+            dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
+        causal_mask = False
+
+        # No block tables associated with encoder attention
+        return (query_seq_start_loc, attn_metadata.max_encoder_seq_len,
+                query_seq_start_loc, attn_metadata.max_encoder_seq_len,
+                attn_metadata.encoder_seq_lens, causal_mask)
+    elif attn_type == AttentionType.DECODER:
+        # Decoder self-attention
+        # Choose max_seq_len based on whether we are in prompt_run
+        assert attn_metadata.seq_lens is not None
+        assert attn_metadata.seq_lens_tensor is not None
+        query_seq_start_loc = torch.tensor(
+            [0] + [
+                partial_prefix_sum := partial_prefix_sum + i
+                for i in attn_metadata.seq_lens
+            ],
+            device=attn_metadata.seq_lens_tensor.device,
+            dtype=attn_metadata.seq_lens_tensor.dtype)
+        max_seq_len = attn_metadata.max_prefill_seq_len
+        causal_mask = True
+
+        return (query_seq_start_loc, max_seq_len, query_seq_start_loc,
+                max_seq_len, attn_metadata.seq_lens, causal_mask)
+    elif attn_type == AttentionType.ENCODER_DECODER:
+        assert attn_metadata.seq_lens is not None
+        assert attn_metadata.encoder_seq_lens_tensor is not None
+        query_start_loc = torch.tensor(
+            [0] + [
+                partial_prefix_sum := partial_prefix_sum + i
+                for i in attn_metadata.seq_lens
+            ],
+            device=attn_metadata.encoder_seq_lens_tensor.device,
+            dtype=attn_metadata.encoder_seq_lens_tensor.dtype)
+
+        partial_prefix_sum = 0
+        assert attn_metadata.encoder_seq_lens is not None
+        assert attn_metadata.seq_lens_tensor is not None
+        key_seq_start_loc = torch.tensor(
+            [0] + [
+                partial_prefix_sum := partial_prefix_sum + i
+                for i in attn_metadata.encoder_seq_lens
+            ],
+            device=attn_metadata.seq_lens_tensor.device,
+            dtype=attn_metadata.seq_lens_tensor.dtype)
+        causal_mask = False
+
+        # Enc/dec cross-attention KVs match encoder sequence length;
+        # cross-attention utilizes special "cross" block tables
+        return (query_start_loc, attn_metadata.max_prefill_seq_len,
+                key_seq_start_loc, attn_metadata.max_encoder_seq_len,
+                attn_metadata.seq_lens, causal_mask)
+    else:
+        raise AttributeError(f"Invalid attention type {str(attn_type)}")
+
+
 class ROCmFlashAttentionImpl(AttentionImpl):
     """
     If the input tensors contain prompt tokens, the layout is as follows:
@@ -391,64 +509,104 @@ def forward(
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
 
+        For decoder-only models: query, key and value must be non-None.
+
+        For encoder/decoder models:
+        * ROCmFlashAttentionImpl.forward() may be invoked for both self- and
+          cross-attention layers.
+        * For self-attention: query, key and value must be non-None.
+        * For cross-attention:
+            * Query must be non-None
+            * During prefill, key and value must be non-None; key and value
+              get cached for use during decode.
+            * During decode, key and value may be None, since:
+              (1) key and value tensors were cached during prefill, and
+              (2) cross-attention key and value tensors do not grow during
+                  decode
+
+        A note on how the attn_type (attention type enum) argument impacts
+        attention forward() behavior:
+
+        * DECODER: normal decoder-only behavior;
+          use decoder self-attention block table
+        * ENCODER: no KV caching; pass encoder sequence
+          attributes (encoder_seq_lens/encoder_seq_lens_tensor/
+          max_encoder_seq_len) to kernel, in lieu of decoder
+          sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
+        * ENCODER_DECODER: cross-attention behavior;
+          use cross-attention block table for caching KVs derived
+          from encoder hidden states; since KV sequence lengths
+          will match encoder sequence lengths, pass encoder sequence
+          attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/
+          max_encoder_seq_len)
+
         Args:
             query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
            attn_metadata: Metadata for attention.
+            attn_type: Select attention type, between encoder attention,
+                       decoder self-attention, or encoder/decoder cross-
+                       attention. Defaults to decoder self-attention,
+                       which is the common vLLM case.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
-                                      "are not implemented for "
-                                      "ROCmFlashAttentionImpl")
-
-        num_tokens, hidden_size = query.shape
-        # Reshape the query, key, and value tensors.
+
         query = query.view(-1, self.num_heads, self.head_size)
-        key = key.view(-1, self.num_kv_heads, self.head_size)
-        value = value.view(-1, self.num_kv_heads, self.head_size)
+        if key is not None:
+            assert value is not None
+            key = key.view(-1, self.num_kv_heads, self.head_size)
+            value = value.view(-1, self.num_kv_heads, self.head_size)
+        else:
+            assert value is None
 
-        if kv_cache is not None:
+        if attn_type != AttentionType.ENCODER and kv_cache is not None:
             key_cache, value_cache = PagedAttention.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)
 
-            # Reshape the input keys and values and store them in the cache.
-            # If kv_cache is not provided, the new key and value tensors are
-            # not cached. This happens during the initial memory profiling run.
-            PagedAttention.write_to_paged_cache(
-                key,
-                value,
-                key_cache,
-                value_cache,
-                attn_metadata.slot_mapping,
-                self.kv_cache_dtype,
-                k_scale,
-                v_scale,
-            )
-
-        num_prefill_tokens = attn_metadata.num_prefill_tokens
-        num_decode_tokens = attn_metadata.num_decode_tokens
-        assert key.shape[0] == num_prefill_tokens + num_decode_tokens
-        assert value.shape[0] == num_prefill_tokens + num_decode_tokens
-
-        output = torch.empty_like(query)
+            if key is not None and value is not None:
+                # Reshape the input keys and values and store them in the
+                # cache. If kv_cache is not provided, the new key and value
+                # tensors are not cached. This happens during the initial
+                # memory profiling run.
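+                # For cross-attention (ENCODER_DECODER) the KVs are derived
+                # from the encoder output, so they are written to the
+                # dedicated cross-attention slot mapping rather than the
+                # decoder slot mapping.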
+                PagedAttention.write_to_paged_cache(
+                    key,
+                    value,
+                    key_cache,
+                    value_cache,
+                    attn_metadata.slot_mapping
+                    if attn_type != AttentionType.ENCODER_DECODER else
+                    attn_metadata.cross_slot_mapping,
+                    self.kv_cache_dtype,
+                    k_scale,
+                    v_scale,
+                )
+
+        if attn_type != AttentionType.ENCODER:
+            num_prefill_tokens = attn_metadata.num_prefill_tokens
+        else:
+            assert attn_metadata.num_encoder_tokens is not None
+            num_prefill_tokens = attn_metadata.num_encoder_tokens
+
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
+
         # QKV for prefill.
         query = query[:num_prefill_tokens]
-        key = key[:num_prefill_tokens]
-        value = value[:num_prefill_tokens]
-
-        assert query.shape[0] == num_prefill_tokens
-        assert decode_query.shape[0] == num_decode_tokens
+        if key is not None and value is not None:
+            key = key[:num_prefill_tokens]
+            value = value[:num_prefill_tokens]
 
         if prefill_meta := attn_metadata.prefill_metadata:
+            output = torch.empty_like(query)
+            (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
+             key_max_seq_len, seq_lens,
+             causal_mask) = _get_seq_len_block_table_args(
+                 prefill_meta, attn_type)
+
             # Prompt run.
-            assert prefill_meta.seq_lens is not None
             if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
@@ -459,18 +617,18 @@ def forward(
                        attn_masks = _make_alibi_bias(
                            self.alibi_slopes,
                            query.dtype,
-                            attn_metadata.seq_lens,
+                            seq_lens,
                            make_attn_mask=False)  # type: ignore
                    out, _ = self.attn_func(
                        query,
                        key,
                        value,
                        None,
-                        prefill_meta.seq_start_loc,
-                        prefill_meta.seq_start_loc,
-                        prefill_meta.max_prefill_seq_len,
-                        prefill_meta.max_prefill_seq_len,
-                        True,
+                        query_seq_start_loc,
+                        key_seq_start_loc,
+                        query_max_seq_len,
+                        key_max_seq_len,
+                        causal_mask,
                        self.scale,
                        attn_masks[0][None]
                        if attn_masks is not None else None,
@@ -494,11 +652,12 @@ def forward(
                        query,
                        key,
                        value,
-                        prefill_meta.seq_lens,
-                        num_tokens,
+                        query_seq_start_loc,
+                        num_prefill_tokens,
                        self.num_heads,
                        self.head_size,
                        self.scale,
+                        causal_mask,
                        attn_masks,
                    )
                else:
                    out = self.attn_func(
                        q=query,
                        k=key,
                        v=value,
-                        cu_seqlens_q=prefill_meta.seq_start_loc,
-                        cu_seqlens_k=prefill_meta.seq_start_loc,
+                        cu_seqlens_q=query_seq_start_loc,
+                        cu_seqlens_k=key_seq_start_loc,
                        max_seqlen_q=prefill_meta.max_prefill_seq_len,
-                        max_seqlen_k=prefill_meta.max_prefill_seq_len,
+                        max_seqlen_k=key_max_seq_len,
                        softmax_scale=self.scale,
                        causal=True,
                        window_size=self.sliding_window,
@@ -545,6 +704,7 @@ def forward(
 
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
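+            # For cross-attention, the decode kernels below are pointed at
+            # the cross-attention block tables and the encoder sequence
+            # lengths instead of the decoder-side ones.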
             # Whether to use rocm custom paged attention or not
+            output = torch.empty_like(decode_query)
             num_seqs, num_heads, head_size = decode_query.shape
             block_size = value_cache.shape[3]
             gqa_ratio = num_heads // self.num_kv_heads
@@ -552,7 +712,10 @@ def forward(
             use_custom = use_rocm_custom_paged_attention(
                 decode_query.dtype, head_size, block_size, gqa_ratio,
                 decode_meta.max_decode_seq_len)
             if use_custom:
-                max_seq_len = decode_meta.max_decode_seq_len
+                max_seq_len = (decode_meta.max_decode_seq_len
+                               if attn_type != AttentionType.ENCODER_DECODER
+                               else decode_meta.max_encoder_seq_len)
+                assert max_seq_len is not None
                 max_num_partitions = (
                     (max_seq_len + _PARTITION_SIZE_ROCM - 1) //
                     _PARTITION_SIZE_ROCM)
@@ -573,7 +736,7 @@ def forward(
                 else:
                     out = output
                 ops.paged_attention_rocm(
-                    out,
+                    output[num_prefill_tokens:],
                     exp_sums,
                     max_logits,
                     tmp_output,
                     decode_query,
                     key_cache,
                     value_cache,
                     self.num_kv_heads,
                     self.scale,
-                    decode_meta.block_tables,
-                    decode_meta.seq_lens_tensor,
+                    decode_meta.block_tables
+                    if attn_type != AttentionType.ENCODER_DECODER else
+                    decode_meta.cross_block_tables,
+                    decode_meta.seq_lens_tensor
+                    if attn_type != AttentionType.ENCODER_DECODER else
+                    decode_meta.encoder_seq_lens_tensor,
                     block_size,
                     max_seq_len,
                     self.alibi_slopes,
@@ -596,9 +763,15 @@ def forward(
                     decode_query,
                     key_cache,
                     value_cache,
-                    decode_meta.block_tables,
-                    decode_meta.seq_lens_tensor,
-                    decode_meta.max_decode_seq_len,
+                    decode_meta.block_tables
+                    if attn_type != AttentionType.ENCODER_DECODER else
+                    decode_meta.cross_block_tables,
+                    decode_meta.seq_lens_tensor
+                    if attn_type != AttentionType.ENCODER_DECODER else
+                    decode_meta.encoder_seq_lens_tensor,
+                    decode_meta.max_decode_seq_len
+                    if attn_type != AttentionType.ENCODER_DECODER else
+                    decode_meta.max_encoder_seq_len,
                     self.kv_cache_dtype,
                     self.num_kv_heads,
                     self.scale,
@@ -608,7 +781,7 @@ def forward(
             )
 
         # Reshape the output tensor.
-        return output.view(num_tokens, hidden_size)
+        return output.view(-1, self.num_heads * self.head_size)
 
 
 def _sdpa_attention(
@@ -620,6 +793,7 @@ def _sdpa_attention(
     num_heads: int,
     head_size: int,
     scale: float,
+    is_causal: bool,
     attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
@@ -637,7 +811,7 @@ def _sdpa_attention(
             key[:, start:end, :],
             value[:, start:end, :],
             dropout_p=0.0,
-            is_causal=attn_masks is None,
+            is_causal=is_causal,
             attn_mask=attn_masks[i] if attn_masks else None,
             scale=scale).movedim(query.dim() - 2, 0)
         output[start:end, :, :] = sub_out
diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py
index 7ea1d8d93ea2b..f765b8c39fa6c 100644
--- a/vllm/model_executor/layers/tuned_gemm.py
+++ b/vllm/model_executor/layers/tuned_gemm.py
@@ -84,8 +84,11 @@ def mm(self, inp, weights, bias=None):
         # uses this for linear units. However, sampler
         # will use torch.matmul with 2 dimensions only
         if inp.dim() == 3:
-            inp_view = inp.view(-1, inp.size(-1))
-            batched = True
+            try:
+                inp_view = inp.view(-1, inp.size(-1))
+                batched = True
+            except RuntimeError:
+                return F.linear(inp, weights, bias)
         else:
             inp_view = inp
             batched = False
diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py
index bd716ac3e7ec3..4606866bdba52 100644
--- a/vllm/worker/enc_dec_model_runner.py
+++ b/vllm/worker/enc_dec_model_runner.py
@@ -23,7 +23,8 @@
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (IntermediateTensors, PoolerOutput,
                            SequenceGroupMetadata)
-from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
+from vllm.utils import (STR_NOT_IMPL_ENC_DEC_BACKEND, is_hip,
+                        make_tensor_with_pad)
 from vllm.worker.model_runner import (GPUModelRunnerBase,
                                       ModelInputForGPUBuilder,
                                       ModelInputForGPUWithSamplingMetadata,
@@ -120,7 +121,7 @@ def __init__(
 
     def _maybe_force_supported_attention_backend(self):
         '''
-        Force vLLM to use the XFormers attention backend,
-        which is currently the only supported option.
+        Force vLLM to use the XFormers or ROCM attention backend,
+        which are currently the only supported options.
         '''
 
@@ -138,18 +139,21 @@ def raise_backend_err():
             # The user has not already specified an attention backend
             # override
             logger.info("EncoderDecoderModelRunner requires "
-                        "XFormers backend; overriding backend "
-                        "auto-selection and forcing XFormers.")
-            global_force_attn_backend(_Backend.XFORMERS)
+                        "XFormers or ROCM backend; overriding backend "
+                        "auto-selection and forcing XFormers or ROCM.")
+            global_force_attn_backend(
+                _Backend.ROCM_FLASH if is_hip() else _Backend.XFORMERS)
         elif is_forced_by_global:
             # Backend override enforced by global variable takes
             # precedence over vLLM backend environment variable.
-            if maybe_global_forced_backend != _Backend.XFORMERS:
+            if maybe_global_forced_backend != _Backend.XFORMERS and \
+                    maybe_global_forced_backend != _Backend.ROCM_FLASH:
                 raise_backend_err()
         elif is_forced_by_env_var:
             # Backend override enforced by vLLM backend
             # environment variable
-            if maybe_env_var_forced_backend != _Backend.XFORMERS:
+            if maybe_env_var_forced_backend != _Backend.XFORMERS and \
+                    maybe_env_var_forced_backend != _Backend.ROCM_FLASH:
                 raise_backend_err()
 
     def _list_to_int32_tensor(