[PyTorch] Miscellaneous fixes for attention #1780

Merged: 20 commits, Jun 3, 2025
126 changes: 87 additions & 39 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -1107,9 +1107,11 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
"te_1_3": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"te_2_3": ModelConfig(1, 16, 16, 64, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"),
"te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
}
@@ -1120,7 +1122,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd"])
@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd", "thd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(
@@ -1137,13 +1139,18 @@ def test_transformer_layer(
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
qkv_layout=(
qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
),
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
# Skip if qkv_format = thd and "padding" not in attn_mask_type
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
pytest.skip("THD requires padding mask.")

# UnfusedDotProductAttention backend
if unfused_attn_supported:
@@ -1264,48 +1271,84 @@ def _run_transformer_layer(
_attention_backends["backend_selection_requires_update"] = True

# Create input tensor
inp = torch.randn(
config.max_seqlen_q,
config.batch_size,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
# In case the format to be tested is batch-first, need to transpose the
# input tensor.
if qkv_format == "sbhd":
inp = torch.randn(
config.max_seqlen_q,
config.batch_size,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_enc = torch.randn(
config.max_seqlen_kv,
config.batch_size,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
if qkv_format == "bshd":
inp = inp.transpose(0, 1)
inp = torch.randn(
config.batch_size,
config.max_seqlen_q,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_enc = torch.randn(
config.batch_size,
config.max_seqlen_kv,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)

# Create seqlens
if "padding" in config.attn_mask_type:
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
if "padding" in config.attn_mask_type or qkv_format == "thd":
if config.attn_type == "self":
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
seqlens_kv = seqlens_q
if config.attn_type == "cross":
if config.max_seqlen_q > 1:
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
else:
seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
seqlens_kv = torch.randint(
1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
)
else:
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
)

# Create attention mask if padding
attention_mask = None
if "padding" in config.attn_mask_type:
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat(
[
attention_mask_q,
torch.Tensor(
[False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
)
.to(torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask = attention_mask_q.to(device="cuda")
seqlens_kv = torch.full(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
if qkv_format == "thd":
inp = torch.randn(
cu_seqlens_q[-1],
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_enc = torch.randn(
cu_seqlens_kv[-1],
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)

sigma = 0.02
init_method = init_method_normal(sigma)
@@ -1357,7 +1400,7 @@ def _run_transformer_layer(
sequence_parallel=False,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="encoder",
layer_type="encoder" if config.attn_type == "self" else "decoder",
drop_path_rate=drop_path_rates[layer_number - 1],
set_parallel_mode=True,
fuse_qkv_params=fused_qkv_params,
@@ -1376,13 +1419,18 @@
# Run a forward and backward pass
out = block(
inp,
attention_mask=attention_mask,
self_attn_mask_type=config.attn_mask_type,
encoder_output=inp_enc if config.attn_type == "cross" else None,
enc_dec_attn_mask_type=config.attn_mask_type if config.attn_type == "cross" else None,
checkpoint_core_attention=False,
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
)
loss = out.sum()
loss.backward()
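For reference, the new `qkv_format="thd"` path in this test packs all valid tokens into a single dimension and derives cumulative offsets from the per-sequence lengths. Below is a minimal standalone sketch of that packing; the batch size, sequence length, hidden size, and dtype are hypothetical values, not taken from the test configs.

```python
import torch

# Standalone sketch of the packing used for qkv_format="thd" above.
# batch_size, max_seqlen_q, hidden_size and dtype are hypothetical.
batch_size, max_seqlen_q, hidden_size = 2, 128, 1024
dtype = torch.bfloat16

# Random valid lengths per sequence, then cumulative offsets.
seqlens_q = torch.randint(1, max_seqlen_q, [batch_size], dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)

# Packed "thd" input: one row per valid token across the whole batch,
# instead of a [max_seqlen_q, batch_size, hidden_size] tensor.
inp = torch.randn(
    int(cu_seqlens_q[-1]), hidden_size, dtype=dtype, device="cuda", requires_grad=True
)
```

`cu_seqlens_q` (and its KV counterpart) is what the test then passes to the `TransformerLayer` forward call, together with `max_seqlen_q` and `max_seqlen_kv`.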
@@ -215,7 +215,12 @@ def forward(

if "padding" in attn_mask_type and attention_mask is None:
attention_mask = dpa_utils.get_padding_mask(
batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
batch_size,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
self.attention_type,
)
attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = (
dpa_utils.get_full_mask(
53 changes: 33 additions & 20 deletions transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -946,16 +946,24 @@ def get_attention_backend(
@torch.no_grad()
def get_padding_mask(
batch_size: int,
cu_seqlens_q: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_q: int,
max_seqlen_kv: int,
cu_seqlens_q: torch.Tensor = None,
cu_seqlens_kv: torch.Tensor = None,
max_seqlen_q: int = None,
max_seqlen_kv: int = None,
attention_type: str = "self",
):
"""Convert cu_seqlens to attention_mask"""
assert (
cu_seqlens_q is not None and max_seqlen_q is not None
), "cu_seqlens_q and max_seqlen_q are required for self-attention and cross-attention"
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
if attention_type == "cross":
assert (
cu_seqlens_kv is not None and max_seqlen_kv is not None
), "cu_seqlens_kv and max_seqlen_kv are required for cross-attention"
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
for i in range(batch_size):
attention_mask_q = torch.cat(
[
@@ -968,21 +976,26 @@ def get_padding_mask(
],
dim=0,
)
attention_mask_kv = torch.cat(
[
attention_mask_kv,
torch.Tensor([False] * seqlens_kv[i] + [True] * (max_seqlen_kv - seqlens_kv[i]))
.to(dtype=torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
if attention_type == "cross":
attention_mask_kv = torch.cat(
[
attention_mask_kv,
torch.Tensor([False] * seqlens_kv[i] + [True] * (max_seqlen_kv - seqlens_kv[i]))
.to(dtype=torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask_q = attention_mask_q.to(device="cuda")
if attention_type == "self":
attention_mask = attention_mask_q
else:
attention_mask = (
attention_mask_q,
attention_mask_kv.to(device="cuda"),
)
attention_mask = (
attention_mask_q.to(device="cuda"),
attention_mask_kv.to(device="cuda"),
)
return attention_mask


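A minimal usage sketch of the updated `get_padding_mask` helper, assuming the import path follows the file location shown above and using made-up `cu_seqlens` values: for self-attention only the query-side arguments are required and a single `[batch, 1, 1, max_seqlen_q]` boolean mask is returned (`True` marks padded positions); for cross-attention the KV-side arguments are also required and a `(mask_q, mask_kv)` tuple is returned.

```python
import torch

from transformer_engine.pytorch.attention.dot_product_attention.utils import get_padding_mask

# Hypothetical cumulative sequence lengths for a batch of 2.
cu_seqlens_q = torch.tensor([0, 3, 7], dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.tensor([0, 5, 9], dtype=torch.int32, device="cuda")

# Self-attention: only the query-side arguments are needed.
mask = get_padding_mask(
    batch_size=2,
    cu_seqlens_q=cu_seqlens_q,
    max_seqlen_q=8,
    attention_type="self",
)  # bool tensor of shape [2, 1, 1, 8], True = padded position

# Cross-attention: KV-side arguments are required as well; a tuple is returned.
mask_q, mask_kv = get_padding_mask(
    batch_size=2,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_kv=cu_seqlens_kv,
    max_seqlen_q=8,
    max_seqlen_kv=16,
    attention_type="cross",
)
```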
22 changes: 22 additions & 0 deletions transformer_engine/pytorch/attention/multi_head_attention.py
@@ -482,6 +482,8 @@ def forward(
alibi_slopes: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None,
fast_zero_fill: bool = True,
@@ -556,6 +558,12 @@ def forward(
cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`,
with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
max_seqlen_q: Optional[int], default = `None`
Maximum sequence length in `query_layer`.
Calculated from `cu_seqlens_q` if not provided.
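The "(without offset)" versus "(with offset)" wording can be made concrete with small numbers. The sketch below assumes a packed `thd` batch in which each sequence sits in a fixed-size padded slot; the actual padded layout depends on the caller.

```python
import torch

# Hypothetical thd batch: 3 sequences with 5, 2 and 7 valid tokens,
# each stored in a padded slot of 8 tokens.
seqlens = torch.tensor([5, 2, 7], dtype=torch.int32)
slot_len = 8

# cu_seqlens (without offset): cumulative count of valid tokens only.
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(seqlens, dim=0).to(torch.int32)]
)  # tensor([ 0,  5,  7, 14], dtype=torch.int32)

# cu_seqlens_padded (with offset): cumulative slot boundaries, i.e. where each
# padded sequence starts in the packed buffer.
cu_seqlens_padded = torch.arange(0, len(seqlens) + 1, dtype=torch.int32) * slot_len
# tensor([ 0,  8, 16, 24], dtype=torch.int32)
```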
@@ -714,6 +722,18 @@ def forward(
for x in (key_layer, value_layer)
)

if self.qkv_format == "thd":
key_layer, value_layer = (
x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
for x in (key_layer, value_layer)
)
else:
# key, value: -> [sq, b, ng, hn]
key_layer, value_layer = (
x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
for x in (key_layer, value_layer)
)

# Attention head [sq, b, h] --> [sq, b, hp]
if self.input_layernorm:
layernorm_query_outputs = self.layernorm_query(
@@ -803,6 +823,8 @@ def forward(
qkv_format=self.qkv_format,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
attention_mask=attention_mask,