CUDA versions > 12.3 do not correctly compile H100 Flash Attention 3 #1243

Open
rohany opened this issue Sep 21, 2024 · 1 comment

rohany commented Sep 21, 2024

On CUDA 12.5 and 12.6, I see warnings that NVCC is unable to correctly analyze register usage around WGMMA and inserts unnecessary warpgroup waits:

ptxas info    : (C7517) warpgroup.wait is injected in around line 13239 by compiler to allow use of registers defined by GMMA in function '_ZN5flash15compute_attn_wsI23Flash_fwd_kernel_traitsILi64ELi192ELi128ELi16ELi2ELb0ELi1EN7cutlass10bfloat16_tEELb1ELb0ENS_30DynamicPersistentTileSchedulerILi384ELi32EEENS_12SeqLenTraitsILb0EEEEEvNS_21CollectiveMainloopFwdIT_XT0_EXT1_ET3_E6ParamsENS_21CollectiveEpilogueFwdISA_SB_E6ParamsENT2_6ParamsESB_SB_'
ptxas info    : (C7517) warpgroup.wait is injected in around line 14386 by compiler to allow use of registers defined by GMMA in function '_ZN5flash15compute_attn_wsI23Flash_fwd_kernel_traitsILi64ELi192ELi128ELi16ELi2ELb0ELi1EN7cutlass10bfloat16_tEELb1ELb0ENS_30DynamicPersistentTileSchedulerILi384ELi32EEENS_12SeqLenTraitsILb0EEEEEvNS_21CollectiveMainloopFwdIT_XT0_EXT1_ET3_E6ParamsENS_21CollectiveEpilogueFwdISA_SB_E6ParamsENT2_6ParamsESB_SB_'
ptxas info    : (C7517) warpgroup.wait is injected in around line 14959 by compiler to allow use of registers defined by GMMA in function '_ZN5flash15compute_attn_wsI23Flash_fwd_kernel_traitsILi64ELi192ELi128ELi16ELi2ELb0ELi1EN7cutlass10bfloat16_tEELb1ELb0ENS_30DynamicPersistentTileSchedulerILi384ELi32EEENS_12SeqLenTraitsILb0EEEEEvNS_21CollectiveMainloopFwdIT_XT0_EXT1_ET3_E6ParamsENS_21CollectiveEpilogueFwdISA_SB_E6ParamsENT2_6ParamsESB_SB_'
ptxas info    : (C7517) warpgroup.wait is injected in around line 16120 by compiler to allow use of registers defined by GMMA in function '_ZN5flash15compute_attn_wsI23Flash_fwd_kernel_traitsILi64ELi192ELi128ELi16ELi2ELb0ELi1EN7cutlass10bfloat16_tEELb1ELb0ENS_30DynamicPersistentTileSchedulerILi384ELi32EEENS_12SeqLenTraitsILb0EEEEEvNS_21CollectiveMainloopFwdIT_XT0_EXT1_ET3_E6ParamsENS_21CollectiveEpilogueFwdISA_SB_E6ParamsENT2_6ParamsESB_SB_'
ptxas info    : (C7517) warpgroup.wait is injected in around line 16720 by compiler to allow use of registers defined by GMMA in function '_ZN5flash15compute_attn_wsI23Flash_fwd_kernel_traitsILi64ELi192ELi128ELi16ELi2ELb0ELi1EN7cutlass10bfloat16_tEELb1ELb0ENS_30DynamicPersistentTileSchedulerILi384ELi32EEENS_12SeqLenTraitsILb0EEEEEvNS_21CollectiveMainloopFwdIT_XT0_EXT1_ET3_E6ParamsENS_21CollectiveEpilogueFwdISA_SB_E6ParamsENT2_6ParamsESB_SB_'

and the resulting performance is worse than expected:

(cypress) root@eos0266:/opt/flash-attention/hopper# python benchmark_attn.py
### mode = 'fwd', batch_size = 4, headdim = 128, seqlen = 2048, causal = False ###
Fav3: 0.515ms, 533.9 TFLOPS
### mode = 'fwd', batch_size = 4, headdim = 128, seqlen = 4096, causal = False ###
Fav3: 2.023ms, 543.6 TFLOPS
### mode = 'fwd', batch_size = 4, headdim = 128, seqlen = 8192, causal = False ###
Fav3: 7.996ms, 550.0 TFLOPS
### mode = 'fwd', batch_size = 4, headdim = 128, seqlen = 16384, causal = False ###
Fav3: 32.211ms, 546.1 TFLOPS
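
As a rough sanity check on these figures, a minimal sketch that recomputes the TFLOPS from the reported timings, assuming nheads = 32 (the head count is not shown in the log) and the standard 4*B*H*S^2*D FLOP count for non-causal forward attention (two GEMMs, QK^T and PV, each 2*B*H*S^2*D FLOPs). Plain host code; builds with nvcc or any C++ compiler:

// Sketch: reproduce the TFLOPS figures above from the reported timings.
#include <cstdio>

int main() {
    const double batch = 4, headdim = 128;                     // from the log
    const double nheads = 32;                                  // assumed, not shown in the log
    const double seqlens[]  = {2048, 4096, 8192, 16384};       // from the log
    const double times_ms[] = {0.515, 2.023, 7.996, 32.211};   // from the log
    for (int i = 0; i < 4; ++i) {
        // Two GEMMs (QK^T and PV), each 2*B*H*S^2*D FLOPs.
        const double flops = 4.0 * batch * nheads * seqlens[i] * seqlens[i] * headdim;
        std::printf("seqlen=%6.0f: %.1f TFLOPS\n",
                    seqlens[i], flops / (times_ms[i] * 1e-3) / 1e12);
    }
    return 0;  // prints ~533.7, 543.5, 550.0, 546.1 TFLOPS, matching the log to rounding
}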

I have reason to believe that this is a bug in NVCC, as an independent implementation of FA3 that I have been working on suffers from the same problem on later NVCC versions. I don't have a suggestion for how to convince NVCC not to get confused.
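
For reference, warning C7517 concerns the explicit fence/commit/wait sequence that Hopper warp-specialized kernels issue around wgmma.mma_async. A minimal raw-PTX sketch of that sequence (illustrative only, not FA3's actual code; CuTe wraps the same instructions as cute::warpgroup_arrive(), cute::warpgroup_commit_batch(), and cute::warpgroup_wait<N>()):

// Illustrative sketch, not FA3's actual code.
// Compile with: nvcc -arch=sm_90a -c wgmma_sketch.cu
__device__ void pipelined_mma_step() {
    // wgmma.fence: order prior register reads/writes before the async MMAs.
    asm volatile("wgmma.fence.sync.aligned;" ::: "memory");

    // ... wgmma.mma_async instructions for one tile would be issued here ...

    // wgmma.commit_group: close the batch of MMAs just issued.
    asm volatile("wgmma.commit_group.sync.aligned;" ::: "memory");

    // wgmma.wait_group 1: block until at most one batch is still in flight,
    // so other work can overlap with the remaining batch. An extra,
    // conservatively placed wait of this kind (what C7517 reports as an
    // injected warpgroup.wait) removes that overlap.
    asm volatile("wgmma.wait_group.sync.aligned 1;" ::: "memory");
}

__global__ void sketch_kernel() { pipelined_mma_step(); }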

I can confirm that I get the expected performance on CUDA 12.3.

tridao (Contributor) commented Sep 21, 2024

Yes, we're seeing the best performance on CUDA 12.3. There might be a fix for 12.5 (by better tuning), but we're not there yet.
