From cb3a4eef30c4f5d34633ece3bf1503811e90d7f4 Mon Sep 17 00:00:00 2001
From: Diana Liskovich
Date: Fri, 23 Dec 2022 14:07:14 -0800
Subject: [PATCH] formatting

---
 .../modules/sequence_parallel_transformer_layer.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
index 0823b0936..d29b7ebe5 100644
--- a/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
+++ b/metaseq/model_parallel/modules/sequence_parallel_transformer_layer.py
@@ -195,7 +195,7 @@ def forward(
                     op=xf_op[0],
                 )
                 .transpose(0, 1)
-                .reshape(seq_len, bsz, num_heads*head_dim)
+                .reshape(seq_len, bsz, num_heads * head_dim)
             )
             # TODO: Reshape q/k/v back to original?
         else:
@@ -413,7 +413,11 @@ def backward(ctx, grad_output):
                 op=xf_op[0],
             )
             out = attn
-            attn = attn.transpose(0, 1).reshape(seq_len, bsz, num_heads*head_dim).contiguous()
+            attn = (
+                attn.transpose(0, 1)
+                .reshape(seq_len, bsz, num_heads * head_dim)
+                .contiguous()
+            )
         else:
             attn, attn_probs = SequeuceParallelTransformerBlock.forward_mha(
                 q, k, v, bsz, seq_len, head_dim, embed_dim_per_partition, dtype
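
Note (not part of the patch): the change above only reformats the chained transpose/reshape call for style; behavior is unchanged. Below is a minimal, self-contained sketch of the shape bookkeeping that chain performs, assuming the attention output is laid out as (bsz, seq_len, num_heads, head_dim) as xformers' memory-efficient attention returns it. The concrete sizes are placeholders, not values from metaseq.

import torch

# Placeholder sizes; in metaseq these come from the model/parallel config.
bsz, seq_len, num_heads, head_dim = 2, 16, 4, 8

# Stand-in for the xformers attention output: (bsz, seq_len, num_heads, head_dim).
attn = torch.randn(bsz, seq_len, num_heads, head_dim)

attn = (
    attn.transpose(0, 1)                          # -> (seq_len, bsz, num_heads, head_dim)
    .reshape(seq_len, bsz, num_heads * head_dim)  # fold heads back into one embed dim
    .contiguous()                                 # materialize the transposed layout
)

assert attn.shape == (seq_len, bsz, num_heads * head_dim)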