[Distributed] [CustomDevices] Adapt SP on lora && polish MC2 APIs (#8303)

* [Distributed] adapt sequence parallel on LoRA (#8235)

* [Distributed] [CustomDevices] adapt lora sp && polish MC2 APIs
SylarTiaNII authored and JunnYu committed Apr 24, 2024
1 parent 871070d commit 0f428bb
Showing 6 changed files with 572 additions and 272 deletions.
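Note on the new layers: the ColumnSequenceParallelLoRALinear and RowSequenceParallelLoRALinear classes added in this diff mirror the existing tensor-parallel LoRA layers, but they gather and scatter activations along the sequence dimension (AllGatherOp / ReduceScatterOp) and route through the fused MC2 kernels whenever the corresponding MC2*SeqParallelCoreLinear class is not None. Below is a minimal, unverified usage sketch. It assumes a fleet hybrid-parallel environment with mp_degree > 1; the LoRA arguments (r, lora_alpha, lora_A_weight_attr) are taken from the __init__ signatures in this diff, while gather_output, input_is_parallel, and the layer sizes are illustrative assumptions forwarded through **kwargs to the base Paddle layers.

import paddle
from paddle.distributed import fleet

from paddlenlp.peft.lora.lora_layers import (
    ColumnSequenceParallelLoRALinear,
    RowSequenceParallelLoRALinear,
)

# Hypothetical setup: 2-way tensor/model parallelism, no data or pipeline parallelism.
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)

# Column-parallel up-projection with a rank-8 LoRA adapter. lora_A_weight_attr
# mirrors the KaimingUniform initialization used for lora_A elsewhere in this file.
up_proj = ColumnSequenceParallelLoRALinear(
    in_features=1024,   # hypothetical hidden size
    out_features=4096,  # hypothetical intermediate size
    r=8,
    lora_alpha=16,
    lora_A_weight_attr=paddle.ParamAttr(
        initializer=paddle.nn.initializer.KaimingUniform()
    ),
    gather_output=False,      # assumed base-class kwarg of ColumnSequenceParallelLinear
)

# Matching row-parallel down-projection; its input is already split across ranks.
down_proj = RowSequenceParallelLoRALinear(
    in_features=4096,
    out_features=1024,
    r=8,
    lora_alpha=16,
    input_is_parallel=True,   # assumed base-class kwarg of RowSequenceParallelLinear
)

Launch under python -m paddle.distributed.launch so the model-parallel communication group exists before the layers are constructed. On NPU devices with MC2 enabled, the fused path is taken automatically when paddlenlp.transformers.mc2_parallel_linear exposes the MC2 classes; otherwise they are None and the plain all-gather / reduce-scatter path shown in the diff runs.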
278 changes: 260 additions & 18 deletions paddlenlp/peft/lora/lora_layers.py
@@ -13,7 +13,6 @@
# limitations under the License.

import math
import os
from typing import List, Optional

import paddle
@@ -25,13 +24,25 @@
RowParallelLinear,
)

from .lora_quick_layers import quick_lora
try:
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
AllGatherOp,
ColumnSequenceParallelLinear,
ReduceScatterOp,
RowSequenceParallelLinear,
mark_as_sequence_parallel_parameter,
)
except:
pass

from paddlenlp.transformers.mc2_parallel_linear import (
MC2ColumnParallelCoreLinear,
MC2ColumnSeqParallelCoreLinear,
MC2RowParallelCoreLinear,
MC2RowSeqParallelCoreLinear,
)

if "npu" in paddle.device.get_all_custom_device_type():
from .mc2_lora_npu import MC2LoRaColumnParallelLinear, MC2LoRaRowParallelLinear
else:
MC2LoRaRowParallelLinear = None
MC2LoRaColumnParallelLinear = None
from .lora_quick_layers import quick_lora


class LoRALinear(nn.Linear):
@@ -266,16 +277,16 @@ def forward(self, x: paddle.Tensor):
)
else:
# x @ W : [bz, in_f / ws] ===> [bz, out_f]
if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
output = MC2LoRaRowParallelLinear.apply(input_mp, self.weight, self.model_parallel_group)
else:
if MC2RowParallelCoreLinear is None:
result_mp = F.linear(x=input_mp, weight=self.weight, name=self.name)
output = mp_ops._mp_allreduce(
result_mp,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True,
)
else:
output = MC2RowParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group)

if not self.merged:
# x @ A: [bz, in_f/ ws] ===> [bz, r]
@@ -298,6 +309,120 @@ def extra_repr(self):
return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


class RowSequenceParallelLoRALinear(RowSequenceParallelLinear):
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
rslora: bool = False,
lora_plus_scale: float = 1.0,
merge_weights: bool = True,
use_quick_lora: bool = False,
**kwargs
):
RowSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
if not isinstance(r, int) or r <= 0:
raise ValueError("Lora rank r should be a positive integer")
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
if lora_dropout > 0.0:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
self.merge_weights = merge_weights

# compatible
self.name = self._name

# Actual trainable parameters
self.lora_A = self.create_parameter(
shape=[self.input_size_per_partition, r],
dtype=self._dtype,
is_bias=False,
attr=paddle.ParamAttr(
initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
),
)
self.lora_B = self.create_parameter(
shape=[r, self.out_features],
dtype=self._dtype,
is_bias=False,
attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.0),
learning_rate=lora_plus_scale,
),
)

self.lora_A.is_distributed = True
self.lora_A.split_axis = 0
self.lora_B.is_distributed = False
mark_as_sequence_parallel_parameter(self.lora_B)
if not rslora:
self.scaling = self.lora_alpha / self.r
else:
self.scaling = self.lora_alpha / math.sqrt(self.r)

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0

@property
def use_quick_lora(self):
# TODO(@gexiao): support qlora
return False # self._use_quick_lora and self.training and not self.merged

def train(self):
super().train()
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
self.weight.set_value(new_weight)
self.merged = False

def eval(self):
super().eval()
if self.merge_weights and not self.merged:
# Merge the weights and mark it
new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
self.weight.set_value(new_weight)
self.merged = True

def forward(self, x: paddle.Tensor):
if not self.input_is_parallel:
input_mp = mp_ops._c_split(x, group=self.model_parallel_group)
else:
input_mp = x

if MC2RowSeqParallelCoreLinear is None:
output_parallel = self.linear(input_mp, self.weight, name=self._name)
output_ = ReduceScatterOp.apply(output_parallel)
result_mp = output_ + self.bias if self.bias is not None else output_
else:
output_ = MC2RowSeqParallelCoreLinear.apply(input_mp, self.weight, self.model_parallel_group)
result_mp = output_ + self.bias if self.bias is not None else output_

if not self.merged:
input_mp = self.lora_dropout(input_mp)
if MC2RowSeqParallelCoreLinear is None:
input_mp = input_mp @ self.lora_A
input_mp = ReduceScatterOp.apply(input_mp)
else:
input_mp = MC2RowSeqParallelCoreLinear.apply(input_mp, self.lora_A, self.model_parallel_group)
delta_mp = (input_mp @ self.lora_B) * self.scaling
result_mp += delta_mp
return result_mp

def extra_repr(self):
name = f", name={self.name}" if self.name else ""
return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


class ColumnParallelLoRALinear(ColumnParallelLinear):
def __init__(
self,
@@ -400,21 +525,21 @@ def forward(self, input: paddle.Tensor):
world_size=self.world_size,
)
else:
if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
res_mp = MC2LoRaColumnParallelLinear.apply(input, self.weight, self.model_parallel_group)
result_mp = res_mp + self.bias
else:
if MC2ColumnParallelCoreLinear is None:
input_mp = mp_ops._c_identity(input, group=self.model_parallel_group)
result_mp = F.linear(x=input_mp, weight=self.weight, bias=self.bias, name=self.name)
else:
res_mp = MC2ColumnParallelCoreLinear.apply(input, self.weight, self.model_parallel_group)
result_mp = res_mp + self.bias

if not self.merged:
input_a = self.lora_dropout(input) @ self.lora_A
if "npu" in paddle.device.get_all_custom_device_type() and int(os.getenv("MC2", "0")):
tmp = MC2LoRaColumnParallelLinear.apply(input_a, self.lora_B, self.model_parallel_group)
delta_mp = tmp * self.scaling
else:
if MC2ColumnParallelCoreLinear is None:
input_a_mp = mp_ops._c_identity(input_a, group=self.model_parallel_group)
delta_mp = (input_a_mp @ self.lora_B) * self.scaling
else:
tmp = MC2ColumnParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group)
delta_mp = tmp * self.scaling
result_mp += delta_mp

if self.gather_output and self.is_mp:
@@ -428,6 +553,123 @@ def extra_repr(self):
return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


class ColumnSequenceParallelLoRALinear(ColumnSequenceParallelLinear):
def __init__(
self,
in_features: int,
out_features: int,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
rslora: bool = False,
lora_plus_scale: float = 1.0,
merge_weights: bool = True,
lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
use_quick_lora: bool = False,
**kwargs
):
ColumnSequenceParallelLinear.__init__(self, in_features, out_features, **kwargs)
if not isinstance(r, int) or r <= 0:
raise ValueError("Lora rank r should be a positive integer")
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
if lora_dropout > 0.0:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
self.merge_weights = merge_weights

# compatible
self.name = self._name

# Actual trainable parameters
self.lora_A = self.create_parameter(
shape=[in_features, r],
dtype=self._dtype,
is_bias=False,
attr=lora_A_weight_attr,
)
self.lora_A.is_distributed = False
mark_as_sequence_parallel_parameter(self.lora_A)

self.lora_B = self.create_parameter(
shape=[r, self.output_size_per_partition],
dtype=self._dtype,
is_bias=False,
attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.0),
learning_rate=lora_plus_scale,
),
)

self.lora_B.is_distributed = True
self.lora_B.split_axis = 1
if not rslora:
self.scaling = self.lora_alpha / self.r
else:
self.scaling = self.lora_alpha / math.sqrt(self.r)

# Freezing the pre-trained weight matrix
self.weight.stop_gradient = True
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0

@property
def use_quick_lora(self):
# TODO(@gexiao): support qlora
return False # self._use_quick_lora and self.training and not self.merged

def train(self):
super().train()
if self.merge_weights and self.merged:
# Make sure that the weights are not merged
new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
self.weight.set_value(new_weight)
self.merged = False

def eval(self):
super().eval()
if self.merge_weights and not self.merged:
# Merge the weights and mark it
new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
self.weight.set_value(new_weight)
self.merged = True

def forward(self, x: paddle.Tensor):
if MC2ColumnSeqParallelCoreLinear is None:
if self.is_mp:
input_parallel = AllGatherOp.apply(x)
else:
input_parallel = x
result_mp = self.linear(input_parallel, self.weight, self.bias, name=self._name)
else:
result_mp = MC2ColumnSeqParallelCoreLinear.apply(x, self.weight, self.model_parallel_group)
if self.bias is not None:
result_mp += self.bias

if not self.merged:
input_a = self.lora_dropout(x) @ self.lora_A
if MC2ColumnSeqParallelCoreLinear is None:
input_a = AllGatherOp.apply(input_a)
delta_mp = (input_a @ self.lora_B) * self.scaling
else:
input_a = MC2ColumnSeqParallelCoreLinear.apply(input_a, self.lora_B, self.model_parallel_group)
delta_mp = input_a * self.scaling
result_mp += delta_mp

if self.gather_output and self.is_mp:
result = mp_ops._c_concat(result_mp, group=self.model_parallel_group)
else:
result = result_mp
return result

def extra_repr(self):
name = f", name={self.name}" if self.name else ""
return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}"


class LoRAMergedLinear(nn.Linear):
# LoRA implemented in a dense layer with merged linear weights for q, k, v
def __init__(
