use_fuse_mask_softmax #412

Open · wants to merge 3 commits into base: main
7 changes: 5 additions & 2 deletions projects/T5/configs/mt5_pretrain.py
@@ -18,7 +18,7 @@
 train_data_path = "projects/T5/data/training_data/part_0"
 pretrained_model_path = None
 
-micro_batch_size = 64
+micro_batch_size = 4
 optim["lr"] = 1e-4
 
 # dataloader
@@ -54,6 +54,7 @@
 model.cfg.embedding_dropout_prob = 0.0
 model.cfg.layernorm_eps = 1e-6
 model.cfg.model_type = "mt5"
+model.cfg.scale_mask_softmax_fusion = True
 model.cfg.pretrained_model_path = pretrained_model_path
 
 train.update(
@@ -63,7 +64,7 @@
     train_epoch=1,
     train_iter=24000,
     log_period=10,
-    amp=dict(enabled=False),
+    amp=dict(enabled=True),
     warmup_ratio=1 / 24,
     # checkpointer=dict(period=10, max_to_keep=20),
     dist=dict(
@@ -89,3 +90,5 @@
 
 train.zero_optimization.enabled = True
 train.zero_optimization.stage = 2
+train.activation_checkpoint.enabled = True
+train.num_accumulation_steps = 8
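Taken together, the config trades one large micro-batch for several small ones: AMP, ZeRO stage 2, activation checkpointing, and gradient accumulation now carry the memory load instead of micro_batch_size=64. A minimal sketch of the resulting per-step arithmetic, assuming the usual convention that the global batch per optimizer step is micro_batch_size × num_accumulation_steps × data-parallel size; the variable names and the 8-rank figure are illustrative, not part of this PR:

# Illustrative only: the combination rule is an assumption about how the
# trainer uses these settings; the first two values come from this config.
micro_batch_size = 4
num_accumulation_steps = 8
data_parallel_size = 8  # hypothetical 8-GPU data-parallel run, not set in this PR

# Each rank accumulates gradients over 8 micro-batches before the optimizer
# steps, so the effective global batch per step is:
global_batch_size = micro_batch_size * num_accumulation_steps * data_parallel_size
print(global_batch_size)  # 256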
1 change: 1 addition & 0 deletions projects/T5/configs/t5_model_config.py
@@ -14,6 +14,7 @@
     embedding_dropout_prob=0.1,
     initializer_range=0.02,
     layernorm_eps=1e-5,
+    scale_mask_softmax_fusion=True,
     amp_enabled=False,
     model_type="t5",
 )
30 changes: 22 additions & 8 deletions projects/T5/models/attention.py
@@ -54,6 +54,7 @@ def __init__(
         output_dropout_prob=0.0,
         init_method=nn.init.xavier_normal_,
         output_layer_init_method=None,
+        scale_mask_softmax_fusion=True,
         *,
         layer_idx=0,
         has_relative_attention_bias=False,
@@ -65,6 +66,7 @@
         self.has_relative_attention_bias = has_relative_attention_bias
         self.is_decoder = is_decoder
         self.attention_dropout_prob = attention_dropout_prob
+        self.scale_mask_softmax_fusion = scale_mask_softmax_fusion
 
         if output_layer_init_method is None:
             output_layer_init_method = init_method
@@ -230,14 +232,26 @@ def forward(
 
         # [S(0), S(1)] x [S(0), B] = [S(0), S(1)]
         if attention_mask is not None:
-            attention_scores = flow.mul(attention_scores, attention_mask)
-            attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
-            # TODO(xingyu.liao): graph will occur `where_scalar` errors
-            # when using `masked_fill`
-            # attention_scores = attention_scores.masked_fill(1 - attention_mask, -10000.0)
-            attention_weights = flow.softmax(attention_scores, dim=-1)
-            # [bsz, num_heads, tgt_len, src_len]
-            attention_weights = self.dropout(attention_weights)
+            if self.scale_mask_softmax_fusion:
+                attention_mask = (
+                    attention_mask.expand_as(attention_scores) if use_cache else attention_mask
+                )
+                attention_weights = flow._C.fused_scale_mask_softmax_dropout(
+                    attention_scores,
+                    attention_mask,
+                    fill_value=-10000.0,
+                    scale=1,
+                    p=self.attention_dropout_prob,
+                )[0]
+            else:
+                attention_scores = flow.mul(attention_scores, attention_mask)
+                attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
+                # TODO(xingyu.liao): graph will occur `where_scalar` errors
+                # when using `masked_fill`
+                # attention_scores = attention_scores.masked_fill(1 - attention_mask, -10000.0)
+                attention_weights = flow.softmax(attention_scores, dim=-1)
+                # [bsz, num_heads, tgt_len, src_len]
+                attention_weights = self.dropout(attention_weights)
         else:
             attention_weights = flow.softmax(attention_scores, dim=-1)
             # [bsz, num_heads, tgt_len, src_len]
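This is the heart of the PR: the handwritten mask/softmax/dropout sequence is collapsed into OneFlow's fused kernel, avoiding the intermediate tensors and extra kernel launches. Below is a minimal standalone sketch of the intended equivalence between the two branches, using the same call signature as the diff; the shapes are made up, and dropout is disabled (p=0.0) so the comparison is deterministic. Treat it as an illustration of the expected semantics, not a test shipped with this PR.

import oneflow as flow

def masked_softmax_reference(scores, mask, fill_value=-10000.0):
    # Unfused path from the `else` branch above: keep positions where mask == 1,
    # push positions where mask == 0 down to fill_value, then normalize.
    scores = flow.mul(scores, mask) + fill_value * (1 - mask)
    return flow.softmax(scores, dim=-1)

# Made-up shapes: [bsz, num_heads, tgt_len, src_len]
scores = flow.randn(2, 8, 16, 16)
# 0/1 float mask, as used in this codebase; here the last 8 source positions are "padding"
mask = flow.cat([flow.ones(2, 8, 16, 8), flow.zeros(2, 8, 16, 8)], dim=-1)

fused = flow._C.fused_scale_mask_softmax_dropout(
    scores, mask, fill_value=-10000.0, scale=1, p=0.0
)[0]  # p=0.0 makes dropout a no-op; with p > 0 the fused op also applies dropout

print(flow.allclose(fused, masked_softmax_reference(scores, mask), atol=1e-4))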
9 changes: 9 additions & 0 deletions projects/T5/models/t5_model.py
@@ -41,6 +41,7 @@ def __init__(
         hidden_dropout_prob,
         attention_probs_dropout_prob,
         relative_attention_num_buckets,
+        scale_mask_softmax_fusion=True,
         initializer_range=0.02,
         layernorm_eps=1e-12,
         amp_enabled=False,
@@ -73,6 +74,7 @@ def __init__(
                     layernorm_epsilon=layernorm_eps,
                     init_method=init_method,
                     output_layer_init_method=scaled_init_method,
+                    scale_mask_softmax_fusion=scale_mask_softmax_fusion,
                     layer_idx=i,
                     model_type=model_type,
                     has_relative_attention_bias=bool(i == 0),
@@ -105,6 +107,7 @@ def __init__(
                     layernorm_epsilon=layernorm_eps,
                     init_method=init_method,
                     output_layer_init_method=scaled_init_method,
+                    scale_mask_softmax_fusion=scale_mask_softmax_fusion,
                     layer_idx=i,
                     model_type=model_type,
                     has_relative_attention_bias=bool(i - hidden_layers == 0),
@@ -150,6 +153,7 @@ def from_config(cls, cfg):
             "layernorm_eps": cfg.layernorm_eps,
             "amp_enabled": cfg.amp_enabled,
             "model_type": cfg.model_type,
+            "scale_mask_softmax_fusion": cfg.scale_mask_softmax_fusion,
         }
 
     def forward(
@@ -311,3 +315,8 @@ def set_pipeline_stage_id(model):
             dist_utils.get_layer_stage_id(model.t5_model.decoder.final_layernorm.layer_idx),
             dist.get_layer_placement(model.t5_model.decoder.final_layernorm.layer_idx),
         )
+
+    def set_activation_checkpoint(self):
+        for module_block in self.t5_model.modules():
+            if isinstance(module_block.origin, TransformerLayer):
+                module_block.config.activation_checkpointing = True
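The new set_activation_checkpoint hook flags every TransformerLayer block for recomputation instead of storing its activations, which is what lets the config above switch train.activation_checkpoint.enabled on. A minimal sketch of how such a hook is typically invoked from an OneFlow nn.Graph; the graph class, loss key, and flag below are assumptions for illustration, not code from this PR:

import oneflow as flow

class PretrainGraph(flow.nn.Graph):
    # Hypothetical training graph; the project's real graph class differs in detail.
    def __init__(self, model, optimizer, activation_checkpoint=True):
        super().__init__()
        self.model = model
        self.add_optimizer(optimizer)
        if activation_checkpoint:
            # Calls the new hook: each TransformerLayer's module block is
            # marked so its forward activations are recomputed during backward.
            self.model.set_activation_checkpoint()

    def build(self, *args, **kwargs):
        loss = self.model(*args, **kwargs)["total_loss"]  # assumed loss key
        loss.backward()
        return loss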
6 changes: 6 additions & 0 deletions projects/T5/models/transformer_layer.py
@@ -58,6 +58,7 @@ def __init__(
         layernorm_epsilon=1e-5,
         init_method=nn.init.xavier_normal_,
         output_layer_init_method=None,
+        scale_mask_softmax_fusion=True,
         *,
         layer_idx=0,
         model_type="t5",
@@ -73,6 +74,7 @@
         self.layernorm_epsilon = layernorm_epsilon
         self.layer_idx = layer_idx
         self.is_decoder = is_decoder
+        self.scale_mask_softmax_fusion = scale_mask_softmax_fusion
 
         self.init_method = init_method
         if output_layer_init_method is None:
@@ -89,6 +91,7 @@
             is_cross_attention=False,
             relative_attention_num_buckets=relative_attention_num_buckets,
             has_relative_attention_bias=has_relative_attention_bias,
+            scale_mask_softmax_fusion=scale_mask_softmax_fusion,
             is_decoder=self.is_decoder,
         )
         self.post_attention_layernorm = LayerNorm(
@@ -99,6 +102,7 @@
             self.cross_attention = self.build_attention(
                 is_cross_attention=True,
                 relative_attention_num_buckets=relative_attention_num_buckets,
+                scale_mask_softmax_fusion=scale_mask_softmax_fusion,
                 is_decoder=self.is_decoder,
             )
             self.post_cross_attention_layernorm = LayerNorm(
@@ -234,6 +238,7 @@ def build_attention(
         is_cross_attention=False,
         relative_attention_num_buckets=None,
         has_relative_attention_bias=False,
+        scale_mask_softmax_fusion=True,
         is_decoder=False,
     ):
         return MultiheadAttention(
@@ -246,6 +251,7 @@
             output_dropout_prob=self.output_dropout_prob,
             init_method=self.init_method,
             output_layer_init_method=self.output_layer_init_method,
+            scale_mask_softmax_fusion=scale_mask_softmax_fusion,
             layer_idx=self.layer_idx,
             has_relative_attention_bias=has_relative_attention_bias,
             is_decoder=is_decoder,