Skip to content

Commit

Permalink
[Shardformer] Add parallel output for shardformer models(bloom, falco…
Browse files Browse the repository at this point in the history
…n) (hpcaitech#5702)

* [pre-commit.ci] auto fixes from pre-commit.com hooks

* add parallel cross entropy output for falcon model & fix some typos in bloom.py

* fix module name error, self.model -> self.transformers in bloom, falcon model

* Fix the overflow bug of distributed cross entropy loss function when training with fp16

* add dtype to parallel cross entropy loss function

* fix dtype related typos and prettify the loss.py

* fix grad dtype and update dtype mismatch error

* fix typo bugs
  • Loading branch information
Hz188 authored May 21, 2024
1 parent 9d83c6d commit 22ce873
Show file tree
Hide file tree
Showing 9 changed files with 230 additions and 17 deletions.
15 changes: 9 additions & 6 deletions colossalai/shardformer/layer/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def forward(
ignore_index: int,
process_group: ProcessGroup,
vocab_size: int,
dtype=torch.float32,
):
r"""
Calculate the cross entropy loss before gather, the origin loss function is as follows:
Expand All @@ -34,7 +35,7 @@ def forward(
Args:
vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
[batch_size, seq_len, vocab_size]
labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is
target (:class:`torch.Tensor`): The labels of the vocabulary, shape is
[batch_size, seq_len]
Returns:
Expand Down Expand Up @@ -86,7 +87,7 @@ def forward(
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
exp_logits = vocab_logits
torch.exp(vocab_logits, out=exp_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1)
sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)

# calculate the loss
Expand All @@ -97,9 +98,10 @@ def forward(
loss = torch.sum(loss).div_(num_non_zero)

# calculate the softmax
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
exp_logits[target == ignore_index] = 0.0
ctx.save_for_backward(exp_logits, mask, masked_target_1d)
ctx.dtype = dtype

return loss

Expand All @@ -114,11 +116,11 @@ def backward(ctx, grad_output):
partion_vocab_size = grad_logits.shape[-1]
grad_logits_2d = grad_logits.view(-1, partion_vocab_size)

update = 1.0 - mask.view(-1).float()
update = 1.0 - mask.view(-1).float().to(ctx.dtype)
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update

grad_logits.mul_(grad_output.unsqueeze(dim=-1))
return grad_logits, None, None, None, None
return grad_logits, None, None, None, None, None


def cross_entropy_1d(
Expand All @@ -127,5 +129,6 @@ def cross_entropy_1d(
ignore_index: int = -100,
process_group: ProcessGroup = None,
vocab_size: int = None,
dtype: torch.dtype = None,
) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size)
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
100 changes: 95 additions & 5 deletions colossalai/shardformer/modeling/bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
Expand All @@ -27,6 +28,8 @@
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig

from ..layer import cross_entropy_1d

logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -354,7 +357,7 @@ def bloom_for_causal_lm_forward(
past_key_values = None
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
lm_logits = self.lm_head(hidden_states).contiguous()

loss = None
if labels is not None:
Expand All @@ -365,10 +368,21 @@ def bloom_for_causal_lm_forward(
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = lm_logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
else:
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
loss = loss_fct(shift_logits, shift_labels.view(-1))

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
Expand Down Expand Up @@ -1065,3 +1079,79 @@ def forward(
)

return forward


def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
    """Build a replacement ``forward`` for ``BloomForCausalLM`` that computes its loss with ``cross_entropy_1d``.

    Args:
        shard_config (ShardConfig): Its ``tensor_parallel_process_group`` is forwarded to
            ``cross_entropy_1d`` as the process group for the loss computation.

    Returns:
        Callable: A ``forward`` function taking the model instance as ``self``.
    """
    from transformers import BloomForCausalLM

    def forward(
        self: BloomForCausalLM,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        # Fall back to the model config for every output switch the caller left unset.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens. The last dimension is read from the logits tensor itself rather than
            # from the config.
            # NOTE(review): presumably because lm_head emits only this rank's vocabulary partition — confirm
            # against the Bloom policy.
            new_vocab_size = lm_logits.shape[-1]
            shift_logits = shift_logits.view(-1, new_vocab_size)
            shift_labels = shift_labels.view(-1)
            loss = cross_entropy_1d(
                shift_logits,
                shift_labels,
                process_group=shard_config.tensor_parallel_process_group,
                vocab_size=self.lm_head.out_features,
                dtype=self.transformer.dtype,
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    return forward
99 changes: 96 additions & 3 deletions colossalai/shardformer/modeling/falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
CausalLMOutputWithPast,
QuestionAnsweringModelOutput,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
Expand All @@ -31,6 +32,8 @@
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.shard import ShardConfig

from ..layer import cross_entropy_1d


def build_falcon_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
def build_falcon_alibi_tensor(
Expand Down Expand Up @@ -437,14 +440,28 @@ def falcon_for_causal_lm_forward(
loss = None
if labels is not None:
# Shift so that tokens < n predict n
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
batch_size, seq_length, vocab_size = shift_logits.shape
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
)
if shard_config.enable_tensor_parallelism and shard_config.parallel_output:
new_vocab_size = shift_logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
else:
loss = loss_fct(
shift_logits.view(batch_size * seq_length, vocab_size),
shift_labels.view(batch_size * seq_length),
)

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
Expand Down Expand Up @@ -747,3 +764,79 @@ def falcon_for_question_answering_forward(
else:
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}


def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
    """Build a replacement ``forward`` for ``FalconForCausalLM`` that computes its loss with ``cross_entropy_1d``.

    Args:
        shard_config (ShardConfig): Its ``tensor_parallel_process_group`` is forwarded to
            ``cross_entropy_1d`` as the process group for the loss computation.

    Returns:
        Callable: A ``forward`` function taking the model instance as ``self``.
    """
    from transformers import FalconForCausalLM

    def forward(
        self: FalconForCausalLM,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        # Fall back to the model config for every output switch the caller left unset.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # move labels to the logits' device before indexing
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens. The last dimension is read from the logits tensor itself rather than
            # from the config.
            # NOTE(review): presumably because lm_head emits only this rank's vocabulary partition — confirm
            # against the Falcon policy.
            new_vocab_size = shift_logits.shape[-1]
            shift_logits = shift_logits.view(-1, new_vocab_size)
            shift_labels = shift_labels.view(-1)
            loss = cross_entropy_1d(
                shift_logits,
                shift_labels,
                process_group=shard_config.tensor_parallel_process_group,
                vocab_size=self.lm_head.out_features,
                dtype=self.transformer.dtype,
            )

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    return forward
2 changes: 2 additions & 0 deletions colossalai/shardformer/modeling/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ def gpt2_lmhead_model_forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)
else:
loss = loss_fct(shift_logits, shift_labels)
Expand Down Expand Up @@ -1294,6 +1295,7 @@ def forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.transformer.dtype,
)

if not return_dict:
Expand Down
2 changes: 2 additions & 0 deletions colossalai/shardformer/modeling/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def llama_for_causal_lm_forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
Expand Down Expand Up @@ -768,6 +769,7 @@ def forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)

if not return_dict:
Expand Down
2 changes: 2 additions & 0 deletions colossalai/shardformer/modeling/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def mistral_for_causal_lm_forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
Expand Down Expand Up @@ -701,6 +702,7 @@ def forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)

if not return_dict:
Expand Down
2 changes: 2 additions & 0 deletions colossalai/shardformer/modeling/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ def opt_for_causal_lm_forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.decoder.dtype,
)
else:
loss_fct = CrossEntropyLoss()
Expand Down Expand Up @@ -988,6 +989,7 @@ def forward(
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.decoder.dtype,
)

if not return_dict:
Expand Down
Loading

0 comments on commit 22ce873

Please sign in to comment.