Refactor transformer encoder (#111)
hhou435 authored Nov 28, 2023
1 parent b27211e commit 1f28a18
Showing 3 changed files with 87 additions and 43 deletions.
44 changes: 16 additions & 28 deletions tencentpretrain/encoders/transformer_encoder.py
@@ -21,6 +21,13 @@ def __init__(self, args):
self.rotary_position_embedding = args.rotary_position_embedding
self.has_residual_attention = args.has_residual_attention
self.use_mp = args.use_mp

if self.relative_position_embedding:
args.relative_pos_emb = RelativePositionEmbedding(bidirectional=True, heads_num=args.heads_num,
num_buckets=args.relative_attention_buckets_num)
elif self.rotary_position_embedding:
args.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2)

if "deepspeed_checkpoint_activations" in args:
self.deepspeed_checkpoint_activations = args.deepspeed_checkpoint_activations
self.deepspeed_checkpoint_layers_num = args.deepspeed_checkpoint_layers_num
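The encoder now builds its positional-embedding state once and hands it to the layers through args: a shared RelativePositionEmbedding for relative positions, or a precomputed freqs_cis table for rotary positions. Below is a minimal sketch of what a LLaMA-style precompute_freqs_cis typically returns, assuming the usual rotary-embedding recipe; the project's own helper may differ in details:

    import torch

    def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
        # One inverse frequency per pair of head dimensions (dim = hidden_size // heads_num).
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        positions = torch.arange(end).float()
        angles = torch.outer(positions, inv_freq)               # [end, dim // 2]
        # Unit-magnitude complex rotations e^{i * angle}; sliced per sequence length in forward().
        return torch.polar(torch.ones_like(angles), angles)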
@@ -101,54 +108,35 @@ def forward(self, emb, seg):
mask = (1.0 - mask) * -10000.0

hidden = emb

if self.relative_position_embedding:
position_bias = self.relative_pos_emb(hidden, hidden)
else:
position_bias = None

if self.rotary_position_embedding:
freqs_cis = self.freqs_cis[:seq_length].to(hidden.device)
else:
freqs_cis = None

prev_attn = None
inputs = hidden, mask

if self.deepspeed_checkpoint_activations:
from deepspeed import checkpointing

def custom(start, end):
def custom_forward(*inputs):
x_, y_, position_bias_, freqs_cis_ = inputs
for index in range(start, end):
if self.parameter_sharing:
x_, y_ = self.transformer(x_, mask, position_bias=position_bias_,
has_residual_attention=self.has_residual_attention,
prev_attn=y_, freqs_cis=freqs_cis_)
inputs = self.transformer(*inputs)
else:
x_, y_ = self.transformer[index](x_, mask, position_bias=position_bias_,
has_residual_attention=self.has_residual_attention,
prev_attn=y_, freqs_cis=freqs_cis_)
return x_, y_
inputs = self.transformer[index](*inputs)
return inputs

return custom_forward
if self.use_mp:
mpu.reset_checkpointed_activations_memory_buffer()
l = 0
while l < self.layers_num:
hidden, prev_attn = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num),
hidden, prev_attn, position_bias, freqs_cis)
inputs = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num), inputs)
l += self.deepspeed_checkpoint_layers_num
else:
for i in range(self.layers_num):
if self.parameter_sharing:
hidden, prev_attn = self.transformer(hidden, mask, position_bias=position_bias,
has_residual_attention=self.has_residual_attention,
prev_attn=prev_attn, freqs_cis=freqs_cis)
inputs = self.transformer(inputs)
else:
hidden, prev_attn = self.transformer[i](hidden, mask, position_bias=position_bias,
has_residual_attention=self.has_residual_attention,
prev_attn=prev_attn, freqs_cis=freqs_cis)
inputs = self.transformer[i](inputs)

hidden = inputs[0]

if self.layernorm_positioning == "pre":
return self.layer_norm(hidden)
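With the position logic pushed down into the layers, the encoder's forward now just threads one tuple through the stack: (hidden, mask), growing to (hidden, mask, prev_attn) when residual attention is enabled. Because every layer takes and returns the same tuple, the plain loop and the DeepSpeed activation-checkpointing path can call layers identically. A self-contained sketch of the pattern, using toy layers and torch.utils.checkpoint as a stand-in for deepspeed's checkpointing (names here are illustrative, not the project's):

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint  # stand-in for deepspeed checkpointing

    class ToyLayer(nn.Module):
        """Hypothetical layer following the tuple-in / tuple-out convention."""
        def __init__(self, hidden_size):
            super().__init__()
            self.linear = nn.Linear(hidden_size, hidden_size)

        def forward(self, inputs):
            hidden, mask = inputs[0], inputs[1]
            return self.linear(hidden) + hidden, mask  # same tuple shape goes back out

    class ToyEncoder(nn.Module):
        def __init__(self, hidden_size, layers_num, checkpoint_layers_num=2):
            super().__init__()
            self.layers = nn.ModuleList(ToyLayer(hidden_size) for _ in range(layers_num))
            self.checkpoint_layers_num = checkpoint_layers_num

        def forward(self, hidden, mask, use_checkpointing=False):
            inputs = (hidden, mask)
            if use_checkpointing:
                def custom(start, end):
                    def custom_forward(*inputs):
                        for i in range(start, end):
                            inputs = self.layers[i](inputs)
                        return inputs
                    return custom_forward
                l = 0
                while l < len(self.layers):
                    end = min(l + self.checkpoint_layers_num, len(self.layers))
                    # Only the chunk's inputs are kept; intermediate activations are recomputed in backward.
                    inputs = checkpoint(custom(l, end), *inputs, use_reentrant=False)
                    l = end
            else:
                for layer in self.layers:
                    inputs = layer(inputs)
            return inputs[0]

    hidden = torch.randn(2, 8, 16, requires_grad=True)
    mask = torch.zeros(2, 1, 8, 8)
    out = ToyEncoder(16, 4)(hidden, mask, use_checkpointing=True)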
84 changes: 70 additions & 14 deletions tencentpretrain/layers/transformer.py
@@ -11,6 +11,13 @@ def __init__(self, args, layer_number=None):
super(TransformerLayer, self).__init__()

self.layernorm_positioning = args.layernorm_positioning
self.relative_position_embedding = args.relative_position_embedding
self.rotary_position_embedding = args.rotary_position_embedding
self.has_residual_attention = args.has_residual_attention
if self.relative_position_embedding:
self.relative_pos_emb = args.relative_pos_emb
if self.rotary_position_embedding:
self.freqs_cis = args.freqs_cis

if hasattr(args, "attention_head_size"):
attention_head_size = args.attention_head_size
@@ -45,8 +52,8 @@ def __init__(self, args, layer_number=None):
self.layer_norm_1 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)
self.layer_norm_2 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)

def forward(self, hidden, mask, position_bias=None, has_residual_attention=False,
prev_attn=None, freqs_cis=None, alibi=None):
def forward(self, inputs):

"""
Args:
hidden: [batch_size x seq_length x emb_size]
@@ -55,23 +62,44 @@ def forward(self, hidden, mask, position_bias=None, has_residual_attention=False
Returns:
output: [batch_size x seq_length x hidden_size]
"""
if len(inputs)==2:
hidden, mask = inputs
prev_attn = None
else:
hidden, mask, prev_attn = inputs

_, seq_length, _ = hidden.size()

if self.relative_position_embedding:
position_bias = self.relative_pos_emb(hidden, hidden)
else:
position_bias = None

if self.rotary_position_embedding:
freqs_cis = self.freqs_cis[:seq_length].to(hidden.device)
else:
freqs_cis = None

if self.layernorm_positioning == "post":
inter, prev_attn_out = self.self_attn(hidden, hidden, hidden, mask, position_bias, has_residual_attention,
prev_attn, freqs_cis, alibi)
inter, prev_attn_out = self.self_attn(hidden, hidden, hidden, mask, position_bias, self.has_residual_attention,
prev_attn, freqs_cis)
inter = self.dropout_1(inter)
inter = self.layer_norm_1(inter + hidden)
output = self.dropout_2(self.feed_forward(inter))
output = self.layer_norm_2(output + inter)
else:
inter = self.layer_norm_1(hidden)
inter, prev_attn_out = self.self_attn(inter, inter, inter, mask, position_bias, has_residual_attention,
prev_attn, freqs_cis, alibi)
inter, prev_attn_out = self.self_attn(inter, inter, inter, mask, position_bias, self.has_residual_attention,
prev_attn, freqs_cis)
inter = self.dropout_1(inter)
hidden = hidden + inter
output = self.layer_norm_2(hidden)
output = self.dropout_2(self.feed_forward(output)) + hidden
return output, prev_attn_out

if self.has_residual_attention:
return output, mask, prev_attn_out
else:
return output, mask
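The conditional return above is why the tuple only carries prev_attn when residual attention is switched on. In RealFormer-style residual attention, each layer adds the previous layer's raw attention scores to its own before the softmax; a minimal sketch of that step, given as an assumption about what self_attn does with prev_attn rather than the project's code:

    import torch

    def residual_attention_scores(scores, prev_attn=None):
        # scores: [batch, heads, q_len, k_len] raw (pre-softmax) attention logits.
        if prev_attn is not None:
            scores = scores + prev_attn          # residual connection on the logits
        probs = torch.softmax(scores, dim=-1)
        return probs, scores                     # scores become the next layer's prev_attn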


class ParallelTransformerLayer(nn.Module):
Expand All @@ -80,6 +108,13 @@ def __init__(self, args, layer_number=None):
super(ParallelTransformerLayer, self).__init__()

self.layernorm_positioning = args.layernorm_positioning
self.relative_position_embedding = args.relative_position_embedding
self.rotary_position_embedding = args.rotary_position_embedding
self.has_residual_attention = args.has_residual_attention
if self.relative_position_embedding:
self.relative_pos_emb = args.relative_pos_emb
if self.rotary_position_embedding:
self.freqs_cis = args.freqs_cis

if hasattr(args, "attention_head_size"):
attention_head_size = args.attention_head_size
@@ -114,8 +149,7 @@ def __init__(self, args, layer_number=None):
self.layer_norm_1 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)
self.layer_norm_2 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)

def forward(self, hidden, mask, position_bias=None, has_residual_attention=False,
prev_attn=None, freqs_cis=None, alibi=None):
def forward(self, inputs):

"""
Args:
@@ -126,22 +160,44 @@ def forward(self, hidden, mask, position_bias=None, has_residual_attention=False
output: [batch_size x seq_length x hidden_size]
"""

if len(inputs)==2:
hidden, mask = inputs
prev_attn = None
else:
hidden, mask, prev_attn = inputs

_, seq_length, _ = hidden.size()

if self.relative_position_embedding:
position_bias = self.relative_pos_emb(hidden, hidden)
else:
position_bias = None

if self.rotary_position_embedding:
freqs_cis = self.freqs_cis[:seq_length].to(hidden.device)
else:
freqs_cis = None

if self.layernorm_positioning == "post":
inter, prev_attn_out = self.self_attn(hidden, hidden, hidden, mask, position_bias, has_residual_attention,
prev_attn, freqs_cis, alibi)
inter, prev_attn_out = self.self_attn(hidden, hidden, hidden, mask, position_bias, self.has_residual_attention,
prev_attn, freqs_cis)
inter = self.dropout_1(inter)
inter = self.layer_norm_1(inter + hidden)
output = self.dropout_2(self.feed_forward(inter))
output = self.layer_norm_2(output + inter)
else:
inter = self.layer_norm_1(hidden)
inter, prev_attn_out = self.self_attn(inter, inter, inter, mask, position_bias, has_residual_attention,
prev_attn, freqs_cis, alibi)
inter, prev_attn_out = self.self_attn(inter, inter, inter, mask, position_bias, self.has_residual_attention,
prev_attn, freqs_cis)
inter = self.dropout_1(inter)
hidden = hidden + inter
output = self.layer_norm_2(hidden)
output = self.dropout_2(self.feed_forward(output)) + hidden
return output, prev_attn_out

if self.has_residual_attention:
return output, mask, prev_attn_out
else:
return output, mask


class TransformerDecoderLayer(nn.Module):
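Both layer classes now own their positional-embedding modules and call self.relative_pos_emb(hidden, hidden) to produce the per-head bias that is added to the attention scores. A simplified, self-contained sketch of such a bias module, using plain distance clipping instead of T5's log-spaced buckets, so it illustrates the idea rather than the RelativePositionEmbedding used here:

    import torch
    import torch.nn as nn

    class ToyRelativePositionBias(nn.Module):
        """Hypothetical stand-in: a learned bias per (head, clipped relative distance)."""
        def __init__(self, heads_num, max_distance=128):
            super().__init__()
            self.max_distance = max_distance
            self.bias = nn.Embedding(2 * max_distance + 1, heads_num)

        def forward(self, query, key):
            q_len, k_len = query.size(1), key.size(1)
            context = torch.arange(q_len)[:, None]
            memory = torch.arange(k_len)[None, :]
            relative = (memory - context).clamp(-self.max_distance, self.max_distance)
            # [q_len, k_len, heads] -> [1, heads, q_len, k_len], broadcast over the batch.
            return self.bias(relative + self.max_distance).permute(2, 0, 1).unsqueeze(0)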
2 changes: 1 addition & 1 deletion tencentpretrain/trainer.py
@@ -23,7 +23,7 @@ def init_model(args):
with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group() if args.use_mp else None,
remote_device=None,
config_dict_or_path=args.deepspeed_config,
enabled=args.enable_zero3 == 3,
enabled=args.enable_zero3 == True,
mpu=mpu if args.use_mp else None ):
model_for_training = build_model(args)
if args.use_mp:
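The trainer change swaps the integer comparison enable_zero3 == 3 for a boolean check, since deepspeed.zero.Init's enabled flag simply toggles whether parameters are partitioned at construction time. A minimal sketch of that gating pattern, assuming a boolean args.enable_zero3 and the build_model call already shown in this file:

    import deepspeed

    # When enabled=False, zero.Init is a no-op context manager, so the same
    # construction path serves both ZeRO-3 and non-ZeRO-3 runs.
    with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config,
                             enabled=bool(args.enable_zero3)):
        model_for_training = build_model(args)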
