Adding new ROCm Triton flash attention kernel #4
Conversation
Co-authored-by: Vinayak Gokhale <[email protected]>
encoded_softmax = None

M = torch.empty((batch, nheads_q, metadata.max_seq_len), device=q.device, dtype=torch.float32)
I think you can remove this, and the associated writeback in the kernel, for a minor speedup. M is only used during the backward pass; since we don't have a separate inference-only forward kernel, it gets allocated and written even though we don't need it for inference.
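A minimal host-side sketch of that suggestion, assuming a hypothetical `need_backward` flag (the actual gating in the PR and kernel may differ): allocate M only when a backward pass will consume it, and return None otherwise so the kernel writeback can be skipped too.

```python
import torch

def maybe_alloc_softmax_lse(q: torch.Tensor, batch: int, nheads_q: int,
                            max_seq_len: int, need_backward: bool):
    """Allocate the per-row softmax statistics buffer (M) only when the
    backward pass needs it; for inference-only forward, return None so the
    kernel can skip its writeback as well."""
    if not need_backward:
        return None
    return torch.empty((batch, nheads_q, max_seq_len),
                       device=q.device, dtype=torch.float32)
```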
import torch

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.paged_attn import (
    PagedAttentionImpl)
from vllm.model_executor.layers.attention.ops.flash_attention_triton import attention
I have some questions: can flash_attn_func and the Triton FA kernel co-exist, or should we allow users to install both at the same time in the Dockerfile?
Also, what are the steps to validate this PR?
> can flash_attn_func and the Triton FA kernel co-exist, or should we allow users to install both at the same time in the Dockerfile?

Some background:
flash_attn_func is Tri Dao's "reference" implementation of Flash Attention. It calls the CUTLASS version in the backend. If is_hip(), flash_attn_func calls the CK version in the backend instead.
The CUTLASS version, I believe, is a pip package, so one can just pip install flash-attention. I think for CK we need to install from our fork, but I'm a little less familiar with that.
Triton is completely separate and does not intersect with either the front-end flash_attn_func or the CUTLASS / CK backends. It requires a Triton install, plus the kernel, which is available in this PR.
So, yes, both can co-exist, but only one is needed.
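As a concrete illustration of "both can co-exist but only one is needed", here is a small, hedged probe that reports which front ends are importable in a given environment; the module paths mirror the ones discussed above, and nothing here is required by the PR itself.

```python
def available_flash_attn_backends():
    """Report which flash attention front ends can be imported; both may be
    installed side by side, but only one is used at runtime."""
    backends = []
    try:
        # Tri Dao's reference front end (CUTLASS on CUDA, CK on ROCm).
        from flash_attn import flash_attn_func  # noqa: F401
        backends.append("flash_attn_func")
    except ImportError:
        pass
    try:
        # The Triton kernel added by this PR.
        from vllm.model_executor.layers.attention.ops.flash_attention_triton import (  # noqa: F401
            attention)
        backends.append("triton")
    except ImportError:
        pass
    return backends

if __name__ == "__main__":
    print(available_flash_attn_backends())
```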
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
if is_hip():
Add a flag to skip this path.
Force-pushed c582158 to 0e63661

Added a flag for controlling the Triton vs. default flow. More small changes to the Dockerfile.
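A rough sketch of what that flag-controlled flow could look like in the attention forward path; the environment variable name, the forward_attention helper, and the call signatures are illustrative assumptions, not the exact code in this commit.

```python
import os

# Assumed flag name; the PR's actual flag may be spelled or wired differently.
USE_TRITON_FLASH_ATTN = os.environ.get("VLLM_USE_FLASH_ATTN_TRITON", "0") == "1"

def forward_attention(query, key, value, flash_attn_func, triton_attention,
                      sliding_window=None, alibi_slopes=None, on_rocm=True):
    """Route the prefill attention call: use the Triton kernel when running on
    ROCm with the flag set, otherwise fall back to the default flash_attn_func
    flow shown in the diff above (sliding window / ALiBi arguments included)."""
    if on_rocm and USE_TRITON_FLASH_ATTN:
        # Triton path: the kernel from this PR, passed in here as a callable.
        return triton_attention(query, key, value)
    # Default path: Tri Dao's front end (CUTLASS on CUDA, CK on ROCm).
    return flash_attn_func(query, key, value,
                           window_size=sliding_window,
                           alibi_slopes=alibi_slopes)
```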
mkdir -p libs \
    && cd libs \
    && pip uninstall -y triton \
    && git clone https://github.com/ROCmSoftwarePlatform/triton.git \
Do we need to use a specific branch for this PR?
As far as I know, just the default branch should be fine.
Closing as we merge the Triton kernel upstream.
Making this PR for a quick review before I open the main PR for upstream.
@vgokhale Added you as a co-author since vllm/model_executor/layers/attention/ops/flash_attention_triton.py is basically all ours 😆