[PyTorch] Reduce the amount of roundup for max_seqlen in THD (NVIDIA#1079)

Reduce the roundup of max_seqlen for THD from the next power of 2 to the next multiple of 64.

Signed-off-by: Charlene Yang <[email protected]>
cyanguwa authored Aug 6, 2024
1 parent 121ff62 commit 8833a8d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/attention.py
@@ -5725,13 +5725,13 @@ def forward(
                 seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]
             else:
                 seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
-            max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item())))
+            max_seqlen_q = int((seqlens_q.max().item() + 63) // 64 * 64)
         if max_seqlen_kv is None:
             if cu_seqlens_kv_padded is not None:
                 seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]
             else:
                 seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
-            max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item())))
+            max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64)
         batch_size = len(cu_seqlens_q) - 1

         cp_size = 1 if self.cp_group is None else get_distributed_world_size(self.cp_group)
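The effect of this change can be seen in isolation. A minimal sketch (the helper names are mine, not from the commit) comparing the old power-of-2 roundup with the new multiple-of-64 roundup of the maximum sequence length:

```python
import math

def roundup_pow2(n: int) -> int:
    # Old behavior: round n up to the next power of 2.
    return pow(2, math.ceil(math.log2(n)))

def roundup_64(n: int) -> int:
    # New behavior: round n up to the next multiple of 64.
    return int((n + 63) // 64 * 64)

for n in (100, 513, 1025, 4097):
    print(n, roundup_pow2(n), roundup_64(n))
```

For a max sequence length of 1025 the old scheme pads to 2048 while the new one pads only to 1088, so the overallocation shrinks as sequence lengths grow past a power-of-2 boundary.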
