[pull] main from NVIDIA:main #47

Merged 1 commit on Jan 9, 2025
tests/pytorch/fused_attn/run_fused_attn_with_cp.py: 10 additions & 16 deletions
@@ -163,12 +163,10 @@ def run_dpa_with_cp(
                 torch.tensor([q_input_shape[0]], dtype=torch.int32),
             ]
         ).cuda()
-        if kernel_backend == "FlashAttention":
-            cu_seqlens_q = cu_seqlens_q_padded[:-1]
-        else:
-            cu_seqlens_q = torch.cat(
-                [torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0, dtype=torch.int32)]
-            ).cuda()
+        cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
+        if kernel_backend == "FusedAttention":
+            cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
+            cu_seqlens_q[-1] = cu_seqlens_q[-2]
         cu_seqlens_kv = cu_seqlens_q
         cu_seqlens_kv_padded = cu_seqlens_q_padded
     else:
@@ -204,10 +202,8 @@ def run_dpa_with_cp(
         core_attention_bias=bias,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_kv=cu_seqlens_kv,
-        cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
-        cu_seqlens_kv_padded=(
-            None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
-        ),
+        cu_seqlens_q_padded=cu_seqlens_q_padded,
+        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
     )
     if fp8_mha:
         dout_fp8 = Float8Tensor.to_float8(dout, fp8_dtype=tex.DType.kFloat8E5M2)
@@ -276,10 +272,8 @@ def run_dpa_with_cp(
         core_attention_bias=bias_,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_kv=cu_seqlens_kv,
-        cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
-        cu_seqlens_kv_padded=(
-            None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
-        ),
+        cu_seqlens_q_padded=cu_seqlens_q_padded,
+        cu_seqlens_kv_padded=cu_seqlens_kv_padded,
     )
     if fp8_mha:
         dout_fp8_ = Float8Tensor.to_float8(dout_, fp8_dtype=tex.DType.kFloat8E5M2)
@@ -311,7 +305,7 @@ def run_dpa_with_cp(
         dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]]
         dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]]
         dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_]
-        cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size
+        cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
         cu_seqlens_q = get_cu_seqlens_on_cp_rank(
             cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
         )
@@ -327,7 +321,7 @@
             ).item()
             == 0
         )
-        cu_seqlens_kv_padded = cu_seqlens_kv_padded[:-1] // world_size
+        cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size
         cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
             cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
         )
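For reference, here is a minimal sketch of what the new cu_seqlens construction yields on the FusedAttention path, and of the per-rank division applied before get_cu_seqlens_on_cp_rank. The sequence lengths, total token count, and world_size below are illustrative values only, and the .cuda() calls from the test are omitted so the snippet runs on CPU; it is not part of the PR.

import torch

# Illustrative THD batch (not from the PR): 2 sequences with real lengths [3, 5],
# padded per sequence to [4, 8], stored in a token buffer of 16 rows.
seqlens_q = torch.tensor([3, 5], dtype=torch.int32)
seqlens_q_padded = torch.tensor([4, 8], dtype=torch.int32)
total_tokens = 16  # plays the role of q_input_shape[0]
world_size = 2     # assumed context-parallel group size

cu_seqlens_q_padded = torch.cat(
    [
        torch.zeros([1], dtype=torch.int32),
        seqlens_q_padded.cumsum(0, dtype=torch.int32),
        torch.tensor([total_tokens], dtype=torch.int32),
    ]
)
# cu_seqlens_q_padded -> [0, 4, 12, 16]; the last entry is the total padded length.

# FusedAttention path: clone the padded offsets, overwrite the interior with the real
# cumulative lengths, then repeat the last real offset so the tail is pure padding.
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32)
cu_seqlens_q[-1] = cu_seqlens_q[-2]
# cu_seqlens_q -> [0, 3, 8, 8]

# Per-rank padded offsets, as now computed without the old [:-1] slicing.
cu_seqlens_q_padded_per_rank = cu_seqlens_q_padded // world_size
# cu_seqlens_q_padded_per_rank -> [0, 2, 6, 8]

The last entry of cu_seqlens_q equals the second-to-last, so the trailing region between the final real offset and the end of the padded buffer contributes zero tokens.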
tests/pytorch/fused_attn/test_fused_attn_with_cp.py: 2 additions & 8 deletions
@@ -121,22 +121,14 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
        pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
    if dtype == "fp8" and get_device_compute_capability() < (9, 0):
        pytest.skip("FP8 attention is only supported on sm90+!")
    if qkv_format == "thd" and get_cudnn_version() >= (9, 6, 0):
        pytest.skip("THD format is not supported for cuDNN 9.6+!")

    config = model_configs_fused_attn[model]
    if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
        pytest.skip("THD format does not support QGA/MQA yet!")
    if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
        pytest.skip("THD format does not support post_scale_bias yet!")
    if qkv_format == "thd" and cp_comm_type == "all_gather":
        pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
    if qkv_format == "thd" and "a2a" in cp_comm_type:
        pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
    if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a":
        pytest.skip(
            "Sliding window attention only can be supported with the implementation of QKVO A2A!"
        )
    if dtype == "fp8" and cp_comm_type == "all_gather":
        pytest.skip(
            "CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
@@ -147,6 +139,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("FP8 attention cannot work with bias yet!")
     if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
         pytest.skip("FP8 attention cannot work with sliding window yet!")
+    if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
+        pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
     if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
         pytest.skip("CP implementation with KV all-gather does not support bias yet!")
     if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":