Adding new ROCm Triton flash attention kernel #4

Closed. Wants to merge 8 commits.
14 changes: 14 additions & 0 deletions Dockerfile.rocm
@@ -26,6 +26,9 @@ ARG BUILD_FA="1"
# whether to build cupy on rocm
ARG BUILD_CUPY="1"

# whether to build triton on rocm
ARG BUILD_TRITON="1"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

@@ -95,6 +98,17 @@ RUN if [ "$BUILD_CUPY" = "1" ]; then \
&& cd ..; \
fi

# build triton
RUN if [ "$BUILD_TRITON" = "1"]; then \
jpvillam-amd marked this conversation as resolved.
Show resolved Hide resolved
mkdir -p libs \
&& cd libs \
&& pip uninstall -y triton \
&& git clone https://github.com/ROCmSoftwarePlatform/triton.git \
&& cd triton/python \
&& pip3 install -e . \
&& cd ../..; \
fi

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
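
The Triton build is controlled by the new BUILD_TRITON build argument; a minimal usage sketch, assuming the repository root as build context (the image tags are illustrative, not part of this PR):

# Build the ROCm image with the Triton fork and kernel included.
docker build -f Dockerfile.rocm --build-arg BUILD_TRITON=1 -t vllm-rocm .

# Build without Triton (any value other than "1" skips the block above).
docker build -f Dockerfile.rocm --build-arg BUILD_TRITON=0 -t vllm-rocm-no-triton .
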
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/attention/attention.py
@@ -30,7 +30,7 @@ def __init__(
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
-if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and
+if (torch.cuda.get_device_capability()[0] >= 8 and
torch.get_default_dtype() in (torch.float16, torch.bfloat16)):
# Ampere or later NVIDIA GPUs.
# NOTE(woosuk): FlashAttention does not support FP32.
Expand Down
40 changes: 30 additions & 10 deletions vllm/model_executor/layers/attention/backends/flash_attn.py
@@ -2,12 +2,21 @@
from typing import List, Optional

# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
-from flash_attn import flash_attn_func
+from vllm.utils import is_hip
+try:
+    from flash_attn import flash_attn_func
+except ImportError:
+    if is_hip():
+        pass
+    else:
+        raise

import torch

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.paged_attn import (
PagedAttentionImpl)
+from vllm.model_executor.layers.attention.ops.flash_attention_triton import attention

Review comment:

I have some questions: can flash_attn_func and the Triton FA kernel co-exist, or should we allow users to install both at the same time in the Dockerfile? Also, what are the steps to validate this PR?

Reply:

Some background:

flash_attn_func is Tri Dao's "reference" implementation of Flash Attention. It calls the cutlass version in the backend.

If is_hip(), flash_attn_func calls the CK version in the backend.

The cutlass version, I believe, is a pip package, so one can just pip install flash-attention. For CK, I think we need to install from our fork, but I'm a little less familiar with that.

Triton is completely separate and does not intersect with either the flash_attn_func front end or the cutlass/CK backends. It requires a Triton install, plus the kernel, which is available in this PR.

So yes, both can co-exist, but only one is needed.
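
A rough sketch of how the result could be checked, assuming an image built from Dockerfile.rocm with BUILD_TRITON=1 (the image name matches the illustrative build command earlier; the device flags are the usual ROCm docker options, not something this PR prescribes):

# Start a shell in the ROCm image.
docker run -it --rm --device=/dev/kfd --device=/dev/dri vllm-rocm bash

# Inside the container: the Triton fork and the new kernel module should import cleanly.
python3 -c "import triton; print(triton.__version__)"
python3 -c "from vllm.model_executor.layers.attention.ops.flash_attention_triton import attention"

# flash_attn remains optional on ROCm; the guarded import in this file tolerates its absence.
python3 -c "from flash_attn import flash_attn_func" || echo "flash_attn not installed (tolerated on ROCm)"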



class FlashAttentionBackend:
@@ -86,15 +95,26 @@ def forward(
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
-output = flash_attn_func(
-    query,
-    key,
-    value,
-    softmax_scale=self.scale,
-    causal=True,
-    window_size=self.sliding_window,
-    alibi_slopes=self.alibi_slopes,
-)
jpvillam-amd (Author), Mar 19, 2024: Add a flag to skip.
+if is_hip():
+    output, _ = attention(
+        query,
+        key,
+        value,
+        None,
+        input_metadata,
+        True,
+        self.scale,
+    )
+else:
+    output = flash_attn_func(
+        query,
+        key,
+        value,
+        softmax_scale=self.scale,
+        causal=True,
+        window_size=self.sliding_window,
+        alibi_slopes=self.alibi_slopes,
+    )
else:
# prefix-enabled attention
output = PagedAttentionImpl.forward_prefix(