Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
dianaml0 committed Jan 4, 2023
1 parent 2f7e693 commit 135972e
Showing 1 changed file with 6 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def forward(
op=xf_op[0],
)
.transpose(0, 1)
.reshape(seq_len, bsz, num_heads*head_dim)
.reshape(seq_len, bsz, num_heads * head_dim)
)
# TODO: Reshape q/k/v back to original?
else:
Expand Down Expand Up @@ -413,7 +413,11 @@ def backward(ctx, grad_output):
op=xf_op[0],
)
out = attn
attn = attn.transpose(0, 1).reshape(seq_len, bsz, num_heads*head_dim).contiguous()
attn = (
attn.transpose(0, 1)
.reshape(seq_len, bsz, num_heads * head_dim)
.contiguous()
)
else:
attn, attn_probs = SequeuceParallelTransformerBlock.forward_mha(
q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype
Expand Down

0 comments on commit 135972e

Please sign in to comment.