
Commit 4d9b33a

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent af5aec9 commit 4d9b33a


7 files changed (+65, -57 lines)

tests/pytorch/fused_attn/test_fused_attn.py

Lines changed: 1 addition & 1 deletion

@@ -211,7 +211,7 @@ def test():
         return available_backends, flash_attention_backend, fused_attention_backend
 
     backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
-    #with logging_context():
+    # with logging_context():
     for i in range(3):
         os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
         _attention_backends["backend_selection_requires_update"] = True
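The loop in this hunk exercises each backend by setting NVTE_FUSED_ATTN_BACKEND to the indices mapped in the `backends` dict (0: F16_max512_seqlen, 1: F16_arbitrary_seqlen, 2: FP8) and flagging backend selection for re-evaluation. A minimal C++ sketch of how such an integer override could map onto a backend enum is shown below; the enum values beyond the two named elsewhere in this commit, and the parsing itself, are illustrative assumptions rather than Transformer Engine's actual override path.

// Illustrative sketch only: map the integer taken from NVTE_FUSED_ATTN_BACKEND onto a
// fused-attention backend enum, mirroring the 0/1/2 ordering of the `backends` dict above.
// The enum values other than NVTE_No_Backend and NVTE_F16_max512_seqlen are assumed here.
#include <cstdlib>

enum class FusedAttnBackendChoice {
  NVTE_No_Backend = -1,
  NVTE_F16_max512_seqlen = 0,
  NVTE_F16_arbitrary_seqlen = 1,
  NVTE_FP8 = 2,
};

// Returns the backend forced via the environment, or NVTE_No_Backend when the variable is
// unset or out of the 0..2 range. Note that std::atoi parses non-numeric input as 0.
inline FusedAttnBackendChoice backend_from_env() {
  const char *env = std::getenv("NVTE_FUSED_ATTN_BACKEND");
  if (env == nullptr) return FusedAttnBackendChoice::NVTE_No_Backend;
  const int idx = std::atoi(env);
  if (idx < 0 || idx > 2) return FusedAttnBackendChoice::NVTE_No_Backend;
  return static_cast<FusedAttnBackendChoice>(idx);
}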

tests/pytorch/fused_attn/test_kv_cache.py

Lines changed: 15 additions & 5 deletions

@@ -48,9 +48,9 @@
 
 model_configs_infer = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
-    #"infer_0": ModelConfig(
+    # "infer_0": ModelConfig(
     # 4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
-    #),
+    # ),
     "infer_1": ModelConfig(
         2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
     ),

@@ -490,8 +490,16 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
     # TransformerLayer FP8 TN Gemm currently requires %8=0
     if is_fp8 and not (qkv_format == "thd" and module == "DotProductAttention"):
         pytest.skip("BSHD/SBHD <-> THD conversions for FP8 are not supported")
-    if backend == "FusedAttention" and config.head_dim_qk > 128 and not is_paged and not is_cuda_graph:
-        pytest.skip("No support for KV caching with head dim > 128, non-paged attention, sq = 1, and mask != no_mask")
+    if (
+        backend == "FusedAttention"
+        and config.head_dim_qk > 128
+        and not is_paged
+        and not is_cuda_graph
+    ):
+        pytest.skip(
+            "No support for KV caching with head dim > 128, non-paged attention, sq = 1, and mask"
+            " != no_mask"
+        )
 
     # create full model
     logger.info("=== Generating all tokens at once ===")

@@ -670,7 +678,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
         incremental_output = incremental_output[0]
 
     # compare results
-    atol, rtol = get_tols(config, module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn)
+    atol, rtol = get_tols(
+        config, module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn
+    )
     for i, seq in enumerate(sim.t_seq_ids):
         token_index = sim.step_lens[i] - 1
         if qkv_format == "bshd":

transformer_engine/common/fused_attn/fused_attn.cpp

Lines changed: 24 additions & 24 deletions

@@ -134,10 +134,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {
 
 // select a backend for fused attention
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
-    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
-    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
-    int64_t window_size_left, int64_t window_size_right) {
+    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
   using namespace transformer_engine;
   NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
   const int device_id = cuda::current_device();

@@ -228,16 +228,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
        (cudnn_runtime_version >= 8907)) &&
       // head dimension
       (head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 &&
-       ((head_dim_qk <= 128 && head_dim_v <= 128) ||
-        (head_dim_qk <= 256 && head_dim_v <= 256 &&
-         ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) ||
-          (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) ||
-        (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 &&
-         max_seqlen_q > 1 && layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
-        (!is_training && cudnn_runtime_version >= 91000 &&
-         (max_seqlen_q > 1 ||
-          layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD ||
-          (max_seqlen_q == 1 && sm_arch_ >= 100 && attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) &&
+       ((head_dim_qk <= 128 && head_dim_v <= 128) ||
+        (head_dim_qk <= 256 && head_dim_v <= 256 &&
+         ((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) ||
+          (is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) ||
+        (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
+         layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
+        (!is_training && cudnn_runtime_version >= 91000 &&
+         (max_seqlen_q > 1 || layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD ||
+          (max_seqlen_q == 1 && sm_arch_ >= 100 &&
+           attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) &&
       // bias type
       ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
        (cudnn_runtime_version >= 8906 &&

@@ -427,8 +427,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
   const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
-      max_seqlen, d, d, window_size_left, window_size_right);
+      is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h,
+      max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)

@@ -640,8 +640,8 @@ void nvte_fused_attn_fwd_kvpacked(
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d, d, window_size_left, window_size_right);
+      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)

@@ -735,8 +735,8 @@ void nvte_fused_attn_bwd_kvpacked(
   const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d, d, window_size_left, window_size_right);
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)

@@ -866,8 +866,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+      is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)

@@ -958,8 +958,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
   const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);
 
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
-      max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
+      true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
+      max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
 
   if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
 #if (CUDNN_VERSION >= 8901)
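The @@ -228,16 +228,16 @@ hunk above only reflows the arbitrary-seqlen eligibility condition; its logic is unchanged. As a reading aid, the head-dimension clause can be restated as a standalone predicate. The sketch below paraphrases exactly the condition visible in the diff, with the layout-group and mask-type comparisons reduced to booleans; it is not a drop-in piece of nvte_get_fused_attn_backend.

// Sketch: the head-dimension clause of the arbitrary-seqlen backend check, restated as a
// standalone predicate. Parameter names follow the diff; the paged-KV layout-group and
// no-mask comparisons are passed in as booleans for brevity.
#include <cstddef>

bool head_dim_clause(std::size_t head_dim_qk, std::size_t head_dim_v, bool is_training,
                     int sm_arch, int cudnn_runtime_version, std::size_t max_seqlen_q,
                     bool is_paged_kv_layout, bool mask_is_no_mask) {
  return head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 &&
         ((head_dim_qk <= 128 && head_dim_v <= 128) ||
          // head dims up to 256 on sm90; training requires a newer cuDNN
          (head_dim_qk <= 256 && head_dim_v <= 256 &&
           ((!is_training && sm_arch == 90 && cudnn_runtime_version >= 90100) ||
            (is_training && sm_arch == 90 && cudnn_runtime_version >= 90500))) ||
          // inference-only paths on sm100+ / newer cuDNN releases
          (!is_training && sm_arch >= 100 && cudnn_runtime_version >= 90900 &&
           max_seqlen_q > 1 && !is_paged_kv_layout) ||
          (!is_training && cudnn_runtime_version >= 91000 &&
           (max_seqlen_q > 1 || is_paged_kv_layout ||
            (max_seqlen_q == 1 && sm_arch >= 100 && mask_is_no_mask))));
}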

transformer_engine/common/include/transformer_engine/fused_attn.h

Lines changed: 4 additions & 4 deletions

@@ -189,10 +189,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
  * \param[in] window_size_right Sliding window size (the right half).
  */
 NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
-    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
-    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
-    int64_t window_size_left, int64_t window_size_right);
+    bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
 
 /*! \brief Compute dot product attention with packed QKV input.
  *
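Because only line breaks moved, existing callers of the public declaration above are unaffected. Below is a minimal sketch of a call against the 15-parameter signature; the concrete enum values (kNVTEFloat16, NVTE_BSHD_BSHD_BSHD), the no-sliding-window sentinel (-1, -1), and the shape numbers are assumptions chosen for illustration, not values taken from this commit.

// Sketch: query the fused-attention backend for an FP16, BSHD, no-bias, no-mask
// configuration using the reflowed signature declared above. Enum values and shapes
// here are illustrative assumptions.
#include <transformer_engine/fused_attn.h>

NVTE_Fused_Attn_Backend pick_backend_example() {
  const bool is_training = true;
  // 16 query heads and 16 KV heads (no GQA), head dim 128, sequence length 2048,
  // no sliding window (window sizes assumed to use -1 as "unbounded").
  return nvte_get_fused_attn_backend(
      is_training, NVTEDType::kNVTEFloat16, NVTEDType::kNVTEFloat16,
      NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD, NVTE_Bias_Type::NVTE_NO_BIAS,
      NVTE_Mask_Type::NVTE_NO_MASK, /*dropout=*/0.0f, /*num_attn_heads=*/16,
      /*num_gqa_groups=*/16, /*max_seqlen_q=*/2048, /*max_seqlen_kv=*/2048,
      /*head_dim_qk=*/128, /*head_dim_v=*/128, /*window_size_left=*/-1,
      /*window_size_right=*/-1);
}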

transformer_engine/jax/csrc/extensions/attention.cpp

Lines changed: 9 additions & 9 deletions

@@ -18,9 +18,9 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy
                                             size_t head_dim, int64_t window_size_left,
                                             int64_t window_size_right) {
   auto backend = nvte_get_fused_attn_backend(
-      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
-      mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
-      head_dim, head_dim, window_size_left, window_size_right);
+      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
+      bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen,
+      kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
   return backend;
 }
 

@@ -245,9 +245,9 @@ static void FusedAttnForwardImpl(
   /* Prepare RNG state */
   auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
   auto backend = nvte_get_fused_attn_backend(
-      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
-      mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
-      head_dim, head_dim, window_size_left, window_size_right);
+      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
+      bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
+      kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
   nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
 
   /* Auxiliary tensors (to be propagated to the backward pass later) */

@@ -498,9 +498,9 @@ static void FusedAttnBackwardImpl(
   NVTETensorPack aux_input_tensors;
   nvte_tensor_pack_create(&aux_input_tensors);
   auto backend = nvte_get_fused_attn_backend(
-      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
-      mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
-      head_dim, head_dim, window_size_left, window_size_right);
+      is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
+      bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
+      kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
   PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
                                      bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
                                      softmax_aux, rng_state, bias);

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 5 additions & 7 deletions

@@ -35,13 +35,11 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
  * Attention
 **************************************************************************************************/
 
-NVTE_Fused_Attn_Backend get_fused_attn_backend(bool is_training, const DType q_dtype, const DType kv_dtype,
-                                               NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-                                               NVTE_Mask_Type attn_mask_type, float p_dropout,
-                                               size_t num_attn_heads, size_t num_gqa_groups,
-                                               size_t max_seqlen_q, size_t max_seqlen_kv,
-                                               size_t head_dim_qk, size_t head_dim_v,
-                                               int64_t window_size_left, int64_t window_size_right);
+NVTE_Fused_Attn_Backend get_fused_attn_backend(
+    bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
 
 std::vector<py::object> fused_attn_fwd(
     size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,

transformer_engine/pytorch/csrc/extensions/attention.cpp

Lines changed: 7 additions & 7 deletions

@@ -57,14 +57,14 @@ namespace transformer_engine::pytorch {
 
 // get the fused attention backend
 NVTE_Fused_Attn_Backend get_fused_attn_backend(
-    bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
-    NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups,
-    size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
-    int64_t window_size_left, int64_t window_size_right) {
+    bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
+    NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
+    size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
+    size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
   NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
-      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
-      attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv,
-      head_dim_qk, head_dim_v, window_size_left, window_size_right);
+      is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
+      bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q,
+      max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right);
   return fused_attention_backend;
 }
 
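The PyTorch wrapper above only converts the extension's DType values to NVTEDType and forwards the arguments in the same order. For context, a sketch of how a helper like this is typically exposed to the Python tests appearing earlier in this commit is shown below; the registration hook, module object, and binding name are assumptions about the pybind11 layer, which this diff does not show (the NVTE enum types would also need their own bindings, omitted here).

// Sketch: registering the wrapper with pybind11 so Python-side tests can query the
// selected backend. The function register_attention_bindings and the exposed name are
// hypothetical; only the parameter order matches the wrapper shown above.
#include <pybind11/pybind11.h>
#include "extensions.h"  // assumed to declare transformer_engine::pytorch::get_fused_attn_backend

namespace py = pybind11;

void register_attention_bindings(py::module_ &m) {
  m.def("get_fused_attn_backend", &transformer_engine::pytorch::get_fused_attn_backend,
        py::arg("is_training"), py::arg("q_dtype"), py::arg("kv_dtype"), py::arg("qkv_layout"),
        py::arg("bias_type"), py::arg("attn_mask_type"), py::arg("p_dropout"),
        py::arg("num_attn_heads"), py::arg("num_gqa_groups"), py::arg("max_seqlen_q"),
        py::arg("max_seqlen_kv"), py::arg("head_dim_qk"), py::arg("head_dim_v"),
        py::arg("window_size_left"), py::arg("window_size_right"));
}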
