Adding new rocm triton flash attention kernel #4

Closed
wants to merge 8 commits
14 changes: 14 additions & 0 deletions Dockerfile.rocm
@@ -26,6 +26,9 @@ ARG BUILD_FA="1"
# whether to build cupy on rocm
ARG BUILD_CUPY="1"

# whether to build triton on rocm
ARG BUILD_TRITON="1"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

@@ -95,6 +98,17 @@ RUN if [ "$BUILD_CUPY" = "1" ]; then \
&& cd ..; \
fi

# build triton
RUN if [ "$BUILD_TRITON" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& pip uninstall -y triton \
&& git clone https://github.com/ROCmSoftwarePlatform/triton.git \

Reviewer: Do we need to use a specific branch for this PR?

Author: As far as I know, just the default branch should be fine.

&& cd triton/python \
&& pip3 install . \
&& cd ../..; \
fi
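As a quick check that this build stage worked, the ROCm Triton fork installed above should be importable inside the resulting image. A minimal sketch, assuming the image was built with BUILD_TRITON=1 and python3 is the interpreter pip3 installed into:

```python
# Minimal sanity check (run inside the image built with BUILD_TRITON=1):
# the ROCmSoftwarePlatform/triton fork installed above should import cleanly.
import triton

print("triton version:", triton.__version__)
```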

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
47 changes: 30 additions & 17 deletions vllm/model_executor/layers/attention/attention.py
@@ -8,6 +8,7 @@
from vllm.logger import init_logger
from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip
import os

logger = init_logger(__name__)

@@ -34,11 +35,12 @@ def __init__(
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
if _use_flash_attn():
if use_triton := _use_flash_attn():
from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend # noqa: E501
self.backend = FlashAttentionBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)
sliding_window,
use_triton == 2)
else:
from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend # noqa: E501
self.backend = XFormersBackend(num_heads, head_size, scale,
@@ -59,26 +61,37 @@ def forward(


@lru_cache(maxsize=1)
def _use_flash_attn() -> bool:
try:
import flash_attn # noqa: F401
except ImportError:
logger.info("flash_attn is not found. Using xformers backend.")
return False

if is_hip():
# AMD GPUs.
return False
if torch.cuda.get_device_capability()[0] < 8:
def _use_flash_attn() -> int:
"""Returns if and which flash attention to use.

Returns:
int: 0 for none, 1 for default implementation, 2 for triton implementation.
"""
if not (os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') and is_hip()):
# AMD GPUs can use flash_attn package or triton impl.
try:
import flash_attn # noqa: F401
except ImportError:
logger.info("flash_attn is not found. Using xformers backend.")
return 0

if (not is_hip()) and torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info("flash_attn is not supported on Turing or older GPUs. "
"Using xformers backend.")
return False
return 0

if is_hip() and torch.cuda.get_device_capability()[0] != 9:
# not Instinct series GPUs.
logger.info("flash_atten is not supported on NAVI GPUs. "
"Using xformers backend.")
return 0

if torch.get_default_dtype() not in (torch.float16, torch.bfloat16):
logger.info(
"flash_attn only supports torch.float16 or torch.bfloat16. "
"Using xformers backend.")
return False
return 0

logger.info("Using flash_attn backend.")
return True
logger.info(f"Using {'Triton' if os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') else ''} flash_attn backend.")
return 2 if os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') else 1
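For reference, a hedged usage sketch of the new selection path: VLLM_USE_FLASH_ATTN_TRITON must be set before any Attention layer is constructed, because _use_flash_attn() is memoized with lru_cache(maxsize=1). The model name below is only illustrative.

```python
# Sketch: opting into the Triton flash-attention path added in this PR.
# The environment variable is read inside _use_flash_attn(), which is cached
# after its first call, so it has to be set before the model is built.
import os

os.environ["VLLM_USE_FLASH_ATTN_TRITON"] = "1"

from vllm import LLM  # standard vLLM entry point

llm = LLM(model="meta-llama/Llama-2-7b-hf")  # model choice is illustrative
outputs = llm.generate(["Hello from ROCm!"])
print(outputs[0].outputs[0].text)
```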
43 changes: 34 additions & 9 deletions vllm/model_executor/layers/attention/backends/flash_attn.py
@@ -1,12 +1,14 @@
"""Attention layer with Flash and PagedAttention."""
from typing import List, Optional

from vllm.utils import is_hip
from flash_attn import flash_attn_func
import torch

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.paged_attn import (
PagedAttentionImpl)
from vllm.model_executor.layers.attention.ops.flash_attention_triton import triton_attention


class FlashAttentionBackend:
@@ -19,6 +21,7 @@ def __init__(
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
use_triton: Optional[bool] = False,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -28,6 +31,7 @@ def __init__(
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.use_triton = use_triton

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -85,15 +89,36 @@ def forward(
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
output = flash_attn_func(
query,
key,
value,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
if self.use_triton:
output, _ = triton_attention(
query,
key,
value,
None,
input_metadata,
True,
self.scale,
)
else:
if is_hip():
#XXX: window_size and alibi_slopes not supported
output = flash_attn_func(
query,
key,
value,
softmax_scale=self.scale,
causal=True,
)
else:
output = flash_attn_func(
query,
key,
value,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
# prefix-enabled attention
output = PagedAttentionImpl.forward_prefix(
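Since the prompt-path changes are split across several hunks above, here is a condensed restatement of the new dispatch as a standalone sketch. The helper name is hypothetical and not part of the PR; the three branches and the triton_attention argument order simply mirror the calls shown in this diff.

```python
# Condensed sketch of the prompt-path dispatch introduced by this diff.
# `_dispatch_prompt_attention` is a hypothetical helper, not part of the PR;
# the three branches mirror the code in FlashAttentionBackend.forward above.
def _dispatch_prompt_attention(self, query, key, value, input_metadata):
    if self.use_triton:
        # ROCm Triton kernel; the second return value is discarded here.
        output, _ = triton_attention(query, key, value, None, input_metadata,
                                     True, self.scale)
    elif is_hip():
        # ROCm flash_attn build: window_size and alibi_slopes are unsupported.
        output = flash_attn_func(query, key, value,
                                 softmax_scale=self.scale, causal=True)
    else:
        # CUDA flash_attn: sliding window and ALiBi slopes are passed through.
        output = flash_attn_func(query, key, value,
                                 softmax_scale=self.scale, causal=True,
                                 window_size=self.sliding_window,
                                 alibi_slopes=self.alibi_slopes)
    return output
```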