Use fused multi-head attention #417

Open · wants to merge 14 commits into main
18 changes: 10 additions & 8 deletions dev/model_loader_test.sh
@@ -7,20 +7,22 @@ export TEST_OUTPUT=output_unittest
export ONEFLOW_TEST_DEVICE_NUM=4
export ONEFLOW_EP_CUDA_ENABLE_TF32_EXECUTION=0

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_bert_loader.py
# python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_bert_loader.py

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_roberta_loader.py
# python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_roberta_loader.py

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_gpt_loader.py
# python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_gpt_loader.py

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_mt5_loader.py
# python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_mt5_loader.py

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_t5_loader.py
# python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_t5_loader.py

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_swin_loader.py
# python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_swin_loader.py

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_swinv2_loader.py
# python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_swinv2_loader.py

python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_vit_loader.py
# python3 -m oneflow.distributed.launch --nproc_per_node 4 -m pytest -s --disable-warnings tests/model_utils/test_vit_loader.py

python3 -m oneflow.distributed.launch --nproc_per_node 1 -m pytest -s --disable-warnings tests/model_utils/test_mt5_loader_2.py

rm -rf $TEST_OUTPUT
11 changes: 7 additions & 4 deletions projects/T5/configs/mt5_pretrain.py
@@ -18,7 +18,7 @@
train_data_path = "projects/T5/data/training_data/part_0"
pretrained_model_path = None

micro_batch_size = 64
micro_batch_size = 4
optim["lr"] = 1e-4

# dataloader
@@ -30,7 +30,7 @@
)
],
collate_fn=collate_fn(
vocab_size=12902,
vocab_size=12900,
max_seq_length=512,
noise_density=0.15,
mean_noise_span_length=3,
@@ -43,7 +43,7 @@
model = LazyCall(T5ForPreTraining)(cfg=cfg)

# model config
model.cfg.vocab_size = 12902
model.cfg.vocab_size = 12900
model.cfg.hidden_size = 512
model.cfg.hidden_layers = 8
model.cfg.num_attention_heads = 6
@@ -53,6 +53,7 @@
model.cfg.attention_probs_dropout_prob = 0.0
model.cfg.embedding_dropout_prob = 0.0
model.cfg.layernorm_eps = 1e-6

model.cfg.model_type = "mt5"
model.cfg.pretrained_model_path = pretrained_model_path

@@ -63,7 +64,7 @@
train_epoch=1,
train_iter=24000,
log_period=10,
amp=dict(enabled=False),
amp=dict(enabled=True),
warmup_ratio=1 / 24,
# checkpointer=dict(period=10, max_to_keep=20),
dist=dict(
@@ -89,3 +90,5 @@

train.zero_optimization.enabled = True
train.zero_optimization.stage = 2
train.activation_checkpoint.enabled = False
train.num_accumulation_steps = 8
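
Note on the batch-size settings above: the per-device micro-batch drops from 64 to 4, but the newly added train.num_accumulation_steps = 8 restores most of the effective batch size, which also scales with the data-parallel degree. A minimal sketch of that arithmetic, where the helper name and the data-parallel size of 1 are illustrative assumptions rather than values taken from this config:

def effective_global_batch_size(micro_batch_size, num_accumulation_steps, data_parallel_size):
    # Each optimizer step consumes num_accumulation_steps micro-batches
    # on every data-parallel replica.
    return micro_batch_size * num_accumulation_steps * data_parallel_size

# With the values from this config and an assumed data-parallel size of 1:
print(effective_global_batch_size(4, 8, 1))  # 32 samples per optimizer step
# The old setting used micro_batch_size=64 with no explicit accumulation.
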
54 changes: 15 additions & 39 deletions projects/T5/models/attention.py
@@ -147,17 +147,13 @@ def forward(
use_cache (bool, optional): it will be set to True, when the model is in the inference
phase and used for incremental decoding. Defaults to False.
"""

# hidden_states, encoder_states: [S(0), B]
# attention_mask: [S(0), B]

if encoder_states is not None:
encoder_states = encoder_states.to_global(placement=hidden_states.placement)

if attention_mask is not None:
attention_mask = attention_mask.to_global(placement=hidden_states.placement)

bsz, real_seq_length = hidden_states.size()[:2]
real_seq_length, bsz = hidden_states.size()[:2]

if past_key_value is not None:
assert (
@@ -166,47 +162,39 @@
f"Got {len(past_key_value)} past states.\n"
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

key_length = real_seq_length if encoder_states is None else encoder_states.shape[1]

key_length = real_seq_length if encoder_states is None else encoder_states.shape[0]
if self.is_cross_attention:
# if it is cross attention, key and value should be calculated only once, and the
# result can be reused.
query = self.query(hidden_states)
query = query.view(bsz, -1, self.num_heads, self.head_size)
query = query.permute(0, 2, 1, 3)
query = query.view(-1, bsz, self.num_heads, self.head_size)
query = query.permute(1, 2, 0, 3) # bsz, num_head, seq_len, head_size

if past_key_value is not None:
key, value = past_key_value
elif encoder_states is not None:
key_value = self.key_value(encoder_states)
key_value = key_value.view(bsz, -1, self.num_heads, 2 * self.head_size)
key_value = key_value.permute(0, 2, 1, 3)
key_value = key_value.view(-1, bsz, self.num_heads, 2 * self.head_size)
key_value = key_value.permute(1, 2, 0, 3)
key, value = flow.chunk(key_value, chunks=2, dim=-1)
else:
raise ValueError(
"past_key_value and encoder_states cannot be None at the same time."
)
else:
# if it is self attention, query, key, and value are all obtained from hidden_states.
# when in the inference phase of an incremental decoder,
# hidden_states is the last-added state,
# the full key and value could be obtained by concatenating with past_key_value.
query_key_value = self.query_key_value(hidden_states)
query_key_value = query_key_value.view(bsz, -1, self.num_heads, 3 * self.head_size)
query_key_value = query_key_value.permute(
0, 2, 1, 3
) # [bsz, num_heads, src_len, 3 * head_size]
query, key, value = flow.chunk(query_key_value, chunks=3, dim=-1)
attention_scores, value = flow._C.fused_self_attention(
query_key_value, head_size=self.head_size, alpha=1
)
if past_key_value is not None:
past_key, past_value = past_key_value
key = flow.cat((past_key.type_as(key), key), dim=2)
value = flow.cat((past_value.type_as(value), value), dim=2)

# query, key, value: [S(0), S(1)], shape: [bsz, num_heads, seq_length, head_size]
if use_cache:
past_key_value = (key, value)

# [bsz, num_heads, tgt_len, src_len] with [S(0), S(1)]
attention_scores = flow.matmul(query, key, transpose_b=True)
if self.is_cross_attention:
attention_scores = flow.matmul(query, key, transpose_b=True, alpha=1)

if position_bias is None:
if not self.has_relative_attention_bias:
@@ -228,30 +216,19 @@

attention_scores = attention_scores + position_bias

# [S(0), S(1)] x [S(0), B] = [S(0), S(1)]
if attention_mask is not None:
attention_scores = flow.mul(attention_scores, attention_mask)
attention_scores = attention_scores - 10000.0 * (1 - attention_mask)
# TODO(xingyu.liao): graph will occur `where_scalar` errors
# when using `masked_fill`
# attention_scores = attention_scores.masked_fill(1 - attention_mask, -10000.0)
attention_weights = flow.softmax(attention_scores, dim=-1)
# [bsz, num_heads, tgt_len, src_len]
attention_weights = self.dropout(attention_weights)
else:
attention_weights = flow.softmax(attention_scores, dim=-1)
# [bsz, num_heads, tgt_len, src_len]
attention_weights = self.dropout(attention_weights)

# Context shape: [bsz, num_heads, tgt_len, head_size] with [S(0), S(1)]
context = flow.matmul(attention_weights, value)
# Change shape: [bsz, num_heads, tgt_len, head_size] -> [bsz, tgt_len, num_heads, head_size]
context = context.transpose(1, 2)

# Concat multi-head results from
# [bsz, tgt_len, num_heads, head_size] -> [bsz, tgt_len, num_heads * head_size]
# SBP sign: [S(0), S(2)]
# [S(0), S(2)] x [B, S(0)] = [S(0), P] -> [S(0), B]
context = flow._C.transpose(context, perm=(2, 0, 1, 3))

output = self.dense(context.flatten(2))

output = self.output_dropout(output)
Expand All @@ -272,7 +249,6 @@ def extra_repr(self) -> str:
def _relative_position_bucket(
self, relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
# relative_position: (seq_len, seq_len)
relative_buckets = 0
if bidirectional:
num_buckets //= 2
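
The attention changes above are the core of this PR: the hand-written view/permute/chunk/matmul sequence for self-attention is replaced by a single flow._C.fused_self_attention call operating on a sequence-first packed QKV tensor. The sketch below is a rough unfused reference for the shape bookkeeping the fused kernel takes over; the layout and return values are inferred from this diff, not from a documented kernel contract.

import oneflow as flow

def unfused_self_attention(query_key_value, head_size, alpha=1.0):
    # query_key_value: [seq_len, bsz, num_heads * 3 * head_size] (sequence-first packed QKV)
    seq_len, bsz, packed = query_key_value.shape
    num_heads = packed // (3 * head_size)
    qkv = query_key_value.view(seq_len, bsz, num_heads, 3 * head_size)
    qkv = qkv.permute(1, 2, 0, 3)  # [bsz, num_heads, seq_len, 3 * head_size]
    query, key, value = flow.chunk(qkv, chunks=3, dim=-1)
    attention_scores = flow.matmul(query, key, transpose_b=True) * alpha
    return attention_scores, value  # roughly what the fused kernel returns

The fused kernel produces the same (scores, value) pair in one launch and avoids materializing the permuted query/key/value copies, which is why the surrounding code switches to the [seq_len, bsz, ...] layout.
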
19 changes: 14 additions & 5 deletions projects/T5/models/t5_model.py
@@ -176,8 +176,10 @@ def forward(
encoder_decoder_position_bias = None
self.set_cache(encoder_states=None, past_key_values=None)
encoder_attn_mask = self.extended_attn_mask(encoder_attn_mask)
enc_embedding_output = self.embedding(encoder_input_ids)
enc_hidden_states = enc_embedding_output

enc_hidden_states = self.embedding(encoder_input_ids)

enc_hidden_states = enc_hidden_states.transpose(0, 1)

for layer in self.encoder.layers:
enc_hidden_states, position_bias = layer(
@@ -192,8 +194,10 @@
)
encoder_decoder_attn_mask = self.extended_attn_mask(encoder_decoder_attn_mask)

dec_embedding_output = self.embedding(decoder_input_ids)
dec_hidden_states = dec_embedding_output
dec_hidden_states = self.embedding(decoder_input_ids)

dec_hidden_states = dec_hidden_states.transpose(0, 1)

if use_cache:
presents = []

@@ -270,7 +274,7 @@ def forward(
encoder_decoder_attn_mask,
use_cache=use_cache,
)

logits = logits.transpose(0, 1)
if lm_labels is not None:
lm_loss = self.loss_func(logits, lm_labels, loss_mask)
return lm_loss
@@ -311,3 +315,8 @@ def set_pipeline_stage_id(model):
dist_utils.get_layer_stage_id(model.t5_model.decoder.final_layernorm.layer_idx),
dist.get_layer_placement(model.t5_model.decoder.final_layernorm.layer_idx),
)

def set_activation_checkpoint(self):
for module_block in self.t5_model.modules():
if isinstance(module_block.origin, TransformerLayer):
module_block.config.activation_checkpointing = True
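
Because the fused attention path works on sequence-first tensors, the model now transposes the embedding output from [batch, seq, hidden] to [seq, batch, hidden] before the encoder/decoder stacks and transposes the logits back before the loss. A small shape-only illustration of that round trip; the tensor sizes are made up for the example:

import oneflow as flow

batch, seq, hidden = 2, 16, 512
emb = flow.randn(batch, seq, hidden)    # embedding output: [batch, seq, hidden]

hidden_states = emb.transpose(0, 1)     # [seq, batch, hidden] for the fused-attention layers
# ... encoder / decoder layers run in the sequence-first layout ...
logits = hidden_states.transpose(0, 1)  # back to [batch, seq, hidden] before the loss
print(logits.shape)                     # (2, 16, 512)
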
2 changes: 0 additions & 2 deletions projects/T5/models/transformer_layer.py
@@ -152,10 +152,8 @@ def forward(
use_cache: it will be set to `True` when the model is in the inference phase and
used for incremental decoding.
"""
# Change placement for pipeline parallelsim
hidden_states = hidden_states.to_global(placement=dist.get_layer_placement(self.layer_idx))

# hidden_states shape: (batch_size, seq_length, hidden_size)
if attention_mask is not None:
attention_mask = attention_mask.to_global(
placement=dist.get_layer_placement(self.layer_idx)
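
The last hunk only removes two comments around the to_global call that moves activations onto the placement of the current pipeline stage. For context, here is a minimal sketch of that placement-transfer pattern with OneFlow's global-tensor API; the ranks and SBP signature are illustrative, the snippet assumes it is launched via oneflow.distributed.launch on at least two ranks, and LiBai's dist.get_layer_placement encapsulates the per-layer placement lookup:

import oneflow as flow

# A global tensor living on the devices of pipeline stage 0.
x = flow.ones(4, 8, placement=flow.placement("cuda", ranks=[0]), sbp=flow.sbp.broadcast)

# Before running a layer assigned to the next stage, move the activation there.
x = x.to_global(placement=flow.placement("cuda", ranks=[1]), sbp=flow.sbp.broadcast)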