llama3.2 + cross attn test (#220)
* llama3.2 + cross attn test

* lint issues fix

* mypy errors

* making yapf happy

* cut off WA for tuned gemms

* try and catch for non-contiguous tensor (see the illustrative sketch below)

---------

Co-authored-by: Aleksandr Malyshev <[email protected]>
maleksan85 and Aleksandr Malyshev authored Oct 4, 2024
1 parent 4075b35 commit 2550f14
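
The last bullet refers to a change in one of the files not shown on this page. Purely as a hypothetical illustration of that kind of fallback (the helper below is invented for this sketch and is not the commit's actual code), the pattern looks roughly like:

import torch


def run_kernel_safely(kernel, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper, not from this commit: some custom kernels require
    # a contiguous memory layout and raise RuntimeError otherwise.
    try:
        return kernel(x)
    except RuntimeError:
        # Fall back to a contiguous copy of the tensor and retry.
        return kernel(x.contiguous())


# A transposed view is non-contiguous; the helper retries with .contiguous().
out = run_kernel_safely(lambda t: t.view(-1), torch.randn(4, 8).t())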
Showing 5 changed files with 280 additions and 94 deletions.
4 changes: 2 additions & 2 deletions tests/kernels/test_encoder_decoder_attn.py
@@ -21,7 +21,8 @@
 from vllm.utils import is_hip
 
 # List of support backends for encoder/decoder models
-LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS]
+LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS] if not is_hip() \
+    else [_Backend.ROCM_FLASH]
 
 HEAD_SIZES = [64, 256]
 
@@ -807,7 +808,6 @@ def test_encoder_only(
     assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out)
 
 
-@pytest.mark.skipif(is_hip(), reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
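
As a self-contained sketch (not part of the commit; the names below are stand-ins for vllm's _Backend members and vllm.utils.is_hip), this is roughly how the platform-conditional list above drives the parametrized tests:

import pytest

XFORMERS, ROCM_FLASH = "XFORMERS", "ROCM_FLASH"  # stand-ins for _Backend members


def is_hip() -> bool:
    # Stand-in for vllm.utils.is_hip(); returns True on ROCm builds.
    return False


# Same selection pattern as the diff above: xFormers on CUDA, ROCm flash
# attention on HIP/ROCm.
LIST_ENC_DEC_SUPPORTED_BACKENDS = [XFORMERS] if not is_hip() else [ROCM_FLASH]


@pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS)
def test_runs_once_per_supported_backend(attn_backend):
    assert attn_backend in (XFORMERS, ROCM_FLASH)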
11 changes: 8 additions & 3 deletions tests/kernels/utils.py
@@ -12,8 +12,8 @@
 from torch._prims_common import TensorLikeType
 
 from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
-from vllm.utils import (STR_BACKEND_ENV_VAR, STR_XFORMERS_ATTN_VAL,
-                        make_tensor_with_pad)
+from vllm.utils import (STR_BACKEND_ENV_VAR, STR_ROCM_FLASH_ATTN_VAL,
+                        STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)
 
 # For now, disable "test_aot_dispatch_dynamic" since there are some
 # bugs related to this test in PyTorch 2.4.
@@ -524,8 +524,13 @@ def make_backend(backend_name: str) -> AttentionBackend:
     if backend_name == STR_XFORMERS_ATTN_VAL:
         # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
         from vllm.attention.backends.xformers import XFormersBackend
+
         return XFormersBackend()
 
+    if backend_name == STR_ROCM_FLASH_ATTN_VAL:
+        from vllm.attention.backends.rocm_flash_attn import (  # noqa: F401
+            ROCmFlashAttentionBackend)
+        return ROCmFlashAttentionBackend
 
     raise AssertionError(
         f"Unrecognized backend_name {backend_name} for unit test")
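A hedged usage sketch (not part of the commit) of the helper patched above; it assumes the test utilities are importable as tests.kernels.utils and that each machine picks the string constant matching its platform:

from tests.kernels.utils import make_backend
from vllm.utils import STR_ROCM_FLASH_ATTN_VAL, STR_XFORMERS_ATTN_VAL, is_hip

# Pick the backend name the same way the encoder/decoder tests do.
backend_name = STR_ROCM_FLASH_ATTN_VAL if is_hip() else STR_XFORMERS_ATTN_VAL
attn_backend = make_backend(backend_name)
# Note: per the diff above, the ROCm branch returns the
# ROCmFlashAttentionBackend class itself rather than an instance.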
