
Misc 2.4 #1780


Open

wants to merge 14 commits into main
180 changes: 144 additions & 36 deletions tests/pytorch/fused_attn/test_fused_attn.py
@@ -1107,9 +1107,11 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
"te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
"te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
"te_1_3": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
"te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
"te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
"te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
"te_2_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
"te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
}
@@ -1120,7 +1122,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@pytest.mark.parametrize("ckpt_attn", [False])
@pytest.mark.parametrize("qkv_format", ["sbhd"])
@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd", "thd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
def test_transformer_layer(
@@ -1137,13 +1139,18 @@ def test_transformer_layer(
available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
qkv_layout=(
qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
),
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

# Skip if only unfused backend is supported
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
# Skip if qkv_format = thd and "padding" not in attn_mask_type
if qkv_format == "thd" and "padding" not in config.attn_mask_type:
pytest.skip("THD requires padding mask.")

# UnfusedDotProductAttention backend
if unfused_attn_supported:
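
Aside (not part of the diff): the qkv_layout argument above is now derived from qkv_format by string substitution instead of being hard-coded to "sbh3d"/"sb3hd". A minimal sketch of the strings that expression produces for the three parametrized formats:

```python
# Layout strings produced by the replace-based expression in the test above.
for qkv_format in ("sbhd", "bshd", "thd"):
    fused_layout = qkv_format.replace("hd", "h3d")    # used when fused_qkv_params=True
    unfused_layout = qkv_format.replace("hd", "3hd")  # used when fused_qkv_params=False
    print(qkv_format, "->", fused_layout, "/", unfused_layout)
# sbhd -> sbh3d / sb3hd
# bshd -> bsh3d / bs3hd
# thd  -> th3d  / t3hd
```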
@@ -1194,6 +1201,8 @@ def test_transformer_layer(
torch.testing.assert_close(flash_attn_bwd, unfused_attn_bwd, **tols)
if fused_attn_supported and flash_attn_supported:
logging.info("[test_transformer_layer]: fused attn vs flash attn")
print("fused min/max", fused_attn_fwd.min().item(), fused_attn_fwd.max().item())
print("flash min/max", flash_attn_fwd.min().item(), flash_attn_fwd.max().item())
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)

@@ -1264,48 +1273,140 @@ def _run_transformer_layer(
_attention_backends["backend_selection_requires_update"] = True

# Create input tensor
inp = torch.randn(
config.max_seqlen_q,
config.batch_size,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
# In case the format to be tested is batch-first, need to transpose the
# input tensor.
if qkv_format == "sbhd":
inp = torch.randn(
config.max_seqlen_q,
config.batch_size,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_enc = torch.randn(
config.max_seqlen_kv,
config.batch_size,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
if qkv_format == "bshd":
inp = inp.transpose(0, 1)
inp = torch.randn(
config.batch_size,
config.max_seqlen_q,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_enc = torch.randn(
config.batch_size,
config.max_seqlen_kv,
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)

# Create seqlens
if "padding" in config.attn_mask_type:
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
if "padding" in config.attn_mask_type or qkv_format == "thd":
if config.attn_type == "self":
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
seqlens_kv = seqlens_q
if config.attn_type == "cross":
if config.max_seqlen_q > 1:
seqlens_q = torch.randint(
1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
)
else:
seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
seqlens_kv = torch.randint(
1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
)
else:
seqlens_q = torch.full(
[config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
)
seqlens_kv = torch.full(
[config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
)
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
if qkv_format == "thd":
inp = torch.randn(
cu_seqlens_q[-1],
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
inp_enc = torch.randn(
cu_seqlens_kv[-1],
config.hidden_size,
dtype=dtype,
device="cuda",
requires_grad=True,
)
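
Aside (not part of the diff): for the thd path, the packed input's leading dimension is the total token count, cu_seqlens[-1]. A small illustration of that bookkeeping, with made-up sequence lengths and hidden_size = 64:

```python
# Toy illustration of the cu_seqlens construction used above (values are made up).
import torch

seqlens_q = torch.tensor([3, 5, 2], dtype=torch.int32)             # per-sequence valid lengths
cu_seqlens_q = torch.zeros(seqlens_q.numel() + 1, dtype=torch.int32)
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)                   # tensor([0, 3, 8, 10])
total_tokens = int(cu_seqlens_q[-1])                                # 10
inp = torch.randn(total_tokens, 64)                                 # thd input: [t, hidden_size]
```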

# Create attention mask if padding
attention_mask = None
if "padding" in config.attn_mask_type:
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
attention_mask_q = torch.cat(
[
attention_mask_q,
torch.Tensor(
[False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
)
.to(torch.bool)
.unsqueeze(0)
.unsqueeze(0)
.unsqueeze(0),
],
dim=0,
)
attention_mask = attention_mask_q.to(device="cuda")
# if "padding" in config.attn_mask_type:
# if config.attn_type == "self":
# attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
# for i in range(config.batch_size):
# attention_mask_q = torch.cat(
# [
# attention_mask_q,
# torch.Tensor(
# [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
# )
# .to(dtype=torch.bool)
# .unsqueeze(0)
# .unsqueeze(0)
# .unsqueeze(0),
# ],
# dim=0,
# )
# attention_mask = attention_mask_q.to(device="cuda")
# if config.attn_type == "cross":
# attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
# attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
# for i in range(config.batch_size):
# attention_mask_q = torch.cat(
# [
# attention_mask_q,
# torch.Tensor(
# [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
# )
# .to(dtype=torch.bool)
# .unsqueeze(0)
# .unsqueeze(0)
# .unsqueeze(0),
# ],
# dim=0,
# )
# attention_mask_kv = torch.cat(
# [
# attention_mask_kv,
# torch.Tensor(
# [False] * seqlens_kv[i]
# + [True] * (config.max_seqlen_kv - seqlens_kv[i])
# )
# .to(dtype=torch.bool)
# .unsqueeze(0)
# .unsqueeze(0)
# .unsqueeze(0),
# ],
# dim=0,
# )
# attention_mask = (
# attention_mask_q.to(device="cuda"),
# attention_mask_kv.to(device="cuda"),
# )
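
Aside (not part of the diff): the mask construction commented out above builds a boolean padding mask of shape [batch, 1, 1, max_seqlen] in which True marks padded positions. A compact, vectorized equivalent with illustrative lengths only:

```python
# Vectorized padding-mask sketch: True = padded (masked-out) position.
import torch

max_seqlen = 6
seqlens = torch.tensor([4, 6, 2])                                    # hypothetical valid lengths
positions = torch.arange(max_seqlen)                                 # [0, 1, ..., 5]
attention_mask = (positions[None, :] >= seqlens[:, None])[:, None, None, :]
print(attention_mask.shape)     # torch.Size([3, 1, 1, 6])
print(attention_mask[0, 0, 0])  # tensor([False, False, False, False,  True,  True])
```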

sigma = 0.02
init_method = init_method_normal(sigma)
@@ -1357,7 +1458,7 @@ def _run_transformer_layer(
sequence_parallel=False,
apply_residual_connection_post_layernorm=False,
output_layernorm=False,
layer_type="encoder",
layer_type="encoder" if config.attn_type == "self" else "decoder",
drop_path_rate=drop_path_rates[layer_number - 1],
set_parallel_mode=True,
fuse_qkv_params=fused_qkv_params,
@@ -1376,13 +1477,20 @@ def _run_transformer_layer(
# Run a forward and backward pass
out = block(
inp,
attention_mask=attention_mask,
attention_mask=None, # attention_mask_q,
self_attn_mask_type=config.attn_mask_type,
encoder_output=inp_enc if config.attn_type == "cross" else None,
enc_dec_attn_mask=None, # attention_mask if config.attn_type == "cross" else None,
enc_dec_attn_mask_type=config.attn_mask_type if config.attn_type == "cross" else None,
checkpoint_core_attention=False,
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
alibi_slopes=alibi_slopes,
max_seqlen_q=config.max_seqlen_q,
max_seqlen_kv=config.max_seqlen_kv,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
)
loss = out.sum()
loss.backward()
@@ -3484,7 +3484,64 @@ def attn_forward_func_with_cp(
use_flash_attn_3=False,
) -> torch.Tensor:
"""
Attention implementation with context parallelism.
Attention implementation with context parallelism (CP). CP partitions tensors along the sequence
dimension and, by reducing the memory and computational pressure on each GPU, enables long-context
LLMs in a distributed fashion. Transformer Engine's PyTorch CP implementation currently uses the
DualChunkSwap strategy to ensure load balancing across CP ranks. It is applied to all `attn_mask_type`s
and all `qkv_format`s, and it requires sequence lengths to be, or to be padded to, a multiple of
(cp_size * 2). It also requires tokens to be re-ordered before entering this function.

For qkv_format = {'bshd', 'sbhd'}, the token re-ordering is illustrated below for an example
use case of s = 12, attn_mask_type = 'causal', and cp_size = 2. seq_pos indicates each token's
position in its sequence.

GPU0 | GPU1 GPU0 | GPU1
seq_pos | 0 1 2 3 4 5 | 6 7 8 9 10 11 seq_pos | 0 1 2 9 10 11 | 3 4 5 6 7 8
---------------------------|----------------- ---------------------------|------------------
0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
U 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 9 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 1, 1,
0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 10 | 1, 1, 1, 1, 1, 0,| 1, 1, 1, 1, 1, 1,
5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1,
---------------------------|----------------- ---------------------------|------------------
6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 3 | 1, 1, 1, 0, 0, 0,| 1, 0, 0, 0, 0, 0,
G 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 4 | 1, 1, 1, 0, 0, 0,| 1, 1, 0, 0, 0, 0,
P 8 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 0, 0, 0, P 5 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 0, 0, 0,
U 9 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 0, 0, U 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0,
1 10 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 0, 1 7 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 0,
11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1, 8 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 1,

For qkv_format = 'thd', multiple sequences, possibly of different lengths, may be packed into one
batch. DualChunkSwap divides each sequence into (cp_size * 2) chunks and distributes 2 chunks of
every sequence onto each CP rank. The token matrix transformation is shown below for an example with
batch_size = 2, seq_ids = [0, 1], seq_lens = [8, 4], t = 12, attn_mask_type = 'padding_causal', and
cp_size = 2.

GPU0 | GPU1 GPU0 | GPU1
seq_id | 0 0 0 0 0 0 | 0 0 1 1 1 1 seq_id | 0 0 0 0 1 1 | 0 0 0 0 1 1
seq_pos | 0 1 2 3 4 5 | 6 7 0 1 2 3 seq_pos | 0 1 6 7 0 3 | 2 3 4 5 1 2
---------------------------|----------------- ---------------------------|------------------
0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
P 0 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 0 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0,
U 0 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 0 7 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 0, 0,
0 0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 1 0 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 0, 0,
0 5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 1 3 | 0, 0, 0, 0, 2, 2,| 0, 0, 0, 0, 2, 2,
---------------------------|----------------- ---------------------------|------------------
0 6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 0 2 | 1, 1, 0, 0, 0, 0,| 1, 0, 0, 0, 0, 0,
G 0 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 0 3 | 1, 1, 0, 0, 0, 0,| 1, 1, 0, 0, 0, 0,
P 1 0 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 0, 0, 0 P 0 4 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 0, 0, 0,
U 1 1 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 0, 0 U 0 5 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 1, 0, 0,
1 1 2 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 0 1 1 1 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 0,
1 3 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 2 1 2 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 2,

When all transformer layers in a model share the same CP configuration (cp_group, cp_global_ranks,
cp_comm_type and cp_stream), token re-ordering can take place in the dataloader, i.e. only once for
all layers. An example of such re-ordering code is `get_batch_on_this_cp_rank
<https://github.com/NVIDIA/Megatron-LM/blob/d6eb60b5ea1efca47401c0be97f456fbe3a55bcd/megatron/core/utils.py#L1725>`_
in Megatron-LM.

"""

if cp_comm_type == "a2a+p2p":
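Aside (not part of the diff): the re-ordering described in the docstring can be sketched as follows for qkv_format = 'sbhd'/'bshd'. This is an illustrative helper, not Transformer Engine's actual implementation (the real dataloader-side helper is Megatron-LM's get_batch_on_this_cp_rank linked above): each rank keeps chunk r and chunk 2*cp_size - 1 - r of the sequence.

```python
# Minimal sketch of DualChunkSwap token re-ordering (illustration only).
import torch

def reorder_for_cp(x: torch.Tensor, cp_size: int, rank: int, seq_dim: int = 0) -> torch.Tensor:
    """Return the two sequence chunks processed by this CP rank."""
    chunks = x.chunk(2 * cp_size, dim=seq_dim)
    return torch.cat([chunks[rank], chunks[2 * cp_size - 1 - rank]], dim=seq_dim)

seq_pos = torch.arange(12)  # the s = 12, cp_size = 2 example from the docstring
print(reorder_for_cp(seq_pos, cp_size=2, rank=0))  # tensor([ 0,  1,  2,  9, 10, 11])
print(reorder_for_cp(seq_pos, cp_size=2, rank=1))  # tensor([3, 4, 5, 6, 7, 8])
```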
@@ -675,6 +675,7 @@ def forward(
), "Keys and values must have the same batch size, sequence length and number of heads!"
num_attention_heads = query_layer.shape[-2]
num_gqa_groups = key_layer.shape[-2]
print("xxxxxxxxxx", [x.shape for x in [query_layer, key_layer, value_layer]])
assert (
query_layer.shape[-1] == key_layer.shape[-1]
), "Queries and keys must have the same head dimension!"
13 changes: 13 additions & 0 deletions transformer_engine/pytorch/attention/multi_head_attention.py
@@ -713,6 +713,19 @@ def forward(
)
for x in (key_layer, value_layer)
)
print("-----------cros", [x.shape for x in [key_layer, value_layer]])

if self.qkv_format == "thd":
key_layer, value_layer = (
x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
for x in (key_layer, value_layer)
)
else:
# key, value: -> [sq, b, ng, hn]
key_layer, value_layer = (
x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
for x in (key_layer, value_layer)
)

# Attention head [sq, b, h] --> [sq, b, hp]
if self.input_layernorm:
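Aside (not part of the diff): a quick shape check of the thd vs. sbhd/bshd key/value reshape added above, using made-up sizes (t = 10, sq = 5, b = 2, ng = 4, hn = 16):

```python
# Shape check for the two reshape branches (illustrative sizes only).
import torch

hn = 16                                      # hidden_size_per_attention_head
kv_thd = torch.randn(10, 4 * hn)             # thd: [t, ng * hn]
kv_sbhd = torch.randn(5, 2, 4 * hn)          # sbhd/bshd: [sq, b, ng * hn]

print(kv_thd.reshape(kv_thd.size(0), -1, hn).shape)                     # torch.Size([10, 4, 16])
print(kv_sbhd.reshape(kv_sbhd.size(0), kv_sbhd.size(1), -1, hn).shape)  # torch.Size([5, 2, 4, 16])
```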