
Flash attention support. #20152

Merged: 13 commits into keras-team:master on Oct 8, 2024

Conversation

@hazemessamm (Contributor)

I added support for flash attention for PyTorch.

Let me know what you think about the current implementation so I can add support for JAX, and maybe I'll also try TF.

@codecov-commenter commented Aug 22, 2024

Codecov Report

Attention: Patch coverage is 26.31579% with 14 lines in your changes missing coverage. Please review.

Project coverage is 78.85%. Comparing base (5aa5f88) to head (57e6e56).
Report is 2 commits behind head on master.

Files with missing lines              Patch %   Lines
keras/src/backend/torch/nn.py         18.18%    8 Missing and 1 partial ⚠️
keras/src/backend/numpy/nn.py         0.00%     1 Missing and 1 partial ⚠️
keras/src/backend/tensorflow/nn.py    0.00%     1 Missing and 1 partial ⚠️
keras/src/backend/jax/nn.py           66.66%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20152      +/-   ##
==========================================
+ Coverage   78.81%   78.85%   +0.04%     
==========================================
  Files         512      513       +1     
  Lines       49063    49250     +187     
  Branches     9035     9080      +45     
==========================================
+ Hits        38668    38837     +169     
- Misses       8530     8543      +13     
- Partials     1865     1870       +5     
Flag               Coverage Δ
keras              78.71% <26.31%> (+0.04%) ⬆️
keras-jax          62.36% <21.05%> (+0.10%) ⬆️
keras-numpy        57.38% <10.52%> (-0.03%) ⬇️
keras-tensorflow   63.62% <10.52%> (+0.06%) ⬆️
keras-torch        62.35% <15.78%> (+0.09%) ⬆️

Flags with carried forward coverage won't be shown.


@fchollet (Member) left a comment

Thanks for the PR -- the code looks good! Please add a unit test.

For the JAX version, I think we'd want to rely on a Pallas kernel. We can get help from the JAX team.

github-actions (bot)

This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

The github-actions bot added the stale label on Sep 24, 2024.
@hazemessamm (Contributor, Author)

Hey, sorry for not finishing this PR. I have a quick question: where should I add the tests?

@fchollet (Member) commented Oct 2, 2024

> Hey, sorry for not finishing this PR. I have a quick question: where should I add the tests?

In keras/src/ops/nn_test.py. Ops are tested through the op class (e.g. in keras/src/ops/nn.py), rather than in a backend-specific way.
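
A minimal sketch of such a backend-agnostic test (class name, shapes, and the layout assumption here are illustrative, not the exact test that was added; the reference assumes the op follows the (batch, seq_len, num_heads, head_dim) convention of jax.nn.dot_product_attention):

import numpy as np
from keras.src import testing
from keras.src.ops import nn as knn


class DotProductAttentionTest(testing.TestCase):
    def test_dot_product_attention_matches_reference(self):
        # Inputs in (batch, seq_len, num_heads, head_dim) layout.
        query = np.random.random((1, 8, 2, 16)).astype("float32")
        key = np.random.random((1, 8, 2, 16)).astype("float32")
        value = np.random.random((1, 8, 2, 16)).astype("float32")

        outputs = knn.dot_product_attention(query, key, value)

        # NumPy reference: softmax(Q K^T / sqrt(head_dim)) V.
        scale = 1.0 / np.sqrt(query.shape[-1])
        logits = np.einsum("btnh,bsnh->bnts", query, key) * scale
        weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
        weights = weights / weights.sum(axis=-1, keepdims=True)
        expected = np.einsum("bnts,bsnh->btnh", weights, value)
        self.assertAllClose(outputs, expected, atol=1e-4)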

@fchollet (Member) commented Oct 5, 2024

@james77777778 do you think flash attention should be a standalone op, or could this be managed at the level of the dot_product_attention op (e.g. as an argument)?

@james77777778 (Contributor)

> @james77777778 do you think flash attention should be a standalone op, or could this be managed at the level of the dot_product_attention op (e.g. as an argument)?

It should be possible to consolidate this into dot_product_attention. That’s how it's implemented in torch, and I've seen a similar approach in jax
(https://github.com/jax-ml/jax/blob/81a31f6adf453b2afc39936e15c15d8ad327bf6e/jax/_src/nn/functions.py#L1037-L1041)

As far as I know, torch uses flash attention automatically when the conditions are met. For jax, we need to pass implementation="cudnn" to use it.
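
For reference, a minimal sketch of the jax path mentioned above, using the jax.nn.dot_product_attention function linked above (it requires a recent jax release, fp16/bf16 inputs, and a supported GPU):

import jax
import jax.numpy as jnp

# Inputs in (batch, seq_len, num_heads, head_dim) layout; the fused cuDNN
# kernel requires fp16/bf16 dtypes and a supported GPU.
query = jnp.ones((1, 128, 4, 64), dtype=jnp.float16)
key = jnp.ones((1, 128, 4, 64), dtype=jnp.float16)
value = jnp.ones((1, 128, 4, 64), dtype=jnp.float16)

# implementation="cudnn" requests the fused (flash) attention kernel;
# implementation="xla" (or None) falls back to the plain XLA computation.
out = jax.nn.dot_product_attention(
    query, key, value, is_causal=True, implementation="cudnn"
)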

@fchollet (Member) commented Oct 6, 2024

Very cool -- @hazemessamm can we do that, e.g. by adding a flash_attention argument to dot_product_attention? That would also make it easy to add support for JAX in addition to PyTorch. For TF I think we can skip support for now.
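
For illustration, the user-facing call could then look roughly like this (the flash_attention argument name is taken from the suggestion above; the exact final signature is whatever the PR ends up merging):

import numpy as np
from keras import ops

query = np.random.random((1, 32, 4, 64)).astype("float16")
key = np.random.random((1, 32, 4, 64)).astype("float16")
value = np.random.random((1, 32, 4, 64)).astype("float16")

# Same op as before; the extra flag opts into the fused flash-attention path
# on backends and hardware that support it.
out = ops.dot_product_attention(query, key, value, flash_attention=True)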

@fchollet (Member) left a comment

Awesome work! Thank you.

@hazemessamm (Contributor, Author)

> Awesome work! Thank you.

Thank you, glad I could help.

@fchollet (Member) commented Oct 6, 2024

The test fails on torch + GPU:

FAILED keras/src/ops/nn_test.py::NNOpsCorrectnessTest::test_dot_product_attention_none_none_(true, false)_true - RuntimeError: No available kernel. Aborting execution.

Do you know if this is an issue with the torch version? What version is required? What torch + GPU setup were you testing on?

@hazemessamm (Contributor, Author) commented Oct 6, 2024

> The test fails on torch + GPU:
>
> FAILED keras/src/ops/nn_test.py::NNOpsCorrectnessTest::test_dot_product_attention_none_none_(true, false)_true - RuntimeError: No available kernel. Aborting execution.
>
> Do you know if this is an issue with the torch version? What version is required? What torch + GPU setup were you testing on?

I think flash attention in PyTorch does not work with any dtype except float16, and only on specific GPUs. I just tested it on an H100 GPU and it worked fine, but it did not work on a T4 GPU on Colab.

I also just found the following functions in PyTorch that we can use to check whether the inputs and the current GPU can use flash attention or not.

import torch

# Dummy inputs in the (batch, num_heads, seq_len, head_dim) layout expected by SDPA.
bsz, num_heads, seqlen, head_dim = 1, 2, 10, 16
query = torch.randn((bsz, num_heads, seqlen, head_dim), dtype=torch.float32, device='cuda:0')

# Build the SDPA parameter struct for a self-attention call and ask PyTorch
# whether the flash-attention kernel can be used for these inputs on this GPU.
params = torch.backends.cuda.SDPAParams(query, query, query, None, 16**-0.5, False)
is_flash_attention_enabled = torch.backends.cuda.can_use_flash_attention(params, False)
print(is_flash_attention_enabled)  # Output: False here; it becomes True with dtype=torch.float16

If you think this is a good idea, I will use this snippet in the flash attention function in the PyTorch backend.

Documentation:
https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.SDPAParams
https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.can_use_flash_attention
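
A rough sketch of how that check could be wired into the torch backend (function layout and argument names here are illustrative, not necessarily what was merged; the SDPAParams argument order follows the snippet above and may differ across torch versions):

import math
import torch


def can_use_flash_attention(query, key, value, mask, is_causal):
    # Ask PyTorch whether the fused flash kernel supports these inputs on the
    # current GPU; treat a missing or changed API as "not supported".
    try:
        params = torch.backends.cuda.SDPAParams(
            query, key, value, mask, 0.0, is_causal
        )
        return torch.backends.cuda.can_use_flash_attention(params, False)
    except (AttributeError, TypeError):
        return False


def dot_product_attention(
    query, key, value, mask=None, is_causal=False, scale=None, flash_attention=False
):
    if scale is None:
        scale = 1.0 / math.sqrt(query.shape[-1])
    if flash_attention and not can_use_flash_attention(
        query, key, value, mask, is_causal
    ):
        raise ValueError(
            "Flash attention is not supported for these inputs on this GPU."
        )
    # scaled_dot_product_attention dispatches to the flash kernel automatically
    # whenever the backend conditions are met.
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=mask, is_causal=is_causal, scale=scale
    )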

@fchollet (Member) commented Oct 7, 2024

> If you think this is a good idea, I will use this snippet in the flash attention function in the PyTorch backend.

That sounds great! Then, we can also skip the PyTorch unit test when this check evaluates to False.

@fchollet (Member) left a comment

Looks good, thank you! Can you also add the test back? You can use pytest.mark.skipif to skip it where it's unsupported, e.g. for PyTorch and TF.
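
For example, a skip along these lines (a sketch only; the exact conditions are whatever the PR ends up using):

import pytest

from keras.src import backend


# Sketch: skip the flash-attention test on backends with no flash path; the
# torch/JAX GPU- and dtype-dependent conditions can be checked inside the test.
@pytest.mark.skipif(
    backend.backend() in ("tensorflow", "numpy"),
    reason="Flash attention is not supported on this backend.",
)
def test_dot_product_attention_with_flash_attention():
    ...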

@hazemessamm (Contributor, Author) commented Oct 7, 2024

I skipped the tests for TensorFlow, NumPy and torch. I also tested JAX on a T4 GPU on Colab and got this error: RuntimeError: Require at least Ampere arch to run. So we will need the JAX + GPU tests to run on an Ampere-or-newer GPU; otherwise we will need to skip the tests for all frameworks. Also, the JAX version currently used in the GitHub CI does not have the dot_product_attention function.
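
A sketch of how those two JAX conditions could be checked before running the test (this assumes the jax GPU device object exposes a compute_capability attribute, as recent CUDA-enabled jaxlib builds do; the helper name is illustrative):

import jax


def jax_supports_flash_attention():
    # Older jax releases (including the one on CI at the time) do not ship
    # jax.nn.dot_product_attention at all.
    if not hasattr(jax.nn, "dot_product_attention"):
        return False
    # The cuDNN fused kernel needs an Ampere (compute capability 8.0) or newer
    # GPU; a T4 (7.5) fails with "Require at least Ampere arch to run".
    try:
        device = jax.devices("gpu")[0]
        return float(device.compute_capability) >= 8.0
    except (RuntimeError, AttributeError, ValueError):
        return False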

@hazemessamm (Contributor, Author) commented Oct 7, 2024

I added some conditions for JAX so the tests are skipped when they are met. What do you think?

@fchollet merged commit 8e67e0e into keras-team:master on Oct 8, 2024. 9 checks passed.