Refactor Linear handling in TransformersModel (#12727)
Signed-off-by: Harry Mellor <[email protected]>
hmellor authored Feb 5, 2025
1 parent 64862d1 commit 249824c
Showing 2 changed files with 48 additions and 58 deletions.
30 changes: 15 additions & 15 deletions vllm/model_executor/layers/linear.py
@@ -2,7 +2,7 @@

import itertools
from abc import abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Optional

import torch
import torch.nn.functional as F
@@ -47,8 +47,8 @@ def adjust_marlin_shard(param, shard_size, shard_offset):


def adjust_bitsandbytes_4bit_shard(param: Parameter,
-shard_offsets: Dict[str, Tuple[int, int]],
-loaded_shard_id: str) -> Tuple[int, int]:
+shard_offsets: dict[str, tuple[int, int]],
+loaded_shard_id: str) -> tuple[int, int]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""

total, _ = shard_offsets["total"]
@@ -90,7 +90,7 @@ class LinearMethodBase(QuantizeMethodBase):
@abstractmethod
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
-output_partition_sizes: List[int], input_size: int,
+output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""Create weights for a linear layer.
@@ -123,7 +123,7 @@ class UnquantizedLinearMethod(LinearMethodBase):

def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
-output_partition_sizes: List[int], input_size: int,
+output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
weight = Parameter(torch.empty(sum(output_partition_sizes),
@@ -179,7 +179,8 @@ def __init__(
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)

-def forward(self, x: torch.Tensor) -> torch.Tensor:
+def forward(self,
+x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
raise NotImplementedError


@@ -240,9 +241,8 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)

-def forward(
-self, x: torch.Tensor
-) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+def forward(self,
+x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
@@ -288,7 +288,7 @@ def __init__(self,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
-output_sizes: Optional[List[int]] = None,
+output_sizes: Optional[list[int]] = None,
prefix: str = ""):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
@@ -374,7 +374,7 @@ def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
loaded_weight = loaded_weight.reshape(1)
param.load_column_parallel_weight(loaded_weight=loaded_weight)

-def forward(self, input_):
+def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
bias = self.bias if not self.skip_bias_add else None

# Matrix multiply.
@@ -422,7 +422,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):

def __init__(self,
input_size: int,
-output_sizes: List[int],
+output_sizes: list[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
@@ -500,7 +500,7 @@ def weight_loader(self,
current_shard_offset = 0
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
-shard_offsets: List[Tuple[int, int, int]] = []
+shard_offsets: list[tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
@@ -602,7 +602,7 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter,
"""

current_shard_offset = 0
-shard_offsets: List[Tuple[int, int, int]] = []
+shard_offsets: list[tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
@@ -1124,7 +1124,7 @@ def weight_loader_v2(self, param: BasevLLMParameter,

param.load_row_parallel_weight(loaded_weight=loaded_weight)

-def forward(self, input_):
+def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
if self.input_is_parallel:
input_parallel = input_
else:
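The return-type annotations above spell out the existing calling convention of vLLM's parallel linear layers: `forward` returns an `(output, output_bias)` tuple rather than a bare tensor, and `output_bias` is only non-`None` when the layer was built with `skip_bias_add=True`. A minimal sketch of how a caller unpacks that tuple; the sizes are illustrative only, and it assumes vLLM's tensor-parallel process group has already been initialized:

```python
import torch

from vllm.model_executor.layers.linear import ColumnParallelLinear

# Illustrative sizes; requires vLLM's distributed environment to be set up.
layer = ColumnParallelLinear(input_size=1024, output_size=4096, bias=True)

x = torch.randn(2, 1024)
output, output_bias = layer(x)  # a tuple, not a bare tensor
# output_bias is None here because skip_bias_add defaults to False,
# so the bias is already folded into `output`.
```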
76 changes: 33 additions & 43 deletions vllm/model_executor/models/transformers.py
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0

# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,7 +15,7 @@
# limitations under the License.
"""Wrapper around `transformers` models"""
import re
-from typing import Iterable, List, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Union

import torch
from torch import nn
@@ -71,23 +72,10 @@ def vllm_flash_attention_forward(
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


-# Linear Layer that is compatible with transformers internal forward
-# TODO: This is a temporary solution, we should find a better way to integrate
-class HFColumnParallelLinear(ColumnParallelLinear):
-
-def forward(self, input: torch.Tensor) -> torch.Tensor:
-return super().forward(input)[0]
-
-
-class HFRowParallelLinear(RowParallelLinear):
-
-def forward(self, input: torch.Tensor) -> torch.Tensor:
-return super().forward(input)[0]
-
-
-def replace_tp_linear_class(orig_module: nn.Linear,
-style: str,
-quant_config=None):
+def replace_linear_class(
+linear: nn.Linear,
+style: str,
+quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
"""
In model configurations, we use a neutral type (string) to specify parallel
styles, here we use it to translate nn.Linear into vllm-style tp Linear.
@@ -99,26 +87,28 @@ def replace_tp_linear_class(orig_module: nn.Linear,
raise ValueError(
f"Unsupported parallel style type {type(style)}, expected str")

-input_size = orig_module.in_features
-output_size = orig_module.out_features
-bias = orig_module.bias is not None
+vllm_linear_cls = {
+"colwise": ColumnParallelLinear,
+"rowwise": RowParallelLinear,
+}.get(style)

-if style == "colwise":
-return HFColumnParallelLinear(
-input_size,
-output_size,
-bias,
-)
-elif style == "rowwise":
-return HFRowParallelLinear(
-input_size,
-output_size,
-bias,
-)
# We don't consider colwise_rep since it's used in lm_head
-else:
+if vllm_linear_cls is None:
raise ValueError(f"Unsupported parallel style value: {style}")

+class HFCompatibleLinear(vllm_linear_cls):
+"""
+Wrapper class that removes `output_bias` from returned output.
+"""
+
+def forward(self, input: torch.Tensor) -> torch.Tensor:
+return super().forward(input)[0]
+
+return HFCompatibleLinear(
+input_size=linear.in_features,
+output_size=linear.out_features,
+bias=linear.bias is not None,
+)


class TransformersModel(nn.Module):
embedding_padding_modules = ["lm_head"]
@@ -192,16 +182,16 @@ def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
"support it yet!")

for child_name, child_module in module.named_children():
-qual_name = prefix + child_name
+qual_name = maybe_prefix(prefix, child_name)
for pattern, style in self.config.base_model_tp_plan.items():
if re.match(pattern, qual_name) and isinstance(
child_module, nn.Linear):
-new_module = replace_tp_linear_class(
-child_module, style, self.quant_config)
+new_module = replace_linear_class(child_module, style,
+self.quant_config)
setattr(module, child_name, new_module)
self.log_replacement(qual_name, child_module, new_module)
else:
-self.tensor_parallelize(child_module, prefix=f"{qual_name}.")
+self.tensor_parallelize(child_module, prefix=qual_name)

def replace_vocab_embed_class(self, module: nn.Module):
# Use native set input embeddings
@@ -219,7 +209,7 @@ def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
-kv_caches: List[torch.Tensor], # argument not used
+kv_caches: list[torch.Tensor], # argument not used
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
@@ -249,10 +239,10 @@ def sample(self, logits: torch.Tensor,
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

-def load_weights(self, weights: Iterable[Tuple[str,
-torch.Tensor]]) -> Set[str]:
+def load_weights(self, weights: Iterable[tuple[str,
+torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
-loaded_params: Set[str] = set()
+loaded_params = set[str]()
for name, loaded_weight in weights:
if name not in params_dict:
name = f"{self.model.base_model_prefix}.{name}"
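For context, the new `replace_linear_class` helper is meant as a drop-in swap for `nn.Linear` submodules selected by the model's tensor-parallel plan. A minimal usage sketch, assuming the tensor-parallel group is initialized; the toy module and layer names below are illustrative and not taken from this commit:

```python
import torch.nn as nn

from vllm.model_executor.models.transformers import replace_linear_class


class ToyMLP(nn.Module):
    """Hypothetical stand-in for a Hugging Face MLP block."""

    def __init__(self, hidden: int = 64):
        super().__init__()
        self.up_proj = nn.Linear(hidden, 4 * hidden)
        self.down_proj = nn.Linear(4 * hidden, hidden)


mlp = ToyMLP()
# Swap each nn.Linear for a vLLM tensor-parallel layer per its tp-plan style.
mlp.up_proj = replace_linear_class(mlp.up_proj, "colwise", quant_config=None)
mlp.down_proj = replace_linear_class(mlp.down_proj, "rowwise", quant_config=None)
# The returned HFCompatibleLinear drops `output_bias` from the output tuple,
# so code written against nn.Linear (a single tensor return) keeps working.
```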
