
Commit ea63d61

cyanguwa authored and KshitijLakhani committed
[PyTorch] Add docstring for CP load balancing (#1802)
add docstring for CP

Signed-off-by: Charlene Yang <[email protected]>
1 parent 7fe5d68 commit ea63d61

File tree

1 file changed: +58 -1 lines changed


transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py

Lines changed: 58 additions & 1 deletion
@@ -3484,7 +3484,64 @@ def attn_forward_func_with_cp(
     use_flash_attn_3=False,
 ) -> torch.Tensor:
     """
-    Attention implementation with context parallelism.
+    Attention implementation with context parallelism (CP). CP partitions tensors along the sequence
+    dimension and, by reducing the memory and computational pressure on each GPU, enables long-context
+    LLMs in a distributed fashion. Transformer Engine's PyTorch CP implementation currently utilizes
+    the DualChunkSwap strategy to ensure load balancing across CP ranks. It is applied to all `attn_mask_type`s
+    and all `qkv_format`s, and it requires sequence lengths to be, or to be padded to be, divisible by
+    (cp_size * 2). It also requires tokens to be re-ordered before entering this function.
+
+    For qkv_format = {'bshd', 'sbhd'}, the token re-ordering is illustrated below for an example
+    use case of s = 12, attn_mask_type = 'causal', and cp_size = 2. seq_pos indicates each token's position
+    in its corresponding sequence.
+
+                    GPU0         |        GPU1                               GPU0         |        GPU1
+       seq_pos | 0  1  2  3  4  5 | 6  7  8  9 10 11         seq_pos | 0  1  2  9 10 11 | 3  4  5  6  7  8
+       --------|-------------------|--------------------      --------|-------------------|--------------------
+            0  | 1, 0, 0, 0, 0, 0, | 0, 0, 0, 0, 0, 0,             0  | 1, 0, 0, 0, 0, 0, | 0, 0, 0, 0, 0, 0,
+       G    1  | 1, 1, 0, 0, 0, 0, | 0, 0, 0, 0, 0, 0,        G    1  | 1, 1, 0, 0, 0, 0, | 0, 0, 0, 0, 0, 0,
+       P    2  | 1, 1, 1, 0, 0, 0, | 0, 0, 0, 0, 0, 0,        P    2  | 1, 1, 1, 0, 0, 0, | 0, 0, 0, 0, 0, 0,
+       U    3  | 1, 1, 1, 1, 0, 0, | 0, 0, 0, 0, 0, 0,        U    9  | 1, 1, 1, 1, 0, 0, | 1, 1, 1, 1, 1, 1,
+       0    4  | 1, 1, 1, 1, 1, 0, | 0, 0, 0, 0, 0, 0,   ->   0   10  | 1, 1, 1, 1, 1, 0, | 1, 1, 1, 1, 1, 1,
+            5  | 1, 1, 1, 1, 1, 1, | 0, 0, 0, 0, 0, 0,            11  | 1, 1, 1, 1, 1, 1, | 1, 1, 1, 1, 1, 1,
+       --------|-------------------|--------------------      --------|-------------------|--------------------
+            6  | 1, 1, 1, 1, 1, 1, | 1, 0, 0, 0, 0, 0,             3  | 1, 1, 1, 0, 0, 0, | 1, 0, 0, 0, 0, 0,
+       G    7  | 1, 1, 1, 1, 1, 1, | 1, 1, 0, 0, 0, 0,        G    4  | 1, 1, 1, 0, 0, 0, | 1, 1, 0, 0, 0, 0,
+       P    8  | 1, 1, 1, 1, 1, 1, | 1, 1, 1, 0, 0, 0,        P    5  | 1, 1, 1, 0, 0, 0, | 1, 1, 1, 0, 0, 0,
+       U    9  | 1, 1, 1, 1, 1, 1, | 1, 1, 1, 1, 0, 0,        U    6  | 1, 1, 1, 0, 0, 0, | 1, 1, 1, 1, 0, 0,
+       1   10  | 1, 1, 1, 1, 1, 1, | 1, 1, 1, 1, 1, 0,        1    7  | 1, 1, 1, 0, 0, 0, | 1, 1, 1, 1, 1, 0,
+           11  | 1, 1, 1, 1, 1, 1, | 1, 1, 1, 1, 1, 1,             8  | 1, 1, 1, 0, 0, 0, | 1, 1, 1, 1, 1, 1,
+
+    For qkv_format = 'thd', multiple sequences may be packed into the batch, and they may be of different
+    lengths. DualChunkSwap divides each sequence into (cp_size * 2) chunks and distributes 2 chunks of
+    every sequence onto each CP rank. The token matrix transformation is shown below for an example of
+    batch_size = 2, seq_ids = [0, 1], seq_lens = [8, 4], t = 12, attn_mask_type = 'padding_causal', and
+    cp_size = 2.
+
+                      GPU0         |        GPU1                                GPU0         |        GPU1
+       seq_id  | 0  0  0  0  0  0 | 0  0  1  1  1  1         seq_id  | 0  0  0  0  1  1 | 0  0  0  0  1  1
+       seq_pos | 0  1  2  3  4  5 | 6  7  0  1  2  3         seq_pos | 0  1  6  7  0  3 | 2  3  4  5  1  2
+       --------|-------------------|--------------------      --------|-------------------|--------------------
+          0  0 | 1, 0, 0, 0, 0, 0, | 0, 0, 0, 0, 0, 0,           0  0 | 1, 0, 0, 0, 0, 0, | 0, 0, 0, 0, 0, 0,
+       G  0  1 | 1, 1, 0, 0, 0, 0, | 0, 0, 0, 0, 0, 0,        G  0  1 | 1, 1, 0, 0, 0, 0, | 0, 0, 0, 0, 0, 0,
+       P  0  2 | 1, 1, 1, 0, 0, 0, | 0, 0, 0, 0, 0, 0,        P  0  6 | 1, 1, 1, 0, 0, 0, | 1, 1, 1, 1, 0, 0,
+       U  0  3 | 1, 1, 1, 1, 0, 0, | 0, 0, 0, 0, 0, 0,        U  0  7 | 1, 1, 1, 1, 0, 0, | 1, 1, 1, 1, 0, 0,
+       0  0  4 | 1, 1, 1, 1, 1, 0, | 0, 0, 0, 0, 0, 0,   ->   0  1  0 | 0, 0, 0, 0, 2, 0, | 0, 0, 0, 0, 0, 0,
+          0  5 | 1, 1, 1, 1, 1, 1, | 0, 0, 0, 0, 0, 0,           1  3 | 0, 0, 0, 0, 2, 2, | 0, 0, 0, 0, 2, 2,
+       --------|-------------------|--------------------      --------|-------------------|--------------------
+          0  6 | 1, 1, 1, 1, 1, 1, | 1, 0, 0, 0, 0, 0,           0  2 | 1, 1, 0, 0, 0, 0, | 1, 0, 0, 0, 0, 0,
+       G  0  7 | 1, 1, 1, 1, 1, 1, | 1, 1, 0, 0, 0, 0,        G  0  3 | 1, 1, 0, 0, 0, 0, | 1, 1, 0, 0, 0, 0,
+       P  1  0 | 0, 0, 0, 0, 0, 0, | 0, 0, 2, 0, 0, 0,        P  0  4 | 1, 1, 0, 0, 0, 0, | 1, 1, 1, 0, 0, 0,
+       U  1  1 | 0, 0, 0, 0, 0, 0, | 0, 0, 2, 2, 0, 0,        U  0  5 | 1, 1, 0, 0, 0, 0, | 1, 1, 1, 1, 0, 0,
+       1  1  2 | 0, 0, 0, 0, 0, 0, | 0, 0, 2, 2, 2, 0,        1  1  1 | 0, 0, 0, 0, 2, 0, | 0, 0, 0, 0, 2, 0,
+          1  3 | 0, 0, 0, 0, 0, 0, | 0, 0, 2, 2, 2, 2,           1  2 | 0, 0, 0, 0, 2, 0, | 0, 0, 0, 0, 2, 2,
+
+    When all transformer layers in a model share the same CP configuration, i.e. cp_group, cp_global_ranks,
+    cp_comm_type and cp_stream, token re-ordering can take place in the dataloader, i.e. only once for
+    all the layers. An example of the re-ordering code is `get_batch_on_this_cp_rank
+    <https://github.com/NVIDIA/Megatron-LM/blob/d6eb60b5ea1efca47401c0be97f456fbe3a55bcd/megatron/core/utils.py#L1725>`_
+    in Megatron-LM.
+
     """

     if cp_comm_type == "a2a+p2p":
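
To make the 'bshd'/'sbhd' re-ordering in the diff concrete: DualChunkSwap splits the sequence dimension into (cp_size * 2) chunks, and rank r keeps chunks r and (2 * cp_size - 1 - r). The following is a minimal illustrative sketch of that chunk selection in PyTorch; the function name, tensor layout, and arguments are assumptions for this example, not Transformer Engine's internal API.

import torch

def reorder_seq_for_cp_rank(x: torch.Tensor, cp_size: int, cp_rank: int, seq_dim: int = 1) -> torch.Tensor:
    """Illustrative sketch: keep the two DualChunkSwap chunks owned by one CP rank.

    Splits the sequence dimension into (cp_size * 2) equal chunks and keeps
    chunks cp_rank and (2 * cp_size - 1 - cp_rank), matching the docstring's
    s = 12, cp_size = 2 example (rank 0 -> positions [0, 1, 2, 9, 10, 11]).
    """
    seq_len = x.size(seq_dim)
    assert seq_len % (2 * cp_size) == 0, "sequence length must be divisible by 2 * cp_size"
    chunks = x.chunk(2 * cp_size, dim=seq_dim)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=seq_dim)

# Example: s = 12, cp_size = 2, 'bshd' layout with b = 1, h = 1, d = 1.
x = torch.arange(12).view(1, 12, 1, 1)  # token positions 0..11 along the sequence dim
print(reorder_seq_for_cp_rank(x, cp_size=2, cp_rank=0, seq_dim=1).flatten().tolist())  # [0, 1, 2, 9, 10, 11]
print(reorder_seq_for_cp_rank(x, cp_size=2, cp_rank=1, seq_dim=1).flatten().tolist())  # [3, 4, 5, 6, 7, 8]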
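
A corresponding sketch for the packed 'thd' example (seq_lens = [8, 4], cp_size = 2), assuming the usual cu_seqlens convention for sequence boundaries; again an illustrative helper rather than the library's implementation. Each packed sequence is chunked independently and the kept chunks are re-packed in order.

import torch

def reorder_thd_for_cp_rank(x: torch.Tensor, cu_seqlens: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    """Illustrative sketch for qkv_format = 'thd' (packed sequences along dim 0).

    Each sequence is split into (cp_size * 2) chunks; rank cp_rank keeps chunks
    cp_rank and (2 * cp_size - 1 - cp_rank) of every sequence.
    """
    kept = []
    for i in range(cu_seqlens.numel() - 1):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        seq = x[start:end]
        assert seq.size(0) % (2 * cp_size) == 0, "each sequence length must be divisible by 2 * cp_size"
        chunks = seq.chunk(2 * cp_size, dim=0)
        kept.append(torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=0))
    return torch.cat(kept, dim=0)

# Example matching the diff: two sequences of lengths 8 and 4 packed into t = 12 tokens.
tokens = torch.cat([torch.arange(8), torch.arange(4)])  # seq_pos within each sequence
cu_seqlens = torch.tensor([0, 8, 12])
print(reorder_thd_for_cp_rank(tokens, cu_seqlens, cp_size=2, cp_rank=0).tolist())  # [0, 1, 6, 7, 0, 3]
print(reorder_thd_for_cp_rank(tokens, cu_seqlens, cp_size=2, cp_rank=1).tolist())  # [2, 3, 4, 5, 1, 2]

The printed seq_pos values match the GPU0 and GPU1 columns of the re-ordered 'thd' diagram above.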

0 commit comments
