Commit 494944e

add support for head dim > 128

Signed-off-by: Charlene Yang <[email protected]>
1 parent 1d903f5 commit 494944e

14 files changed: 84 additions, 46 deletions
tests/jax/test_distributed_fused_attn.py (3 additions, 0 deletions)

@@ -68,6 +68,7 @@ def impl_test_self_attn(
        batch, seqlen, num_head, hidden = data_shape

        if not is_fused_attn_kernel_available(
+            is_training,
            dtype,
            dtype,
            QKVLayout.BS3HD,

@@ -214,6 +215,7 @@ def test_cross_attn(
        batch, seqlen, num_head, hidden = data_shape

        if not is_fused_attn_kernel_available(
+            is_training,
            dtype,
            dtype,
            QKVLayout.BSHD_BS2HD,

@@ -345,6 +347,7 @@ def impl_test_context_parallel_attn(

        def check_has_backend_for_mask(mask_type):
            return is_fused_attn_kernel_available(
+                is_training,
                dtype,
                dtype,
                qkv_layout,

tests/jax/test_fused_attn.py (1 addition, 0 deletions)

@@ -347,6 +347,7 @@ def _check_configs(self):
        )

        self.backend = FusedAttnHelper(
+            self.is_training,
            self.dtype,
            self.dtype,
            self.qkv_layout,

tests/pytorch/fused_attn/test_fused_attn.py (23 additions, 10 deletions)

@@ -211,13 +211,13 @@ def test():
        return available_backends, flash_attention_backend, fused_attention_backend

    backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
-    with logging_context():
-        for i in range(3):
-            os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
-            _attention_backends["backend_selection_requires_update"] = True
-            available_backends, flash_attention_backend, fused_attention_backend = test()
-            if fused_attention_backend == FusedAttnBackend[backends[i]]:
-                fused_attn_backends.append(fused_attention_backend)
+    #with logging_context():
+    for i in range(3):
+        os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
+        _attention_backends["backend_selection_requires_update"] = True
+        available_backends, flash_attention_backend, fused_attention_backend = test()
+        if fused_attention_backend == FusedAttnBackend[backends[i]]:
+            fused_attn_backends.append(fused_attention_backend)
    return available_backends, flash_attention_backend, fused_attn_backends


@@ -229,6 +229,12 @@ def test():
    "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
    "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
    "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
+    "base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
+    "base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"), # inference
+    "base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
+    "base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"), # inference
+    "base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
+    "base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"), # inference
}


@@ -270,12 +276,15 @@ def test_dot_product_attention(
    if config.window_size == (-1, -1) and swa:
        config.window_size = [2, 2]
    config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)
+
+    is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
    available_backends, _, fused_attn_backends = _get_attention_backends(
        config,
        qkv_dtype=dtype,
        qkv_layout=qkv_layout,
        window_size=config.window_size,
        pad_between_seqs=pad_between_seqs,
+        is_training=is_training,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

@@ -296,7 +305,6 @@ def test_dot_product_attention(
    if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
        pytest.skip("Less than two backends to compare.")

-    is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
    # UnfusedDotProductAttention backend
    if unfused_attn_supported:
        unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(

@@ -1024,6 +1032,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
        layer_number=1,
        attention_type=config.attn_type,
    ).to(dtype=dtype, device="cuda")
+    if not is_training:
+        block = block.eval()

    # Run a forward and backward pass
    if backend in ["FlashAttention", "UnfusedDotProductAttention"]:

@@ -1367,6 +1377,8 @@ def _run_transformer_layer(
        bias=True,
        attn_input_format=qkv_format,
    ).to(dtype=dtype, device="cuda")
+    if not is_training:
+        block = block.eval()

    # Create ALiBi slopes
    alibi_slopes = None

@@ -1384,8 +1396,9 @@ def _run_transformer_layer(
        core_attention_bias=bias,
        alibi_slopes=alibi_slopes,
    )
-    loss = out.sum()
-    loss.backward()
+    if is_training:
+        loss = out.sum()
+        loss.backward()

    return out, inp.grad
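
The hunks above all follow one pattern: head dims larger than 128 are exercised in inference mode only. A minimal sketch of that pattern outside the test harness (run_case, make_block and make_input are hypothetical stand-ins, not helpers from this file):

import torch

def run_case(config, make_block, make_input):
    # Backward is only supported for head dims up to 128, so larger head dims
    # run forward-only in eval mode, mirroring the changes above.
    is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128

    block = make_block(config).to(dtype=torch.bfloat16, device="cuda")
    if not is_training:
        block = block.eval()

    inp = make_input(config).cuda().requires_grad_(is_training)
    out = block(inp)
    if is_training:
        out.sum().backward()
    return out, (inp.grad if is_training else None)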

tests/pytorch/fused_attn/test_kv_cache.py (18 additions, 10 deletions)

@@ -48,11 +48,11 @@

model_configs_infer = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
-    "infer_0": ModelConfig(
-        4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
-    ),
+    #"infer_0": ModelConfig(
+    #    4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
+    #),
    "infer_1": ModelConfig(
-        2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
+        2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
    ),
}

@@ -370,12 +370,18 @@ def generate_args(
]


-def get_tols(module, backend, dtype):
+def get_tols(config, module, backend, dtype):
    if module == "TransformerLayer":
-        tols = {
-            torch.half: (5e-3, 5e-3),
-            torch.bfloat16: (3.5e-2, 3.5e-2),
-        }
+        if config.head_dim_qk <= 128:
+            tols = {
+                torch.half: (5e-3, 5e-3),
+                torch.bfloat16: (3.5e-2, 3.5e-2),
+            }
+        else:
+            tols = {
+                torch.half: (7e-3, 7e-3),
+                torch.bfloat16: (5e-2, 5e-2),
+            }
    if module == "DotProductAttention":
        tols = {
            torch.half: (1e-3, 1e-3),

@@ -484,6 +490,8 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
    # TransformerLayer FP8 TN Gemm currently requires %8=0
    if is_fp8 and not (qkv_format == "thd" and module == "DotProductAttention"):
        pytest.skip("BSHD/SBHD <-> THD conversions for FP8 are not supported")
+    if backend == "FusedAttention" and config.head_dim_qk > 128 and not is_paged and not is_cuda_graph:
+        pytest.skip("No support for KV caching with head dim > 128, non-paged attention, sq = 1, and mask != no_mask")

    # create full model
    logger.info("=== Generating all tokens at once ===")

@@ -662,7 +670,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
    incremental_output = incremental_output[0]

    # compare results
-    atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn)
+    atol, rtol = get_tols(config, module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn)
    for i, seq in enumerate(sim.t_seq_ids):
        token_index = sim.step_lens[i] - 1
        if qkv_format == "bshd":
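
With "infer_1" now using a head dim of 256, the relaxed TransformerLayer tolerances above are the ones picked up. A hedged usage sketch, assuming get_tols still returns tols[dtype] and that ModelConfig's fourth positional argument is the head dim (per the "b, h, hg, d, ..." comment above); ModelConfig and get_tols come from this test module:

import torch

config = ModelConfig(
    2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
)
atol, rtol = get_tols(config, module="TransformerLayer", backend="FusedAttention",
                      dtype=torch.bfloat16)
assert (atol, rtol) == (5e-2, 5e-2)  # head_dim_qk > 128 selects the wider tolerances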

transformer_engine/common/fused_attn/fused_attn.cpp (21 additions, 17 deletions)

@@ -134,7 +134,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {

// select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
-    NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
    int64_t window_size_left, int64_t window_size_right) {

@@ -216,24 +216,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
  }
  if (
      // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging
-      // special conditions for blackwell
-      // TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7
-      !(sm_arch_ >= 100 && (head_dim_qk > 128 || head_dim_v > 128)) &&
      // architecture
-      ((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
-       (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
+      ((cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)) ||
+       (cudnn_runtime_version >= 8903 && sm_arch_ >= 80 && sm_arch_ < 100) ||
+       (cudnn_runtime_version >= 90700 && sm_arch_ >= 80)) &&
      // sequence length
      ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) ||
       (cudnn_runtime_version >= 90000)) &&
      // number of heads
      ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) ||
       (cudnn_runtime_version >= 8907)) &&
      // head dimension
-      ((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) ||
-       // TODO (cyang): add is_training to nvte_get_fused_attn_backend
-       // d=256 only supported for forward
-       (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 &&
-        head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) &&
+      (head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 &&
+       ((head_dim_qk <= 128 && head_dim_v <= 128) ||
+        (head_dim_qk <= 256 && head_dim_v <= 256 &&
+         ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) ||
+          (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) ||
+        (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 &&
+         max_seqlen_q > 1 && layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
+        (!is_training && cudnn_runtime_version >= 91000 &&
+         (max_seqlen_q > 1 ||
+          layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD ||
+          (max_seqlen_q == 1 && sm_arch_ >= 100 && attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) &&
      // bias type
      ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
       (cudnn_runtime_version >= 8906 &&

@@ -423,7 +427,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
  const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);

  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
+      is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
      max_seqlen, d, d, window_size_left, window_size_right);

  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {

@@ -505,7 +509,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
  const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);

  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
+      true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
      max_seqlen, d, d, window_size_left, window_size_right);

  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {

@@ -636,7 +640,7 @@ void nvte_fused_attn_fwd_kvpacked(
  const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);

  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
+      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
      max_seqlen_kv, d, d, window_size_left, window_size_right);

  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {

@@ -731,7 +735,7 @@ void nvte_fused_attn_bwd_kvpacked(
  const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);

  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
      max_seqlen_kv, d, d, window_size_left, window_size_right);

  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {

@@ -862,7 +866,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
  const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);

  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
+      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
      max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);

  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {

@@ -954,7 +958,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
  const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);

  NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
      max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);

  if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {

transformer_engine/common/include/transformer_engine/fused_attn.h (2 additions, 1 deletion)

@@ -172,6 +172,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);

/*! \brief Get fused attention backend based on input parameters.
 *
+ *  \param[in]     is_training       Whether the model is in training mode.
 *  \param[in]     q_dtype           The data type of Tensor Q.
 *  \param[in]     kv_dtype          The data type of Tensors K, V.
 *  \param[in]     qkv_layout        The layout of Tensors Q, K, V.

@@ -188,7 +189,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
 *  \param[in]     window_size_right Sliding window size (the right half).
 */
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
-    NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
    NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
    int64_t window_size_left, int64_t window_size_right);

transformer_engine/jax/attention.py (2 additions, 0 deletions)

@@ -277,6 +277,7 @@ def canonicalize_attn_mask_type(attn_mask_type: str):


def is_fused_attn_kernel_available(
+    is_training,
    q_dtype,
    kv_dtype,
    qkv_layout,

@@ -296,6 +297,7 @@ def is_fused_attn_kernel_available(

    def make_helper(attn_mask_type):
        return tex.FusedAttnHelper(
+            is_training,
            q_dtype,
            kv_dtype,
            qkv_layout,
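
Call sites only need to prepend the training flag; the remaining arguments keep their existing order. A minimal sketch of an updated caller (the wrapper name and the *rest forwarding are illustrative, not part of the library):

from transformer_engine.jax.attention import is_fused_attn_kernel_available

def probe_fused_attn(is_training, q_dtype, kv_dtype, qkv_layout, *rest):
    # Before this commit the probe started with (q_dtype, kv_dtype, qkv_layout, ...);
    # is_training is now the leading positional argument.
    return is_fused_attn_kernel_available(is_training, q_dtype, kv_dtype, qkv_layout, *rest)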

transformer_engine/jax/cpp_extensions/attention.py (3 additions, 0 deletions)

@@ -100,6 +100,7 @@ class FusedAttnHelper:
    Helper for the fused attention backend
    """

+    is_training: bool
    q_dtype: jnp.dtype
    kv_dtype: jnp.dtype
    qkv_layout: QKVLayout

@@ -120,6 +121,7 @@ def is_fused_attn_kernel_available(self):
    def get_fused_attn_backend(self):
        """Get the fused attention kernel backend"""
        return transformer_engine_jax.get_fused_attn_backend(
+            self.is_training,
            jax_dtype_to_te_dtype(self.q_dtype),
            jax_dtype_to_te_dtype(self.kv_dtype),
            self.qkv_layout.value,

@@ -273,6 +275,7 @@ def abstract(

    # backend determines the softmax buffer shape/dtype
    backend = FusedAttnHelper(
+        config.is_training,
        q_dtype,
        k_dtype,
        config.qkv_layout,

transformer_engine/jax/csrc/extensions.h (1 addition, 1 deletion)

@@ -93,7 +93,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);

-NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
+NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_num_heads, size_t kv_num_heads,

transformer_engine/jax/csrc/extensions/attention.cpp (4 additions, 4 deletions)

@@ -10,15 +10,15 @@
namespace transformer_engine {
namespace jax {

-NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
+NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
                                            NVTE_Mask_Type mask_type, float dropout_probability,
                                            size_t q_attn_heads, size_t kv_attn_heads,
                                            size_t q_max_seqlen, size_t kv_max_seqlen,
                                            size_t head_dim, int64_t window_size_left,
                                            int64_t window_size_right) {
  auto backend = nvte_get_fused_attn_backend(
-      static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
+      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
      mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
      head_dim, head_dim, window_size_left, window_size_right);
  return backend;

@@ -245,7 +245,7 @@ static void FusedAttnForwardImpl(
  /* Prepare RNG state */
  auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
  auto backend = nvte_get_fused_attn_backend(
-      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
+      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
      mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
      head_dim, head_dim, window_size_left, window_size_right);
  nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

@@ -498,7 +498,7 @@ static void FusedAttnBackwardImpl(
  NVTETensorPack aux_input_tensors;
  nvte_tensor_pack_create(&aux_input_tensors);
  auto backend = nvte_get_fused_attn_backend(
-      static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
+      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
      mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
      head_dim, head_dim, window_size_left, window_size_right);
  PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,

transformer_engine/jax/flax/transformer.py (2 additions, 0 deletions)

@@ -596,6 +596,8 @@ def __call__(
        seqlen_kv = key.shape[sequence_dim]

        has_fused_attn_kernel = is_fused_attn_kernel_available(
+            # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
+            not deterministic,
            self.dtype,
            self.dtype,
            qkv_layout,

transformer_engine/pytorch/attention/dot_product_attention/utils.py (1 addition, 0 deletions)

@@ -766,6 +766,7 @@ def get_attention_backend(
        q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
        kv_type = q_type
    fused_attention_backend = tex.get_fused_attn_backend(
+        is_training,
        q_type,
        kv_type,
        QKVLayout[qkv_layout],
