Add support for head_dim > 128 #1797

Merged
merged 27 commits on Jun 13, 2025

Commits (27)
494944e
add support for head dim > 128
cyanguwa May 17, 2025
750dd91
remove debugging
cyanguwa May 18, 2025
517030d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 18, 2025
049b19f
raise tols slightly to tolerate 1/2048 mismatches
cyanguwa May 24, 2025
d363bc3
Merge branch 'main' into d_256
cyanguwa May 24, 2025
66e912b
fix is_training for test_te_layer
cyanguwa May 25, 2025
a070c8c
Merge branch 'main' into d_256
cyanguwa Jun 3, 2025
e323399
add bprop support for blackwell
cyanguwa Jun 3, 2025
2dbdafa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2025
39e54eb
minor tweak for format
cyanguwa Jun 3, 2025
0ce126c
fix backend selection results
cyanguwa Jun 3, 2025
0fdbb2c
bump sm100 to sm100+
cyanguwa Jun 3, 2025
3596a36
add sq=1 test for MLA
cyanguwa Jun 3, 2025
78e2e93
enable sq=1 for bprop
cyanguwa Jun 3, 2025
bc75c84
minor tweak in comments
cyanguwa Jun 4, 2025
bb0e83a
Merge branch 'NVIDIA:main' into d_256
cyanguwa Jun 4, 2025
39248e8
fix head_dim logic and remove pytest skip
cyanguwa Jun 10, 2025
6f6ce66
Merge branch 'main' into d_256
cyanguwa Jun 11, 2025
408e110
add FE fix for d>128
cyanguwa Jun 12, 2025
3d2b1fb
Merge branch 'main' into d_256
cyanguwa Jun 12, 2025
bbe9ef9
update FE again to take in small fixes
cyanguwa Jun 12, 2025
3e1b426
add cuDNN version info in L0 tests
cyanguwa Jun 12, 2025
0e36eeb
increase tols for Unfused + large dim
cyanguwa Jun 12, 2025
70647e3
Merge branch 'main' into d_256
cyanguwa Jun 12, 2025
e795580
Revert "add cuDNN version info in L0 tests"
cyanguwa Jun 12, 2025
d8bba8e
fix tols for Unfused
cyanguwa Jun 13, 2025
57ff6a6
Merge branch 'main' into d_256
cyanguwa Jun 13, 2025
3 changes: 3 additions & 0 deletions tests/jax/test_distributed_fused_attn.py
@@ -68,6 +68,7 @@ def impl_test_self_attn(
batch, seqlen, num_head, hidden = data_shape

if not is_fused_attn_kernel_available(
is_training,
dtype,
dtype,
QKVLayout.BS3HD,
@@ -214,6 +215,7 @@ def test_cross_attn(
batch, seqlen, num_head, hidden = data_shape

if not is_fused_attn_kernel_available(
is_training,
dtype,
dtype,
QKVLayout.BSHD_BS2HD,
@@ -346,6 +348,7 @@ def impl_test_context_parallel_attn(

def check_has_backend_for_mask(mask_type):
return is_fused_attn_kernel_available(
is_training,
dtype,
dtype,
qkv_layout,
1 change: 1 addition & 0 deletions tests/jax/test_fused_attn.py
@@ -347,6 +347,7 @@ def _check_configs(self):
)

self.backend = FusedAttnHelper(
self.is_training,
self.dtype,
self.dtype,
self.qkv_layout,
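Both JAX test files now pass is_training as the new leading argument of the fused-attention availability check, so forward-only (inference) support can be reported separately from training support. Below is a minimal, self-contained sketch of that gating pattern; is_fused_attn_kernel_available_stub, its reduced argument list, and the head-dim rule it encodes are illustrative assumptions, not the real transformer_engine.jax API.

```python
# Stub standing in for the availability check called in the hunks above; the real
# helper takes more arguments (QKV layout, bias type, sequence lengths, ...).
def is_fused_attn_kernel_available_stub(is_training, q_dtype, kv_dtype, head_dim):
    # Assumed rule for illustration only: forward-only coverage is wider than
    # training coverage once head_dim exceeds 128.
    if head_dim <= 128:
        return True
    return not is_training


def maybe_skip(is_training, dtype, head_dim):
    """Mirror the test pattern: bail out when no fused kernel exists for this setup."""
    if not is_fused_attn_kernel_available_stub(is_training, dtype, dtype, head_dim):
        print(f"skip: no fused attention kernel (is_training={is_training}, d={head_dim})")
        return True
    return False


maybe_skip(is_training=True, dtype="bfloat16", head_dim=256)   # may skip on some GPUs
maybe_skip(is_training=False, dtype="bfloat16", head_dim=256)  # inference path still runs
```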
74 changes: 64 additions & 10 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -222,13 +222,19 @@ def test():


model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
# test: b, h, hg, d, sq, skv, p, mask, bias
"base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
"base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
"base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
"base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"),
"base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"),
"base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"),
"base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"),
}
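For reference, the positional arguments in these entries follow the column comment (b, h, hg, d, sq, skv, p, mask, bias). A hedged sketch of that mapping is below; ModelConfigSketch is a hypothetical stand-in, and the real ModelConfig in this test file carries additional fields (for example head_dim_v and window_size).

```python
from dataclasses import dataclass


@dataclass
class ModelConfigSketch:
    batch_size: int       # b
    num_heads: int        # h
    num_gqa_groups: int   # hg
    head_dim_qk: int      # d
    max_seqlen_q: int     # sq
    max_seqlen_kv: int    # skv
    dropout_p: float      # p
    attn_mask_type: str   # mask
    attn_bias_type: str   # bias


# "base_6_1" from the table above: head_dim 1024, 128 query tokens against 2048 KV tokens.
base_6_1 = ModelConfigSketch(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias")
```

The new base_4_* through base_6_* entries exercise head dims 192, 512, and 1024, each with an sq=1 (inference-style) case and an sq=128 case.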


@@ -270,14 +276,28 @@ def test_dot_product_attention(
if config.window_size == (-1, -1) and swa:
config.window_size = [2, 2]
config.window_size = check_set_window_size(config.attn_mask_type, config.window_size)

is_training = True
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

# FlashAttention does not support pad_between_seqs, but _run_dot_product_attention
# manually pads and unpads the input and output of FlashAttention for testing purposes
@@ -296,7 +316,6 @@
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")

is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
@@ -360,6 +379,7 @@ def test_dot_product_attention(
is_training,
)

logging.info(f"[test_dot_product_attention]: is_training = {is_training}")
if unfused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
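The hunks above add a retry step: backend availability is first queried with is_training=True, and only if no fused backend can train the configuration (for example, head dims above 128 on GPUs that lack the new bprop kernels) is the query repeated with is_training=False, so forward-only kernels can still be compared. A minimal sketch of that retry, with a toy get_backends standing in for _get_attention_backends and an assumed support cutoff:

```python
# Toy query standing in for _get_attention_backends(); the cutoff below is an
# assumption for illustration, not the library's actual support matrix.
def toy_get_backends(config, is_training):
    fused_ok = config["head_dim"] <= 128 or not is_training
    backends = ["FusedAttention"] if fused_ok else []
    return (True, fused_ok, True), backends  # (flash, fused, unfused), fused backends


def select_backends(get_backends, config):
    """Prefer a training comparison; fall back to inference-only if fused bwd is unavailable."""
    is_training = True
    (flash_ok, fused_ok, unfused_ok), fused_backends = get_backends(config, is_training)
    if not fused_ok:
        is_training = False
        (flash_ok, fused_ok, unfused_ok), fused_backends = get_backends(config, is_training)
    return is_training, (flash_ok, fused_ok, unfused_ok), fused_backends


print(select_backends(toy_get_backends, {"head_dim": 256}))  # falls back to inference mode
```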
@@ -399,18 +419,27 @@ def test_dpa_checkpoint(dtype, model_configs, model):
"mla_1_1": ModelConfig(
4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_1_2": ModelConfig(
4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128
), # cross, 0
"mla_2_0": ModelConfig(
2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64
), # self , 1
"mla_2_1": ModelConfig(
1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64
), # cross, 1
"mla_2_2": ModelConfig(
1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128
), # cross, 1
"mla_3_0": ModelConfig(
8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64
), # inference
"mla_3_1": ModelConfig(
8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
"mla_3_2": ModelConfig(
8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
}


@@ -1024,6 +1053,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
layer_number=1,
attention_type=config.attn_type,
).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()

# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
@@ -1136,14 +1167,29 @@ def test_transformer_layer(
workspace_opt = True

# Test backend availability
is_training = True
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=(
qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
),
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not fused_attn_supported:
is_training = False
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=(
qkv_format.replace("hd", "h3d")
if fused_qkv_params
else qkv_format.replace("hd", "3hd")
),
is_training=is_training,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
@@ -1163,6 +1209,7 @@ def test_transformer_layer(
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)

# FusedAttention backend
@@ -1176,6 +1223,7 @@ def test_transformer_layer(
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)

# FlashAttention backend
@@ -1189,8 +1237,10 @@ def test_transformer_layer(
workspace_opt,
fused_qkv_params,
RoPE,
is_training,
)

logging.info(f"[test_transformer_layer]: is_training = {is_training}")
if unfused_attn_supported and fused_attn_supported:
logging.info("[test_transformer_layer]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
@@ -1257,6 +1307,7 @@ def _run_transformer_layer(
workspace_opt: bool,
fused_qkv_params: bool,
RoPE: bool,
is_training: bool,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run TransformerLayer module with one forward pass and one backward pass"""

@@ -1410,6 +1461,8 @@ def _run_transformer_layer(
bias=True,
attn_input_format=qkv_format,
).to(dtype=dtype, device="cuda")
if not is_training:
block = block.eval()

# Create ALiBi slopes
alibi_slopes = None
@@ -1432,8 +1485,9 @@ def _run_transformer_layer(
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
)
loss = out.sum()
loss.backward()
if is_training:
loss = out.sum()
loss.backward()

return out, inp.grad

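When the fallback lands on is_training=False, the test helpers above switch the module to eval mode and skip the backward pass, so forward-only fused kernels (large head dims without bprop support) are still validated. A small sketch of that flow with a plain torch.nn.Linear in place of TransformerLayer; the helper name run_layer is illustrative:

```python
import torch


def run_layer(block: torch.nn.Module, inp: torch.Tensor, is_training: bool):
    if not is_training:
        block = block.eval()  # e.g. turns off dropout for the forward-only comparison
    out = block(inp)
    if is_training:
        loss = out.sum()      # only training configs exercise the backward (bprop) kernels
        loss.backward()
    return out, inp.grad      # inp.grad stays None when is_training is False


layer = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)
out, grad = run_layer(layer, x, is_training=False)
assert grad is None  # no gradients are produced on the inference-only path
```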
28 changes: 21 additions & 7 deletions tests/pytorch/fused_attn/test_kv_cache.py
@@ -52,7 +52,7 @@
4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16
),
"infer_1": ModelConfig(
2, 16, 4, 64, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16
),
}

@@ -370,12 +370,24 @@ def generate_args(
]


def get_tols(module, backend, dtype):
def get_tols(config, module, backend, dtype):
if module == "TransformerLayer":
tols = {
torch.half: (5e-3, 5e-3),
torch.bfloat16: (3.5e-2, 3.5e-2),
}
if config.head_dim_qk <= 128:
tols = {
torch.half: (5e-3, 5e-3),
torch.bfloat16: (3.5e-2, 3.5e-2),
}
else:
if backend == "UnfusedAttention":
tols = {
torch.half: (1.6e-2, 1.6e-2),
torch.bfloat16: (1.2e-1, 1e-1),
}
else:
tols = {
torch.half: (1e-2, 1e-2),
torch.bfloat16: (8e-2, 7e-2),
}
if module == "DotProductAttention":
tols = {
torch.half: (1e-3, 1e-3),
@@ -662,7 +674,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g
incremental_output = incremental_output[0]

# compare results
atol, rtol = get_tols(module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn)
atol, rtol = get_tols(
config, module, backend, dtype=dtype if not is_fp8 else torch.float8_e4m3fn
)
for i, seq in enumerate(sim.t_seq_ids):
token_index = sim.step_lens[i] - 1
if qkv_format == "bshd":
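The reworked get_tols keys the TransformerLayer tolerances on head_dim_qk: the original (atol, rtol) pairs are kept for head dims up to 128, and looser pairs are used above that, with the unfused backend loosest of all. A self-contained restatement of that branch is below; the function name transformer_layer_tols is illustrative, and the numbers are copied from the diff.

```python
import torch


def transformer_layer_tols(head_dim_qk: int, backend: str, dtype: torch.dtype):
    # (atol, rtol) pairs as in the TransformerLayer branch of get_tols() above.
    if head_dim_qk <= 128:
        tols = {torch.half: (5e-3, 5e-3), torch.bfloat16: (3.5e-2, 3.5e-2)}
    elif backend == "UnfusedAttention":
        tols = {torch.half: (1.6e-2, 1.6e-2), torch.bfloat16: (1.2e-1, 1e-1)}
    else:
        tols = {torch.half: (1e-2, 1e-2), torch.bfloat16: (8e-2, 7e-2)}
    return tols[dtype]


atol, rtol = transformer_layer_tols(256, "FusedAttention", torch.half)  # (1e-2, 1e-2)
```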