Adding Flash Attention and xFormers memory-efficient kernels through PT SDPA #97
Conversation
@HamidShojanazeri Thanks for adding the BT (BetterTransformer) optimizations. Please see the comments inline.
Please fix the issues and also attach the logs for the inference speedup.
What does this PR do?
This PR adds the Flash Attention and xFormers memory-efficient kernels through PyTorch SDPA. This work has been integrated into HF's optimum library; read more about it here. Tested on 7B: FSDP alone saw a nice ~30% speedup, FSDP+PEFT about 5%, and there was not much change for PEFT + quantization on a single GPU.
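As context for reviewers, here is a minimal sketch of how these kernels are typically enabled through HF optimum's BetterTransformer API and PyTorch's SDPA backend selection; the checkpoint name and backend flags below are illustrative and may not match this PR's exact call sites.

```python
# Sketch only (assumed checkpoint/flags): enable PyTorch SDPA so the
# Flash Attention and memory-efficient (xFormers-style) kernels are used.
import torch
from transformers import AutoModelForCausalLM
from optimum.bettertransformer import BetterTransformer

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",   # illustrative 7B checkpoint
    torch_dtype=torch.bfloat16,
).to("cuda")

# Swap the HF attention implementation for
# torch.nn.functional.scaled_dot_product_attention.
model = BetterTransformer.transform(model)

# Optionally restrict SDPA to the flash / memory-efficient backends,
# disabling the plain math fallback.
input_ids = torch.tensor([[1, 2, 3]], device="cuda")
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_mem_efficient=True, enable_math=False
):
    outputs = model(input_ids)
```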
Fixes # (issue)
No related issue.
Feature/Issue validation/testing
Please describe the tests that you ran to verify your changes and summarize the relevant results. Provide instructions so they can be reproduced.
Please also list any relevant details of your test configuration.
Test A: Logs/perf numbers without this feature, 10 steps: avg epoch time 55.17 s
Logs for Test A
Test B: Logs/perf numbers with this feature, 10 steps: avg epoch time 42.44 s
Logs for Test B
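For a quick sanity check on the numbers above: 55.17 s / 42.44 s ≈ 1.30, i.e. roughly a 30% per-epoch speedup, consistent with the FSDP figure quoted in the description.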
Thanks for contributing 🎉!