Adding new ROCm Triton flash attention kernel #4

Closed
jpvillam-amd wants to merge 8 commits.

Conversation

jpvillam-amd

Making this PR for a quick review before I open the main PR for upstream

@vgokhale Added you as a co-author since vllm/model_executor/layers/attention/ops/flash_attention_triton.py is basically all ours 😆


encoded_softmax = None

M = torch.empty((batch, nheads_q, metadata.max_seq_len), device=q.device, dtype=torch.float32)


I think you can remove this, along with the associated writeback in the kernel, for a minor speedup. M is only used by the backward pass, and although we don't have a separate inference-only forward kernel, it isn't needed for inference.
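
A minimal sketch of how the host-side allocation could be gated (the need_backward flag and helper name are hypothetical, not part of this PR; the matching writeback in the Triton kernel would also have to be skipped, e.g. via a constexpr argument):

import torch

def alloc_softmax_lse(q, batch, nheads_q, max_seq_len, need_backward=False):
    # M holds the per-row softmax log-sum-exp that only the backward pass
    # consumes; skip the allocation (and the kernel writeback) for inference.
    if not need_backward:
        return None
    return torch.empty((batch, nheads_q, max_seq_len),
                       device=q.device, dtype=torch.float32)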

import torch

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.paged_attn import (
PagedAttentionImpl)
from vllm.model_executor.layers.attention.ops.flash_attention_triton import attention


I have some questions: can flash_attn_func and the Triton FA kernel co-exist, or should we allow users to install both at the same time in the Dockerfile?
Also, what are the steps to validate this PR?


I have some questions: can flash_attn_func and the Triton FA kernel co-exist, or should we allow users to install both at the same time in the Dockerfile?

Some background

flash_attn_func is Tri Dao's "reference" implementation of Flash Attention; it calls the CUTLASS version in the backend.

On ROCm (is_hip()), flash_attn_func calls the CK version in the backend instead.

The CUTLASS version is, I believe, a pip package, so one can just pip install flash-attn. For CK I think we need to install from our fork, but I'm a little less familiar with that.

The Triton path is completely separate and does not intersect with either the flash_attn_func front end or the CUTLASS/CK backends. It requires a Triton install plus the kernel, which is added in this PR.

So yes, both can co-exist, but only one is needed.
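
To illustrate the co-existence point, a rough sketch of probing both front ends at import time so that whichever is installed gets used (the module paths come from flash-attn and this PR; the HAS_* names are made up for illustration):

try:
    # Tri Dao's front end: cutlass backend, or CK when built for ROCm.
    from flash_attn import flash_attn_func
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False

try:
    # Triton kernel added in this PR.
    from vllm.model_executor.layers.attention.ops.flash_attention_triton import (
        attention as triton_attention)
    HAS_TRITON_FA = True
except ImportError:
    HAS_TRITON_FA = False

Only one of the two paths is taken at runtime, so installing both is harmless.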

window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
if is_hip():

jpvillam-amd (Author) commented Mar 19, 2024


Add a flag to skip

Added flag for controlling Triton vs. default flow.
More small changes to Dockerfile.
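
For context, the switch that commit adds could look roughly like the following (the environment-variable name here is hypothetical; the PR's actual flag and plumbing may differ):

import os

from vllm.utils import is_hip

# Hypothetical opt-out flag: set VLLM_USE_TRITON_FLASH_ATTN=0 to fall back to
# the default flash_attn flow on ROCm.
_USE_TRITON_FA = os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "1") == "1"

def use_triton_flash_attention() -> bool:
    # Take the Triton kernel path only on ROCm and only when the flag is on.
    return is_hip() and _USE_TRITON_FA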
mkdir -p libs \
&& cd libs \
&& pip uninstall -y triton \
&& git clone https://github.com/ROCmSoftwarePlatform/triton.git \


Do we need to use a specific branch for this PR?

jpvillam-amd (Author)


As far as I know, the default branch should be fine.
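
For reference, a hedged sketch of what the completed Dockerfile.rocm stanza might look like when building Triton from the default branch of the ROCm fork (the exact steps in this PR may differ):

RUN mkdir -p libs \
    && cd libs \
    && pip uninstall -y triton \
    && git clone https://github.com/ROCmSoftwarePlatform/triton.git \
    && cd triton/python \
    && pip3 install . \
    && cd ../..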

shajrawi (Collaborator) commented May 3, 2024

Closing as we merge the Triton kernel upstream.

@shajrawi closed this on May 3, 2024
gshtras pushed a commit that referenced this pull request Sep 27, 2024