Skip to content

Commit

Permalink
llama3.2 + cross attn test
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleksandr Malyshev committed Oct 3, 2024
1 parent 030374b commit 64cdd32
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 91 deletions.
5 changes: 3 additions & 2 deletions tests/kernels/test_encoder_decoder_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from vllm.utils import is_hip

# List of supported backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] if not is_hip() \
else [_Backend.ROCM_FLASH]

HEAD_SIZES = [64, 256]

Expand Down Expand Up @@ -807,7 +808,7 @@ def test_encoder_only(
assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)


@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
# @pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
Expand Down
9 changes: 7 additions & 2 deletions tests/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
make_tensor_with_pad)
STR_ROCM_FLASH_ATTN_VAL, make_tensor_with_pad)

# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
def make_backend(backend_name: str) -> AttentionBackend:
    """Construct and return the attention backend instance named by
    *backend_name*.

    Args:
        backend_name: one of the STR_*_ATTN_VAL backend identifiers from
            ``vllm.utils`` (currently xFormers or ROCm flash-attention).

    Returns:
        An ``AttentionBackend`` instance for the requested backend.

    Raises:
        AssertionError: if *backend_name* is not a recognized backend.
    """
    if backend_name == STR_XFORMERS_ATTN_VAL:
        # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs,
        # hence the deferred (branch-local) import.
        from vllm.attention.backends.xformers import XFormersBackend
        return XFormersBackend()

    if backend_name == STR_ROCM_FLASH_ATTN_VAL:
        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
            ROCmFlashAttentionBackend)
        # Instantiate the backend: callers expect an AttentionBackend
        # *instance* (see the XFormers branch above and the return
        # annotation); the original returned the bare class.
        return ROCmFlashAttentionBackend()

    raise AssertionError(
        f"Unrecognized backend_name {backend_name} for unit test")

Expand Down
Loading

0 comments on commit 64cdd32

Please sign in to comment.