
Commit 52fcbc5

add example of token reordering

Signed-off-by: Charlene Yang <[email protected]>
1 parent fc4d718 commit 52fcbc5

File tree

1 file changed: +48 -7 lines changed

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py

Lines changed: 48 additions & 7 deletions
@@ -3490,13 +3490,54 @@ def attn_forward_func_with_cp(
 
 Context parallelism distributes chunks of the sequence onto different GPUs. To help with
 load balancing, users are expected to reorder their tokens before entering this function.
-For example, given cp_size = 2, we divide each sequence in a batch into 4 chunks, and
-distribute chunk 0 and chunk 3 onto GPU 0, and chunk 1 and chunk 2 onto GPU 1. This requires
-sequence lengths to be divisible by (cp_size * 2), and if not, sequences need to be padded to
-meet this requirement. When all transformer layers use the same context parallelism configuration,
-token reordering can happen in the dataloader, i.e. only once for all the layers. An example of
-the reordering is in Megatron-LM (see `get_batch_on_this_cp_rank
-<https://github.com/NVIDIA/Megatron-LM/blob/d6eb60b5ea1efca47401c0be97f456fbe3a55bcd/megatron/core/utils.py#L1725>`_).
+This reordering is required for all `attn_mask_type` and `qkv_format` options in the current
+implementation. Each sequence must satisfy seq_len % (cp_size * 2) == 0; if it does not, it
+must be padded to meet this requirement. When all transformer layers use the same context
+parallelism configuration, the reordering can take place in the dataloader, i.e. only once
+for all layers. An example of the reordering is `get_batch_on_this_cp_rank
+<https://github.com/NVIDIA/Megatron-LM/blob/d6eb60b5ea1efca47401c0be97f456fbe3a55bcd/megatron/core/utils.py#L1725>`_
+in Megatron-LM.
+
+For qkv_format = {'bshd', 'sbhd'}, all sequences in a batch have equal length. For an example
+with s = 12 and cp_size = 2, the reordering transforms the attention mask as follows.
+
+           GPU0       |      GPU1                          GPU0        |      GPU1
+ seq_pos | 0 1 2 3 4 5 | 6 7 8 9 10 11       seq_pos | 0 1 2 9 10 11 | 3 4 5 6 7 8
+---------------------------|-----------------    ---------------------------|------------------
+    0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0        0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
+G   1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0    G   1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
+P   2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0    P   2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
+U   3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0    U   9 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 1, 1,
+0   4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 --> 0  10 | 1, 1, 1, 1, 1, 0,| 1, 1, 1, 1, 1, 1,
+    5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0       11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1,
+---------------------------|-----------------    ---------------------------|------------------
+    6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0        3 | 1, 1, 1, 0, 0, 0,| 1, 0, 0, 0, 0, 0,
+G   7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0    G   4 | 1, 1, 1, 0, 0, 0,| 1, 1, 0, 0, 0, 0,
+P   8 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 0, 0, 0,   P   5 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 0, 0, 0,
+U   9 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 0, 0,   U   6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0,
+1  10 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 0,   1   7 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 0,
+   11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1,       8 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 1,
+
+For qkv_format = 'thd', sequences in a batch may have different lengths. For a batch of 2 with
+seq_ids = [0, 1], seq_lengths = [8, 4] and cp_size = 2, the reordering looks as follows.
+
+            GPU0      |      GPU1                            GPU0      |      GPU1
+ seq_id  | 0 0 0 0 0 0 | 0 0 1 1 1 1          seq_id  | 0 0 0 0 1 1 | 0 0 0 0 1 1
+ seq_pos | 0 1 2 3 4 5 | 6 7 0 1 2 3          seq_pos | 0 1 6 7 0 3 | 2 3 4 5 1 2
+---------------------------|-----------------    ---------------------------|------------------
+  0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0      0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
+G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0    G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
+P 0 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0    P 0 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0,
+U 0 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0    U 0 7 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 0, 0,
+0 0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 --> 0 1 0 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 0, 0,
+  0 5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0      1 3 | 0, 0, 0, 0, 2, 2,| 0, 0, 0, 0, 2, 2,
+---------------------------|-----------------    ---------------------------|------------------
+  0 6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0      0 2 | 1, 1, 0, 0, 0, 0,| 1, 0, 0, 0, 0, 0,
+G 0 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0    G 0 3 | 1, 1, 0, 0, 0, 0,| 1, 1, 0, 0, 0, 0,
+P 1 0 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 0, 0, 0    P 0 4 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 0, 0, 0,
+U 1 1 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 0, 0    U 0 5 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 1, 0, 0,
+1 1 2 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 0    1 1 1 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 0,
+  1 3 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 2      1 2 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 2,
 
 """
 
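Below is a minimal sketch (not part of this commit) of the reordering the new docstring describes,
for qkv_format = 'bshd'. The helper name reorder_for_cp_rank and its signature are hypothetical;
in practice, Megatron-LM's get_batch_on_this_cp_rank linked above plays this role.

    import torch

    def reorder_for_cp_rank(x: torch.Tensor, cp_size: int, cp_rank: int, seq_dim: int = 1) -> torch.Tensor:
        """Return the two chunks of x assigned to cp_rank: chunk r and chunk (2*cp_size - 1 - r)."""
        s = x.size(seq_dim)
        assert s % (2 * cp_size) == 0, "pad sequences so seq_len % (cp_size * 2) == 0"
        chunks = x.chunk(2 * cp_size, dim=seq_dim)
        # Pairing chunk i with chunk (2*cp_size - 1 - i) balances causal-attention work across ranks.
        return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=seq_dim)

    # s = 12, cp_size = 2, as in the first diagram: GPU 0 keeps positions [0, 1, 2, 9, 10, 11]
    # and GPU 1 keeps positions [3, 4, 5, 6, 7, 8].
    tokens = torch.arange(12).reshape(1, 12)                  # b = 1, s = 12
    print(reorder_for_cp_rank(tokens, cp_size=2, cp_rank=0))  # tensor([[ 0,  1,  2,  9, 10, 11]])
    print(reorder_for_cp_rank(tokens, cp_size=2, cp_rank=1))  # tensor([[ 3,  4,  5,  6,  7,  8]])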

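The same pairing applies per sequence for qkv_format = 'thd', where variable-length sequences are
packed along the token dimension. A sketch under the same assumptions (the hypothetical helper
takes cumulative sequence lengths in the usual cu_seqlens layout):

    import torch

    def reorder_thd_for_cp_rank(x: torch.Tensor, cu_seqlens: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
        """Apply the chunk pairing independently to each packed sequence of a [t, ...] tensor."""
        out = []
        for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
            seq = x[start:end]
            assert seq.size(0) % (2 * cp_size) == 0, "pad each sequence to a multiple of cp_size * 2"
            chunks = seq.chunk(2 * cp_size, dim=0)
            out.append(torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=0))
        return torch.cat(out, dim=0)

    # Batch of 2 with seq_lengths = [8, 4] and cp_size = 2, as in the thd diagram.
    tokens = torch.arange(12)              # 8 + 4 packed tokens
    cu_seqlens = torch.tensor([0, 8, 12])
    print(reorder_thd_for_cp_rank(tokens, cu_seqlens, cp_size=2, cp_rank=0))
    # tensor([ 0,  1,  6,  7,  8, 11]) -> seq 0 positions [0, 1, 6, 7], seq 1 positions [0, 3]
    print(reorder_thd_for_cp_rank(tokens, cu_seqlens, cp_size=2, cp_rank=1))
    # tensor([ 2,  3,  4,  5,  9, 10]) -> seq 0 positions [2, 3, 4, 5], seq 1 positions [1, 2]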