diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py
index ecca5ab322..6ec2510675 100644
--- a/tests/jax/test_distributed_fused_attn.py
+++ b/tests/jax/test_distributed_fused_attn.py
@@ -68,6 +68,7 @@ def impl_test_self_attn(
         batch, seqlen, num_head, hidden = data_shape

         if not is_fused_attn_kernel_available(
+            is_training,
             dtype,
             dtype,
             QKVLayout.BS3HD,
@@ -214,6 +215,7 @@ def test_cross_attn(
         batch, seqlen, num_head, hidden = data_shape

         if not is_fused_attn_kernel_available(
+            is_training,
             dtype,
             dtype,
             QKVLayout.BSHD_BS2HD,
@@ -345,6 +347,7 @@ def impl_test_context_parallel_attn(

         def check_has_backend_for_mask(mask_type):
             return is_fused_attn_kernel_available(
+                is_training,
                 dtype,
                 dtype,
                 qkv_layout,
diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index 745f1cc633..2332bbc0de 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -347,6 +347,7 @@ def _check_configs(self):
         )

         self.backend = FusedAttnHelper(
+            self.is_training,
             self.dtype,
             self.dtype,
             self.qkv_layout,
diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index 6ce8637bc7..f91b9d4a52 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -229,6 +229,12 @@ def test():
     "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),  # cross, 1
     "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"),  # inference
     "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),  # inference
+    "base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"),  # inference
+    "base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"),  # inference
+    "base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"),  # inference
+    "base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"),  # inference
+    "base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"),  # inference
+    "base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"),  # inference
 }
@@ -270,12 +276,15 @@ def test_dot_product_attention(
     if config.window_size == (-1, -1) and swa:
         config.window_size = [2, 2]
     config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
+
+    is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
     available_backends, _, fused_attn_backends = _get_attention_backends(
         config,
         qkv_dtype=dtype,
         qkv_layout=qkv_layout,
         window_size=config.window_size,
         pad_between_seqs=pad_between_seqs,
+        is_training=is_training,
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
@@ -296,7 +305,6 @@ def test_dot_product_attention(
     if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
         pytest.skip("Less than two backends to compare.")

-    is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
     # UnfusedDotProductAttention backend
     if unfused_attn_supported:
         unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
@@ -1024,6 +1032,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
         layer_number=1,
         attention_type=config.attn_type,
     ).to(dtype=dtype, device="cuda")
+    if not is_training:
+        block = block.eval()

     # Run a forward and backward pass
     if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
@@ -1367,6 +1377,8 @@ def _run_transformer_layer(
         bias=True,
         attn_input_format=qkv_format,
     ).to(dtype=dtype, device="cuda")
+    if not is_training:
+        block = block.eval()

     # Create ALiBi slopes
     alibi_slopes = None
@@ -1384,8 +1396,9 @@ def _run_transformer_layer(
         core_attention_bias=bias,
         alibi_slopes=alibi_slopes,
     )
-    loss = out.sum()
-    loss.backward()
+    if is_training:
+        loss = out.sum()
+        loss.backward()

     return out, inp.grad
diff --git a/tests/pytorch/fused_attn/test_kv_cache.py b/tests/pytorch/fused_attn/test_kv_cache.py
index eb3838ff12..6094c153fd 100644
--- a/tests/pytorch/fused_attn/test_kv_cache.py
+++ b/tests/pytorch/fused_attn/test_kv_cache.py
@@ -52,7 +52,7 @@
         4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
     ),
     "infer_1": ModelConfig(
-        2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
+        2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
     ),
 }
@@ -370,12 +370,18 @@ def generate_args(
     ]


-def get_tols(module, backend, dtype):
+def get_tols(config, module, backend, dtype):
     if module == "TransformerLayer":
-        tols = {
-            torch.half: (5e-3, 5e-3),
-            torch.bfloat16: (3.5e-2, 3.5e-2),
-        }
+        if config.head_dim_qk <= 128:
+            tols = {
+                torch.half: (5e-3, 5e-3),
+                torch.bfloat16: (3.5e-2, 3.5e-2),
+            }
+        else:
+            tols = {
+                torch.half: (7e-3, 7e-3),
+                torch.bfloat16: (5e-2, 5e-2),
+            }
     if module == "DotProductAttention":
         tols = {
             torch.half: (1e-3, 1e-3),
@@ -484,6 +490,16 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
     # TransformerLayer FP8 TN Gemm currently requires %8=0
     if is_fp8 and not (qkv_format == "thd" and module == "DotProductAttention"):
         pytest.skip("BSHD/SBHD <-> THD conversions for FP8 are not supported")
+    if (
+        backend == "FusedAttention"
+        and config.head_dim_qk > 128
+        and not is_paged
+        and not is_cuda_graph
+    ):
+        pytest.skip(
+            "No support for KV caching with head dim > 128, non-paged attention, sq = 1, and mask"
+            " != no_mask"
+        )

     # create full model
     logger.info("=== Generating all tokens at once ===")
@@ -662,7 +678,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
             incremental_output = incremental_output[0]

     # compare results
-    atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn)
+    atol, rtol = get_tols(
+        config, module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn
+    )
     for i, seq in enumerate(sim.t_seq_ids):
         token_index = sim.step_lens[i] - 1
         if qkv_format == "bshd":
diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index 25340dd879..80bd089ce3 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -134,10 +134,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {

 // select a backend for fused attention
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
-    NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
-    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
-    int64_t window_size_left, int64_t window_size_right) {
+    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
   using namespace transformer_engine;
   NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
   const int device_id = cuda::current_device();
@@ -216,12 +216,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
   }
   if (
       // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging
-      // special conditions for blackwell
-      // TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7
-      !(sm_arch_ >= 100 && (head_dim_qk > 128 || head_dim_v > 128)) &&
       // architecture
-      ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
-       (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
+      ((cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)) ||
+       (cudnn_runtime_version >= 8903 && sm_arch_ >= 80 && sm_arch_ < 100) ||
+       (cudnn_runtime_version >= 90700 && sm_arch_ >= 80)) &&
       // sequence length
       ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) ||
        (cudnn_runtime_version >= 90000)) &&
@@ -229,11 +227,17 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
       ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) ||
        (cudnn_runtime_version >= 8907)) &&
       // head dimension
-      ((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) ||
-       // TODO (cyang): add is_training to nvte_get_fused_attn_backend
-       // d=256 only supported for forward
-       (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 &&
-        head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) &&
+      (head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 &&
+       ((head_dim_qk <= 128 && head_dim_v <= 128) ||
+        (head_dim_qk <= 256 && head_dim_v <= 256 &&
+         ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) ||
+          (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) ||
+        (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
+         layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
+        (!is_training && cudnn_runtime_version >= 91000 &&
+         (max_seqlen_q > 1 || layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD ||
+          (max_seqlen_q == 1 && sm_arch_ >= 100 &&
+           attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) &&
       // bias type
       ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
        (cudnn_runtime_version >= 8906 &&
@@ -423,8 +427,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);

   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
-      max_seqlen, d, d, window_size_left, window_size_right);
+      is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h,
+      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -505,7 +509,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
   const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);

   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
+      true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
       max_seqlen, d, d, window_size_left, window_size_right);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -636,8 +640,8 @@ void nvte_fused_attn_fwd_kvpacked(
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);

   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d, d, window_size_left, window_size_right);
+      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -731,8 +735,8 @@ void nvte_fused_attn_bwd_kvpacked(
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);

   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d, d, window_size_left, window_size_right);
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -862,8 +866,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);

   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
@@ -954,8 +958,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);

   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);

   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h
index 2a2013358b..b0c5c2360c 100644
--- a/transformer_engine/common/include/transformer_engine/fused_attn.h
+++ b/transformer_engine/common/include/transformer_engine/fused_attn.h
@@ -172,6 +172,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);

 /*! \brief Get fused attention backend based on input parameters.
  *
+ * \param[in] is_training Whether the model is in training mode.
  * \param[in] q_dtype The data type of Tensor Q.
  * \param[in] kv_dtype The data type of Tensors K, V.
  * \param[in] qkv_layout The layout of Tensors Q, K, V.
@@ -188,10 +189,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
  * \param[in] window_size_right Sliding window size (the right half).
  */
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
-    NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
-    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
-    int64_t window_size_left, int64_t window_size_right);
+    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);

 /*! \brief Compute dot product attention with packed QKV input.
  *
diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py
index 2c57d284de..d24e853e1c 100644
--- a/transformer_engine/jax/attention.py
+++ b/transformer_engine/jax/attention.py
@@ -277,6 +277,7 @@ def canonicalize_attn_mask_type(attn_mask_type: str):


 def is_fused_attn_kernel_available(
+    is_training,
     q_dtype,
     kv_dtype,
     qkv_layout,
@@ -296,6 +297,7 @@ def is_fused_attn_kernel_available(

     def make_helper(attn_mask_type):
         return tex.FusedAttnHelper(
+            is_training,
             q_dtype,
             kv_dtype,
             qkv_layout,
diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py
index ea682d4c47..c9490cb16b 100644
--- a/transformer_engine/jax/cpp_extensions/attention.py
+++ b/transformer_engine/jax/cpp_extensions/attention.py
@@ -100,6 +100,7 @@ class FusedAttnHelper:
     Helper for the fused attention backend
     """

+    is_training: bool
     q_dtype: jnp.dtype
     kv_dtype: jnp.dtype
     qkv_layout: QKVLayout
@@ -120,6 +121,7 @@ def is_fused_attn_kernel_available(self):
     def get_fused_attn_backend(self):
         """Get the fused attention kernel backend"""
         return transformer_engine_jax.get_fused_attn_backend(
+            self.is_training,
             jax_dtype_to_te_dtype(self.q_dtype),
             jax_dtype_to_te_dtype(self.kv_dtype),
             self.qkv_layout.value,
@@ -273,6 +275,7 @@ def abstract(

         # backend determines the softmax buffer shape/dtype
         backend = FusedAttnHelper(
+            config.is_training,
             q_dtype,
             k_dtype,
             config.qkv_layout,
diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h
index acc9c1dcb1..15fe6a4265 100644
--- a/transformer_engine/jax/csrc/extensions.h
+++ b/transformer_engine/jax/csrc/extensions.h
@@ -93,7 +93,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);

 XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);

-NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
+NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
                                             NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                             NVTE_Mask_Type mask_type, float dropout_probability,
                                             size_t q_num_heads, size_t kv_num_heads,
diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp
index dd71a29d70..f919422bb2 100644
--- a/transformer_engine/jax/csrc/extensions/attention.cpp
+++ b/transformer_engine/jax/csrc/extensions/attention.cpp
@@ -10,7 +10,7 @@ namespace transformer_engine {

 namespace jax {

-NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
+NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
                                             NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                             NVTE_Mask_Type mask_type, float dropout_probability,
                                             size_t q_attn_heads, size_t kv_attn_heads,
@@ -18,9 +18,9 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                             size_t head_dim, int64_t window_size_left,
                                             int64_t window_size_right) {
   auto backend = nvte_get_fused_attn_backend(
-      static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
-      mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
-      head_dim, head_dim, window_size_left, window_size_right);
+      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
+      bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen,
+      kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
   return backend;
 }
@@ -245,9 +245,9 @@ static void FusedAttnForwardImpl(
   /* Prepare RNG state */
   auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
   auto backend = nvte_get_fused_attn_backend(
-      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
-      mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
-      head_dim, head_dim, window_size_left, window_size_right);
+      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
+      bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
+      kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
   nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

   /* Auxiliary tensors (to be propagated to the backward pass later) */
@@ -498,9 +498,9 @@ static void FusedAttnBackwardImpl(
   NVTETensorPack aux_input_tensors;
   nvte_tensor_pack_create(&aux_input_tensors);
   auto backend = nvte_get_fused_attn_backend(
-      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
-      mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
-      head_dim, head_dim, window_size_left, window_size_right);
+      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
+      bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
+      kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
   PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
                                      bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
                                      softmax_aux, rng_state, bias);
diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py
index 1ac22a6d2f..e9a3047742 100644
--- a/transformer_engine/jax/flax/transformer.py
+++ b/transformer_engine/jax/flax/transformer.py
@@ -596,6 +596,8 @@ def __call__(
         seqlen_kv = key.shape[sequence_dim]

         has_fused_attn_kernel = is_fused_attn_kernel_available(
+            # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
+            not deterministic,
             self.dtype,
             self.dtype,
             qkv_layout,
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index 82fc04a69a..1031ce8a5c 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -766,6 +766,7 @@ def get_attention_backend(
             q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
             kv_type = q_type
         fused_attention_backend = tex.get_fused_attn_backend(
+            is_training,
             q_type,
             kv_type,
             QKVLayout[qkv_layout],
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index 361c24b22c..781fd865d5 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -35,13 +35,11 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T
  * Attention
  **************************************************************************************************/

-NVTE_Fused_Attn_Backend get_fused_attn_backend(const DType q_dtype, const DType kv_dtype,
-                                               NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                               NVTE_Mask_Type attn_mask_type, float p_dropout,
-                                               size_t num_attn_heads, size_t num_gqa_groups,
-                                               size_t max_seqlen_q, size_t max_seqlen_kv,
-                                               size_t head_dim_qk, size_t head_dim_v,
-                                               int64_t window_size_left, int64_t window_size_right);
+NVTE_Fused_Attn_Backend get_fused_attn_backend(
+    bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);

 std::vector fused_attn_fwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp
index efe825f0db..371bef2de9 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp
@@ -57,14 +57,14 @@ namespace transformer_engine::pytorch {

 // get the fused attention backend
 NVTE_Fused_Attn_Backend get_fused_attn_backend(
-    const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups,
-    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
-    int64_t window_size_left, int64_t window_size_right) {
+    bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
-      attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
-      head_dim_qk, head_dim_v, window_size_left, window_size_right);
+      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
+      bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q,
+      max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
   return fused_attention_backend;
 }
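
Usage sketch (illustrative, not part of the patch): with this change, `is_training` is the first argument to `nvte_get_fused_attn_backend`, so an inference-time caller (for example, single-token KV-cache decode with head dim 256) can query backend availability without the training-path restrictions. Only the function signature below is taken from the header change above; the include path and the specific enum values chosen (kNVTEFloat16, NVTE_BSHD_BSHD_BSHD, problem sizes) are assumptions picked for the example.

// Hypothetical standalone query of the fused-attention backend selector.
// Assumes the installed headers expose <transformer_engine/fused_attn.h>;
// adjust the include path and enum choices to your build.
#include <cstdint>
#include <cstdio>

#include <transformer_engine/fused_attn.h>

int main() {
  const bool is_training = false;  // inference-only query, e.g. single-token decode
  NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend(
      is_training, kNVTEFloat16, kNVTEFloat16, NVTE_BSHD_BSHD_BSHD, NVTE_NO_BIAS, NVTE_NO_MASK,
      /*dropout=*/0.0f, /*num_attn_heads=*/16, /*num_gqa_groups=*/16,
      /*max_seqlen_q=*/1, /*max_seqlen_kv=*/2048, /*head_dim_qk=*/256, /*head_dim_v=*/256,
      /*window_size_left=*/-1, /*window_size_right=*/-1);
  if (backend == NVTE_Fused_Attn_Backend::NVTE_No_Backend) {
    std::printf("no fused-attention backend supports this configuration\n");
  } else {
    std::printf("selected fused-attention backend: %d\n", static_cast<int>(backend));
  }
  return 0;
}

Passing is_training=false here is what lets the selector consider the forward-only head-dim > 128 paths added in fused_attn.cpp; the training entry points (nvte_fused_attn_bwd*) continue to pass true, so backward-pass restrictions still apply there.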