diff --git a/tencentpretrain/encoders/transformer_encoder.py b/tencentpretrain/encoders/transformer_encoder.py
index 6bcb3c3..f3dd653 100644
--- a/tencentpretrain/encoders/transformer_encoder.py
+++ b/tencentpretrain/encoders/transformer_encoder.py
@@ -21,6 +21,13 @@ def __init__(self, args):
         self.rotary_position_embedding = args.rotary_position_embedding
         self.has_residual_attention = args.has_residual_attention
         self.use_mp = args.use_mp
+
+        if self.relative_position_embedding:
+            args.relative_pos_emb = RelativePositionEmbedding(bidirectional=True, heads_num=args.heads_num,
+                                                              num_buckets=args.relative_attention_buckets_num)
+        elif self.rotary_position_embedding:
+            args.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2)
+
         if "deepspeed_checkpoint_activations" in args:
             self.deepspeed_checkpoint_activations = args.deepspeed_checkpoint_activations
             self.deepspeed_checkpoint_layers_num = args.deepspeed_checkpoint_layers_num
@@ -101,54 +108,35 @@ def forward(self, emb, seg):
             mask = (1.0 - mask) * -10000.0

         hidden = emb
-
-        if self.relative_position_embedding:
-            position_bias = self.relative_pos_emb(hidden, hidden)
-        else:
-            position_bias = None
-
-        if self.rotary_position_embedding:
-            freqs_cis = self.freqs_cis[:seq_length].to(hidden.device)
-        else:
-            freqs_cis = None
-
-        prev_attn = None
+        inputs = hidden, mask

         if self.deepspeed_checkpoint_activations:
             from deepspeed import checkpointing

             def custom(start, end):
                 def custom_forward(*inputs):
-                    x_, y_, position_bias_, freqs_cis_ = inputs
                     for index in range(start, end):
                         if self.parameter_sharing:
-                            x_, y_ = self.transformer(x_, mask, position_bias=position_bias_,
-                                                      has_residual_attention=self.has_residual_attention,
-                                                      prev_attn=y_, freqs_cis=freqs_cis_)
+                            inputs = self.transformer(*inputs)
                         else:
-                            x_, y_ = self.transformer[index](x_, mask, position_bias=position_bias_,
-                                                             has_residual_attention=self.has_residual_attention,
-                                                             prev_attn=y_, freqs_cis=freqs_cis_)
-                    return x_, y_
+                            inputs = self.transformer[index](*inputs)
+                    return inputs
                 return custom_forward

             if self.use_mp:
                 mpu.reset_checkpointed_activations_memory_buffer()
             l = 0
             while l < self.layers_num:
-                hidden, prev_attn = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num),
-                                                             hidden, prev_attn, position_bias, freqs_cis)
+                inputs = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num), inputs)
                 l += self.deepspeed_checkpoint_layers_num
         else:
             for i in range(self.layers_num):
                 if self.parameter_sharing:
-                    hidden, prev_attn = self.transformer(hidden, mask, position_bias=position_bias,
-                                                         has_residual_attention=self.has_residual_attention,
-                                                         prev_attn=prev_attn, freqs_cis=freqs_cis)
+                    inputs = self.transformer(inputs)
                 else:
-                    hidden, prev_attn = self.transformer[i](hidden, mask, position_bias=position_bias,
-                                                            has_residual_attention=self.has_residual_attention,
-                                                            prev_attn=prev_attn, freqs_cis=freqs_cis)
+                    inputs = self.transformer[i](inputs)
+
+        hidden = inputs[0]

         if self.layernorm_positioning == "pre":
             return self.layer_norm(hidden)
diff --git a/tencentpretrain/layers/transformer.py b/tencentpretrain/layers/transformer.py
index 08acfca..909908a 100755
--- a/tencentpretrain/layers/transformer.py
+++ b/tencentpretrain/layers/transformer.py
@@ -11,6 +11,13 @@ def __init__(self, args, layer_number=None):
         super(TransformerLayer, self).__init__()

         self.layernorm_positioning = args.layernorm_positioning
+        self.relative_position_embedding = args.relative_position_embedding
+        self.rotary_position_embedding = args.rotary_position_embedding
+        self.has_residual_attention = args.has_residual_attention
+        if self.relative_position_embedding:
+            self.relative_pos_emb = args.relative_pos_emb
+        if self.rotary_position_embedding:
+            self.freqs_cis = args.freqs_cis

         if hasattr(args, "attention_head_size"):
             attention_head_size = args.attention_head_size
@@ -45,8 +52,8 @@ def __init__(self, args, layer_number=None):
         self.layer_norm_1 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)
         self.layer_norm_2 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)

-    def forward(self, hidden, mask, position_bias=None, has_residual_attention=False,
-                prev_attn=None, freqs_cis=None, alibi=None):
+    def forward(self, inputs):
+
         """
         Args:
             hidden: [batch_size x seq_length x emb_size]
@@ -55,23 +62,44 @@ def forward(self, hidden, mask, position_bias=None, has_residual_attention=False
         Returns:
             output: [batch_size x seq_length x hidden_size]
         """
+        if len(inputs)==2:
+            hidden, mask = inputs
+            prev_attn = None
+        else:
+            hidden, mask, prev_attn = inputs
+
+        _, seq_length, _ = hidden.size()
+
+        if self.relative_position_embedding:
+            position_bias = self.relative_pos_emb(hidden, hidden)
+        else:
+            position_bias = None
+
+        if self.rotary_position_embedding:
+            freqs_cis = self.freqs_cis[:seq_length].to(hidden.device)
+        else:
+            freqs_cis = None

         if self.layernorm_positioning == "post":
-            inter, prev_attn_out = self.self_attn(hidden, hidden, hidden, mask, position_bias, has_residual_attention,
-                                                  prev_attn, freqs_cis, alibi)
+            inter, prev_attn_out = self.self_attn(hidden, hidden, hidden, mask, position_bias, self.has_residual_attention,
+                                                  prev_attn, freqs_cis)
             inter = self.dropout_1(inter)
             inter = self.layer_norm_1(inter + hidden)
             output = self.dropout_2(self.feed_forward(inter))
             output = self.layer_norm_2(output + inter)
         else:
             inter = self.layer_norm_1(hidden)
-            inter, prev_attn_out = self.self_attn(inter, inter, inter, mask, position_bias, has_residual_attention,
-                                                  prev_attn, freqs_cis, alibi)
+            inter, prev_attn_out = self.self_attn(inter, inter, inter, mask, position_bias, self.has_residual_attention,
+                                                  prev_attn, freqs_cis)
             inter = self.dropout_1(inter)
             hidden = hidden + inter
             output = self.layer_norm_2(hidden)
             output = self.dropout_2(self.feed_forward(output)) + hidden
-        return output, prev_attn_out
+
+        if self.has_residual_attention:
+            return output, mask, prev_attn_out
+        else:
+            return output, mask


 class ParallelTransformerLayer(nn.Module):
@@ -80,6 +108,13 @@ def __init__(self, args, layer_number=None):
         super(ParallelTransformerLayer, self).__init__()

         self.layernorm_positioning = args.layernorm_positioning
+        self.relative_position_embedding = args.relative_position_embedding
+        self.rotary_position_embedding = args.rotary_position_embedding
+        self.has_residual_attention = args.has_residual_attention
+        if self.relative_position_embedding:
+            self.relative_pos_emb = args.relative_pos_emb
+        if self.rotary_position_embedding:
+            self.freqs_cis = args.freqs_cis

         if hasattr(args, "attention_head_size"):
             attention_head_size = args.attention_head_size
@@ -114,8 +149,7 @@ def __init__(self, args, layer_number=None):
         self.layer_norm_1 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)
         self.layer_norm_2 = str2layernorm[args.layernorm](args.hidden_size, eps=args.layernorm_eps)

-    def forward(self, hidden, mask, position_bias=None, has_residual_attention=False,
-                prev_attn=None, freqs_cis=None, alibi=None):
+    def forward(self, inputs):
         """

         Args:
@@ -126,22 +160,44 @@ def forward(self, hidden, mask, position_bias=None, has_residual_attention=False
             output: [batch_size x seq_length x hidden_size]
         """
+        if len(inputs)==2:
+            hidden, mask = inputs
+            prev_attn = None
+        else:
+            hidden, mask, prev_attn = inputs
+
+        _, seq_length, _ = hidden.size()
+
+        if self.relative_position_embedding:
+            position_bias = self.relative_pos_emb(hidden, hidden)
+        else:
+            position_bias = None
+
+        if self.rotary_position_embedding:
+            freqs_cis = self.freqs_cis[:seq_length].to(hidden.device)
+        else:
+            freqs_cis = None
+
         if self.layernorm_positioning == "post":
-            inter, prev_attn_out = self.self_attn(hidden, hidden, hidden, mask, position_bias, has_residual_attention,
-                                                  prev_attn, freqs_cis, alibi)
+            inter, prev_attn_out = self.self_attn(hidden, hidden, hidden, mask, position_bias, self.has_residual_attention,
+                                                  prev_attn, freqs_cis)
             inter = self.dropout_1(inter)
             inter = self.layer_norm_1(inter + hidden)
             output = self.dropout_2(self.feed_forward(inter))
             output = self.layer_norm_2(output + inter)
         else:
             inter = self.layer_norm_1(hidden)
-            inter, prev_attn_out = self.self_attn(inter, inter, inter, mask, position_bias, has_residual_attention,
-                                                  prev_attn, freqs_cis, alibi)
+            inter, prev_attn_out = self.self_attn(inter, inter, inter, mask, position_bias, self.has_residual_attention,
+                                                  prev_attn, freqs_cis)
             inter = self.dropout_1(inter)
             hidden = hidden + inter
             output = self.layer_norm_2(hidden)
             output = self.dropout_2(self.feed_forward(output)) + hidden
-        return output, prev_attn_out
+
+        if self.has_residual_attention:
+            return output, mask, prev_attn_out
+        else:
+            return output, mask


 class TransformerDecoderLayer(nn.Module):
diff --git a/tencentpretrain/trainer.py b/tencentpretrain/trainer.py
index 6f10cba..5ee00ac 100755
--- a/tencentpretrain/trainer.py
+++ b/tencentpretrain/trainer.py
@@ -23,7 +23,7 @@ def init_model(args):
         with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group() if args.use_mp else None,
                                  remote_device=None,
                                  config_dict_or_path=args.deepspeed_config,
-                                 enabled=args.enable_zero3 == 3,
+                                 enabled=args.enable_zero3 == True,
                                  mpu=mpu if args.use_mp else None
                                  ):
            model_for_training = build_model(args)
            if args.use_mp:
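For context on the first hunk: the encoder now builds the position-embedding module once, stashes it on `args`, and every layer keeps a reference to it. Below is a minimal sketch of that sharing pattern (toy class names, not the repository's API); every layer ends up pointing at the same parameters instead of building its own table.

```python
import torch.nn as nn
from argparse import Namespace

class SharedBias(nn.Module):
    """Stand-in for a relative-position-bias module: one bucket -> per-head bias table."""
    def __init__(self, heads_num, num_buckets):
        super().__init__()
        self.relative_attention_bias = nn.Embedding(num_buckets, heads_num)

class Layer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.relative_pos_emb = args.relative_pos_emb  # a reference, not a copy

args = Namespace(relative_pos_emb=SharedBias(heads_num=12, num_buckets=32))
layers = nn.ModuleList([Layer(args) for _ in range(4)])

# Every layer sees the same module instance, so the bias table is stored and
# trained once rather than once per layer.
assert all(layer.relative_pos_emb is layers[0].relative_pos_emb for layer in layers)
```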
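The `precompute_freqs_cis` call in the encoder's `__init__` is the LLaMA-style rotary-embedding helper: it precomputes a table of unit-modulus complex rotations for up to `max_seq_length * 2` positions, and each layer slices `self.freqs_cis[:seq_length]` at forward time. A sketch of what such a helper typically computes, shown for orientation rather than copied from the repository:

```python
import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Per-pair inverse frequencies theta^(-2i/dim) for i = 0 .. dim/2 - 1.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    angles = torch.outer(t, freqs)                       # [end, dim // 2] rotation angles
    return torch.polar(torch.ones_like(angles), angles)  # complex64: cos + i*sin

# Example matching the diff's usage: per-head size, twice the max sequence length.
freqs_cis = precompute_freqs_cis(dim=64, end=2048 * 2)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([4096, 32]) torch.complex64
```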
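The forward-signature change is what lets the encoder treat the layer stack as a uniform chain: each layer consumes and produces a single tuple, so a block of layers can be replayed under activation checkpointing without threading keyword arguments through the checkpoint boundary. A minimal sketch of that pattern with a toy layer and `torch.utils.checkpoint` standing in for `deepspeed.checkpointing` (assumptions for illustration, not the repository's code):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyLayer(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, inputs):
        # Accept (hidden, mask) or (hidden, mask, prev_attn) and return a tuple of the same form.
        hidden, mask = inputs[0], inputs[1]
        hidden = hidden + torch.relu(self.linear(hidden))
        return hidden, mask

layers = nn.ModuleList([ToyLayer(16) for _ in range(4)])

def chunk(start, end):
    # Re-runs layers [start, end) on the saved inputs during the backward pass.
    def custom_forward(*inputs):
        for i in range(start, end):
            inputs = layers[i](inputs)
        return inputs
    return custom_forward

hidden = torch.randn(2, 8, 16, requires_grad=True)
mask = torch.zeros(2, 1, 8, 8)
inputs = (hidden, mask)
for l in range(0, len(layers), 2):          # checkpoint two layers at a time
    inputs = checkpoint(chunk(l, l + 2), *inputs, use_reentrant=False)
inputs[0].sum().backward()                  # activations are recomputed here
```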