Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] fix moe_wna16 get_quant_method #12648

Merged
merged 1 commit into from
Feb 2, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 12 additions & 15 deletions vllm/model_executor/layers/quantization/moe_wna16.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
AWQLinearMethod)
from vllm.model_executor.layers.quantization.awq_marlin import (
AWQMarlinConfig, AWQMarlinLinearMethod)
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.gptq import (GPTQConfig,
GPTQLinearMethod)
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig, GPTQMarlinLinearMethod)
GPTQMarlinConfig)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

Expand Down Expand Up @@ -131,18 +128,18 @@ def get_quant_method(self, layer: torch.nn.Module,
else:
if self.linear_quant_method == "gptq":
if self.use_marlin:
return GPTQMarlinLinearMethod(
GPTQMarlinConfig.from_config(self.full_config))
return GPTQMarlinConfig.from_config(
self.full_config).get_quant_method(layer, prefix)
else:
return GPTQLinearMethod(
GPTQConfig.from_config(self.full_config))
return GPTQConfig.from_config(
self.full_config).get_quant_method(layer, prefix)
elif self.linear_quant_method == "awq":
if self.use_marlin:
return AWQMarlinLinearMethod(
AWQMarlinConfig.from_config(self.full_config))
return AWQMarlinConfig.from_config(
self.full_config).get_quant_method(layer, prefix)
else:
return AWQLinearMethod(
AWQConfig.from_config(self.full_config))
return AWQConfig.from_config(
self.full_config).get_quant_method(layer, prefix)
else:
raise ValueError("moe_wna16 only support gptq and awq.")

Expand Down