Commit 0d8fb83

add modulo cpx2 requirement
Signed-off-by: Charlene Yang <[email protected]>
1 parent bbccfe3 commit 0d8fb83

File tree

1 file changed (+6 / -4 lines)

transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py

Lines changed: 6 additions & 4 deletions
@@ -3490,10 +3490,12 @@ def attn_forward_func_with_cp(
 
     Context parallelism distributes chunks of the sequence onto different GPUs. To help with
     load balancing, users are expected to reorder their tokens before entering this function.
-    For example, given cp_size = 2, we divide each sequence into 4 chunks, and distribute chunk 0
-    and chunk 3 onto GPU 0, and chunk 1 and chunk 2 onto GPU 1. If all transformer layers use
-    the same context parallel configuration, this reordering can happen only once, i.e. before
-    the first layer. An example of the reordering is in Megatron-LM (please see `get_batch_on_this_cp_rank
+    For example, given cp_size = 2, we divide each sequence in a batch into 4 chunks, and
+    distribute chunk 0 and chunk 3 onto GPU 0, and chunk 1 and chunk 2 onto GPU 1. This requires
+    sequence lengths to be divisible by (cp_size * 2), and if not, sequences need to be padded to
+    meet this requirement. When all transformer layers use the same context parallelism configuration,
+    token reordering can happen in the dataloader, i.e. only once for all the layers. An example of
+    the reordering is in Megatron-LM (see `get_batch_on_this_cp_rank
     <https://github.com/NVIDIA/Megatron-LM/blob/d6eb60b5ea1efca47401c0be97f456fbe3a55bcd/megatron/core/utils.py#L1725>`_).
 
     """

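To make the reordering described in the updated docstring concrete, here is a minimal, hypothetical Python sketch. The helper name reorder_for_cp and the [batch, seq_len] token layout are assumptions for illustration, not Transformer Engine or Megatron-LM APIs; the reference implementation is the get_batch_on_this_cp_rank function linked in the diff. Each sequence is split into 2 * cp_size chunks, and CP rank r keeps chunk r together with its mirror chunk 2 * cp_size - 1 - r, which is why sequence lengths must be divisible by cp_size * 2.

import torch

def reorder_for_cp(tokens: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    """Hypothetical sketch of load-balanced token selection for one CP rank.

    tokens: [batch, seq_len] token ids; seq_len must be divisible by cp_size * 2
    (pad sequences beforehand if it is not).
    """
    batch, seq_len = tokens.shape
    assert seq_len % (cp_size * 2) == 0, "pad sequences to a multiple of cp_size * 2"
    # Split every sequence into 2 * cp_size chunks along the sequence dimension.
    chunks = tokens.view(batch, 2 * cp_size, seq_len // (2 * cp_size))
    # Rank r keeps chunk r and its mirror chunk; e.g. for cp_size = 2,
    # rank 0 gets chunks 0 and 3, rank 1 gets chunks 1 and 2.
    index = torch.tensor([cp_rank, 2 * cp_size - 1 - cp_rank], device=tokens.device)
    return chunks.index_select(1, index).reshape(batch, -1)

# Example: cp_size = 2, seq_len = 8 -> rank 0 keeps tokens [0, 1, 6, 7],
# rank 1 keeps tokens [2, 3, 4, 5].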
0 commit comments
