diff --git a/llm/auto_parallel/llama/run_llama2.sh b/llm/auto_parallel/llama/run_llama2.sh
new file mode 100644
index 000000000000..8edce8d6b848
--- /dev/null
+++ b/llm/auto_parallel/llama/run_llama2.sh
@@ -0,0 +1,43 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# just for debug
+
+set -x
+unset CUDA_VISIBLE_DEVICES
+
+task_name="llama3_dp2pp4sd2"
+rm -rf output/$task_name/
+rm -rf "output/$task_name""_log"
+
+export SOT_LOG_LEVEL=4
+export PYTHONPATH=../../../:$PYTHONPATH
+
+#ulimit -c unlimited
+# export GLOG_v=6
+export NCCL_DEBUG=INFO
+
+# export FLAGS_call_stack_level=3
+# export FLAGS_use_cuda_managed_memory=true
+
+# export FLAGS_embedding_deterministic=1
+# export FLAGS_cudnn_deterministic=1
+# export NVIDIA_TF32_OVERRIDE=0
+rm -rf core.*
+python -u -m paddle.distributed.launch \
+    --gpus "0,1,2,3,4,5,6,7" \
+    --log_dir "output/$task_name""_log" \
+    ./run_pretrain_auto.py \
+    ../../../tests/test_tipc/static/auto_parallel/llama2/pretrain_config_llama2_13b/pretrain-llama2_13b.json
+
diff --git a/llm/auto_parallel/llama/run_pretrain_auto.py b/llm/auto_parallel/llama/run_pretrain_auto.py
index dfc8e1023bcd..0284b3cc9e4a 100644
--- a/llm/auto_parallel/llama/run_pretrain_auto.py
+++ b/llm/auto_parallel/llama/run_pretrain_auto.py
@@ -41,14 +41,16 @@
     LinearAnnealingWithWarmupDecay,
     LlamaConfig,
     LlamaForCausalLM3DAuto,
+    LlamaForCausalLM3DAutoPP,
     LlamaForCausalLMNet,
     LlamaPretrainingCriterion3DAuto,
     LlamaPretrainingCriterionNet,
 )
 from paddlenlp.utils.log import logger
+from paddle.distributed.auto_parallel.pipelining.schedules import ScheduleGPipe
 
 MODEL_CLASSES = {
-    "llama": (LlamaConfig, LlamaForCausalLM3DAuto, LlamaPretrainingCriterion3DAuto),
+    "llama": (LlamaConfig, LlamaForCausalLM3DAutoPP, LlamaPretrainingCriterion3DAuto),
     "llama_network": (LlamaConfig, LlamaForCausalLMNet, LlamaPretrainingCriterionNet),
 }
 
diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py
index a1579173b66a..3f848e1bc3c3 100644
--- a/paddlenlp/trainer/auto_trainer.py
+++ b/paddlenlp/trainer/auto_trainer.py
@@ -57,6 +57,10 @@
     from ..quantization.quantization_linear import QuantizationLinear
 except:
     QuantizationLinear = None
+
+from paddle.distributed.auto_parallel.pipelining.schedules import ScheduleGPipe, Schedule1F1B
+from paddle.distributed.auto_parallel.pipelining.stage import PipelineStage
+
 MODEL_NAME = "model"
 OPTIMIZER_NAME = "optimizer"
 
@@ -64,6 +68,63 @@
 DIST_MODEL_PATH = "dist_model"
 FREE_SVAE_LOAD_KEY_PATTERNS = ["learning_rate_", "gradient_merge_", "@GRAD@MERG", "eager_tmp"]
 
+is_split_model = False
+local_stage = None
+
+
+def manual_model_split(model, stage_idx, group):
+    # Manually split the model into two pipeline stages: stage 0 keeps the first 10
+    # decoder layers, stage 1 keeps the last 10, and each stage gets its own forward.
+    global is_split_model
+    global local_stage
+
+    if is_split_model:
+        return local_stage
+    if stage_idx == 0:
+        for i in range(10):
+            del model.layers[10]
+
+        def forward0(
+            self,
+            input_ids=None,
+            labels=None,
+            position_ids=None,
+            attention_mask=None,
+            inputs_embeds=None,
+            use_cache=False,
+            past_key_values=None,
+            output_attentions=None,
+            output_hidden_states=None,
+            return_dict=None,
+        ):
+            outputs = tuple([input_ids, attention_mask, position_ids])
+
+            # decoder layers
+            for idx, (decoder_layer) in enumerate(self.layers):
+                outputs = decoder_layer(outputs)
+            return outputs
+
+        setattr(model.__class__, "forward", forward0)
+
+    elif stage_idx == 1:
+        for i in range(10):
+            del model.layers[0]
+
+        def forward1(self, *args):
+            outputs = args
+            # decoder layers
+            for idx, (decoder_layer) in enumerate(self.layers):
+                outputs = decoder_layer(outputs)
+            return outputs
+
+        setattr(model.__class__, "forward", forward1)
+    else:
+        raise ValueError("Invalid stage index.")
+
+    stage = PipelineStage(
+        model,
+        stage_idx,
+        2,
+        group=group,
+    )
+    is_split_model = True
+    local_stage = stage
+    return stage
+
 
 class AutoTrainer(Trainer):
     def __init__(self, *args, **kwargs):
@@ -88,7 +149,7 @@ def loss_func(loss, outputs):
             ), "if use AutoTrainer.parallel_model , auto_dist_config obtained from parallel_model should be passed to AutoTrainer "
             self.auto_dist_config = kwargs.pop("auto_dist_config")
             model = kwargs["model"]
-            for param in model.parameters():
+            for name, param in model.named_parameters():
                 # NOTE(zhangwl):in pipeline mode , param my be initialized before while delte init_func ,but param is still not is_initialized
                 if not param._is_initialized() and param._init_func is not None:
                     param.initialize()
@@ -98,7 +159,7 @@
         assert self.args.enable_auto_parallel
 
         self.global_mesh = fleet.auto.get_mesh()
-        self.comm_group_in_pp = fleet.get_hybrid_communicate_group().get_pipe_parallel_group()
+        self.comm_group_in_pp = fleet.get_hybrid_communicate_group().get_pipe_parallel_group()
         self._in_pir_mode = paddle.base.framework.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]
 
     @classmethod
@@ -670,50 +731,69 @@ def compute_loss(self, model, inputs, return_outputs=False):
             labels = inputs["generator_labels"]
         else:
             labels = None
+
+        def get_mesh(pp_idx=0):
+            mesh = fleet.auto.get_mesh()
+            if "pp" in mesh.dim_names:
+                mesh = mesh.get_mesh_with_dim("pp", pp_idx)
+            return mesh
+
+        # Ranks 0-3 build pipeline stage 0, the remaining ranks build stage 1.
+        rank = dist.get_rank()
+        if rank == 0 or rank == 1 or rank == 2 or rank == 3:
+            stage = manual_model_split(model, 0, self.comm_group_in_pp)
+        else:
+            stage = manual_model_split(model, 1, self.comm_group_in_pp)
 
-        outputs = model(**inputs)
-
-        if self.criterion is not None:
+        schedule = Schedule1F1B(stage, n_microbatches=2, loss_fn=self.criterion)
 
-            def to_list(value):
-                if value is None:
-                    return value
-                if isinstance(value, (list, tuple)):
-                    return list(value)
-                return [value]
-
-            criterion_inputs = to_list(outputs)
-            criterion_labels = to_list(labels)
-            loss = self.criterion(*(criterion_inputs + criterion_labels))
-            outputs = (loss, outputs)
-
-        # Save past state if it exists
-        # TODO: this needs to be fixed and made cleaner later.
-        if self.args.past_index >= 0:
-            self._past = outputs[self.args.past_index]
-
-        # We don't use .loss here since the model may return tuples instead of ModelOutput.
-        loss = outputs["loss"] if isinstance(outputs, dict) else outputs
-        if isinstance(outputs, dict):
-            loss = outputs["loss"]
-        elif isinstance(outputs, tuple):
-            loss = outputs[0]
+        if rank == 0 or rank == 1 or rank == 2 or rank == 3:
+            schedule.step(**inputs)
         else:
-            loss = outputs
-
-        return (loss, outputs) if return_outputs else loss
+            losses = []
+            schedule.step(target=labels, losses=losses)
+            print("losses: ", losses)
+        return 0
+        # outputs = model(**inputs)
+
+        # if self.criterion is not None:
+
+        #     def to_list(value):
+        #         if value is None:
+        #             return value
+        #         if isinstance(value, (list, tuple)):
+        #             return list(value)
+        #         return [value]
+
+        #     criterion_inputs = to_list(outputs)
+        #     criterion_labels = to_list(labels)
+        #     loss = self.criterion(*(criterion_inputs + criterion_labels))
+        #     outputs = (loss, outputs)
+
+        # # Save past state if it exists
+        # # TODO: this needs to be fixed and made cleaner later.
+        # if self.args.past_index >= 0:
+        #     self._past = outputs[self.args.past_index]
+
+        # # We don't use .loss here since the model may return tuples instead of ModelOutput.
+        # loss = outputs["loss"] if isinstance(outputs, dict) else outputs
+        # if isinstance(outputs, dict):
+        #     loss = outputs["loss"]
+        # elif isinstance(outputs, tuple):
+        #     loss = outputs[0]
+        # else:
+        #     loss = outputs
+
+        # return (loss, outputs) if return_outputs else loss
 
     def dynamic_training(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
         with self.autocast_smart_context_manager():
             loss = self.compute_loss(model, inputs)
-
-        if loss is not None and self.args.gradient_accumulation_steps > 1 and not self._enable_delay_scale_loss():
-            loss = loss / self.args.gradient_accumulation_steps
-
-        if self.do_grad_scaling:
-            self.scaler.scale(loss).backward()
-        else:
-            loss.backward()
+
+        # if loss is not None and self.args.gradient_accumulation_steps > 1 and not self._enable_delay_scale_loss():
+        #     loss = loss / self.args.gradient_accumulation_steps
+
+        # if self.do_grad_scaling:
+        #     self.scaler.scale(loss).backward()
+        # else:
+        #     loss.backward()
         return loss
diff --git a/paddlenlp/transformers/llama/__init__.py b/paddlenlp/transformers/llama/__init__.py
index a85b249f356d..7cb6d082eb64 100644
--- a/paddlenlp/transformers/llama/__init__.py
+++ b/paddlenlp/transformers/llama/__init__.py
@@ -15,6 +15,7 @@
 from .configuration import *
 from .modeling import *
 from .modeling_auto import *
+from .modeling_auto_pp import *
 from .modeling_network import *
 from .modeling_pp import *
 from .tokenizer import *
diff --git a/paddlenlp/transformers/llama/modeling_auto_pp.py b/paddlenlp/transformers/llama/modeling_auto_pp.py
new file mode 100644
index 000000000000..20fc7a3b4882
--- /dev/null
+++ b/paddlenlp/transformers/llama/modeling_auto_pp.py
@@ -0,0 +1,533 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Paddle Llama model""" +from __future__ import annotations + +import math +import os +import warnings +from functools import partial +from typing import Optional, Tuple + +import paddle +import paddle.distributed as dist +import paddle.nn.functional as F +from paddle import nn +from paddle.distributed import fleet +from paddle.distributed.fleet.utils import recompute + +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except ImportError: + fused_rotary_position_embedding = None + +try: + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +from paddlenlp.transformers.conversion_utils import ( + StateDictNameMapping, + init_name_mappings, +) +from paddlenlp.transformers.model_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model +from paddlenlp.utils.tools import get_env_device + +from . import fusion_ops +from .configuration import ( + LLAMA_PRETRAINED_INIT_CONFIGURATION, + LLAMA_PRETRAINED_RESOURCE_FILES_MAP, + LlamaConfig, +) +from .modeling import ( + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaLinearScalingRotaryEmbedding, + LlamaNTKScalingRotaryEmbedding, + LlamaRotaryEmbedding, + _expand_2d_mask, + _make_causal_mask, + apply_rotary_pos_emb, + build_alibi_tensor, + get_triangle_upper_mask, + repeat_kv, +) + +from .modeling_auto import ( + LlamaMLPAuto, + LlamaAttentionAuto, + LlamaPretrainedModelAuto, + LlamaDecoderLayerAuto, + LlamaModelAuto, + LlamaForCausalLM3DAuto, +) + +try: + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None + +__all__ = [ + "LlamaForCausalLM3DAutoPP", +] + + +def enable_fuse_ffn_qkv_pass(): + if os.getenv("FLAGS_enable_fused_ffn_qkv_pass") in [ + "True", + "true", + "1", + ]: + return True + else: + return False + + +def is_pp_enable(): + mesh = fleet.auto.get_mesh() + return "pp" in mesh.dim_names + + +def get_mesh(pp_idx=0): + mesh = fleet.auto.get_mesh() + if "pp" in mesh.dim_names: + mesh = mesh.get_mesh_with_dim("pp", pp_idx) + return mesh + + +def global_mesh_starts_with_pp(): + mesh = fleet.auto.get_mesh() + if is_pp_enable(): + return mesh.get_mesh_with_dim("pp") + else: + return mesh + +# hidden_states, position_ids, inputs_embeds, attention_mask, output_attentions, past_key_values, use_cache, alibi +# output_attentions、use_cache 可以由config控制且有默认值, delete掉 +# inputs_embeds、past_key_values 动手PP组网没有使用,delete掉,使用默认值 + +# attn_mask_startend_row_indices 自动并行组网没有使用,不考虑 + + +def parse_args(args): + attention_mask, position_ids, alibi = None, None, None + if isinstance(args, tuple): + if len(args) == 4: + hidden_states, attention_mask, position_ids, alibi = args + if len(args) == 3: + hidden_states, attention_mask, position_ids = args + + elif len(args) == 2: + hidden_states, attention_mask = args + + if len(args) == 1: + hidden_states = args[0] + else: + hidden_states = args + + if position_ids is not None: + position_ids.stop_gradient = True + + if attention_mask is not None: + attention_mask.stop_gradient = True + + if alibi is not None: + alibi.stop_gradient = True + + return hidden_states, attention_mask, position_ids, alibi + + +def return_args( + hidden_states, attention_mask = None, position_ids = None, alibi=None +): + ret = (hidden_states,) + + if attention_mask is not None: + ret += 
+    if position_ids is not None:
+        ret += (position_ids.clone(),)
+    if alibi is not None:
+        ret += (alibi.clone(),)
+    if len(ret) == 1:
+        ret = ret[0]
+
+    return ret
+
+
+colwise_placements = [dist.Replicate(), dist.Shard(1)]
+rowise_placement = [dist.Replicate(), dist.Shard(0)]
+
+
+class LlamaRMSNormAutoPP(nn.Layer):
+    def __init__(self, config, ipp):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.weight = paddle.create_parameter(
+            shape=[self.hidden_size],
+            dtype=paddle.get_default_dtype(),
+            default_initializer=nn.initializer.Constant(1.0),
+        )
+        self.ipp = ipp
+        self.weight = dist.shard_tensor(
+            self.weight,
+            get_mesh(self.ipp),
+            [dist.Replicate(), dist.Replicate()],
+        )
+        self.variance_epsilon = config.rms_norm_eps
+        self.config = config
+
+    def forward(self, args):
+        hidden_states, attention_mask, position_ids, alibi = parse_args(args)
+        if self.config.use_fused_rms_norm:
+            hidden_states = fusion_ops.fusion_rms_norm(
+                hidden_states, self.weight, self.variance_epsilon, self.config.use_fast_layer_norm
+            )
+            return return_args(hidden_states, attention_mask, position_ids, alibi)
+
+        with paddle.amp.auto_cast(False):
+            variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
+            hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states
+
+        if self.weight.dtype in [paddle.float16, paddle.bfloat16]:
+            hidden_states = paddle.cast(hidden_states, self.weight.dtype)
+
+        return return_args(hidden_states * self.weight, attention_mask, position_ids, alibi)
+
+
+class LlamaEmbeddingAutoPP(nn.Layer):
+    """Extends LlamaEmbeddings to forward attention_mask through the pipeline."""
+
+    def __init__(self, config):
+        super(LlamaEmbeddingAutoPP, self).__init__()
+        self.config = config
+
+        self.vocab_size = config.vocab_size
+        self.hidden_size = config.hidden_size
+        self.embed_tokens = nn.Embedding(
+            self.vocab_size,
+            self.hidden_size,
+        )
+
+        embedding_placements = (
+            [dist.Replicate(), dist.Shard(1)]
+            if self.config.tensor_parallel_degree > 1
+            else [dist.Replicate(), dist.Replicate()]
+        )
+
+        self.embed_tokens.weight = dist.shard_tensor(
+            self.embed_tokens.weight,
+            get_mesh(),
+            embedding_placements,
+        )
+
+        self.placements = (
+            [dist.Shard(1), dist.Shard(0)] if self.config.sequence_parallel else [dist.Shard(0), dist.Replicate()]
+        )
+
+    # @property
+    # def embedding_weight(self):
+    #     return get_attr(self.embed_tokens, "weight")
+
+    @staticmethod
+    def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype):
+        if attention_mask is not None:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            if len(attention_mask.shape) == 2:
+                expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1])
+                # For decoding phase in generation, seq_length = 1, we don't need to add causal mask
+                if input_shape[-1] > 1:
+                    combined_attention_mask = _make_causal_mask(
+                        input_shape, past_key_values_length=past_key_values_length
+                    )
+                    expanded_attn_mask = expanded_attn_mask & combined_attention_mask
+            # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
+            elif len(attention_mask.shape) == 3:
+                expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
+            # if attention_mask is already 4-D, do nothing
+            else:
+                expanded_attn_mask = attention_mask
+        else:
+            expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+        # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
+        if get_env_device() in ["npu", "mlu", "intel_hpu"]:
+            x = paddle.to_tensor(0.0, dtype="float32")
+            y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
+        elif get_env_device() == "xpu":
+            x = paddle.to_tensor(0.0, dtype="float32")
+            y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y)
+        elif get_env_device() == "gcu":
+            min_val = paddle.finfo(dtype).min
+            x = paddle.to_tensor(0.0, dtype=dtype)
+            y = paddle.to_tensor(min_val, dtype=dtype)
+            expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
+        else:
+            expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min)
+            expanded_attn_mask = expanded_attn_mask.astype(dtype)
+        return expanded_attn_mask
+
+    def forward(self, args):
+        input_ids, attention_mask, position_ids, alibi = parse_args(args)
+
+        input_ids.stop_gradient = True
+
+        # output_hidden_states = (
+        #     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        # )
+        # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_attentions = self.config.output_attentions
+
+        use_cache = self.config.use_cache
+
+        # retrieve input_ids
+        if input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        else:
+            raise ValueError("You have to specify input_ids")
+
+        past_key_values = tuple([None] * self.config.num_hidden_layers)
+
+        seq_length_with_past = seq_length
+        cache_length = 0
+
+        with paddle.amp.auto_cast(False):
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        if self.config.sequence_parallel:
+            # [B, S, H] -> [S, B, H]
+            inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])
+
+        global_mesh = global_mesh_starts_with_pp()
+        if position_ids is None and self.config.sep_parallel_degree > 1:
+            position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
+        if position_ids is not None:
+            position_ids = dist.shard_tensor(
+                position_ids,
+                global_mesh,
+                [dist.Replicate() for _ in range(len(global_mesh._shape))],
+            )
+        # embed positions
+        if not self.config.use_flash_attention and attention_mask is None:
+            # [bs, seq_len]
+            attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
+
+        if self.config.alibi:
+            if attention_mask is None:
+                attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
+            alibi_place = [dist.Replicate() for _ in range(len(global_mesh._shape))]
+            alibi = build_alibi_tensor(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype)
+            alibi = dist.shard_tensor(alibi, global_mesh, alibi_place)
+        else:
+            alibi = None
+        if self.config.use_flash_attention and not self.config.alibi:
+            # attention_mask in flash_attn is always None for pretrain
+            # attention_mask is used in scaled_dot_product_attention with alibi_tensor
+            attention_mask = None
+        else:
+            attention_mask = self._prepare_decoder_attention_mask(
+                attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
+            )  # [bs, 1, seq_len, seq_len]
+            attention_mask = dist.shard_tensor(
+                attention_mask,
+                global_mesh,
+                [dist.Replicate() for _ in range(len(global_mesh._shape))],
+            )
+        hidden_states = inputs_embeds
+        hidden_states = dist.reshard(hidden_states, get_mesh(), self.placements)
+        return return_args(hidden_states, attention_mask, position_ids, alibi)
+
+
+class LlamaDecoderLayerAutoPP(nn.Layer):
+    def __init__(self, config, idx, layerwise_recompute: bool = False, ipp: Optional[int] = None):
+        super(LlamaDecoderLayerAutoPP, self).__init__()
+        self.config = config
+        self.layer_id = idx
+        self.layer = LlamaDecoderLayerAuto(config, layerwise_recompute, ipp)
+        self.ipp = ipp
+        self.enable_recompute = False
+        self.recompute_granularity = config.recompute_granularity
+        self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else []
+        self.embed_tokens = None
+        self.norm = None
+        self.lm_head = None
+        if self.layer_id == 0:
+            self.embed_tokens = LlamaEmbeddingAutoPP(config)
+
+        if self.layer_id == self.config.num_hidden_layers - 1:
+            self.norm = LlamaRMSNormAutoPP(config, ipp)
+            self.lm_head = LlamaLMHeadAutoPP(config)
+
+    def forward(self, args):
+        if self.embed_tokens is not None:
+            args = self.embed_tokens(args)
+        hidden_states, attention_mask, position_ids, alibi = parse_args(args)
+        output_attentions = self.config.output_attentions
+        use_cache = self.config.use_cache
+
+        past_key_value = None
+
+        has_gradient = not hidden_states.stop_gradient
+
+        if position_ids is not None:
+            position_ids_input = dist.reshard(
+                position_ids,
+                get_mesh(self.ipp),
+                [dist.Replicate(), dist.Replicate()],
+            )
+        else:
+            position_ids_input = position_ids
+        attention_mask_input = (
+            dist.reshard(
+                attention_mask,
+                get_mesh(self.ipp),
+                [dist.Replicate(), dist.Replicate()],
+            )
+            if attention_mask is not None
+            else None
+        )
+        alibi_input = (
+            dist.reshard(
+                alibi,
+                get_mesh(self.ipp),
+                [dist.Replicate(), dist.Replicate()],
+            )
+            if alibi is not None
+            else None
+        )
+        if (
+            self.enable_recompute
+            and self.layer_id not in self.no_recompute_layers
+            and has_gradient
+            and self.recompute_granularity == "full"
+        ):
+            layer_outputs = recompute(
+                self.layer,
+                hidden_states,
+                position_ids_input,
+                attention_mask_input,
+                output_attentions,
+                past_key_value,
+                use_cache,
+                alibi_input,
+            )
+        else:
+            layer_outputs = self.layer(
+                hidden_states,
+                position_ids_input,
+                attention_mask_input,
+                output_attentions,
+                past_key_value,
+                use_cache,
+                alibi_input,
+            )
+
+        if type(layer_outputs) is tuple:
+            hidden_states = layer_outputs[0]
+        else:
+            hidden_states = layer_outputs
+
+        ret_args = return_args(hidden_states, attention_mask, position_ids, alibi)
+        if self.norm is not None:
+            ret_args = self.norm(ret_args)
+
+        if self.lm_head is not None:
+            ret_args = self.lm_head(ret_args)
+        return ret_args
+
+
+class LlamaLMHeadAutoPP(nn.Layer):
+    def __init__(self, config: LlamaConfig):
+        super(LlamaLMHeadAutoPP, self).__init__()
+        self.config = config
+
+        vocab_size = config.vocab_size
+        self.weight = self.create_parameter(
+            shape=[config.hidden_size, vocab_size],
+            dtype=paddle.get_default_dtype(),
+        )
+        self.weight = dist.shard_tensor(
+            self.weight,
+            get_mesh(-1),
+            colwise_placements,
+        )
+
+    def forward(self, args):
+        hidden_states, attention_mask, position_ids, alibi = parse_args(args)
+
+        if self.config.sequence_parallel:
+            hidden_states = dist.reshard(
+                hidden_states,
+                get_mesh(-1),
+                [dist.Shard(1), dist.Replicate()],
+            )
+            hidden_states = paddle.transpose(hidden_states, [1, 0, 2])
+        logits = paddle.matmul(hidden_states, self.weight, transpose_y=False)
+        return return_args(logits, attention_mask, position_ids, alibi)
+
+
+class LlamaForCausalLM3DAutoPP(LlamaForCausalLM3DAuto):
+    enable_to_static_method = True
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else []
+        ## For now PP stage planning is not considered; it will be added later.
+        decoder_layers = []
+        # self.next_pp_stage_indexes = []
+        for i in range(config.num_hidden_layers):
+            # pp_stage_id, input_need_reshard = get_layer_pp_info(i)
+            # Hard-coded split: every 10 decoder layers share one pipeline stage (i // 10).
+            decoder_layers.append(LlamaDecoderLayerAutoPP(config, i, i not in self.no_recompute_layers, i // 10))
+            # if input_need_reshard:
+            #     self.next_pp_stage_indexes.append(i)
+        self.layers = nn.LayerList(decoder_layers)
+
+    def forward(
+        self,
+        input_ids=None,
+        labels=None,
+        position_ids=None,
+        attention_mask=None,
+        inputs_embeds=None,
+        use_cache=False,
+        past_key_values=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+    ):
+        outputs = return_args(input_ids, attention_mask, position_ids)
+
+        # decoder layers
+        for idx, (decoder_layer) in enumerate(self.layers):
+            outputs = decoder_layer(outputs)
+
+        return outputs[0]
diff --git a/tests/test_tipc/static/auto_parallel/llama2/pretrain_config_llama2_13b/pretrain-llama2_13b.json b/tests/test_tipc/static/auto_parallel/llama2/pretrain_config_llama2_13b/pretrain-llama2_13b.json
index aa86a1875597..1da19a005b7b 100644
--- a/tests/test_tipc/static/auto_parallel/llama2/pretrain_config_llama2_13b/pretrain-llama2_13b.json
+++ b/tests/test_tipc/static/auto_parallel/llama2/pretrain_config_llama2_13b/pretrain-llama2_13b.json
@@ -3,18 +3,17 @@
     "tokenizer_name_or_path": "meta-llama/Llama-2-13b",
     "input_dir": "./data",
     "output_dir": "./checkpoints/llama2_pretrain_ckpts",
-    "per_device_train_batch_size": 1,
-    "gradient_accumulation_steps": 4,
+    "per_device_train_batch_size": 2,
+    "gradient_accumulation_steps": 1,
     "per_device_eval_batch_size": 4,
-    "tensor_parallel_degree": 1,
-    "pipeline_parallel_degree": 4,
+    "tensor_parallel_degree": 2,
+    "pipeline_parallel_degree": 2,
+    "sharding_parallel_degree": 2,
+    "num_hidden_layers": 20,
     "sharding": "stage1",
     "data_parallel_config": "enable_allreduce_avg_in_gradinent_scale gradient_sync_after_accumulate",
     "sharding_parallel_config": "enable_overlap enable_tensor_fusion",
     "tensor_parallel_config": "enable_mp_async_allreduce",
-    "pipeline_parallel_config": "enable_send_recv_overlap enable_split_backward",
-    "pipeline_schedule_mode": "VPP",
-    "virtual_pp_degree": 5,
     "sequence_parallel": 0,
     "use_flash_attention": true,
     "use_fused_rms_norm": true,
@@ -51,6 +50,6 @@
     "recompute_granularity": "full",
     "save_total_limit": 2,
     "device": "gpu",
-    "to_static": true,
+    "to_static": false,
     "enable_auto_parallel": true
 }
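
Usage sketch: under the 8-GPU, two-stage layout configured above, the new compute_loss wires the pieces together roughly as in the snippet below. The wrapper name run_pipeline_step and the inputs/labels/criterion/pp_group arguments are illustrative assumptions, not code from the patch; the Schedule1F1B calls mirror the ones added in auto_trainer.py.

import paddle.distributed as dist
from paddle.distributed.auto_parallel.pipelining.schedules import Schedule1F1B


def run_pipeline_step(model, inputs, labels, criterion, pp_group):
    # Ranks 0-3 build pipeline stage 0 (first 10 decoder layers);
    # the remaining ranks build stage 1 (last 10 layers plus norm and lm_head).
    stage_idx = 0 if dist.get_rank() < 4 else 1
    stage = manual_model_split(model, stage_idx, pp_group)  # helper added in auto_trainer.py above

    # One 1F1B schedule per step; the loss is only materialized on the last stage.
    schedule = Schedule1F1B(stage, n_microbatches=2, loss_fn=criterion)
    if stage_idx == 0:
        schedule.step(**inputs)  # first stage: consume input_ids and send activations downstream
        return None
    losses = []
    schedule.step(target=labels, losses=losses)  # last stage: compute one loss per microbatch
    return losses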