
don't save inputs/outputs buffer of FlashAttenFunc to reduce memory usage for inference mode #1383

Merged 1 commit into Dao-AILab:main on Dec 12, 2024

Conversation

XiaobingSuper (Contributor)

In inference mode, we don't need to save the inputs/outputs that the training path requires for backward. This PR reduces memory usage during LLM serving (e.g., vLLM, which calls these flash-attn APIs for better performance).
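For illustration, here is a minimal sketch of the pattern this PR adds (not the actual FlashAttnFunc implementation; the class name, tensor shapes, and the reference attention math are placeholders): ctx.save_for_backward is only called when gradients are needed, so in inference mode the attention inputs/outputs are not kept alive by the autograd context. The condition mirrors the check quoted in the comments below, and the follow-up comments explain why placing it inside forward does not behave as intended.

```python
import torch

class AttnFuncSketch(torch.autograd.Function):
    """Illustrative stand-in for FlashAttnFunc; not the real fused kernel."""

    @staticmethod
    def forward(ctx, q, k, v):
        # Plain PyTorch attention as a placeholder for the fused CUDA kernel.
        scale = q.shape[-1] ** -0.5
        out = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v
        # Skip saving tensors when no backward pass can happen (inference mode),
        # so q/k/v/out are freed as soon as the caller drops its references.
        # NOTE: as discussed below, forward() runs with grad mode disabled,
        # so this check is always False here -- the bug reported later in the thread.
        is_grad = torch.is_grad_enabled() and any(t.requires_grad for t in (q, k, v))
        if is_grad:
            ctx.save_for_backward(q, k, v, out)
        return out

    @staticmethod
    def backward(ctx, dout):
        q, k, v, out = ctx.saved_tensors
        raise NotImplementedError("gradient math omitted in this sketch")
```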

tridao merged commit 0dfb281 into Dao-AILab:main on Dec 12, 2024
rocking5566 (Contributor) commented Dec 17, 2024

@XiaobingSuper
I found that this PR makes the pytest suite (https://github.com/Dao-AILab/flash-attention/tree/main/tests) fail.

is_grad = torch.is_grad_enabled() and qkv.requires_grad

torch.is_grad_enabled() is always False at this point in the test script, so backward cannot access ctx.saved_tensors.
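A minimal repro of the behavior being described (a sketch, not code from the repo): PyTorch runs a custom autograd.Function's forward with grad mode disabled, so a torch.is_grad_enabled() check placed inside forward always sees False, even when the caller is in grad mode.

```python
import torch

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # forward() runs with grad mode disabled, so this prints False
        # even though the caller is in grad mode.
        print("inside forward:", torch.is_grad_enabled())
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out

x = torch.randn(3, requires_grad=True)
print("outside:", torch.is_grad_enabled())  # True
Probe.apply(x)                              # prints: inside forward: False
```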

XiaobingSuper (Contributor, Author) commented Dec 19, 2024

> @XiaobingSuper I found that this PR makes the pytest suite (https://github.com/Dao-AILab/flash-attention/tree/main/tests) fail.
>
> is_grad = torch.is_grad_enabled() and qkv.requires_grad
>
> torch.is_grad_enabled() is always False at this point in the test script, so backward cannot access ctx.saved_tensors.

Sorry, that is my mistake: gradient computation is already disabled inside custom autograd.Function forwards by default (https://discuss.pytorch.org/t/is-torch-no-grad-making-a-difference-in-custom-autograd-functions/186627), so the grad state needs to be checked before calling the custom op.
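A hedged sketch of the corrective pattern described here (the actual fix in PR #1397 may differ in detail): evaluate the grad state in the Python wrapper that calls the custom op, where grad mode still reflects the caller's context, and pass that decision into forward explicitly.

```python
import torch

class AttnFuncFixedSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, is_grad):
        scale = q.shape[-1] ** -0.5
        out = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v
        if is_grad:
            # Training path: keep what backward needs.
            ctx.save_for_backward(q, k, v, out)
        return out

    @staticmethod
    def backward(ctx, dout):
        q, k, v, out = ctx.saved_tensors
        raise NotImplementedError("gradient math omitted in this sketch")

def attn(q, k, v):
    # Evaluated before Function.apply, so grad mode here is the caller's,
    # not the always-disabled mode seen inside forward().
    is_grad = torch.is_grad_enabled() and any(t.requires_grad for t in (q, k, v))
    return AttnFuncFixedSketch.apply(q, k, v, is_grad)
```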

XiaobingSuper (Contributor, Author)

@rocking5566 I created PR #1397 to fix this issue. Thanks.
