diff --git a/paddlenlp/peft/lora/lora_layers.py b/paddlenlp/peft/lora/lora_layers.py index e0c79c47a87a..a31f7c3a33b1 100644 --- a/paddlenlp/peft/lora/lora_layers.py +++ b/paddlenlp/peft/lora/lora_layers.py @@ -13,7 +13,6 @@ # limitations under the License. import math -import os from typing import List, Optional import paddle @@ -25,13 +24,25 @@ RowParallelLinear, ) -from .lora_quick_layers import quick_lora +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + AllGatherOp, + ColumnSequenceParallelLinear, + ReduceScatterOp, + RowSequenceParallelLinear, + mark_as_sequence_parallel_parameter, + ) +except: + pass + +from paddlenlp.transformers.mc2_parallel_linear import ( + MC2ColumnParallelCoreLinear, + MC2ColumnSeqParallelCoreLinear, + MC2RowParallelCoreLinear, + MC2RowSeqParallelCoreLinear, +) -if "npu" in paddle.device.get_all_custom_device_type(): - from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear -else: - MC2LoRaRowParallelLinear = None - MC2LoRaColumnParallelLinear = None +from .lora_quick_layers import quick_lora class LoRALinear(nn.Linear): @@ -266,9 +277,7 @@ def forward(self, x: paddle.Tensor): ) else: # x @ W : [bz, in_f / ws] ===> [bz, out_f] - if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")): - output = MC2LoRaRowParallelLinear.apply(input_mp, self.weight, self.model_parallel_group) - else: + if MC2RowParallelCoreLinear is None: result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name) output = mp_ops._mp_allreduce( result_mp, @@ -276,6 +285,8 @@ def forward(self, x: paddle.Tensor): use_calc_stream=True, use_model_parallel=True, ) + else: + output = MC2RowParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group) if not self.merged: # x @ A: [bz, in_f/ ws] ===> [bz, r] @@ -298,6 +309,120 @@ def extra_repr(self): return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}" +class RowSequenceParallelLoRALinear(RowSequenceParallelLinear): + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + rslora: bool = False, + lora_plus_scale: float = 1.0, + merge_weights: bool = True, + use_quick_lora: bool = False, + **kwargs + ): + RowSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs) + if not isinstance(r, int) or r <= 0: + raise ValueError("Lora rank r should be a positive integer") + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.0: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + + # compatible + self.name = self._name + + # Actual trainable parameters + self.lora_A = self.create_parameter( + shape=[self.input_size_per_partition, r], + dtype=self._dtype, + is_bias=False, + attr=paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu") + ), + ) + self.lora_B = self.create_parameter( + shape=[r, self.out_features], + dtype=self._dtype, + is_bias=False, + attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=0.0), + learning_rate=lora_plus_scale, + ), + ) + + self.lora_A.is_distributed = True + self.lora_A.split_axis = 0 + self.lora_B.is_distributed = False + mark_as_sequence_parallel_parameter(self.lora_B) + if not rslora: + self.scaling = self.lora_alpha / self.r + else: + 
self.scaling = self.lora_alpha / math.sqrt(self.r) + + # Freezing the pre-trained weight matrix + self.weight.stop_gradient = True + self._use_quick_lora = use_quick_lora and lora_dropout == 0.0 + + @property + def use_quick_lora(self): + # TODO(@gexiao): support qlora + return False # self._use_quick_lora and self.training and not self.merged + + def train(self): + super().train() + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling + self.weight.set_value(new_weight) + self.merged = False + + def eval(self): + super().eval() + if self.merge_weights and not self.merged: + # Merge the weights and mark it + new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling + self.weight.set_value(new_weight) + self.merged = True + + def forward(self, x: paddle.Tensor): + if not self.input_is_parallel: + input_mp = mp_ops._c_split(x, group=self.model_parallel_group) + else: + input_mp = x + + if MC2RowSeqParallelCoreLinear is None: + output_parallel = self.linear(input_mp, self.weight, name=self._name) + output_ = ReduceScatterOp.apply(output_parallel) + result_mp = output_ + self.bias if self.bias is not None else output_ + else: + output_ = MC2RowSeqParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group) + result_mp = output_ + self.bias if self.bias is not None else output_ + + if not self.merged: + input_mp = self.lora_dropout(input_mp) + if MC2RowSeqParallelCoreLinear is None: + input_mp = input_mp @ self.lora_A + input_mp = ReduceScatterOp.apply(input_mp) + else: + input_mp = MC2RowSeqParallelCoreLinear.apply(input_mp, self.lora_A, self.model_parallel_group) + delta_mp = (input_mp @ self.lora_B) * self.scaling + result_mp += delta_mp + return result_mp + + def extra_repr(self): + name = f", name={self.name}" if self.name else "" + return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}" + + class ColumnParallelLoRALinear(ColumnParallelLinear): def __init__( self, @@ -400,21 +525,21 @@ def forward(self, input: paddle.Tensor): world_size=self.world_size, ) else: - if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")): - res_mp = MC2LoRaColumnParallelLinear.apply(input, self.weight, self.model_parallel_group) - result_mp = res_mp + self.bias - else: + if MC2ColumnParallelCoreLinear is None: input_mp = mp_ops._c_identity(input, group=self.model_parallel_group) result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name) + else: + res_mp = MC2ColumnParallelCoreLinear.apply(input, self.weight, self.model_parallel_group) + result_mp = res_mp + self.bias if not self.merged: input_a = self.lora_dropout(input) @ self.lora_A - if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")): - tmp = MC2LoRaColumnParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group) - delta_mp = tmp * self.scaling - else: + if MC2ColumnParallelCoreLinear is None: input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group) delta_mp = (input_a_mp @ self.lora_B) * self.scaling + else: + tmp = MC2ColumnParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group) + delta_mp = tmp * self.scaling result_mp += delta_mp if self.gather_output and self.is_mp: @@ -428,6 +553,123 @@ def extra_repr(self): return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}" +class 
ColumnSequenceParallelLoRALinear(ColumnSequenceParallelLinear): + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + rslora: bool = False, + lora_plus_scale: float = 1.0, + merge_weights: bool = True, + lora_A_weight_attr: Optional[paddle.ParamAttr] = None, + use_quick_lora: bool = False, + **kwargs + ): + ColumnSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs) + if not isinstance(r, int) or r <= 0: + raise ValueError("Lora rank r should be a positive integer") + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.0: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + + # compatible + self.name = self._name + + # Actual trainable parameters + self.lora_A = self.create_parameter( + shape=[in_features, r], + dtype=self._dtype, + is_bias=False, + attr=lora_A_weight_attr, + ) + self.lora_A.is_distributed = False + mark_as_sequence_parallel_parameter(self.lora_A) + + self.lora_B = self.create_parameter( + shape=[r, self.output_size_per_partition], + dtype=self._dtype, + is_bias=False, + attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=0.0), + learning_rate=lora_plus_scale, + ), + ) + + self.lora_B.is_distributed = True + self.lora_B.split_axis = 1 + if not rslora: + self.scaling = self.lora_alpha / self.r + else: + self.scaling = self.lora_alpha / math.sqrt(self.r) + + # Freezing the pre-trained weight matrix + self.weight.stop_gradient = True + self._use_quick_lora = use_quick_lora and lora_dropout == 0.0 + + @property + def use_quick_lora(self): + # TODO(@gexiao): support qlora + return False # self._use_quick_lora and self.training and not self.merged + + def train(self): + super().train() + if self.merge_weights and self.merged: + # Make sure that the weights are not merged + new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling + self.weight.set_value(new_weight) + self.merged = False + + def eval(self): + super().eval() + if self.merge_weights and not self.merged: + # Merge the weights and mark it + new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling + self.weight.set_value(new_weight) + self.merged = True + + def forward(self, x: paddle.Tensor): + if MC2ColumnSeqParallelCoreLinear is None: + if self.is_mp: + input_parallel = AllGatherOp.apply(x) + else: + input_parallel = x + result_mp = self.linear(input_parallel, self.weight, self.bias, name=self._name) + else: + result_mp = MC2ColumnSeqParallelCoreLinear.apply(x, self.weight, self.model_parallel_group) + if self.bias is not None: + result_mp += self.bias + + if not self.merged: + input_a = self.lora_dropout(x) @ self.lora_A + if MC2ColumnSeqParallelCoreLinear is None: + input_a = AllGatherOp.apply(input_a) + delta_mp = (input_a @ self.lora_B) * self.scaling + else: + input_a = MC2ColumnSeqParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group) + delta_mp = input_a * self.scaling + result_mp += delta_mp + + if self.gather_output and self.is_mp: + result = mp_ops._c_concat(result_mp, group=self.model_parallel_group) + else: + result = result_mp + return result + + def extra_repr(self): + name = f", name={self.name}" if self.name else "" + return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}" + + class LoRAMergedLinear(nn.Linear): # LoRA implemented in 
a dense layer with merged linear weights for q, k, v def __init__( diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py index 1bbd0284823c..41ab1e681e24 100644 --- a/paddlenlp/peft/lora/lora_model.py +++ b/paddlenlp/peft/lora/lora_model.py @@ -45,14 +45,25 @@ from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME from ...utils.log import logger from .lora_config import LoRAConfig -from .lora_layers import ( - ColumnParallelLoRALinear, - ColumnParallelLoRAMergedLinear, - LoRAConv2D, - LoRALinear, - LoRAMergedLinear, - RowParallelLoRALinear, -) + +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + ColumnSequenceParallelLinear, + RowSequenceParallelLinear, + ) + + from .lora_layers import ( + ColumnParallelLoRALinear, + ColumnParallelLoRAMergedLinear, + ColumnSequenceParallelLoRALinear, + LoRAConv2D, + LoRALinear, + LoRAMergedLinear, + RowParallelLoRALinear, + RowSequenceParallelLoRALinear, + ) +except: + pass try: from ...quantization.quantization_linear import ( @@ -454,6 +465,58 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora) # Lora column parallel will spilt lora A matrix self.add_lora_split_mapping(module_name + ".lora_A", is_column=False) + # for lora qat + if self.lora_config.do_qat: + self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False) + self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False) + self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False) + elif isinstance(module, ColumnSequenceParallelLinear): + # recover the original output_features + output_features = module.weight.shape[1] * module.world_size + lora_module = ColumnSequenceParallelLoRALinear( + in_features=module.weight.shape[0], + out_features=output_features, + gather_output=module.gather_output, + has_bias=module.bias is not None, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + rslora=lora_config.rslora, + lora_plus_scale=lora_config.lora_plus_scale, + merge_weights=lora_config.merge_weights, + lora_A_weight_attr=paddle.ParamAttr( + initializer=nn.initializer.KaimingUniform( + negative_slope=math.sqrt(5), nonlinearity="leaky_relu" + ) + ), + use_quick_lora=lora_config.use_quick_lora, + ) + # Lora column parallel will spilt lora B matrix + self.add_lora_split_mapping(module_name + ".lora_B", is_column=True) + + # for lora qat + if self.lora_config.do_qat: + self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True) + self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False) + self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False) + elif isinstance(module, RowSequenceParallelLinear): + # recover the original output_features + lora_module = RowSequenceParallelLoRALinear( + in_features=module.weight.shape[0] * module.world_size, + out_features=module.weight.shape[1], + has_bias=module.bias is not None, + input_is_parallel=module.input_is_parallel, + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + rslora=lora_config.rslora, + lora_plus_scale=lora_config.lora_plus_scale, + merge_weights=lora_config.merge_weights, + use_quick_lora=lora_config.use_quick_lora, + ) + # Lora column parallel will spilt lora A matrix + self.add_lora_split_mapping(module_name + ".lora_A", is_column=False) + # for lora qat if 
self.lora_config.do_qat: self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False) @@ -539,7 +602,7 @@ def _find_and_replace_module(self, model, module_name, lora_config, enable_lora) ) if lora_module is None: raise ValueError( - f"LoRA strategy only supports paddle.nn.Linear or paddle.distributed.fleet.meta_parallel.ColumnParallelLinear. {module}({module_name}) is not supported。" + f"LoRA strategy only supports paddle.nn.Linear or paddle.distributed.fleet.meta_parallel.ColumnParallelLinear or paddlenlp.transformers.sequence_utils. {module}({module_name} {type(module).__name__}) is not supported。" ) if getattr(lora_module, "quant_weight", None) is not None: lora_module.quant_weight = module.quant_weight @@ -597,6 +660,8 @@ def mark_only_lora_as_trainable(self) -> None: or isinstance(layer, LoRAConv2D) or isinstance(layer, ColumnParallelLoRALinear) or isinstance(layer, RowParallelLoRALinear) + or isinstance(layer, ColumnSequenceParallelLoRALinear) + or isinstance(layer, RowSequenceParallelLoRALinear) or isinstance(layer, LoRAMergedLinear) or isinstance(layer, ColumnParallelLoRAMergedLinear) or (QuantizationLoRALinear is not None and isinstance(layer, QuantizationLoRALinear)) @@ -684,9 +749,11 @@ def restore_original_model(self): self._find_and_restore_module(layer_name) elif ( isinstance(layer, ColumnParallelLoRALinear) + or isinstance(layer, ColumnSequenceParallelLoRALinear) or isinstance(layer, LoRAConv2D) or isinstance(layer, ColumnParallelLoRAMergedLinear) or isinstance(layer, RowParallelLoRALinear) + or isinstance(layer, RowSequenceParallelLoRALinear) or (QuantizationLoRALinear is not None and isinstance(layer, QuantizationLoRALinear)) or ( ColumnParallelQuantizationLoRALinear is not None diff --git a/paddlenlp/peft/lora/mc2_lora_npu.py b/paddlenlp/peft/lora/mc2_lora_npu.py deleted file mode 100644 index 7ae47b1496f7..000000000000 --- a/paddlenlp/peft/lora/mc2_lora_npu.py +++ /dev/null @@ -1,80 +0,0 @@ -# !/usr/bin/env python3 - -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""" mc2(tp overlap) """ - -import paddle -import paddle_custom_device -from paddle.autograd import PyLayer - - -class MC2LoRaRowParallelLinear(PyLayer): - @staticmethod - def forward(ctx, input_, weight, group): - ctx.save_for_backward(input_, weight) - rank = paddle.distributed.get_rank() - hcom_name = group.process_group.get_comm_name(rank) - x = input_.reshape([-1, input_.shape[-1]]) - out = paddle_custom_device.npu.fused_mm_allreduce( - x, weight, bias=None, hcom=hcom_name, reduce_op="sum", comm_turn=0 - ) - output = out.reshape([input_.shape[0], input_.shape[1], weight.shape[1]]) - ctx.ring_id = group.id - return output - - @staticmethod - def backward(ctx, dy): - input_, weight = ctx.saved_tensor() - out_grad = dy - sub_grad = out_grad.reshape([-1, out_grad.shape[-1]]) - input_grad = paddle.matmul(sub_grad, weight, transpose_y=True) - if weight.stop_gradient: - return input_grad.reshape(input_.shape) - else: - input_reshape = input_.reshape([-1, input_.shape[-1]]) - weight_grad = paddle.matmul(input_reshape, sub_grad, transpose_x=True) - return input_grad.reshape(input_.shape), weight_grad - - -class MC2LoRaColumnParallelLinear(PyLayer): - @staticmethod - def forward(ctx, input_, weight, group): - ctx.save_for_backward(input_, weight) - ctx.group = group - input_mp = input_ - result_mp = paddle.matmul(input_mp, weight) - return result_mp - - @staticmethod - def backward(ctx, dy): - input_, weight = ctx.saved_tensor() - sub_grad = dy.reshape([-1, dy.shape[-1]]) - rank = paddle.distributed.get_rank() - hcom_name = ctx.group.process_group.get_comm_name(rank) - - d_weight = ( - paddle.matmul(input_.reshape([-1, input_.shape[-1]]), sub_grad, transpose_x=True) - if not weight.stop_gradient - else None - ) - d_input = paddle_custom_device.npu.fused_mm_allreduce( - sub_grad, weight.t(), bias=None, hcom=hcom_name, reduce_op="sum", comm_turn=0 - ) - - if d_weight is not None: - return d_input.reshape(input_.shape), d_weight - else: - return d_input.reshape(input_.shape) diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 5cb13f7aa61a..38f1d244bdf2 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -62,6 +62,10 @@ def swiglu(x, y=None): init_name_mappings, ) from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies +from paddlenlp.transformers.mc2_parallel_linear import ( + MC2ColumnSeqParallelLinear, + MC2RowSeqParallelLinear, +) from paddlenlp.transformers.model_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -96,13 +100,6 @@ def swiglu(x, y=None): ] -def is_mc2_valid(): - current_device = get_env_device() - if current_device == "npu": - return True - return False - - def _get_interleave(n): def _get_interleave_power_of_2(n): start = 2 ** (-(2 ** -(math.log2(n) - 3))) @@ -574,12 +571,7 @@ def __init__(self, config): self.fuse_attention_ffn = config.fuse_attention_ffn if config.sequence_parallel: - if is_mc2_valid and int(os.getenv("FLAGS_NPU_MC2", 0)): - from paddlenlp.transformers.mc2_seqence_parallel_linear import ( - MC2ColumnSeqParallelLinear, - MC2RowSeqParallelLinear, - ) - + if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None: ColumnParallelLinear = MC2ColumnSeqParallelLinear RowParallelLinear = MC2RowSeqParallelLinear else: @@ -697,12 +689,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): self.use_fused_rope = False if config.sequence_parallel: - 
if is_mc2_valid and int(os.getenv("FLAGS_NPU_MC2", 0)): - from paddlenlp.transformers.mc2_seqence_parallel_linear import ( - MC2ColumnSeqParallelLinear, - MC2RowSeqParallelLinear, - ) - + if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None: ColumnParallelLinear = MC2ColumnSeqParallelLinear RowParallelLinear = MC2RowSeqParallelLinear else: diff --git a/paddlenlp/transformers/mc2_parallel_linear.py b/paddlenlp/transformers/mc2_parallel_linear.py new file mode 100644 index 000000000000..066e8074e21f --- /dev/null +++ b/paddlenlp/transformers/mc2_parallel_linear.py @@ -0,0 +1,230 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import paddle + +try: + import paddle_custom_device +except ImportError: + pass + +from paddle import distributed as dist +from paddle.autograd import PyLayer + +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + ColumnSequenceParallelLinear, + RowSequenceParallelLinear, + ) +except: + pass +from paddlenlp.utils.tools import get_env_device + +__all_gather_recomputation__ = False +if int(os.getenv("MC2_Recompute", 0)): + __all_gather_recomputation__ = True + + +def is_mc2_valid(): + current_device = get_env_device() + if current_device == "npu": + return int(os.getenv("MC2", 0)) + return 0 + + +if is_mc2_valid(): + + class MC2ColumnParallelCoreLinear(PyLayer): + @staticmethod + def forward(ctx, input_, weight, group): + ctx.save_for_backward(input_, weight) + ctx.group = group + input_mp = input_ + result_mp = paddle.matmul(input_mp, weight) + return result_mp + + @staticmethod + def backward(ctx, dy): + input_, weight = ctx.saved_tensor() + sub_grad = dy.reshape([-1, dy.shape[-1]]) + rank = paddle.distributed.get_rank() + hcom_name = ctx.group.process_group.get_comm_name(rank) + + d_weight = ( + paddle.matmul(input_.reshape([-1, input_.shape[-1]]), sub_grad, transpose_x=True) + if not weight.stop_gradient + else None + ) + d_input = paddle_custom_device.npu.fused_mm_allreduce( + sub_grad, weight.t(), bias=None, hcom=hcom_name, reduce_op="sum", comm_turn=0 + ) + + if d_weight is not None: + return d_input.reshape(input_.shape), d_weight + else: + return d_input.reshape(input_.shape), None + + class MC2RowParallelCoreLinear(PyLayer): + @staticmethod + def forward(ctx, input_, weight, group): + ctx.save_for_backward(input_, weight) + rank = paddle.distributed.get_rank() + hcom_name = group.process_group.get_comm_name(rank) + x = input_.reshape([-1, input_.shape[-1]]) + out = paddle_custom_device.npu.fused_mm_allreduce( + x, weight, bias=None, hcom=hcom_name, reduce_op="sum", comm_turn=0 + ) + output = out.reshape([input_.shape[0], input_.shape[1], weight.shape[1]]) + ctx.ring_id = group.id + return output + + @staticmethod + def backward(ctx, dy): + input_, weight = ctx.saved_tensor() + out_grad = dy + sub_grad = out_grad.reshape([-1, out_grad.shape[-1]]) + input_grad = paddle.matmul(sub_grad, weight, transpose_y=True) + if 
weight.stop_gradient: + return input_grad.reshape(input_.shape), None + else: + input_reshape = input_.reshape([-1, input_.shape[-1]]) + weight_grad = paddle.matmul(input_reshape, sub_grad, transpose_x=True) + return input_grad.reshape(input_.shape), weight_grad + + class MC2ColumnSeqParallelCoreLinear(PyLayer): + @staticmethod + def forward(ctx, input_, weight, group): + ctx.weight_stop_gradient = weight.stop_gradient + ctx.save_for_backward(input_, weight) + + rank = dist.get_rank() + hcomm_info = group.process_group.get_comm_name(rank) + + world_size = group.nranks + output, gather_out = paddle_custom_device.npu.fused_allgather_mm( + input_, + weight, + bias=None, + hcom=hcomm_info, + world_size=world_size, + gather_index=0, + gather_output=(not __all_gather_recomputation__), + comm_turn=0, + ) + + ctx.all_gather_output = gather_out + ctx.world_size = world_size + ctx.group = group + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensor() + + if __all_gather_recomputation__: + dim_size = input_.shape + dim_size[0] = dim_size[0] * ctx.world_size + all_gather_output = paddle.empty(dim_size, dtype=input_.dtype) + all_gather_output.stop_gradient = True + all_gather_work = dist.stream.all_gather(all_gather_output, input_, group=ctx.group, sync_op=False) + else: + all_gather_output = ctx.all_gather_output + + grad_input = paddle.matmul(grad_output, weight, transpose_y=True) + sub_grad_input = paddle.empty(input_.shape, dtype=input_.dtype) + reduce_scatter_work = dist.stream.reduce_scatter( + sub_grad_input, grad_input, group=ctx.group, sync_op=False + ) + + if __all_gather_recomputation__: + all_gather_work.wait() + + grad_weight = ( + paddle.matmul(all_gather_output, grad_output, transpose_x=True) + if not ctx.weight_stop_gradient + else None + ) + reduce_scatter_work.wait() + + return sub_grad_input, grad_weight + + class MC2RowSeqParallelCoreLinear(PyLayer): + @staticmethod + def forward(ctx, input_, weight, group): + ctx.weight_stop_gradient = weight.stop_gradient + ctx.save_for_backward(input_, weight) + + rank = dist.get_rank() + hcomm_info = group.process_group.get_comm_name(rank) + world_size = group.nranks + + output = paddle_custom_device.npu.fused_mm_reduce_scatter( + input_, + weight, + bias=None, + hcom=hcomm_info, + world_size=world_size, + reduce_op="sum", + comm_turn=0, + ) + + ctx.hcomm_info = hcomm_info + ctx.world_size = world_size + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight = ctx.saved_tensor() + hcomm_info = ctx.hcomm_info + world_size = ctx.world_size + + grad_input, all_gather_grad_output = paddle_custom_device.npu.fused_allgather_mm( + grad_output, + weight.t(), + bias=None, + hcom=hcomm_info, + world_size=world_size, + gather_index=0, + gather_output=True, + comm_turn=0, + ) + grad_weight = ( + paddle.matmul(input_, all_gather_grad_output, transpose_x=True) + if not ctx.weight_stop_gradient + else None + ) + + return grad_input, grad_weight + + class MC2ColumnSeqParallelLinear(ColumnSequenceParallelLinear): + def forward(self, x): + output = MC2ColumnSeqParallelCoreLinear.apply(x, self.weight, self.model_parallel_group) + output = output + self.bias if self.bias is not None else output + return output + + class MC2RowSeqParallelLinear(RowSequenceParallelLinear): + def forward(self, x): + output = MC2RowSeqParallelCoreLinear.apply(x, self.weight, self.model_parallel_group) + output = output + self.bias if self.bias is not None else output + return output + +else: + 
MC2ColumnSeqParallelCoreLinear = None + MC2RowSeqParallelCoreLinear = None + MC2ColumnSeqParallelLinear = None + MC2RowSeqParallelLinear = None + MC2ColumnParallelCoreLinear = None + MC2RowParallelCoreLinear = None diff --git a/paddlenlp/transformers/mc2_seqence_parallel_linear.py b/paddlenlp/transformers/mc2_seqence_parallel_linear.py deleted file mode 100644 index c39a78cc6252..000000000000 --- a/paddlenlp/transformers/mc2_seqence_parallel_linear.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import paddle - -try: - import paddle_custom_device -except ImportError: - raise ImportError("Current device does not support MC2!") - -from paddle import distributed as dist -from paddle.autograd import PyLayer - -try: - from paddle.distributed.fleet.utils.sequence_parallel_utils import ( - ColumnSequenceParallelLinear, - RowSequenceParallelLinear, - ) -except: - pass - -__all_gather_recomputation__ = False -if int(os.getenv("MC2_Recompute", 0)): - __all_gather_recomputation__ = True - - -class MC2Column(PyLayer): - @staticmethod - def forward(ctx, input_, weight, group): - ctx.save_for_backward(input_, weight) - - rank = dist.get_rank() - hcomm_info = group.process_group.get_comm_name(rank) - - world_size = group.nranks - output, gather_out = paddle_custom_device.npu.fused_allgather_mm( - input_, - weight, - bias=None, - hcom=hcomm_info, - world_size=world_size, - gather_index=0, - gather_output=(not __all_gather_recomputation__), - comm_turn=0, - ) - - ctx.all_gather_output = gather_out - ctx.world_size = world_size - ctx.group = group - return output - - @staticmethod - def backward(ctx, grad_output): - input_, weight = ctx.saved_tensor() - - if __all_gather_recomputation__: - dim_size = input_.shape - dim_size[0] = dim_size[0] * ctx.world_size - all_gather_output = paddle.empty(dim_size, dtype=input_.dtype) - all_gather_output.stop_gradient = True - all_gather_work = dist.stream.all_gather(all_gather_output, input_, group=ctx.group, sync_op=False) - else: - all_gather_output = ctx.all_gather_output - - grad_input = paddle.matmul(grad_output, weight, transpose_y=True) - sub_grad_input = paddle.empty(input_.shape, dtype=input_.dtype) - reduce_scatter_work = dist.stream.reduce_scatter(sub_grad_input, grad_input, group=ctx.group, sync_op=False) - - if __all_gather_recomputation__: - all_gather_work.wait() - - grad_weight = paddle.matmul(all_gather_output, grad_output, transpose_x=True) - reduce_scatter_work.wait() - - return sub_grad_input, grad_weight - - -class MC2Row(PyLayer): - @staticmethod - def forward(ctx, input_, weight, group): - ctx.save_for_backward(input_, weight) - - rank = dist.get_rank() - hcomm_info = group.process_group.get_comm_name(rank) - world_size = group.nranks - - output = paddle_custom_device.npu.fused_mm_reduce_scatter( - input_, - weight, - bias=None, - hcom=hcomm_info, - world_size=world_size, - reduce_op="sum", - comm_turn=0, - ) - - ctx.hcomm_info = 
hcomm_info - ctx.world_size = world_size - return output - - @staticmethod - def backward(ctx, grad_output): - input_, weight = ctx.saved_tensor() - hcomm_info = ctx.hcomm_info - world_size = ctx.world_size - - grad_input, all_gather_grad_output = paddle_custom_device.npu.fused_allgather_mm( - grad_output, - weight.t(), - bias=None, - hcom=hcomm_info, - world_size=world_size, - gather_index=0, - gather_output=True, - comm_turn=0, - ) - grad_weight = paddle.matmul(input_, all_gather_grad_output, transpose_x=True) - - return grad_input, grad_weight - - -class MC2ColumnSeqParallelLinear(ColumnSequenceParallelLinear): - def forward(self, x): - output = MC2Column.apply(x, self.weight, self.model_parallel_group) - output = output + self.bias if self.bias is not None else output - return output - - -class MC2RowSeqParallelLinear(RowSequenceParallelLinear): - def forward(self, x): - output = MC2Row.apply(x, self.weight, self.model_parallel_group) - output = output + self.bias if self.bias is not None else output - return output
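
Usage note (illustrative, not part of the patch): the MC2 fast paths introduced above are only defined when the process runs on an NPU device and the MC2 environment variable is set; `is_mc2_valid()` is evaluated at import time of `paddlenlp.transformers.mc2_parallel_linear`, and on any other device every `MC2*CoreLinear`/`MC2*Linear` symbol is `None`, so the LoRA and Llama layers fall back to the stock `mp_ops` / sequence-parallel code paths. A minimal sketch of how a caller could verify which path is active, assuming an environment where paddlenlp is importable (the print messages are illustrative only):

# Sketch: check whether the fused MC2 kernels are in use.
# Assumes paddlenlp is installed; on non-NPU devices the MC2 symbols stay None
# regardless of the environment variables.
import os

# These must be set before the module is imported, because is_mc2_valid() and
# __all_gather_recomputation__ are evaluated at import time.
os.environ["MC2"] = "1"            # opt in to the fused MC2 kernels on NPU
os.environ["MC2_Recompute"] = "0"  # keep the all-gather output instead of recomputing it

from paddlenlp.transformers.mc2_parallel_linear import (
    MC2ColumnSeqParallelCoreLinear,
    MC2RowParallelCoreLinear,
)

if MC2ColumnSeqParallelCoreLinear is None:
    # Non-NPU device or MC2 disabled: LoRA layers use AllGatherOp/ReduceScatterOp
    # (sequence parallel) or mp_ops._c_identity/_mp_allreduce (tensor parallel).
    print("MC2 disabled: using the default parallel linear paths")
else:
    # NPU with MC2=1: LoRA layers dispatch to the fused matmul+collective PyLayers
    # built on fused_allgather_mm / fused_mm_reduce_scatter / fused_mm_allreduce.
    print("MC2 enabled: using the fused MC2 parallel linear paths")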