diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 724f0ec8ce..f937055efc 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 724f0ec8ce06027feada51f2d948cd3313e63720 +Subproject commit f937055efc6d414d11f4c6577e3977fe74f35fb6 diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index adef8dd627..afb3a1df0c 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -68,6 +68,7 @@ def impl_test_self_attn( batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( + is_training, dtype, dtype, QKVLayout.BS3HD, @@ -214,6 +215,7 @@ def test_cross_attn( batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( + is_training, dtype, dtype, QKVLayout.BSHD_BS2HD, @@ -346,6 +348,7 @@ def impl_test_context_parallel_attn( def check_has_backend_for_mask(mask_type): return is_fused_attn_kernel_available( + is_training, dtype, dtype, qkv_layout, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 745f1cc633..2332bbc0de 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -347,6 +347,7 @@ def _check_configs(self): ) self.backend = FusedAttnHelper( + self.is_training, self.dtype, self.dtype, self.qkv_layout, diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py index b82665911a..a05e64fca3 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/fused_attn/test_fused_attn.py @@ -222,13 +222,19 @@ def test(): model_configs_base = { - # test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend - "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0 - "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0 - "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1 - "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1 - "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference - "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference + # test: b, h, hg, d, sq, skv, p, mask, bias + "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), + "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), + "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), + "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), + "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), + "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), + "base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"), + "base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"), + "base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"), + "base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"), + "base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"), + "base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"), } @@ -270,14 +276,28 @@ def test_dot_product_attention( if config.window_size == (-1, -1) and swa: config.window_size = [2, 2] config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) + + is_training = True available_backends, _, 
fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, window_size=config.window_size, pad_between_seqs=pad_between_seqs, + is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not fused_attn_supported: + is_training = False + available_backends, _, fused_attn_backends = _get_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + window_size=config.window_size, + pad_between_seqs=pad_between_seqs, + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention # mannually pads and unpads the input and output of FlashAttention for testing purposes @@ -296,7 +316,6 @@ def test_dot_product_attention( if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: pytest.skip("Less than two backends to compare.") - is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128 # UnfusedDotProductAttention backend if unfused_attn_supported: unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( @@ -360,6 +379,7 @@ def test_dot_product_attention( is_training, ) + logging.info(f"[test_dot_product_attention]: is_training = {is_training}") if unfused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs flash attn") torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols) @@ -399,18 +419,27 @@ def test_dpa_checkpoint(dtype, model_configs, model): "mla_1_1": ModelConfig( 4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 ), # cross, 0 + "mla_1_2": ModelConfig( + 4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 + ), # cross, 0 "mla_2_0": ModelConfig( 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64 ), # self , 1 "mla_2_1": ModelConfig( 1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64 ), # cross, 1 + "mla_2_2": ModelConfig( + 1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128 + ), # cross, 1 "mla_3_0": ModelConfig( 8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64 ), # inference "mla_3_1": ModelConfig( 8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 ), # inference + "mla_3_2": ModelConfig( + 8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 + ), # inference } @@ -1024,6 +1053,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layer_number=1, attention_type=config.attn_type, ).to(dtype=dtype, device="cuda") + if not is_training: + block = block.eval() # Run a forward and backward pass if backend in ["FlashAttention", "UnfusedDotProductAttention"]: @@ -1136,14 +1167,29 @@ def test_transformer_layer( workspace_opt = True # Test backend availability + is_training = True available_backends, _, fused_attn_backends = _get_attention_backends( config, qkv_dtype=dtype, qkv_layout=( qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd") ), + is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not fused_attn_supported: + is_training = False + available_backends, _, fused_attn_backends = _get_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=( + qkv_format.replace("hd", "h3d") + if fused_qkv_params + else qkv_format.replace("hd", "3hd") + ), + is_training=is_training, + ) + flash_attn_supported, 
fused_attn_supported, unfused_attn_supported = available_backends # Skip if only unfused backend is supported if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: @@ -1163,6 +1209,7 @@ def test_transformer_layer( workspace_opt, fused_qkv_params, RoPE, + is_training, ) # FusedAttention backend @@ -1176,6 +1223,7 @@ def test_transformer_layer( workspace_opt, fused_qkv_params, RoPE, + is_training, ) # FlashAttention backend @@ -1189,8 +1237,10 @@ def test_transformer_layer( workspace_opt, fused_qkv_params, RoPE, + is_training, ) + logging.info(f"[test_transformer_layer]: is_training = {is_training}") if unfused_attn_supported and fused_attn_supported: logging.info("[test_transformer_layer]: unfused attn vs fused attn") torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) @@ -1257,6 +1307,7 @@ def _run_transformer_layer( workspace_opt: bool, fused_qkv_params: bool, RoPE: bool, + is_training: bool, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Run TransformerLayer module with one forward pass and one backward pass""" @@ -1410,6 +1461,8 @@ def _run_transformer_layer( bias=True, attn_input_format=qkv_format, ).to(dtype=dtype, device="cuda") + if not is_training: + block = block.eval() # Create ALiBi slopes alibi_slopes = None @@ -1432,8 +1485,9 @@ def _run_transformer_layer( cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, ) - loss = out.sum() - loss.backward() + if is_training: + loss = out.sum() + loss.backward() return out, inp.grad diff --git a/tests/pytorch/fused_attn/test_kv_cache.py b/tests/pytorch/fused_attn/test_kv_cache.py index eb3838ff12..9673094597 100644 --- a/tests/pytorch/fused_attn/test_kv_cache.py +++ b/tests/pytorch/fused_attn/test_kv_cache.py @@ -52,7 +52,7 @@ 4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16 ), "infer_1": ModelConfig( - 2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16 + 2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16 ), } @@ -370,12 +370,24 @@ def generate_args( ] -def get_tols(module, backend, dtype): +def get_tols(config, module, backend, dtype): if module == "TransformerLayer": - tols = { - torch.half: (5e-3, 5e-3), - torch.bfloat16: (3.5e-2, 3.5e-2), - } + if config.head_dim_qk <= 128: + tols = { + torch.half: (5e-3, 5e-3), + torch.bfloat16: (3.5e-2, 3.5e-2), + } + else: + if backend == "UnfusedAttention": + tols = { + torch.half: (1.6e-2, 1.6e-2), + torch.bfloat16: (1.2e-1, 1e-1), + } + else: + tols = { + torch.half: (1e-2, 1e-2), + torch.bfloat16: (8e-2, 7e-2), + } if module == "DotProductAttention": tols = { torch.half: (1e-3, 1e-3), @@ -662,7 +674,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g incremental_output = incremental_output[0] # compare results - atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn) + atol, rtol = get_tols( + config, module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn + ) for i, seq in enumerate(sim.t_seq_ids): token_index = sim.step_lens[i] - 1 if qkv_format == "bshd": diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index a6784bacbb..b512133efd 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -134,10 +134,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { // select a backend for fused 
attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right) { + bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -216,12 +216,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } if ( // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging - // special conditions for blackwell - // TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7 - !(sm_arch_ >= 100 && (head_dim_qk > 128 || head_dim_v > 128)) && // architecture - ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) || - (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) && + ((cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)) || + (cudnn_runtime_version >= 8903 && sm_arch_ >= 80 && sm_arch_ < 100) || + (cudnn_runtime_version >= 90700 && sm_arch_ >= 80)) && // sequence length ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) || (cudnn_runtime_version >= 90000)) && @@ -229,11 +227,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) || (cudnn_runtime_version >= 8907)) && // head dimension - ((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) || - // TODO (cyang): add is_training to nvte_get_fused_attn_backend - // d=256 only supported for forward - (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 && - head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) && + // multiples of 8 + (head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 && + // <= 128 + ((head_dim_qk <= 128 && head_dim_v <= 128) || + // 9.1: <= 256 + Hopper + fprop + // 9.5: <= 256 + Hopper + bprop + (head_dim_qk <= 256 && head_dim_v <= 256 && + ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) || + (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) || + // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 + (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && + layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) || + // 9.10: any head_dim + any arch + fprop + paged + // 9.10: any head_dim + any arch + fprop + non_paged + sq > 1 + // 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} + (!is_training && cudnn_runtime_version >= 91000 && + (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 || + (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK && + attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || + // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged + (head_dim_qk == 192 && head_dim_v == 128 && 
is_training && sm_arch_ >= 100 && + cudnn_runtime_version >= 91100))) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || (cudnn_runtime_version >= 8906 && @@ -423,8 +438,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, d, window_size_left, window_size_right); + is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -505,7 +520,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, + true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { @@ -636,8 +651,8 @@ void nvte_fused_attn_fwd_kvpacked( const NVTEDType KV_type = static_cast(input_KV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, d, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -731,8 +746,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTEDType KV_type = static_cast(input_KV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, d, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -862,8 +877,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -954,8 +969,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = 
nvte_get_fused_attn_backend( - Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, + max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index ebe8341cca..44f5791490 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -172,6 +172,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); /*! \brief Get fused attention backend based on input parameters. * + * \param[in] is_training Whether the model is in training mode. * \param[in] q_dtype The data type of Tensor Q. * \param[in] kv_dtype The data type of Tensors K, V. * \param[in] qkv_layout The layout of Tensors Q, K, V. @@ -188,10 +189,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] window_size_right Sliding window size (the right half). */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( - NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right); + bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); /*! \brief Compute dot product attention with packed QKV input. 
* diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 2c57d284de..d24e853e1c 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -277,6 +277,7 @@ def canonicalize_attn_mask_type(attn_mask_type: str): def is_fused_attn_kernel_available( + is_training, q_dtype, kv_dtype, qkv_layout, @@ -296,6 +297,7 @@ def is_fused_attn_kernel_available( def make_helper(attn_mask_type): return tex.FusedAttnHelper( + is_training, q_dtype, kv_dtype, qkv_layout, diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 6b617355a3..e8907eb127 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -103,6 +103,7 @@ class FusedAttnHelper: Helper for the fused attention backend """ + is_training: bool q_dtype: jnp.dtype kv_dtype: jnp.dtype qkv_layout: QKVLayout @@ -123,6 +124,7 @@ def is_fused_attn_kernel_available(self): def get_fused_attn_backend(self): """Get the fused attention kernel backend""" return transformer_engine_jax.get_fused_attn_backend( + self.is_training, jax_dtype_to_te_dtype(self.q_dtype), jax_dtype_to_te_dtype(self.kv_dtype), self.qkv_layout.value, @@ -276,6 +278,7 @@ def abstract( # backend determines the softmax buffer shape/dtype backend = FusedAttnHelper( + config.is_training, q_dtype, k_dtype, config.qkv_layout, diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index aa257abe95..47399bc791 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -96,7 +96,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler); -NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, +NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float dropout_probability, size_t q_num_heads, size_t kv_num_heads, diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 7235d3f232..d3bb845642 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -11,7 +11,7 @@ namespace transformer_engine { namespace jax { -NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, +NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float dropout_probability, size_t q_attn_heads, size_t kv_attn_heads, @@ -19,9 +19,9 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, size_t head_dim, int64_t window_size_left, int64_t window_size_right) { auto backend = nvte_get_fused_attn_backend( - static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, - mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, - head_dim, head_dim, window_size_left, window_size_right); + is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, + bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, + kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right); return backend; } @@ -263,9 +263,9 @@ static void FusedAttnForwardImpl( /* Prepare RNG state */ auto 
rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - head_dim, head_dim, window_size_left, window_size_right); + is_training, static_cast(dtype), static_cast(dtype), qkv_layout, + bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, + kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -518,9 +518,9 @@ static void FusedAttnBackwardImpl( NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - head_dim, head_dim, window_size_left, window_size_right); + is_training, static_cast(dtype), static_cast(dtype), qkv_layout, + bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, + kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias); diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 1ac22a6d2f..e9a3047742 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -596,6 +596,8 @@ def __call__( seqlen_kv = key.shape[sequence_dim] has_fused_attn_kernel = is_fused_attn_kernel_available( + # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode. 
+ not deterministic, self.dtype, self.dtype, qkv_layout, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index ec93e8c5c8..d98dde0159 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -761,6 +761,7 @@ def get_attention_backend( q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) kv_type = q_type fused_attention_backend = tex.get_fused_attn_backend( + is_training, q_type, kv_type, QKVLayout[qkv_layout], diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 361c24b22c..781fd865d5 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -35,13 +35,11 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T * Attention **************************************************************************************************/ -NVTE_Fused_Attn_Backend get_fused_attn_backend(const DType q_dtype, const DType kv_dtype, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, float p_dropout, - size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, - size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right); +NVTE_Fused_Attn_Backend get_fused_attn_backend( + bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 9d6a99e6a9..ee4f7cdcc6 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -57,14 +57,14 @@ namespace transformer_engine::pytorch { // get the fused attention backend NVTE_Fused_Attn_Backend get_fused_attn_backend( - const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, - int64_t window_size_left, int64_t window_size_right) { + bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, + size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, + size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, - attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, - head_dim_qk, head_dim_v, window_size_left, window_size_right); + is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, + bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, 
max_seqlen_q, + max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right); return fused_attention_backend; }
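
Note for reviewers: the new head-dim gating in nvte_get_fused_attn_backend (transformer_engine/common/fused_attn/fused_attn.cpp above) is easier to follow outside the single C++ boolean expression. The sketch below re-expresses just that clause in standalone Python. It is an illustrative mirror of the patch, not library code: the function name, the plain-string mask_type, and the is_paged_kv flag are stand-ins for the NVTE enums, and the real backend selection also checks dtype, bias, dropout, sequence length, sliding window, and architecture/version constraints not shown here.

```python
# Illustrative mirror of the head-dim branch added to nvte_get_fused_attn_backend.
# Names and argument types are stand-ins for the NVTE enums used in the C++ code.

def head_dim_supported(is_training, sm_arch, cudnn_version, head_dim_qk, head_dim_v,
                       max_seqlen_q, is_paged_kv, mask_type):
    # head dims must be multiples of 8
    if head_dim_qk % 8 != 0 or head_dim_v % 8 != 0:
        return False
    # <= 128: supported for both fprop and bprop
    if head_dim_qk <= 128 and head_dim_v <= 128:
        return True
    # cuDNN 9.1: <= 256 + Hopper + fprop; 9.5: <= 256 + Hopper + bprop
    if head_dim_qk <= 256 and head_dim_v <= 256 and sm_arch == 90:
        if ((not is_training and cudnn_version >= 90100)
                or (is_training and cudnn_version >= 90500)):
            return True
    # 9.9: any head_dim + Blackwell + fprop + non-paged + sq > 1
    if (not is_training and sm_arch >= 100 and cudnn_version >= 90900
            and max_seqlen_q > 1 and not is_paged_kv):
        return True
    # 9.10: any head_dim + any arch + fprop, for paged KV, sq > 1, or
    #       sq == 1 with {no_mask, padding, BRCM, padding_BRCM}
    if (not is_training and cudnn_version >= 91000
            and (is_paged_kv or max_seqlen_q > 1
                 or (max_seqlen_q == 1
                     and mask_type not in ("causal", "padding_causal")))):
        return True
    # 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop
    if (head_dim_qk == 192 and head_dim_v == 128 and is_training
            and sm_arch >= 100 and cudnn_version >= 91100):
        return True
    return False
```

For example, head_dim_qk = head_dim_v = 256 passes for inference on Hopper with cuDNN 9.1+, but for training only from 9.5 onward, which is why the new is_training argument has to reach the backend query and why the tests in this patch probe availability with is_training=True first.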
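The availability-probing pattern added to test_dot_product_attention and test_transformer_layer first asks for backends with is_training=True and, if the fused backend cannot do the backward pass (for example large head_dim on an older cuDNN), retries in inference mode. A minimal sketch of that control flow, where `probe` is a hypothetical stand-in for _get_attention_backends and is assumed to return ((flash, fused, unfused), fused_backends) for a given is_training flag:

```python
# Sketch of the is_training fallback used in the updated tests.
# `probe` is a hypothetical callable standing in for _get_attention_backends.

def resolve_training_mode(probe):
    is_training = True
    (flash_ok, fused_ok, unfused_ok), fused_backends = probe(is_training)
    if not fused_ok:
        is_training = False
        (flash_ok, fused_ok, unfused_ok), fused_backends = probe(is_training)
    return is_training, (flash_ok, fused_ok, unfused_ok), fused_backends

# Example with a fake probe where fused attention is forward-only.
is_training, available, backends = resolve_training_mode(
    lambda training: ((True, not training, True),
                      [] if training else ["F16_arbitrary_seqlen"])
)
assert is_training is False and available[1] is True
```

When the fallback fires, the test modules are also switched to inference mode (block = block.eval()) and the loss.backward() call is skipped, matching the is_training argument now threaded through _run_transformer_layer and _run_dot_product_attention.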
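The forward-only path the tests now take when is_training is False can be summarized with a torch-only sketch, assuming nothing beyond torch itself and using nn.Linear as a stand-in for the TE block:

```python
import torch

def run_block(block, inp, is_training):
    # Mirror of the test-side handling: eval mode and no backward in inference.
    if not is_training:
        block = block.eval()
    out = block(inp)
    if is_training:
        out.sum().backward()
    return out, inp.grad  # inp.grad stays None in the forward-only case

inp = torch.randn(4, 8, requires_grad=True)
out, grad = run_block(torch.nn.Linear(8, 8), inp, is_training=False)
assert grad is None
```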