Flash attention support. #20152
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #20152 +/- ##
==========================================
+ Coverage 78.81% 78.85% +0.04%
==========================================
Files 512 513 +1
Lines 49063 49250 +187
Branches 9035 9080 +45
==========================================
+ Hits 38668 38837 +169
- Misses 8530 8543 +13
- Partials 1865 1870 +5
Flags with carried forward coverage won't be shown.
Thanks for the PR -- the code looks good! Please add a unit test.
For the JAX version, I think we'd want to rely on a Pallas kernel. We can get help from the JAX team.
This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.
Hey, sorry for not finishing this PR. I have a quick question: where should I add the tests?
In
@james77777778 do you think flash attention should be a standalone op, or could this be managed at the level of the dot_product_attention op (e.g. as an argument)?
It should be possible to consolidate this into `dot_product_attention`. As far as I know, for torch, flash attention is utilized if the conditions are met. For jax, we need to specify the implementation explicitly.
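For context, a minimal sketch of what specifying the fused (flash) implementation explicitly in JAX might look like; it assumes a JAX version that exposes the `implementation` argument on `jax.nn.dot_product_attention`, which is not confirmed anywhere in this thread:

```python
import jax
import jax.numpy as jnp

# jax.nn.dot_product_attention expects (batch, seq_len, num_heads, head_dim).
batch, seq_len, num_heads, head_dim = 1, 10, 2, 16
query = jnp.ones((batch, seq_len, num_heads, head_dim), dtype=jnp.float16)
key = jnp.ones((batch, seq_len, num_heads, head_dim), dtype=jnp.float16)
value = jnp.ones((batch, seq_len, num_heads, head_dim), dtype=jnp.float16)

# implementation="cudnn" requests the fused attention kernel and raises if the
# inputs or hardware do not support it; the default (None) lets JAX choose.
out = jax.nn.dot_product_attention(query, key, value, implementation="cudnn")
print(out.shape)  # (1, 10, 2, 16)
```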
Very cool -- @hazemessamm can we do that, e.g. by adding an argument to `dot_product_attention`?
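As a rough illustration of that idea, a consolidated op could thread a flag down to the backend call. The sketch below covers only the torch backend, uses hypothetical names (`flash_attention`, the wrapper itself), is not the PR's actual code, and assumes torch >= 2.3 for `torch.nn.attention.sdpa_kernel`:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # torch >= 2.3

def dot_product_attention(query, key, value, mask=None, scale=None,
                          is_causal=False, flash_attention=False):
    # Hypothetical wrapper: when flash_attention=True, restrict SDPA to the
    # flash backend so torch raises instead of silently falling back.
    if flash_attention:
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            return F.scaled_dot_product_attention(
                query, key, value, attn_mask=mask, is_causal=is_causal, scale=scale)
    # Otherwise let torch pick whichever kernel the inputs and hardware allow.
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=mask, is_causal=is_causal, scale=scale)
```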
Awesome work! Thank you.
Thank you, glad I could help.
The test fails on torch + GPU:
Do you know if this is an issue with the torch version? What version is required? What torch + GPU setup were you testing on? |
I think flash attention in PyTorch does not work with any dtype except float16, and only on specific GPUs. I just tested it on an H100 GPU and it worked fine, but it did not work on a T4 GPU on Colab. I also just found the following functions in PyTorch that we can use to check whether the inputs and the current GPU can use flash attention or not:

```python
import torch

bsz, num_heads, seqlen, head_dim = 1, 2, 10, 16
query = torch.randn((bsz, num_heads, seqlen, head_dim), dtype=torch.float32, device='cuda:0')
params = torch.backends.cuda.SDPAParams(query, query, query, None, 16**-0.5, False)
is_flash_attention_enabled = torch.backends.cuda.can_use_flash_attention(params, False)
print(is_flash_attention_enabled)  # Output: False; it will be True if `dtype=torch.float16`
```

If you think that this is a good idea, then I will use this snippet in the flash attention function in the PyTorch backend. Documentation:
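If it helps, a small helper along those lines could look like the sketch below; the function name, signature, and error handling are assumptions rather than the PR's actual code:

```python
import torch

def _can_use_flash_attention(query, key, value, mask=None, is_causal=False):
    """Best-effort check: can torch dispatch this attention call to flash attention?"""
    if not torch.cuda.is_available():
        return False
    try:
        # Same construction as in the snippet above; the second argument of
        # can_use_flash_attention enables debug output explaining rejections.
        params = torch.backends.cuda.SDPAParams(
            query, key, value, mask, query.shape[-1] ** -0.5, is_causal
        )
        return torch.backends.cuda.can_use_flash_attention(params, False)
    except (AttributeError, TypeError, RuntimeError):
        # Older or CPU-only torch builds may not expose these APIs, and the
        # SDPAParams signature may differ across torch versions.
        return False
```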
That sounds great! Then, we can also skip the PyTorch unit test when this check evaluates to False. |
…ion and removed flash attention from tests
Looks good, thank you! Can you also add the test back? You can use pytest.mark.skipif to skip when flash attention is unimplemented for PyTorch or TF.
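For example, a skip along these lines might work; the backend check and the test body are only a sketch, not the PR's actual test:

```python
import pytest
from keras.src import backend

@pytest.mark.skipif(
    backend.backend() in ("tensorflow", "numpy"),
    reason="Flash attention is not supported for this backend yet.",
)
def test_dot_product_attention_with_flash_attention():
    # Build query/key/value and compare the flash path against the reference path.
    ...
```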
I skipped the tests for TensorFlow, NumPy, and torch. I just tested JAX on a T4 GPU on Colab and I got this error:
I added some conditions for JAX to skip the tests if they are met. What do you think?
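For the JAX side, one possible shape for such a condition is sketched below; the helper name and the platform check are assumptions, not the PR's code:

```python
import pytest
import jax
from keras.src import backend

def _skip_jax_flash_attention():
    # Hypothetical condition: the fused flash path needs a supported NVIDIA GPU,
    # so skip on CPU/TPU runtimes (and, in practice, on older cards like the T4).
    if backend.backend() != "jax":
        return False
    return all(d.platform != "gpu" for d in jax.devices())

@pytest.mark.skipif(
    _skip_jax_flash_attention(),
    reason="JAX flash attention requires a supported GPU.",
)
def test_flash_attention_jax():
    ...
```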
I added support for flash attention for PyTorch.
Let me know what you think about the current implementation so I can add support for JAX, and maybe I will try TF as well.