[Bugfix] fix moe_wna16 get_quant_method (vllm-project#12648)
Fix vllm-project#12647
The `get_quant_method` of `moe_wna16` always returns the MoE method, a
GPTQ-based linear method, or an AWQ-based linear method, even when the
target module is an attention layer:


https://github.com/vllm-project/vllm/blob/baeded25699f9f4851843306f27f685c4d4ee7c5/vllm/attention/layer.py#L86-L92
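
For context, the linked lines perform roughly the following check. This is a
paraphrased, minimal sketch with stand-in classes (`BaseKVCacheMethod`,
`GPTQLinearMethod`, and `attach_quant_method` are simplified assumptions here,
not vLLM's actual definitions):

# Paraphrased sketch of the linked attention-layer check (stand-in classes).
class BaseKVCacheMethod:
    """Stand-in for vLLM's KV-cache quantization method base class."""


class GPTQLinearMethod:
    """Stand-in for a GPTQ/AWQ linear-layer quantization method."""


def attach_quant_method(attn_layer, quant_config, prefix: str):
    # The attention layer asks the quant config for a method and accepts
    # either None or a KV-cache method -- nothing else.
    quant_method = (quant_config.get_quant_method(attn_layer, prefix=prefix)
                    if quant_config is not None else None)
    if quant_method is not None:
        # moe_wna16 used to hand back a GPTQ/AWQ *linear* method even for
        # attention layers, so an assertion like this one failed.
        assert isinstance(quant_method, BaseKVCacheMethod)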

Signed-off-by: Jinzhen Lin <[email protected]>
jinzhen-lin authored and sahelib25 committed Feb 3, 2025
1 parent 77e1722 commit 9972fa1
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions vllm/model_executor/layers/quantization/moe_wna16.py
@@ -6,16 +6,13 @@
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
-from vllm.model_executor.layers.quantization.awq import (AWQConfig,
-                                                         AWQLinearMethod)
-from vllm.model_executor.layers.quantization.awq_marlin import (
-    AWQMarlinConfig, AWQMarlinLinearMethod)
+from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.quantization.gptq import (GPTQConfig,
-                                                          GPTQLinearMethod)
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig, GPTQMarlinLinearMethod)
+    GPTQMarlinConfig)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform

@@ -131,18 +128,18 @@ def get_quant_method(self, layer: torch.nn.Module,
         else:
             if self.linear_quant_method == "gptq":
                 if self.use_marlin:
-                    return GPTQMarlinLinearMethod(
-                        GPTQMarlinConfig.from_config(self.full_config))
+                    return GPTQMarlinConfig.from_config(
+                        self.full_config).get_quant_method(layer, prefix)
                 else:
-                    return GPTQLinearMethod(
-                        GPTQConfig.from_config(self.full_config))
+                    return GPTQConfig.from_config(
+                        self.full_config).get_quant_method(layer, prefix)
             elif self.linear_quant_method == "awq":
                 if self.use_marlin:
-                    return AWQMarlinLinearMethod(
-                        AWQMarlinConfig.from_config(self.full_config))
+                    return AWQMarlinConfig.from_config(
+                        self.full_config).get_quant_method(layer, prefix)
                 else:
-                    return AWQLinearMethod(
-                        AWQConfig.from_config(self.full_config))
+                    return AWQConfig.from_config(
+                        self.full_config).get_quant_method(layer, prefix)
             else:
                 raise ValueError("moe_wna16 only support gptq and awq.")

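The fix delegates to the wrapped GPTQ/AWQ config's own `get_quant_method`,
which inspects the layer type itself. A minimal sketch of the delegation
pattern (simplified, assumed signatures; these are not the real vLLM classes):

from typing import Optional


class LinearBase:
    """Stand-in for vLLM's linear-layer base class."""


class GPTQLinearMethod:
    """Stand-in for the GPTQ linear method."""

    def __init__(self, config: "GPTQConfig"):
        self.config = config


class GPTQConfig:
    """Stand-in for the GPTQ config; only the delegation pattern matters."""

    @classmethod
    def from_config(cls, full_config: dict) -> "GPTQConfig":
        return cls()

    def get_quant_method(self, layer,
                         prefix: str) -> Optional[GPTQLinearMethod]:
        # Linear layers get a quant method; other layers (e.g. attention)
        # get None, which callers treat as "no quantization here".
        if isinstance(layer, LinearBase):
            return GPTQLinearMethod(self)
        return None


# Before: return GPTQLinearMethod(GPTQConfig.from_config(self.full_config))
#         -> a linear method even for attention layers.
# After:  return GPTQConfig.from_config(
#             self.full_config).get_quant_method(layer, prefix)
#         -> None for attention layers, a linear method for linear layers.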
