diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index 6ce8637bc7..4e95c97f37 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -1107,9 +1107,11 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
     "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
     "te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
     "te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
+    "te_1_3": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
     "te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
     "te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
     "te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
+    "te_2_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
     "te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
     "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
 }
@@ -1120,7 +1122,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
 @pytest.mark.parametrize("model_configs", [model_configs_te_layer])
 @pytest.mark.parametrize("model", model_configs_te_layer.keys())
 @pytest.mark.parametrize("ckpt_attn", [False])
-@pytest.mark.parametrize("qkv_format", ["sbhd"])
+@pytest.mark.parametrize("qkv_format", ["sbhd", "bshd", "thd"])
 @pytest.mark.parametrize("fused_qkv_params", [False])
 @pytest.mark.parametrize("RoPE", [False])
 def test_transformer_layer(
@@ -1137,13 +1139,18 @@ def test_transformer_layer(
     available_backends, _, fused_attn_backends = _get_attention_backends(
         config,
         qkv_dtype=dtype,
-        qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
+        qkv_layout=(
+            qkv_format.replace("hd", "h3d") if fused_qkv_params else qkv_format.replace("hd", "3hd")
+        ),
     )
     flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends

     # Skip if only unfused backend is supported
     if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
         pytest.skip("Less than two backends to compare.")
+    # Skip if qkv_format is thd but the mask type does not include padding
+    if qkv_format == "thd" and "padding" not in config.attn_mask_type:
+        pytest.skip("THD requires a padding mask.")

     # UnfusedDotProductAttention backend
     if unfused_attn_supported:
@@ -1264,48 +1273,140 @@ def _run_transformer_layer(
     _attention_backends["backend_selection_requires_update"] = True

     # Create input tensor
-    inp = torch.randn(
-        config.max_seqlen_q,
-        config.batch_size,
-        config.hidden_size,
-        dtype=dtype,
-        device="cuda",
-        requires_grad=True,
-    )
-    # In case the format to be tested is batch-first, need to transpose the
-    # input tensor.
+    if qkv_format == "sbhd":
+        inp = torch.randn(
+            config.max_seqlen_q,
+            config.batch_size,
+            config.hidden_size,
+            dtype=dtype,
+            device="cuda",
+            requires_grad=True,
+        )
+        inp_enc = torch.randn(
+            config.max_seqlen_kv,
+            config.batch_size,
+            config.hidden_size,
+            dtype=dtype,
+            device="cuda",
+            requires_grad=True,
+        )
     if qkv_format == "bshd":
-        inp = inp.transpose(0, 1)
+        inp = torch.randn(
+            config.batch_size,
+            config.max_seqlen_q,
+            config.hidden_size,
+            dtype=dtype,
+            device="cuda",
+            requires_grad=True,
+        )
+        inp_enc = torch.randn(
+            config.batch_size,
+            config.max_seqlen_kv,
+            config.hidden_size,
+            dtype=dtype,
+            device="cuda",
+            requires_grad=True,
+        )

     # Create seqlens
-    if "padding" in config.attn_mask_type:
-        seqlens_q = torch.randint(
-            1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
-        )
+    if "padding" in config.attn_mask_type or qkv_format == "thd":
+        if config.attn_type == "self":
+            seqlens_q = torch.randint(
+                1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
+            )
+            seqlens_kv = seqlens_q
+        if config.attn_type == "cross":
+            if config.max_seqlen_q > 1:
+                seqlens_q = torch.randint(
+                    1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
+                )
+            else:
+                seqlens_q = torch.ones([config.batch_size], dtype=torch.int32, device="cuda")
+            seqlens_kv = torch.randint(
+                1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
+            )
     else:
         seqlens_q = torch.full(
             [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
         )
+        seqlens_kv = torch.full(
+            [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
+        )
+    cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
+    cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
+    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+    cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
+    if qkv_format == "thd":
+        inp = torch.randn(
+            cu_seqlens_q[-1],
+            config.hidden_size,
+            dtype=dtype,
+            device="cuda",
+            requires_grad=True,
+        )
+        inp_enc = torch.randn(
+            cu_seqlens_kv[-1],
+            config.hidden_size,
+            dtype=dtype,
+            device="cuda",
+            requires_grad=True,
+        )

     # Create attention mask if padding
     attention_mask = None
-    if "padding" in config.attn_mask_type:
-        attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
-        for i in range(config.batch_size):
-            attention_mask_q = torch.cat(
-                [
-                    attention_mask_q,
-                    torch.Tensor(
-                        [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
-                    )
-                    .to(torch.bool)
-                    .unsqueeze(0)
-                    .unsqueeze(0)
-                    .unsqueeze(0),
-                ],
-                dim=0,
-            )
-        attention_mask = attention_mask_q.to(device="cuda")
+    # Padding is conveyed to the layer via cu_seqlens_q/cu_seqlens_kv (passed to block() below)
+    # rather than via an explicit boolean attention mask.

     sigma = 0.02
     init_method = init_method_normal(sigma)
@@ -1357,7 +1458,7 @@ def _run_transformer_layer(
         sequence_parallel=False,
         apply_residual_connection_post_layernorm=False,
         output_layernorm=False,
-        layer_type="encoder",
+        layer_type="encoder" if config.attn_type == "self" else "decoder",
         drop_path_rate=drop_path_rates[layer_number - 1],
         set_parallel_mode=True,
         fuse_qkv_params=fused_qkv_params,
@@ -1376,13 +1477,20 @@ def _run_transformer_layer(
     # Run a forward and backward pass
     out = block(
         inp,
-        attention_mask=attention_mask,
+        attention_mask=None,
         self_attn_mask_type=config.attn_mask_type,
+        encoder_output=inp_enc if config.attn_type == "cross" else None,
+        enc_dec_attn_mask=None,
+        enc_dec_attn_mask_type=config.attn_mask_type if config.attn_type == "cross" else None,
         checkpoint_core_attention=False,
         rotary_pos_emb=rotary_pos_emb,
         core_attention_bias_type=config.attn_bias_type,
         core_attention_bias=bias,
         alibi_slopes=alibi_slopes,
+        max_seqlen_q=config.max_seqlen_q,
+        max_seqlen_kv=config.max_seqlen_kv,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_kv=cu_seqlens_kv,
     )
     loss = out.sum()
     loss.backward()
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index cdff0de2df..f9a5d02496 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -3484,7 +3484,64 @@ def attn_forward_func_with_cp(
     use_flash_attn_3=False,
 ) -> torch.Tensor:
     """
-    Attention implementation with context parallelism.
+    Attention implementation with context parallelism (CP). CP partitions tensors along the
+    sequence dimension, reducing the memory and compute pressure on each GPU and enabling
+    long-context LLMs in a distributed fashion. Transformer Engine's PyTorch CP implementation
+    currently uses the DualChunkSwap strategy to ensure load balancing across CP ranks. It is
+    applied to all `attn_mask_type`s and all `qkv_format`s, and it requires sequence lengths to
+    be (or to be padded to be) divisible by (cp_size * 2). It also requires tokens to be
+    re-ordered before entering this function.
+
+    For qkv_format = {'bshd', 'sbhd'}, the token re-ordering is illustrated below for an example
+    use case of s = 12, attn_mask_type = 'causal', and cp_size = 2. seq_pos indicates each token's
+    position within its sequence.
+ + GPU0 | GPU1 GPU0 | GPU1 + seq_pos | 0 1 2 3 4 5 | 6 7 8 9 10 11 seq_pos | 0 1 2 9 10 11 | 3 4 5 6 7 8 + ---------------------------|----------------- ---------------------------|------------------ + 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + U 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 9 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 1, 1, + 0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 10 | 1, 1, 1, 1, 1, 0,| 1, 1, 1, 1, 1, 1, + 5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1, + ---------------------------|----------------- ---------------------------|------------------ + 6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 3 | 1, 1, 1, 0, 0, 0,| 1, 0, 0, 0, 0, 0, + G 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 4 | 1, 1, 1, 0, 0, 0,| 1, 1, 0, 0, 0, 0, + P 8 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 0, 0, 0, P 5 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 0, 0, 0, + U 9 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 0, 0, U 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0, + 1 10 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 0, 1 7 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 0, + 11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1, 8 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 1, + + For qkv_format = 'thd', multiple sequences may be packed into the batch, and they may be of different + lengths. DualChunkSwap divides each sequence into (cp_size * 2) chunks and distributes 2 chunks of + every sequence onto a CP rank. The token matrix transformation is shown as follows, for an example of + batch_size = 2, seq_ids = [0, 1], seq_lens = [8, 4], t = 12, attn_mask_type = 'padding_causal', and + cp_size = 2. + + GPU0 | GPU1 GPU0 | GPU1 + seq_id | 0 0 0 0 0 0 | 0 0 1 1 1 1 seq_id | 0 0 0 0 1 1 | 0 0 0 0 1 1 + seq_pos | 0 1 2 3 4 5 | 6 7 0 1 2 3 seq_pos | 0 1 6 7 0 3 | 2 3 4 5 1 2 + ---------------------------|----------------- ---------------------------|------------------ + 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0, + P 0 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 0 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0, + U 0 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 0 7 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 0, 0, + 0 0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 1 0 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 0, 0, + 0 5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 1 3 | 0, 0, 0, 0, 2, 2,| 0, 0, 0, 0, 2, 2, + ---------------------------|----------------- ---------------------------|------------------ + 0 6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 0 2 | 1, 1, 0, 0, 0, 0,| 1, 0, 0, 0, 0, 0, + G 0 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 0 3 | 1, 1, 0, 0, 0, 0,| 1, 1, 0, 0, 0, 0, + P 1 0 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 0, 0, 0 P 0 4 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 0, 0, 0, + U 1 1 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 0, 0 U 0 5 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 1, 0, 0, + 1 1 2 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 0 1 1 1 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 0, + 1 3 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 2 1 2 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 2, + + When all transformer layers in a model share the same CP configuration, i.e. cp_group, cp_global_ranks, + cp_comm_type and cp_stream, token re-ordering can take place in the dataloader, i.e. only once for + all the layers. An example of the re-ordering code is `get_batch_on_this_cp_rank + `_ + in Megatron-LM. 
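+
+    As a minimal sketch (an illustration, not the actual TE implementation; cp_size, cp_rank and
+    seq_dim are assumed to be known to the caller), the 'bshd'/'sbhd' re-ordering follows the same
+    pattern as Megatron-LM's `get_batch_on_this_cp_rank`:
+
+        s = x.shape[seq_dim]
+        chunks = x.view(*x.shape[:seq_dim], 2 * cp_size, s // (2 * cp_size), *x.shape[seq_dim + 1:])
+        keep = torch.tensor([cp_rank, 2 * cp_size - 1 - cp_rank], device=x.device)
+        x_local = chunks.index_select(seq_dim, keep).view(
+            *x.shape[:seq_dim], -1, *x.shape[seq_dim + 1:]
+        )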
+
     """

     if cp_comm_type == "a2a+p2p":
diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py
index f018465dc1..858783eb6d 100644
--- a/transformer_engine/pytorch/attention/multi_head_attention.py
+++ b/transformer_engine/pytorch/attention/multi_head_attention.py
@@ -713,6 +713,19 @@ def forward(
                 )
                 for x in (key_layer, value_layer)
             )
+
+            if self.qkv_format == "thd":
+                # key, value: -> [t, ng, hn]
+                key_layer, value_layer = (
+                    x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
+                    for x in (key_layer, value_layer)
+                )
+            else:
+                # key, value: -> [sq, b, ng, hn]
+                key_layer, value_layer = (
+                    x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
+                    for x in (key_layer, value_layer)
+                )

         # Attention head [sq, b, h] --> [sq, b, hp]
         if self.input_layernorm:
diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py
index 1f5e6a3ee1..a39beb2a57 100644
--- a/transformer_engine/pytorch/transformer.py
+++ b/transformer_engine/pytorch/transformer.py
@@ -179,12 +179,12 @@ class TransformerLayer(torch.nn.Module):
                       The device on which the parameters of the model will
                       be allocated. It is the user's responsibility to ensure all parameters are
                       moved to the GPU before running the forward pass.
-    attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd'
+    attn_input_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
                       This controls whether the dimensions of the
-                      intermediate hidden states is 'batch first' ('bshd') or
-                      'sequence first' ('sbhd'). `s` stands for the sequence
-                      length, `b` batch size, `h` the number of heads, `d`
-                      head size. Note that these formats are very closely
+                      intermediate hidden states are 'sequence first' ('sbhd'), 'batch first'
+                      ('bshd'), or 'token first' ('thd'). `s` stands for the sequence length,
+                      `b` batch size, `t` the total number of tokens, `h` the number of heads,
+                      `d` head size. Note that these formats are very closely
                       related to the `qkv_format` in the `MultiHeadAttention` and
                       `DotProductAttention` modules.
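+                      When 'thd' is used, the input contains packed sequences and the forward
+                      pass is expected to also receive `cu_seqlens_q`/`cu_seqlens_kv` and
+                      `max_seqlen_q`/`max_seqlen_kv` describing them. For example (names are
+                      illustrative), the cumulative lengths can be built from per-sequence
+                      lengths `seqlens` as:
+
+                          cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
+                          cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)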
name: str, default = `None` @@ -678,7 +678,9 @@ def forward( if ( "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary" ) and attention_mask is not None: - assert attention_mask.dtype == torch.bool, "Attention mask must be a boolean tensor" + assert all( + attention_mask[i].dtype == torch.bool for i in range(len(attention_mask)) + ), "Attention mask must be a boolean tensor or a list/tuple of two boolean tensors" if ( "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary" ) and enc_dec_attn_mask is not None: @@ -707,9 +709,9 @@ def forward( core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_kv=cu_seqlens_q, max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, + max_seqlen_kv=max_seqlen_q, fast_zero_fill=fast_zero_fill, pad_between_seqs=pad_between_seqs, ) @@ -733,12 +735,19 @@ def forward( attn_mask_type=enc_dec_attn_mask_type, window_size=enc_dec_window_size, encoder_output=encoder_output, + inference_params=inference_params, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, + rotary_pos_emb=rotary_pos_emb, core_attention_bias_type=core_attention_bias_type, core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, fast_zero_fill=fast_zero_fill, + pad_between_seqs=pad_between_seqs, ) if self.apply_residual_connection_post_layernorm: attention_output, attention_bias, residual = inter_attention_outputs
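
For reference, a minimal usage sketch of the 'thd' path exercised above. Shapes, dtypes and
sequence lengths are illustrative assumptions, not part of the patch:

    import torch
    import transformer_engine.pytorch as te

    batch_size, hidden_size, num_heads = 2, 1024, 16
    seqlens = torch.tensor([128, 256], dtype=torch.int32, device="cuda")
    cu_seqlens = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)

    # Packed input: [total_tokens, hidden_size]
    inp = torch.randn(int(cu_seqlens[-1]), hidden_size, dtype=torch.bfloat16, device="cuda")

    layer = te.TransformerLayer(
        hidden_size,
        4 * hidden_size,
        num_heads,
        attn_input_format="thd",
        params_dtype=torch.bfloat16,
    ).cuda()

    out = layer(
        inp,
        self_attn_mask_type="padding_causal",
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,
        max_seqlen_q=int(seqlens.max()),
        max_seqlen_kv=int(seqlens.max()),
    )
    out.sum().backward()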