From a1fc307db5d920e138166716c2798af38f85acf5 Mon Sep 17 00:00:00 2001
From: xiezipeng-ML
Date: Fri, 28 Oct 2022 09:47:48 +0000
Subject: [PATCH 1/3] use fused multi-head attention

---
 projects/T5/configs/mt5_pretrain.py     |  3 ++-
 projects/T5/configs/t5_model_config.py  |  1 +
 projects/T5/models/attention.py         | 30 ++++++++++++++++++++++--------
 projects/T5/models/t5_model.py          |  4 ++++
 projects/T5/models/transformer_layer.py |  6 ++++++
 5 files changed, 35 insertions(+), 9 deletions(-)

diff --git a/projects/T5/configs/mt5_pretrain.py b/projects/T5/configs/mt5_pretrain.py
index bd9fd3bd8..e3dcf3bf7 100644
--- a/projects/T5/configs/mt5_pretrain.py
+++ b/projects/T5/configs/mt5_pretrain.py
@@ -18,7 +18,7 @@
 train_data_path = "projects/T5/data/training_data/part_0"
 pretrained_model_path = None
 
-micro_batch_size = 64
+micro_batch_size = 16
 optim["lr"] = 1e-4
 
 # dataloader
@@ -54,6 +54,7 @@
 model.cfg.embedding_dropout_prob = 0.0
 model.cfg.layernorm_eps = 1e-6
 model.cfg.model_type = "mt5"
+model.cfg.scale_mask_softmax_fusion = True
 model.cfg.pretrained_model_path = pretrained_model_path
 
 train.update(
diff --git a/projects/T5/configs/t5_model_config.py b/projects/T5/configs/t5_model_config.py
index 50523f756..9e4cb9d96 100644
--- a/projects/T5/configs/t5_model_config.py
+++ b/projects/T5/configs/t5_model_config.py
@@ -14,6 +14,7 @@
     embedding_dropout_prob=0.1,
     initializer_range=0.02,
     layernorm_eps=1e-5,
+    scale_mask_softmax_fusion=True,
     amp_enabled=False,
     model_type="t5",
 )
diff --git a/projects/T5/models/attention.py b/projects/T5/models/attention.py
index a825f681a..9a8d8b1b5 100644
--- a/projects/T5/models/attention.py
+++ b/projects/T5/models/attention.py
@@ -54,6 +54,7 @@ def __init__(
         output_dropout_prob=0.0,
         init_method=nn.init.xavier_normal_,
         output_layer_init_method=None,
+        scale_mask_softmax_fusion=True,
         *,
         layer_idx=0,
         has_relative_attention_bias=False,
@@ -65,6 +66,7 @@
         self.has_relative_attention_bias = has_relative_attention_bias
         self.is_decoder = is_decoder
         self.attention_dropout_prob = attention_dropout_prob
+        self.scale_mask_softmax_fusion = scale_mask_softmax_fusion
 
         if output_layer_init_method is None:
             output_layer_init_method = init_method
@@ -230,14 +232,26 @@
         # [S(0), S(1)] x [S(0), B] = [S(0), S(1)]
         if attention_mask is not None:
-            attention_scores = flow.mul(attention_scores, attention_mask)
-            attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
-            # TODO(xingyu.liao): graph will occur `where_scalar` errors
-            # when using `masked_fill`
-            # attention_scores = attention_scores.masked_fill(1 - attention_mask, -10000.0)
-            attention_weights = flow.softmax(attention_scores, dim=-1)
-            # [bsz, num_heads, tgt_len, src_len]
-            attention_weights = self.dropout(attention_weights)
+            if self.scale_mask_softmax_fusion:
+                attention_mask = (
+                    attention_mask.expand_as(attention_scores) if use_cache else attention_mask
+                )
+                attention_weights = flow._C.fused_scale_mask_softmax_dropout(
+                    attention_scores,
+                    attention_mask,
+                    fill_value=-10000.0,
+                    scale=1,
+                    p=self.attention_dropout_prob,
+                )[0]
+            else:
+                attention_scores = flow.mul(attention_scores, attention_mask)
+                attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
+                # TODO(xingyu.liao): graph will occur `where_scalar` errors
+                # when using `masked_fill`
+                # attention_scores = attention_scores.masked_fill(1 - attention_mask, -10000.0)
+                attention_weights = flow.softmax(attention_scores, dim=-1)
+                # [bsz, num_heads, tgt_len, src_len]
+                attention_weights = self.dropout(attention_weights)
         else:
             attention_weights = flow.softmax(attention_scores, dim=-1)
             # [bsz, num_heads, tgt_len, src_len]
             attention_weights = self.dropout(attention_weights)
diff --git a/projects/T5/models/t5_model.py b/projects/T5/models/t5_model.py
index f9411f0c9..97e241eba 100644
--- a/projects/T5/models/t5_model.py
+++ b/projects/T5/models/t5_model.py
@@ -41,6 +41,7 @@
         hidden_dropout_prob,
         attention_probs_dropout_prob,
         relative_attention_num_buckets,
+        scale_mask_softmax_fusion=True,
         initializer_range=0.02,
         layernorm_eps=1e-12,
         amp_enabled=False,
@@ -73,6 +74,7 @@
                     layernorm_epsilon=layernorm_eps,
                     init_method=init_method,
                     output_layer_init_method=scaled_init_method,
+                    scale_mask_softmax_fusion=scale_mask_softmax_fusion,
                     layer_idx=i,
                     model_type=model_type,
                     has_relative_attention_bias=bool(i == 0),
@@ -105,6 +107,7 @@
                     layernorm_epsilon=layernorm_eps,
                     init_method=init_method,
                     output_layer_init_method=scaled_init_method,
+                    scale_mask_softmax_fusion=scale_mask_softmax_fusion,
                     layer_idx=i,
                     model_type=model_type,
                     has_relative_attention_bias=bool(i - hidden_layers == 0),
@@ -150,6 +153,7 @@
             "layernorm_eps": cfg.layernorm_eps,
             "amp_enabled": cfg.amp_enabled,
             "model_type": cfg.model_type,
+            "scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion,
         }
 
     def forward(
diff --git a/projects/T5/models/transformer_layer.py b/projects/T5/models/transformer_layer.py
index c23cb903d..fc0c2dc0d 100644
--- a/projects/T5/models/transformer_layer.py
+++ b/projects/T5/models/transformer_layer.py
@@ -58,6 +58,7 @@
         layernorm_epsilon=1e-5,
         init_method=nn.init.xavier_normal_,
         output_layer_init_method=None,
+        scale_mask_softmax_fusion=True,
         *,
         layer_idx=0,
         model_type="t5",
@@ -73,6 +74,7 @@
         self.layernorm_epsilon = layernorm_epsilon
         self.layer_idx = layer_idx
         self.is_decoder = is_decoder
+        self.scale_mask_softmax_fusion = scale_mask_softmax_fusion
         self.init_method = init_method
 
         if output_layer_init_method is None:
@@ -89,6 +91,7 @@
             is_cross_attention=False,
             relative_attention_num_buckets=relative_attention_num_buckets,
             has_relative_attention_bias=has_relative_attention_bias,
+            scale_mask_softmax_fusion=scale_mask_softmax_fusion,
             is_decoder=self.is_decoder,
         )
         self.post_attention_layernorm = LayerNorm(
@@ -99,6 +102,7 @@
             self.cross_attention = self.build_attention(
                 is_cross_attention=True,
                 relative_attention_num_buckets=relative_attention_num_buckets,
+                scale_mask_softmax_fusion=scale_mask_softmax_fusion,
                 is_decoder=self.is_decoder,
             )
             self.post_cross_attention_layernorm = LayerNorm(
@@ -234,6 +238,7 @@
        is_cross_attention=False,
        relative_attention_num_buckets=None,
        has_relative_attention_bias=False,
+       scale_mask_softmax_fusion=True,
        is_decoder=False,
     ):
         return MultiheadAttention(
@@ -246,6 +251,7 @@
             output_dropout_prob=self.output_dropout_prob,
             init_method=self.init_method,
             output_layer_init_method=self.output_layer_init_method,
+            scale_mask_softmax_fusion=scale_mask_softmax_fusion,
             layer_idx=self.layer_idx,
             has_relative_attention_bias=has_relative_attention_bias,
             is_decoder=is_decoder,

From 5cf55a6d8f5910252c120ac3697467e4952c51ca Mon Sep 17 00:00:00 2001
From: xiezipeng-ML
Date: Fri, 28 Oct 2022 12:42:23 +0000
Subject: [PATCH 2/3] use batch size = 4, accumulation steps = 8, amp, enable activation checkpointing

---
 projects/T5/configs/mt5_pretrain.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/projects/T5/configs/mt5_pretrain.py b/projects/T5/configs/mt5_pretrain.py
index e3dcf3bf7..93fcf3269 100644
--- a/projects/T5/configs/mt5_pretrain.py
+++ b/projects/T5/configs/mt5_pretrain.py
@@ -18,7 +18,7 @@
 train_data_path = "projects/T5/data/training_data/part_0"
 pretrained_model_path = None
 
-micro_batch_size = 16
+micro_batch_size = 4
 optim["lr"] = 1e-4
 
 # dataloader
@@ -64,7 +64,7 @@
     train_epoch=1,
     train_iter=24000,
     log_period=10,
-    amp=dict(enabled=False),
+    amp=dict(enabled=True),
     warmup_ratio=1 / 24,
     # checkpointer=dict(period=10, max_to_keep=20),
     dist=dict(
@@ -90,3 +90,5 @@
 
 train.zero_optimization.enabled = True
 train.zero_optimization.stage = 2
+train.activation_checkpoint.enabled = True
+train.num_accumulation_steps = 8

From 1c2ada8534c598c046a88f2d68d27e90cbb51ebb Mon Sep 17 00:00:00 2001
From: xiezipeng-ML
Date: Tue, 1 Nov 2022 14:40:40 +0000
Subject: [PATCH 3/3] add activation checkpointing

---
 projects/T5/models/t5_model.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/projects/T5/models/t5_model.py b/projects/T5/models/t5_model.py
index 97e241eba..f4ed64e43 100644
--- a/projects/T5/models/t5_model.py
+++ b/projects/T5/models/t5_model.py
@@ -315,3 +315,8 @@ def set_pipeline_stage_id(model):
             dist_utils.get_layer_stage_id(model.t5_model.decoder.final_layernorm.layer_idx),
             dist.get_layer_placement(model.t5_model.decoder.final_layernorm.layer_idx),
         )
+
+    def set_activation_checkpoint(self):
+        for module_block in self.t5_model.modules():
+            if isinstance(module_block.origin, TransformerLayer):
+                module_block.config.activation_checkpointing = True
\ No newline at end of file
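
Reviewer note, not part of the patches: patch 1 switches MultiheadAttention.forward between flow._C.fused_scale_mask_softmax_dropout and the original mul / subtract / softmax / dropout sequence based on the new scale_mask_softmax_fusion flag. The sketch below exercises only the unfused fallback kept in the else: branch, to make explicit the mask convention the fused call is configured for (mask entries are 1 for positions that may be attended to, 0 for padding; masked logits are pushed toward -10000 so softmax gives them essentially zero weight). Shapes and tensor names are illustrative only, and the fused kernel itself is not invoked here since it is assumed to require a GPU build of OneFlow.

    import oneflow as flow

    def masked_softmax_reference(attention_scores, attention_mask):
        # Unfused fallback from the diff: zero out masked logits, shift them
        # to ~-10000, then normalize with softmax over the source dimension.
        attention_scores = flow.mul(attention_scores, attention_mask)
        attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
        return flow.softmax(attention_scores, dim=-1)

    bsz, num_heads, tgt_len, src_len = 2, 8, 4, 6
    attention_scores = flow.randn(bsz, num_heads, tgt_len, src_len)
    # Padding-style mask: the last two source positions are masked for every query.
    attention_mask = flow.cat(
        [flow.ones(bsz, 1, 1, src_len - 2), flow.zeros(bsz, 1, 1, 2)], dim=-1
    ).expand_as(attention_scores)

    attention_weights = masked_softmax_reference(attention_scores, attention_mask)
    print(attention_weights[..., -2:].max())  # ~0: masked positions receive no attention
    print(attention_weights.sum(dim=-1))      # each row still sums to 1

With the flag on, the fused call above is expected to fold the same masking (fill_value=-10000.0), the softmax, and the attention dropout (p=self.attention_dropout_prob) into a single kernel instead of four separate ops.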
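
A second note on the training-side changes in patches 2 and 3: with micro_batch_size = 4 and train.num_accumulation_steps = 8, each data-parallel rank still contributes 4 x 8 = 32 samples per optimizer step (half of the 64-sample micro-batch the config started from), while AMP, ZeRO stage 2, and activation checkpointing on every TransformerLayer trade compute for memory. A minimal arithmetic sketch, independent of LiBai; the variable names simply mirror the config fields touched by the diffs.

    # Effective per-rank batch size after patch 2 (pure arithmetic, no framework needed).
    micro_batch_size = 4          # samples per forward/backward pass per rank
    num_accumulation_steps = 8    # micro-batches accumulated before each optimizer step
    samples_per_optimizer_step = micro_batch_size * num_accumulation_steps
    print(samples_per_optimizer_step)  # 32, versus the single 64-sample micro-batch before the series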