@@ -3490,13 +3490,54 @@ def attn_forward_func_with_cp(
 Context parallelism distributes chunks of the sequence onto different GPUs. To help with
 load balancing, users are expected to reorder their tokens before entering this function.
- For example, given cp_size = 2, we divide each sequence in a batch into 4 chunks, and
- distribute chunk 0 and chunk 3 onto GPU 0, and chunk 1 and chunk 2 onto GPU 1. This requires
- sequence lengths to be divisible by (cp_size * 2), and if not, sequences need to be padded to
- meet this requirement. When all transformer layers use the same context parallelism configuration,
- token reordering can happen in the dataloader, i.e. only once for all the layers. An example of
- the reordering is in Megatron-LM (see `get_batch_on_this_cp_rank
- <https://github.com/NVIDIA/Megatron-LM/blob/d6eb60b5ea1efca47401c0be97f456fbe3a55bcd/megatron/core/utils.py#L1725>`_).
+ This is required for all `attn_mask_type`s and all `qkv_format`s in the current implementation.
+ Also, each sequence must satisfy seq_len % (cp_size * 2) == 0; if it does not, it must be padded
+ to meet this requirement. When all transformer layers use the same context parallelism
+ configuration, the reordering can take place in the dataloader, i.e. only once for all layers.
+ Example code for the reordering is `get_batch_on_this_cp_rank
+ <https://github.com/NVIDIA/Megatron-LM/blob/d6eb60b5ea1efca47401c0be97f456fbe3a55bcd/megatron/core/utils.py#L1725>`_
+ in Megatron-LM.
+
+ For qkv_format = {'bshd', 'sbhd'}, all sequences in a batch have the same length. For an example
+ of s = 12 and cp_size = 2, the reordering transforms the causal attention mask matrix as follows.
+
+             GPU0      |      GPU1                            GPU0      |      GPU1
+  seq_pos | 0 1 2 3 4 5 | 6 7 8 9 10 11          seq_pos | 0 1 2 9 10 11 | 3 4 5 6 7 8
+  --------------------------|-----------------      ---------------------------|------------------
+     0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0         0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0
+  G  1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0      G  1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0
+  P  2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0      P  2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0
+  U  3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0      U  9 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 1, 1
+  0  4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -->  0 10 | 1, 1, 1, 1, 1, 0,| 1, 1, 1, 1, 1, 1
+     5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0        11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1
+  --------------------------|-----------------      ---------------------------|------------------
+     6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0         3 | 1, 1, 1, 0, 0, 0,| 1, 0, 0, 0, 0, 0
+  G  7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0      G  4 | 1, 1, 1, 0, 0, 0,| 1, 1, 0, 0, 0, 0
+  P  8 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 0, 0, 0      P  5 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 0, 0, 0
+  U  9 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 0, 0      U  6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0
+  1 10 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 0      1  7 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 0
+    11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1         8 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 1
+
+ For qkv_format = 'thd', sequences in a batch may have different lengths. For an example batch
+ of 2, with seq_ids = [0, 1], seq_lengths = [8, 4] and cp_size = 2, the reordering looks like
+ (entries belonging to sequence 1 are marked with 2):
+
+              GPU0      |      GPU1                              GPU0      |      GPU1
+  seq_id  | 0 0 0 0 0 0 | 0 0 1 1 1 1            seq_id  | 0 0 0 0 1 1 | 0 0 0 0 1 1
+  seq_pos | 0 1 2 3 4 5 | 6 7 0 1 2 3            seq_pos | 0 1 6 7 0 3 | 2 3 4 5 1 2
+  --------------------------|-----------------      ---------------------------|------------------
+    0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0        0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0
+  G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0      G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0
+  P 0 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0      P 0 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0
+  U 0 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0      U 0 7 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 0, 0
+  0 0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -->  0 1 0 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 0, 0
+    0 5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0        1 3 | 0, 0, 0, 0, 2, 2,| 0, 0, 0, 0, 2, 2
+  --------------------------|-----------------      ---------------------------|------------------
+    0 6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0        0 2 | 1, 1, 0, 0, 0, 0,| 1, 0, 0, 0, 0, 0
+  G 0 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0      G 0 3 | 1, 1, 0, 0, 0, 0,| 1, 1, 0, 0, 0, 0
+  P 1 0 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 0, 0, 0      P 0 4 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 0, 0, 0
+  U 1 1 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 0, 0      U 0 5 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 1, 0, 0
+  1 1 2 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 0      1 1 1 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 0
+    1 3 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 2        1 2 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 2
"""
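
For qkv_format = 'thd', the same per-sequence selection can be sketched by iterating over the
packed sequences. Again a minimal illustration, assuming cumulative sequence lengths in a
`cu_seqlens` tensor and that every sequence length is already divisible by (cp_size * 2)::

    import torch

    def reorder_thd_for_cp(x: torch.Tensor, cu_seqlens: torch.Tensor,
                           cp_size: int, cp_rank: int) -> torch.Tensor:
        # 'thd' packs sequences along dim 0; reorder each sequence independently.
        out = []
        for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
            chunks = x[start:end].chunk(2 * cp_size, dim=0)
            out.append(torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=0))
        return torch.cat(out, dim=0)

    # Reproduces the seq_lengths = [8, 4], cp_size = 2 example above:
    pos = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3])  # per-sequence positions
    cu_seqlens = torch.tensor([0, 8, 12])
    print(reorder_thd_for_cp(pos, cu_seqlens, cp_size=2, cp_rank=0))  # tensor([0, 1, 6, 7, 0, 3])
    print(reorder_thd_for_cp(pos, cu_seqlens, cp_size=2, cp_rank=1))  # tensor([2, 3, 4, 5, 1, 2])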