On CUDA 12.5 and 12.6, I see messages indicating that NVCC was not able to correctly analyze register usage around WGMMA and therefore inserted unnecessary warpgroup waits:
ptxas info : (C7517) warpgroup.wait is injected in around line 13239 by compiler to allow use of registers defined by GMMA in function '_ZN5flash15compute_attn_wsI23Flash_fwd_kernel_traitsILi64ELi192ELi128ELi16ELi2ELb0ELi1EN7cutlass10bfloat16_tEELb1ELb0ENS_30DynamicPersistentTileSchedulerILi384ELi32EEENS_12SeqLenTraitsILb0EEEEEvNS_21CollectiveMainloopFwdIT_XT0_EXT1_ET3_E6ParamsENS_21CollectiveEpilogueFwdISA_SB_E6ParamsENT2_6ParamsESB_SB_'
ptxas info : (C7517) warpgroup.wait is injected in around line 14386 by compiler to allow use of registers defined by GMMA in function '_ZN5flash15compute_attn_wsI23Flash_fwd_kernel_traitsILi64ELi192ELi128ELi16ELi2ELb0ELi1EN7cutlass10bfloat16_tEELb1ELb0ENS_30DynamicPersistentTileSchedulerILi384ELi32EEENS_12SeqLenTraitsILb0EEEEEvNS_21CollectiveMainloopFwdIT_XT0_EXT1_ET3_E6ParamsENS_21CollectiveEpilogueFwdISA_SB_E6ParamsENT2_6ParamsESB_SB_'
ptxas info : (C7517) warpgroup.wait is injected in around line 14959 by compiler to allow use of registers defined by GMMA in function '_ZN5flash15compute_attn_wsI23Flash_fwd_kernel_traitsILi64ELi192ELi128ELi16ELi2ELb0ELi1EN7cutlass10bfloat16_tEELb1ELb0ENS_30DynamicPersistentTileSchedulerILi384ELi32EEENS_12SeqLenTraitsILb0EEEEEvNS_21CollectiveMainloopFwdIT_XT0_EXT1_ET3_E6ParamsENS_21CollectiveEpilogueFwdISA_SB_E6ParamsENT2_6ParamsESB_SB_'
ptxas info : (C7517) warpgroup.wait is injected in around line 16120 by compiler to allow use of registers defined by GMMA in function '_ZN5flash15compute_attn_wsI23Flash_fwd_kernel_traitsILi64ELi192ELi128ELi16ELi2ELb0ELi1EN7cutlass10bfloat16_tEELb1ELb0ENS_30DynamicPersistentTileSchedulerILi384ELi32EEENS_12SeqLenTraitsILb0EEEEEvNS_21CollectiveMainloopFwdIT_XT0_EXT1_ET3_E6ParamsENS_21CollectiveEpilogueFwdISA_SB_E6ParamsENT2_6ParamsESB_SB_'
ptxas info : (C7517) warpgroup.wait is injected in around line 16720 by compiler to allow use of registers defined by GMMA in function '_ZN5flash15compute_attn_wsI23Flash_fwd_kernel_traitsILi64ELi192ELi128ELi16ELi2ELb0ELi1EN7cutlass10bfloat16_tEELb1ELb0ENS_30DynamicPersistentTileSchedulerILi384ELi32EEENS_12SeqLenTraitsILb0EEEEEvNS_21CollectiveMainloopFwdIT_XT0_EXT1_ET3_E6ParamsENS_21CollectiveEpilogueFwdISA_SB_E6ParamsENT2_6ParamsESB_SB_'
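Because the problem depends on the CUDA version, one way to compare builds is to count the C7517 messages per kernel in each version's ptxas output. Below is a minimal Python sketch. It assumes the compiler output has been saved to a text file. `injected_waits` and `waits_per_function` are hypothetical helpers, and the regex matches only the message format quoted above.

```python
import re
from collections import Counter

# Matches the ptxas C7517 message format shown in the log lines above.
C7517 = re.compile(
    r"\(C7517\) warpgroup\.wait is injected in around line (\d+) "
    r"by compiler to allow use of registers defined by GMMA in function '([^']+)'"
)

def injected_waits(log_text):
    """Return (reported_line, mangled_function) for every C7517 message."""
    return [(int(m.group(1)), m.group(2)) for m in C7517.finditer(log_text)]

def waits_per_function(log_text):
    """Count injected warpgroup waits per (mangled) kernel name."""
    return Counter(fn for _, fn in injected_waits(log_text))
```

For example, `waits_per_function(open("build.log").read())`, where `build.log` is a hypothetical captured log, gives a per-kernel tally. A 12.3 build would then be expected to show no entries for the affected kernels.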
I have reason to believe that this is a bug in NVCC: an independent FA3 implementation I have been working on suffers from the same problem on the later NVCC versions. I don't have a suggestion for how to convince NVCC not to get confused.
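To make "confused" concrete, here is a toy model of the safety check involved. It is not ptxas's actual algorithm. Async WGMMAs write accumulator registers. `wgmma.commit_group` batches the uncommitted WGMMAs into a group, and `wgmma.wait_group N` blocks until at most N of the most recent groups are still pending. A read of an accumulator is safe only once its group has retired. Otherwise the compiler must insert a wait. The `Op` encoding and `unsafe_uses` are hypothetical. A spurious C7517 corresponds to ptxas flagging a read that this kind of analysis would consider safe.

```python
from dataclasses import dataclass

@dataclass
class Op:
    kind: str      # "mma", "commit", "wait", or "use"
    reg: str = ""  # accumulator written by "mma" or read by "use"
    n: int = 0     # for "wait": max number of groups allowed to remain pending

def unsafe_uses(program):
    """Indices of "use" ops that read an accumulator whose WGMMA may still be in flight."""
    pending = []        # committed groups, oldest first; each is a set of accumulators
    open_group = set()  # accumulators written by mma ops not yet committed
    flagged = []
    for i, op in enumerate(program):
        if op.kind == "mma":
            open_group.add(op.reg)
        elif op.kind == "commit":
            pending.append(open_group)
            open_group = set()
        elif op.kind == "wait":
            # wait_group N: oldest groups retire until at most N remain pending.
            while len(pending) > op.n:
                pending.pop(0)
        elif op.kind == "use":
            if op.reg in open_group or any(op.reg in g for g in pending):
                flagged.append(i)
                # Model the injected fix as "commit everything, then wait_group 0".
                pending.clear()
                open_group = set()
    return flagged

# Two groups in flight; wait_group 1 retires only the older one (acc0).
ok = [Op("mma", "acc0"), Op("commit"), Op("mma", "acc1"), Op("commit"),
      Op("wait", n=1), Op("use", "acc0")]
bad = ok + [Op("use", "acc1")]  # acc1's group may still be pending here
```

Here `unsafe_uses(ok)` returns an empty list, while `unsafe_uses(bad)` flags index 6. A pipelined kernel relies on reads like the one at index 5 being accepted without an extra wait.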
I can confirm that I get the expected performance on CUDA 12.3.
On CUDA 12.5 and 12.6, the injected waits result in worse-than-expected performance.