diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 7586cc1bcb..c8ca157c28 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -5725,13 +5725,13 @@ def forward(
             if cu_seqlens_q_padded is not None:
                 seqlens_q = cu_seqlens_q_padded[1:] - cu_seqlens_q_padded[:-1]
             else:
                 seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
-            max_seqlen_q = pow(2, math.ceil(math.log2(seqlens_q.max().item())))
+            max_seqlen_q = int((seqlens_q.max().item() + 63) // 64 * 64)
         if max_seqlen_kv is None:
             if cu_seqlens_kv_padded is not None:
                 seqlens_kv = cu_seqlens_kv_padded[1:] - cu_seqlens_kv_padded[:-1]
             else:
                 seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
-            max_seqlen_kv = pow(2, math.ceil(math.log2(seqlens_kv.max().item())))
+            max_seqlen_kv = int((seqlens_kv.max().item() + 63) // 64 * 64)
         batch_size = len(cu_seqlens_q) - 1
         cp_size = 1 if self.cp_group is None else get_distributed_world_size(self.cp_group)
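
For reference, a minimal standalone sketch of the two rounding rules this diff swaps when `max_seqlen_q`/`max_seqlen_kv` are inferred from cumulative sequence lengths. The helper names `round_pow2` and `round_mult64` are illustrative only, not part of Transformer Engine:

    import math
    import torch

    def round_pow2(n: int) -> int:
        # Old behavior: round the longest sequence up to the next power of two.
        return pow(2, math.ceil(math.log2(n)))

    def round_mult64(n: int) -> int:
        # New behavior: round up to the next multiple of 64.
        return int((n + 63) // 64 * 64)

    # Example cumulative sequence lengths for a 3-sequence batch (thd layout).
    cu_seqlens_q = torch.tensor([0, 300, 4397, 4500])
    seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]  # tensor([ 300, 4097,  103])
    longest = seqlens_q.max().item()                  # 4097

    print(round_pow2(longest))    # 8192: nearly 2x overestimate for 4097
    print(round_mult64(longest))  # 4160: at most 63 elements of slack

The apparent rationale (inferred from the arithmetic, not stated in the hunk): rounding to a multiple of 64 keeps the inferred max sequence length tight against the true maximum while staying 64-aligned, whereas the power-of-two rule can overestimate by nearly 2x just past a power-of-two boundary.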