
atomic_add slows down attention backwards due to layout conversions #4717

Open
bertmaher opened this issue Sep 12, 2024 · 7 comments

@bertmaher (Collaborator)

@Chillee noticed that using tl.atomic_add in the attention backward kernel slows it down noticeably; in fact it's slower than doing the atomic_add "manually" with inline assembly. The root cause seems to be that the layout conversion from the #mma layout to #blocked adds a lot of overhead. Interestingly, using tl.store (which is incorrect here) performs the same layout conversion but is nevertheless faster than the atomic_add version (possibly because it can copy asynchronously from smem to gmem).

Repro at https://gist.github.com/bertmaher/e33b874f75cb82451060b88ee20b8203.
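For context, here is a minimal sketch (not the gist itself) of the accumulation pattern being compared. The pointer names, the contiguous tile indexing, and the USE_ATOMICS switch are illustrative; the relevant part is that dq comes out of tl.dot, i.e. in the #mma layout, right before the write. It assumes fp16 ds/k tiles and an fp32 dq buffer.

```python
import triton
import triton.language as tl

@triton.jit
def _dq_accumulate(dq_ptr, ds_ptr, k_ptr,
                   USE_ATOMICS: tl.constexpr,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                   HEAD_DIM: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, HEAD_DIM)
    # Contiguous (BLOCK_M, BLOCK_N) and (BLOCK_N, HEAD_DIM) tiles for brevity.
    ds = tl.load(ds_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :])
    k = tl.load(k_ptr + offs_n[:, None] * HEAD_DIM + offs_d[None, :])
    # dq is produced by tl.dot, so it is held in the #mma layout here.
    dq = tl.dot(ds, k)
    dq_ptrs = dq_ptr + offs_m[:, None] * HEAD_DIM + offs_d[None, :]
    if USE_ATOMICS:
        # Correct when several KV blocks contribute to the same dq tile,
        # but the value is converted #mma -> #blocked before the atomics,
        # which is where the overhead shows up.
        tl.atomic_add(dq_ptrs, dq)
    else:
        # The "store" variant: same layout conversion, racy across KV
        # blocks, useful only as a performance baseline.
        tl.store(dq_ptrs, dq)
```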

Results on my H100:

$ python fa.py atomic_add
fused-attention-batch4-head8-d64-bwd-causal=True:
     N_CTX  Triton [FP16]
0   2048.0      76.443623
1   4096.0      96.451661
2   8192.0     118.992739
3  16384.0     128.130487
4  32768.0     129.460807

$ python fa.py none  # entirely omit the atomic_add
fused-attention-batch4-head8-d64-bwd-causal=True:
     N_CTX  Triton [FP16]
0   2048.0     169.550013
1   4096.0     232.080104
2   8192.0     280.737055
3  16384.0     314.714893
4  32768.0     324.544066

$ python fa.py inline_asm  # Do the atomic_add using inline asm; avoids layout conversion
fused-attention-batch4-head8-d64-bwd-causal=True:
     N_CTX  Triton [FP16]
0   2048.0     138.047146
1   4096.0     178.601803
2   8192.0     208.302319
3  16384.0     225.736351
4  32768.0     220.597077

$ python fa.py store  # Just do store; does layout conversion but is faster than atomic_add
fused-attention-batch4-head8-d64-bwd-causal=True:
     N_CTX  Triton [FP16]
0   2048.0     154.972883
1   4096.0     198.467770
2   8192.0     228.074767
3  16384.0     245.977935
4  32768.0     257.773992
lijinpei added a commit to lijinpei/triton that referenced this issue Sep 22, 2024
@lijinpei (Contributor)

I have a fix for this issue and got similar perf metrics on my machine; maybe you can try whether it works for you on your machine: lijinpei@1e344f7

(triton) ➜  issue_4717 python3 fa.py atomic_add
fused-attention-batch4-head8-d64-bwd-causal=True:
     N_CTX  Triton [FP16]
0   2048.0      44.369918
1   4096.0      55.343324
2   8192.0      59.959515
3  16384.0      63.016624
4  32768.0      64.696492
(triton) ➜  issue_4717 python3 fa.py inline_asm                         
fused-attention-batch4-head8-d64-bwd-causal=True:
     N_CTX  Triton [FP16]
0   2048.0      44.053087
1   4096.0      54.871818
2   8192.0      59.553212
3  16384.0      62.199749
4  32768.0      63.910920

Jokeren self-assigned this Sep 22, 2024
@bertmaher (Collaborator, Author)

Thanks for the suggestion @lijinpei! I'm curious, though, whether that is a generally desirable change. I would have expected converting from the mma layout to the blocked layout to generally improve perf due to better memory coalescing. @Jokeren, maybe you have some insight here?

@Jokeren (Contributor) commented Sep 26, 2024

> I would have expected converting from the mma layout to the blocked layout to generally improve perf due to better memory coalescing.

Storing data with the mma layout is still buggy in Triton, and it's something I'll be working on after all the pending linear layout PRs have been merged. It's not practical right now because of the problem you mentioned.

One reason atomic_add is still slow is probably that we don't yet support vector data types for it. This is something @lijinpei could look into if he has time.
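For what it's worth, a quick (hedged) way to check whether vectorized atomics get emitted, without running the full attention kernel, is to compile a small contiguous tl.atomic_add kernel and grep its PTX. The kernel and sizes below are made up for illustration, and reading handle.asm["ptx"] assumes the launch returns the compiled kernel handle, as recent Triton versions do.

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _atomic_add_kernel(out_ptr, val_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    v = tl.load(val_ptr + offs, mask=mask)
    # Contiguous fp16 atomics: candidates for a packed/vectorized lowering.
    tl.atomic_add(out_ptr + offs, v, mask=mask)

n, BLOCK = 1 << 16, 256
out = torch.zeros(n, device="cuda", dtype=torch.float16)
val = torch.randn(n, device="cuda", dtype=torch.float16)
handle = _atomic_add_kernel[(n // BLOCK,)](out, val, n, BLOCK=BLOCK)

# A scalar lowering shows one atom/red instruction per element; a vectorized
# lowering should show packed (e.g. f16x2 or wider) atomic adds instead.
for line in handle.asm["ptx"].splitlines():
    if any(tok.startswith(("atom.", "red.")) for tok in line.split()):
        print(line.strip())
```

The same handle also exposes the intermediate IRs (e.g. handle.asm["ttgir"]), which is where the #mma -> #blocked convert_layout in front of the atomic should be visible.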

@davidberard98 (Contributor)

I'll take a look to see if we can get memory coalescing to happen and/or use vectorized loads.

@davidberard98 (Contributor)

#4971 for generating vectorized atomic_add instructions

@davidberard98 (Contributor)

@bertmaher these are the numbers I get after adding vectorized atomics. I'm guessing this means that the vectorized atomics support is sufficient - let me know if you think it's worth investigating further though!

$ python fa.py atomic_add
fused-attention-batch4-head8-d64-bwd-causal=True:
     N_CTX  Triton [FP16]
0   2048.0     148.479322
1   4096.0     191.492399
2   8192.0     219.929005
3  16384.0     238.591647
4  32768.0     242.346481
$ python fa.py none
fused-attention-batch4-head8-d64-bwd-causal=True:
     N_CTX  Triton [FP16]
0   2048.0     166.910448
1   4096.0     230.168433
2   8192.0     280.479002
3  16384.0     313.807679
4  32768.0     316.903208
$ python fa.py inline_asm
fused-attention-batch4-head8-d64-bwd-causal=True:
     N_CTX  Triton [FP16]
0   2048.0     136.014504
1   4096.0     177.884463
2   8192.0     208.574426
3  16384.0     224.902280
4  32768.0     222.324913
$ python fa.py store
fused-attention-batch4-head8-d64-bwd-causal=True:
     N_CTX  Triton [FP16]
0   2048.0     148.669622
1   4096.0     198.080360
2   8192.0     227.895409
3  16384.0     246.092478
4  32768.0     257.209955

@Jokeren (Contributor) commented Oct 22, 2024

@davidberard98 great!
