This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Adding flash attention for sequence parallel #565

Open
wants to merge 28 commits into
base: main

Conversation

dianaml0
Contributor

@dianaml0 dianaml0 commented Dec 23, 2022

Patch Description
Creating this PR off of #511 so that it can be reviewed by @stephenroller.

The last commit (3d709db) removes some changes from the sequence parallel code that enabled testing with a world size of 1. CI is not currently running the test anyway, because CI needs to be updated before the test can run.

The forward and backward tests are passing right now. However, in some cases about 0.2% of the elements fail the tolerance check.

Testing steps
Unit test: gpu_tests/test_sequence_parallel_transformer_layer.py

@dianaml0
Contributor Author

The CircleCI failure is not related to this PR.

@stephenroller
Contributor

Can we rebase for checks? Should we be concerned about the last bits of numerical differences?

@dianaml0
Contributor Author

dianaml0 commented Jan 3, 2023

@stephenroller I just rebased the PR, so it should be up to date now. The rtol and atol used are the same ones we use in xFormers for testing all flash attention backward passes. I can do a small training run to validate; would that be useful?
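For readers following the tolerance discussion, here is a minimal sketch of how a mismatch fraction under rtol/atol can be measured. It is not the test code from this PR: the function names, the `max_mismatch` parameter, and the tolerance values are hypothetical and chosen only for illustration. The elementwise criterion `|actual - expected| <= atol + rtol * |expected|` is the one used by `torch.allclose` and `numpy.isclose`; the sketch uses plain Python lists so it runs without a GPU or extra packages.

```python
def mismatch_fraction(actual, expected, rtol, atol):
    """Return the fraction of elements outside the rtol/atol tolerance band.

    An element passes when |actual - expected| <= atol + rtol * |expected|.
    """
    if len(actual) != len(expected):
        raise ValueError("inputs must have the same length")
    if not actual:
        return 0.0
    failures = sum(
        1
        for a, e in zip(actual, expected)
        if abs(a - e) > atol + rtol * abs(e)
    )
    return failures / len(actual)


def assert_mostly_close(actual, expected, rtol, atol, max_mismatch=0.0):
    """Raise AssertionError if more than max_mismatch of the elements fail.

    max_mismatch is a hypothetical knob for this sketch; with the default
    of 0.0 every element must be within tolerance.
    """
    fraction = mismatch_fraction(actual, expected, rtol, atol)
    if fraction > max_mismatch:
        raise AssertionError(
            f"{fraction:.4%} of elements exceed rtol={rtol}, atol={atol}"
        )
    return fraction


if __name__ == "__main__":
    # Illustrative data only: 1000 reference values, 2 of them perturbed.
    expected = [1.0] * 1000
    actual = list(expected)
    actual[10] += 0.5
    actual[500] -= 0.5
    print(mismatch_fraction(actual, expected, rtol=1e-3, atol=1e-3))
```

Reporting the fraction alongside the pass/fail result makes it easier to judge whether a failure is a handful of outliers near the tolerance boundary or a systematic error.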

@dianaml0
Contributor Author

dianaml0 commented Jan 4, 2023

Looks like everything is passing now after rebasing.

@facebook-github-bot

Hi @dianaml0!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

3 participants