Add support for head_dim > 128 #1797

Open · wants to merge 3 commits into base: main
3 changes: 3 additions & 0 deletions tests/jax/test_distributed_fused_attn.py
@@ -68,6 +68,7 @@ def impl_test_self_attn(
batch, seqlen, num_head, hidden = data_shape

if not is_fused_attn_kernel_available(
is_training,
dtype,
dtype,
QKVLayout.BS3HD,
@@ -214,6 +215,7 @@ def test_cross_attn(
batch, seqlen, num_head, hidden = data_shape

if not is_fused_attn_kernel_available(
is_training,
dtype,
dtype,
QKVLayout.BSHD_BS2HD,
@@ -345,6 +347,7 @@ def impl_test_context_parallel_attn(

def check_has_backend_for_mask(mask_type):
return is_fused_attn_kernel_available(
is_training,
dtype,
dtype,
qkv_layout,
1 change: 1 addition & 0 deletions tests/jax/test_fused_attn.py
@@ -347,6 +347,7 @@ def _check_configs(self):
)

self.backend = FusedAttnHelper(
self.is_training,
self.dtype,
self.dtype,
self.qkv_layout,
19 changes: 16 additions & 3 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -229,6 +229,12 @@ def test():
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"), # inference
}


@@ -270,12 +276,15 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)

is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

@@ -296,7 +305,6 @@ def test_dot_product_attention(
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")

is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
@@ -1024,6 +1032,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
layer_number=1,
attention_type=config.attn_type,
).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()

# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
@@ -1367,6 +1377,8 @@ def _run_transformer_layer(
bias=True,
attn_input_format=qkv_format,
).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()

# Create ALiBi slopes
alibi_slopes = None
@@ -1384,8 +1396,9 @@ def _run_transformer_layer(
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
)
loss = out.sum()
loss.backward()
if is_training:
loss = out.sum()
loss.backward()

return out, inp.grad

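The pattern behind the PyTorch test changes above: when head_dim_qk or head_dim_v exceeds 128, only the forward pass is supported, so the tests derive is_training from the head dimensions, put the module in eval mode, and skip the backward pass. A minimal sketch of that pattern (shapes and constructor arguments are illustrative, not taken from the tests):

    import torch
    import transformer_engine.pytorch as te

    head_dim = 256                          # > 128: forward-only (inference) support
    is_training = head_dim <= 128

    block = te.DotProductAttention(
        num_attention_heads=16,
        kv_channels=head_dim,
        attention_dropout=0.0,
    ).to(dtype=torch.bfloat16, device="cuda")
    if not is_training:
        block = block.eval()                # inference mode, as in the diff above

    # default sbhd layout: (seq, batch, heads, head_dim)
    q = torch.randn(128, 2, 16, head_dim, dtype=torch.bfloat16, device="cuda",
                    requires_grad=is_training)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    out = block(q, k, v)
    if is_training:                         # skip backward when the backend is forward-only
        out.sum().backward()

If no fused or flash kernel supports the configuration, DotProductAttention falls back to the unfused backend, so the forward call itself still runs.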
32 changes: 25 additions & 7 deletions tests/pytorch/fused_attn/test_kv_cache.py
@@ -52,7 +52,7 @@
4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
),
"infer_1": ModelConfig(
2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
),
}

@@ -370,12 +370,18 @@ def generate_args(
]


def get_tols(module, backend, dtype):
def get_tols(config, module, backend, dtype):
if module == "TransformerLayer":
tols = {
torch.half: (5e-3, 5e-3),
torch.bfloat16: (3.5e-2, 3.5e-2),
}
if config.head_dim_qk <= 128:
tols = {
torch.half: (5e-3, 5e-3),
torch.bfloat16: (3.5e-2, 3.5e-2),
}
else:
tols = {
torch.half: (7e-3, 7e-3),
torch.bfloat16: (5e-2, 5e-2),
}
if module == "DotProductAttention":
tols = {
torch.half: (1e-3, 1e-3),
@@ -484,6 +490,16 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
# TransformerLayer FP8 TN Gemm currently requires %8=0
if is_fp8 and not (qkv_format == "thd" and module == "DotProductAttention"):
pytest.skip("BSHD/SBHD <-> THD conversions for FP8 are not supported")
if (
backend == "FusedAttention"
and config.head_dim_qk > 128
and not is_paged
and not is_cuda_graph
):
pytest.skip(
"No support for KV caching with head dim > 128, non-paged attention, sq = 1, and mask"
" != no_mask"
)

# create full model
logger.info("=== Generating all tokens at once ===")
@@ -662,7 +678,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
incremental_output = incremental_output[0]

# compare results
atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn)
atol, rtol = get_tols(
config, module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn
)
for i, seq in enumerate(sim.t_seq_ids):
token_index = sim.step_lens[i] - 1
if qkv_format == "bshd":
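get_tols now takes the model config and loosens the TransformerLayer tolerances once head_dim_qk goes past 128. A small illustration of the branch above (the Cfg stand-in is hypothetical; the real tests pass their ModelConfig objects, and the expected tuples follow the dictionaries shown in the diff):

    import torch

    class Cfg:                              # hypothetical stand-in for ModelConfig
        def __init__(self, head_dim_qk):
            self.head_dim_qk = head_dim_qk

    atol, rtol = get_tols(Cfg(128), "TransformerLayer", "FusedAttention", dtype=torch.half)
    # -> (5e-3, 5e-3), unchanged for head_dim_qk <= 128
    atol, rtol = get_tols(Cfg(256), "TransformerLayer", "FusedAttention", dtype=torch.half)
    # -> (7e-3, 7e-3), the looser bound for head_dim_qk > 128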
54 changes: 29 additions & 25 deletions transformer_engine/common/fused_attn/fused_attn.cpp
@@ -134,10 +134,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) {

// select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
int64_t window_size_left, int64_t window_size_right) {
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
@@ -216,24 +216,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
}
if (
// TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging
// special conditions for blackwell
// TODO: enable THD max_t in f16_arbitrary_seqlen when support becomes available in 9.7
!(sm_arch_ >= 100 && (head_dim_qk > 128 || head_dim_v > 128)) &&
// architecture
((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
(cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
((cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)) ||
(cudnn_runtime_version >= 8903 && sm_arch_ >= 80 && sm_arch_ < 100) ||
(cudnn_runtime_version >= 90700 && sm_arch_ >= 80)) &&
// sequence length
((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) ||
(cudnn_runtime_version >= 90000)) &&
// number of heads
((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) ||
(cudnn_runtime_version >= 8907)) &&
// head dimension
((head_dim_qk <= 128 && head_dim_qk % 8 == 0 && head_dim_v <= 128 && head_dim_v % 8 == 0) ||
// TODO (cyang): add is_training to nvte_get_fused_attn_backend
// d=256 only supported for forward
(sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim_qk <= 256 &&
head_dim_qk % 8 == 0 && head_dim_v <= 256 && head_dim_v % 8 == 0)) &&
(head_dim_qk % 8 == 0 && head_dim_v % 8 == 0 &&
((head_dim_qk <= 128 && head_dim_v <= 128) ||
(head_dim_qk <= 256 && head_dim_v <= 256 &&
((!is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90100) ||
(is_training && sm_arch_ == 90 && cudnn_runtime_version >= 90500))) ||
(!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 &&
layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) ||
(!is_training && cudnn_runtime_version >= 91000 &&
(max_seqlen_q > 1 || layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD ||
(max_seqlen_q == 1 && sm_arch_ >= 100 &&
attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) &&
// bias type
((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
(cudnn_runtime_version >= 8906 &&
@@ -423,8 +427,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);

NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
max_seqlen, d, d, window_size_left, window_size_right);
is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h,
max_seqlen, max_seqlen, d, d, window_size_left, window_size_right);

if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
@@ -505,7 +509,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);

NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen,
max_seqlen, d, d, window_size_left, window_size_right);

if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
@@ -636,8 +640,8 @@ void nvte_fused_attn_fwd_kvpacked(
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);

NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, d, window_size_left, window_size_right);
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);

if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
@@ -731,8 +735,8 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);

NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d, d, window_size_left, window_size_right);
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right);

if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
@@ -862,8 +866,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);

NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);

if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
@@ -954,8 +958,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
const NVTEDType KV_type = static_cast<NVTEDType>(input_K->data.dtype);

NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q,
max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv,
max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right);

if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
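The new head-dimension clause in nvte_get_fused_attn_backend packs several regimes into one boolean expression. As a readability aid only (a Python paraphrase, not library code), the same rule can be restated as a standalone predicate; the architecture, sequence-length, bias, and dropout clauses of the C++ function are still checked separately:

    def head_dim_supported(is_training, sm_arch, cudnn_version,
                           head_dim_qk, head_dim_v, max_seqlen_q,
                           is_paged_kv, is_no_mask):
        """Paraphrase of the head-dim clause of nvte_get_fused_attn_backend above."""
        if head_dim_qk % 8 != 0 or head_dim_v % 8 != 0:
            return False
        # d <= 128: supported for training and inference, as before
        if head_dim_qk <= 128 and head_dim_v <= 128:
            return True
        # 128 < d <= 256 on Hopper (sm90): inference since cuDNN 9.1, training since 9.5
        if head_dim_qk <= 256 and head_dim_v <= 256 and sm_arch == 90:
            return cudnn_version >= (90500 if is_training else 90100)
        # everything past this point is inference-only
        if is_training:
            return False
        # Blackwell (sm100+), cuDNN 9.9+: non-paged KV with sq > 1
        if sm_arch >= 100 and cudnn_version >= 90900 and max_seqlen_q > 1 and not is_paged_kv:
            return True
        # cuDNN 9.10+: sq > 1, or paged KV, or sq == 1 with no mask on sm100+
        if cudnn_version >= 91000 and (max_seqlen_q > 1 or is_paged_kv
                                       or (max_seqlen_q == 1 and sm_arch >= 100 and is_no_mask)):
            return True
        return False

This matches the test changes above: the d = 192/512/1024 PyTorch configs run forward-only, and the KV-cache test skips the non-paged, sq = 1, masked combination that none of the inference branches cover.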
transformer_engine/common/include/transformer_engine/fused_attn.h
@@ -172,6 +172,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);

/*! \brief Get fused attention backend based on input parameters.
*
* \param[in] is_training Whether the model is in training mode.
* \param[in] q_dtype The data type of Tensor Q.
* \param[in] kv_dtype The data type of Tensors K, V.
* \param[in] qkv_layout The layout of Tensors Q, K, V.
@@ -188,10 +189,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] window_size_right Sliding window size (the right half).
*/
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v,
int64_t window_size_left, int64_t window_size_right);
bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads,
size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);

/*! \brief Compute dot product attention with packed QKV input.
*
2 changes: 2 additions & 0 deletions transformer_engine/jax/attention.py
@@ -277,6 +277,7 @@ def canonicalize_attn_mask_type(attn_mask_type: str):


def is_fused_attn_kernel_available(
is_training,
q_dtype,
kv_dtype,
qkv_layout,
@@ -296,6 +297,7 @@

def make_helper(attn_mask_type):
return tex.FusedAttnHelper(
is_training,
q_dtype,
kv_dtype,
qkv_layout,
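On the JAX side, is_training is now the first positional argument of is_fused_attn_kernel_available (and of FusedAttnHelper and get_fused_attn_backend underneath). A hedged usage sketch: the argument order after qkv_layout, the enum names, and the import path are assumptions based on the usual test call sites, since they are not shown in this diff.

    import jax.numpy as jnp
    from transformer_engine.jax.attention import (
        AttnBiasType, AttnMaskType, QKVLayout, is_fused_attn_kernel_available,
    )

    # Inference-only query: head_dim 256 with max_seqlen_q == 1 (a decode step)
    has_fused = is_fused_attn_kernel_available(
        False,                      # is_training, the newly added leading argument
        jnp.bfloat16,               # q_dtype
        jnp.bfloat16,               # kv_dtype
        QKVLayout.BSHD_BSHD_BSHD,   # qkv_layout
        AttnBiasType.NO_BIAS,       # bias type (assumed order from here on)
        AttnMaskType.NO_MASK,       # mask type
        0.0,                        # dropout probability
        16, 16,                     # attention heads / GQA groups
        1, 2048,                    # max_seqlen_q, max_seqlen_kv
        256,                        # head_dim
    )
    if not has_fused:
        print("no fused kernel for this config; fall back to unfused attention")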
3 changes: 3 additions & 0 deletions transformer_engine/jax/cpp_extensions/attention.py
@@ -100,6 +100,7 @@ class FusedAttnHelper:
Helper for the fused attention backend
"""

is_training: bool
q_dtype: jnp.dtype
kv_dtype: jnp.dtype
qkv_layout: QKVLayout
@@ -120,6 +121,7 @@ def is_fused_attn_kernel_available(self):
def get_fused_attn_backend(self):
"""Get the fused attention kernel backend"""
return transformer_engine_jax.get_fused_attn_backend(
self.is_training,
jax_dtype_to_te_dtype(self.q_dtype),
jax_dtype_to_te_dtype(self.kv_dtype),
self.qkv_layout.value,
@@ -273,6 +275,7 @@ def abstract(

# backend determines the softmax buffer shape/dtype
backend = FusedAttnHelper(
config.is_training,
q_dtype,
k_dtype,
config.qkv_layout,
2 changes: 1 addition & 1 deletion transformer_engine/jax/csrc/extensions.h
@@ -93,7 +93,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);

XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);

NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_num_heads, size_t kv_num_heads,
20 changes: 10 additions & 10 deletions transformer_engine/jax/csrc/extensions/attention.cpp
@@ -10,17 +10,17 @@
namespace transformer_engine {
namespace jax {

NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DType kv_dtype,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, float dropout_probability,
size_t q_attn_heads, size_t kv_attn_heads,
size_t q_max_seqlen, size_t kv_max_seqlen,
size_t head_dim, int64_t window_size_left,
int64_t window_size_right) {
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout, bias_type,
mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
is_training, static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype), qkv_layout,
bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen,
kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
return backend;
}

@@ -245,9 +245,9 @@ static void FusedAttnForwardImpl(
/* Prepare RNG state */
auto rng_state_tensor = TensorWrapper(rng_state, std::vector<size_t>{2}, DType::kInt64);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

/* Auxiliary tensors (to be propagated to the backward pass later) */
@@ -498,9 +498,9 @@ static void FusedAttnBackwardImpl(
NVTETensorPack aux_input_tensors;
nvte_tensor_pack_create(&aux_input_tensors);
auto backend = nvte_get_fused_attn_backend(
static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
head_dim, head_dim, window_size_left, window_size_right);
is_training, static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout,
bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen,
kv_max_seqlen, head_dim, head_dim, window_size_left, window_size_right);
PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads,
bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend,
softmax_aux, rng_state, bias);