@@ -134,10 +134,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
 
 // select a backend for fused attention
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
-    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
-    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
-    int64_t window_size_left, int64_t window_size_right) {
+    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
   using namespace transformer_engine;
   NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
   const int device_id = cuda::current_device();
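For context, a minimal usage sketch of the reflowed signature (the configuration values below are illustrative, not taken from this change): query which backend would serve a BF16, 16-head, head-dim-128, causal self-attention training setup, with window sizes of -1 denoting no sliding-window limit.

// Sketch: ask the backend selector what would run for this (hypothetical) config.
NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend(
    /*is_training=*/true, NVTEDType::kNVTEBFloat16, NVTEDType::kNVTEBFloat16,
    NVTE_QKV_Layout::NVTE_BS3HD, NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_CAUSAL_MASK,
    /*dropout=*/0.0f, /*num_attn_heads=*/16, /*num_gqa_groups=*/16,
    /*max_seqlen_q=*/2048, /*max_seqlen_kv=*/2048, /*head_dim_qk=*/128, /*head_dim_v=*/128,
    /*window_size_left=*/-1, /*window_size_right=*/-1);
if (backend == NVTE_Fused_Attn_Backend::NVTE_No_Backend) {
  // no fused kernel qualifies; callers typically fall back to an unfused attention path
}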
@@ -228,16 +228,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
         (cudnn_runtime_version >= 8907)) &&
        // head dimension
        (head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 &&
-        ((head_dim_qk <= 128 && head_dim_v <= 128) ||
-         (head_dim_qk <= 256 && head_dim_v <= 256 &&
-          ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) ||
-           (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) ||
-         (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 &&
-          max_seqlen_q > 1 && layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
-         (!is_training && cudnn_runtime_version >= 91000 &&
-          (max_seqlen_q > 1 ||
-           layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD ||
-           (max_seqlen_q == 1 && sm_arch_ >= 100 && attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) &&
+        ((head_dim_qk <= 128 && head_dim_v <= 128) ||
+         (head_dim_qk <= 256 && head_dim_v <= 256 &&
+          ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) ||
+           (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) ||
+         (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
+          layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
+         (!is_training && cudnn_runtime_version >= 91000 &&
+          (max_seqlen_q > 1 || layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD ||
+           (max_seqlen_q == 1 && sm_arch_ >= 100 &&
+            attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) &&
        // bias type
        ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
         (cudnn_runtime_version >= 8906 &&
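For readability, the head-dimension gate above can be restated as a standalone predicate. This is a sketch only: the function name and the boolean parameters paged_kv and no_mask are illustrative stand-ins for the layout-group and mask-type checks, while the arch and cuDNN-version cutoffs are copied from the condition in this diff.

#include <cstddef>

// Sketch: mirrors the head-dimension branch of the backend check above.
static bool head_dims_supported(bool is_training, int sm_arch, int cudnn_version,
                                size_t head_dim_qk, size_t head_dim_v, size_t max_seqlen_q,
                                bool paged_kv, bool no_mask) {
  // Both head dims must be multiples of 8.
  if (head_dim_qk % 8 != 0 || head_dim_v % 8 != 0) return false;
  // Head dims up to 128 are the broadly supported case.
  if (head_dim_qk <= 128 && head_dim_v <= 128) return true;
  // Up to 256 on Hopper (sm90): inference from cuDNN 9.1.0, training from 9.5.0.
  if (head_dim_qk <= 256 && head_dim_v <= 256 && sm_arch == 90 &&
      cudnn_version >= (is_training ? 90500 : 90100))
    return true;
  // Inference on sm100+ from cuDNN 9.9.0, excluding paged KV and single-query cases.
  if (!is_training && sm_arch >= 100 && cudnn_version >= 90900 && max_seqlen_q > 1 && !paged_kv)
    return true;
  // Inference from cuDNN 9.10.0: multi-query, paged KV, or unmasked single-query on sm100+.
  if (!is_training && cudnn_version >= 91000 &&
      (max_seqlen_q > 1 || paged_kv || (max_seqlen_q == 1 && sm_arch >= 100 && no_mask)))
    return true;
  return false;
}

For example, head_dims_supported(true, 90, 90500, 192, 128, 2048, false, false) holds via the sm90 branch that admits head dims up to 256 for training on cuDNN 9.5.0 or newer.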
@@ -427,8 +427,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
-      max_seqlen, d, d, window_size_left, window_size_right);
+      is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h,
+      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
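Note: in the QKV-packed path, Q, K and V come from a single packed tensor, so the same head count h, sequence length max_seqlen, and head dimension d are passed for both the Q-side and KV-side parameters of the backend query.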
@@ -640,8 +640,8 @@ void nvte_fused_attn_fwd_kvpacked(
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d, d, window_size_left, window_size_right);
+      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -735,8 +735,8 @@ void nvte_fused_attn_bwd_kvpacked(
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d, d, window_size_left, window_size_right);
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
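Note: the backward paths pass true for the is_training argument when selecting a backend, since running a backward pass implies training mode.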
@@ -866,8 +866,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
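Note: unlike the packed variants above, which pass a single head dimension d twice, nvte_fused_attn_fwd forwards distinct d_qk and d_v, so the QK and V head dimensions can differ.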
@@ -958,8 +958,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)