Fix quantization for chatglm #12586

Closed
wants to merge 5 commits
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/base_config.py
@@ -1,6 +1,6 @@
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Mapping, Optional, Type

import torch
from torch import nn
@@ -57,6 +57,7 @@ def method_has_implemented_embedding(

class QuantizationConfig(ABC):
"""Base class for quantization configs."""
packed_modules_mapping: Mapping[str, List[str]] = dict()

@abstractmethod
def get_name(self) -> str:
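The hunk above adds a `packed_modules_mapping` class attribute to `QuantizationConfig`, giving every quantization backend a place to look up how a fused ("packed") module name expands into the shard names stored in the checkpoint. A minimal sketch of the attribute in use; `DemoQuantConfig` is a stand-in class invented here, and the example mapping is the Llama-style one that this PR deletes from `quant_utils.py`:

```python
from typing import List, Mapping


class DemoQuantConfig:
    # Class-level default, empty just like the base class in this diff; the
    # model loader later assigns the model's own mapping (see loader.py below).
    packed_modules_mapping: Mapping[str, List[str]] = dict()


cfg = DemoQuantConfig()
# A Llama-style mapping, identical to the FUSED_LAYER_NAME_MAPPING constant
# this PR removes from quant_utils.py:
cfg.packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}
print(cfg.packed_modules_mapping["qkv_proj"])  # ['q_proj', 'k_proj', 'v_proj']
```

Because the loader assigns a fresh mapping to the config instance rather than mutating the class-level default, configs for models that define no mapping keep the empty default.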
@@ -78,7 +78,10 @@ def get_quant_method(

# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if should_ignore_layer(prefix, ignore=self.ignore):
if should_ignore_layer(
prefix,
ignore=self.ignore,
packed_modules_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
@@ -1,12 +1,10 @@
import re
from typing import Iterable, Optional
from types import MappingProxyType
from typing import Iterable, List, Mapping, Optional

from compressed_tensors import CompressionFormat
from torch.nn import Module

from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)


def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [
@@ -17,8 +15,11 @@ def is_activation_quantization_format(format: str) -> bool:
return format in _ACTIVATION_QUANTIZATION_FORMATS


def should_ignore_layer(layer_name: Optional[str],
ignore: Iterable[str]) -> bool:
def should_ignore_layer(
layer_name: Optional[str],
ignore: Iterable[str] = tuple(),
packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
if layer_name is None:
return False

@@ -30,8 +31,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in FUSED_LAYER_NAME_MAPPING and layer_name not in ignore:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
if proj_name in packed_modules_mapping and layer_name not in ignore:
shard_proj_names = packed_modules_mapping[proj_name]

# Convert fused_name --> [shard_names]
shard_names = [
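The comment in this hunk (apparently the compressed-tensors variant of `should_ignore_layer`) describes the key trick: the checkpoint only knows the unfused shard names, so a fused name has to be expanded through `packed_modules_mapping` before it can be checked against the ignore list. A self-contained sketch of that expansion under the new signature; `resolve_packed_shards` is invented for illustration and is not the exact vLLM code, and the layer and ignore names are made up:

```python
from types import MappingProxyType
from typing import List, Mapping


def resolve_packed_shards(
    layer_name: str,
    packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({}),
) -> List[str]:
    """Expand a fused layer name into the shard names stored in the checkpoint."""
    proj_name = layer_name.split(".")[-1]
    if proj_name not in packed_modules_mapping:
        return [layer_name]
    parent = layer_name[:-len(proj_name)]  # keeps the trailing "."
    return [parent + shard for shard in packed_modules_mapping[proj_name]]


mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
shards = resolve_packed_shards("model.layers.0.self_attn.qkv_proj", mapping)
# -> ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj',
#     'model.layers.0.self_attn.v_proj']

# should_ignore_layer then requires every shard to agree on whether it is
# ignored; a fused layer is skipped only if all of its shards are.
ignore = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]
print(all(s in ignore for s in shards))  # True
```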
11 changes: 6 additions & 5 deletions vllm/model_executor/layers/quantization/quark/quark.py
@@ -16,8 +16,6 @@
QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.model_executor.layers.quantization.quark.utils import (
deep_compare, should_ignore_layer)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
from vllm.platforms import current_platform

__all__ = ["QuarkLinearMethod"]
@@ -56,7 +54,10 @@ def get_quant_method(self, layer: torch.nn.Module,

# Check if the layer is skipped for quantization.
exclude_layers = cast(List[str], self.quant_config.get("exclude"))
if should_ignore_layer(prefix, ignore=exclude_layers):
if should_ignore_layer(
prefix,
ignore=exclude_layers,
packed_modules_mapping=self.packed_modules_mapping):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
@@ -199,8 +200,8 @@ def _find_matched_config(self, layer_name: str,
module: torch.nn.Module) -> Dict[str, Any]:

proj_name = layer_name.split(".")[-1]
if proj_name in FUSED_LAYER_NAME_MAPPING:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
if proj_name in self.packed_modules_mapping:
shard_proj_names = self.packed_modules_mapping[proj_name]

# Convert fused_name --> [shard_names]
shard_names = [
17 changes: 9 additions & 8 deletions vllm/model_executor/layers/quantization/quark/utils.py
@@ -1,8 +1,6 @@
import re
from typing import Any, Iterable, Optional

from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
from types import MappingProxyType
from typing import Any, Iterable, List, Mapping, Optional


def deep_compare(dict1: Any, dict2: Any) -> bool:
@@ -18,8 +16,11 @@ def deep_compare(dict1: Any, dict2: Any) -> bool:
return dict1 == dict2


def should_ignore_layer(layer_name: Optional[str],
ignore: Iterable[str]) -> bool:
def should_ignore_layer(
layer_name: Optional[str],
ignore: Iterable[str],
packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
if layer_name is None:
return False

@@ -31,8 +32,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in FUSED_LAYER_NAME_MAPPING:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
if proj_name in packed_modules_mapping:
shard_proj_names = packed_modules_mapping[proj_name]

# Convert fused_name --> [shard_names]
shard_names = [
21 changes: 9 additions & 12 deletions vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -1,5 +1,6 @@
"""This file is used for /tests and /benchmarks"""
from typing import List, Optional
from types import MappingProxyType
from typing import List, Mapping, Optional

import numpy
import torch
@@ -11,14 +12,6 @@
SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}


def pack_quantized_values_into_int32(w_q: torch.Tensor,
wtype: ScalarType,
@@ -63,14 +56,18 @@ def unpack_quantized_values_into_int32(w_q: torch.Tensor,
return res.permute(inv_perm)


def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
def is_layer_skipped(
prefix: str,
ignored_layers: List[str],
packed_modules_mapping: Mapping[str, List[str]] = MappingProxyType({})
) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name = prefix.split(".")[-1]
if proj_name in FUSED_LAYER_NAME_MAPPING:
if proj_name in packed_modules_mapping:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
for shard_proj_name in packed_modules_mapping[proj_name]
]

is_skipped = None
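`is_layer_skipped` in `quant_utils.py` gets the same treatment: the caller now supplies the mapping instead of relying on the deleted global constant. A short usage sketch of the updated signature; the prefixes and ignore list are made up, and the behaviour shown for the non-fused fallback (a plain membership test) reflects the part of the function collapsed in this view, so treat it as an assumption:

```python
# Usage sketch; assumes a vLLM checkout with this PR applied on PYTHONPATH.
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    is_layer_skipped)

packed = {"gate_up_proj": ["gate_proj", "up_proj"]}

# Both shards of the fused layer are in the ignore list, so the fused
# layer itself is treated as skipped (unquantized).
print(is_layer_skipped(
    prefix="model.layers.0.mlp.gate_up_proj",
    ignored_layers=["model.layers.0.mlp.gate_proj",
                    "model.layers.0.mlp.up_proj"],
    packed_modules_mapping=packed))  # expected: True

# A name whose last component is not in the mapping falls back to a direct
# membership check against ignored_layers (assumed behaviour, see above).
print(is_layer_skipped(
    prefix="model.layers.0.mlp.down_proj",
    ignored_layers=[],
    packed_modules_mapping=packed))  # expected: False
```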
5 changes: 5 additions & 0 deletions vllm/model_executor/model_loader/loader.py
@@ -110,6 +110,11 @@ def _initialize_model(
model_config = vllm_config.model_config
model_class, _ = get_model_architecture(model_config)

# pass packed_modules_mapping by reference to quant_config
packed_mapping = getattr(model_class, "packed_modules_mapping", None)
if packed_mapping is not None and vllm_config.quant_config is not None:
vllm_config.quant_config.packed_modules_mapping = packed_mapping

signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
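The loader change is the piece that wires a model's own mapping into whatever quantization config is active, replacing the hard-coded global constant. A sketch of the flow with stand-in classes; `FakeQuantConfig` and `FakeChatGLM` are invented here, and the ChatGLM-style mapping shown (fused names mapping onto themselves) is an assumption about that model's `packed_modules_mapping`, not something visible in this diff:

```python
class FakeQuantConfig:
    # Mirrors the empty class-level default added in base_config.py.
    packed_modules_mapping = {}


class FakeChatGLM:
    # Assumed ChatGLM-style mapping: query_key_value and dense_h_to_4h are
    # already fused in the checkpoint, so each fused name maps to itself.
    packed_modules_mapping = {
        "query_key_value": ["query_key_value"],
        "dense_h_to_4h": ["dense_h_to_4h"],
    }


quant_config = FakeQuantConfig()
model_class = FakeChatGLM

# Same getattr-and-assign pattern as the _initialize_model hunk above.
packed_mapping = getattr(model_class, "packed_modules_mapping", None)
if packed_mapping is not None and quant_config is not None:
    quant_config.packed_modules_mapping = packed_mapping

print(quant_config.packed_modules_mapping)
# {'query_key_value': ['query_key_value'], 'dense_h_to_4h': ['dense_h_to_4h']}
```

Because the mapping is assigned by reference (as the hunk's comment notes), later in-place updates to the model class's dict, such as the `cls.packed_modules_mapping.update(...)` call in `ChatGLMForCausalLM.__new__` below, appear to become visible to the quant config as well.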
24 changes: 19 additions & 5 deletions vllm/model_executor/models/chatglm.py
@@ -263,12 +263,14 @@ def __init__(
self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)

# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
@@ -325,6 +327,7 @@ def __init__(
self,
config: ChatGLMConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()

Expand All @@ -336,6 +339,7 @@ def __init__(
[config.ffn_hidden_size] * 2,
bias=config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.dense_h_to_4h",
)

self.activation_func = SiluAndMul()
@@ -346,6 +350,7 @@ def __init__(
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h",
)

def forward(self, hidden_states):
@@ -394,7 +399,7 @@ def __init__(
config.hidden_size, eps=config.layernorm_epsilon)

# MLP
self.mlp = GLMMLP(config, quant_config)
self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")

def forward(
self,
@@ -505,7 +510,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
config.hidden_size,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.embedding")

self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
@@ -764,6 +770,7 @@ class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
SupportsMultiModal):
# Ensure that the LoRA support check passes when the class is not
# initialized, but set all these attributes to empty.
# These will be updated when a model class is selected
packed_modules_mapping = {}
supported_lora_modules = []
embedding_modules = {}
@@ -775,9 +782,16 @@ def __new__(
prefix: str = "",
) -> None:
config = vllm_config.model_config.hf_config

# Initialize VL
if hasattr(config, "vision_config"):
return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
if hasattr(config, "vision_config"): # noqa: SIM108
instance_cls = ChatGLMV
# Initialize LLM
else:
return ChatGLM(vllm_config=vllm_config, prefix=prefix)
instance_cls = ChatGLM

cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping)
cls.supported_lora_modules += instance_cls.supported_lora_modules
cls.embedding_modules.update(instance_cls.embedding_modules)
cls.embedding_padding_modules += instance_cls.embedding_padding_modules
return instance_cls(vllm_config=vllm_config, prefix=prefix)
31 changes: 22 additions & 9 deletions vllm/model_executor/models/glm4_vision_encoder.py
@@ -72,11 +72,13 @@ def __init__(
self.head_dim,
config.num_heads,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)

self.attn = MultiHeadAttention(self.num_heads_per_rank, self.head_dim,
@@ -99,6 +101,7 @@ def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
super().__init__()
self.config = config
@@ -107,11 +110,13 @@ def __init__(
config.hidden_size,
config.intermediate_size,
quant_config=quant_config,
prefix=f"{prefix}.fc1",
)
self.fc2 = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.fc2",
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -135,7 +140,9 @@ def __init__(
self.attention = Attention(config,
quant_config=quant_config,
prefix=f"{prefix}.attention")
self.mlp = MLP(config, quant_config=quant_config)
self.mlp = MLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
self.post_attention_layernorm = LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)

@@ -162,7 +169,7 @@ def __init__(
self.layers = nn.ModuleList([
TransformerLayer(config,
quant_config=quant_config,
prefix=f"{prefix}.layer.{layer_idx}")
prefix=f"{prefix}.layers.{layer_idx}")
for layer_idx in range(config.num_hidden_layers)
])

@@ -179,6 +186,7 @@ def __init__(
config,
in_features,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = '',
):
"""
The original implementation is the same as:
@@ -220,20 +228,24 @@ def __init__(
self.linear_proj = ReplicatedLinear(in_features,
config.hidden_size,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.linear_proj")
self.norm1 = nn.LayerNorm(config.hidden_size)
self.act1 = nn.GELU()
self.act2 = SiluAndMul()

self.merged_proj = MergedColumnParallelLinear(
config.hidden_size, [config.ffn_hidden_size] * 2,
bias=False,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.merged_proj")

self.dense_4h_to_h = RowParallelLinear(config.ffn_hidden_size,
config.hidden_size,
bias=False,
quant_config=quant_config)
self.dense_4h_to_h = RowParallelLinear(
config.ffn_hidden_size,
config.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h")

def forward(self, x):
x, _ = self.linear_proj(x)
@@ -260,7 +272,8 @@ def __init__(
prefix=f"{prefix}.transformer")
self.linear_proj = GLU(config,
in_features=config.hidden_size,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.linear_proj")
self.conv = nn.Conv2d(in_channels=vision_config.hidden_size,
out_channels=config.hidden_size,
kernel_size=2,