Commit

fix

AllentDan committed May 16, 2024
1 parent 886e87c commit f812587

Showing 5 changed files with 40 additions and 8 deletions.
8 changes: 6 additions & 2 deletions lmdeploy/legacy/pytorch/modules/linear.py
@@ -4,6 +4,8 @@
import torch
from torch import nn

from lmdeploy.lite.utils.cal_qparams import QParams

try:
import awq_inference_engine
except ModuleNotFoundError:
@@ -72,7 +74,8 @@ def __init__(
def from_linear(cls: Type['WeightOnlyQLinear'],
linear: nn.Linear,
quantizer: TypeVar('Quantizer'),
awq_layout: bool = True) -> 'WeightOnlyQLinear':
awq_layout: bool = True,
qparams: Optional[QParams] = None) -> 'WeightOnlyQLinear':
"""Create a WeightOnlyQLinear object from a PyTorch Linear object.
Args:
@@ -103,7 +106,8 @@ def from_linear(cls: Type['WeightOnlyQLinear'],
group_size)
qlinear.bias = linear.bias

qparams = quantizer.calculate_qparams(linear.weight)
if qparams is None:
    qparams = quantizer.calculate_qparams(linear.weight)
i32_w = quantizer.quant(linear.weight, qparams, real=True)
i32_w = i32_w.t().contiguous()

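The new optional qparams argument lets a caller hand from_linear a precomputed set of quantization parameters, so the quantizer's own statistics pass runs only when nothing is supplied. Below is a minimal, self-contained sketch of that control flow; the QParams field names and the symmetric fake-quantization math are illustrative assumptions, not the library's exact implementation.

# Illustrative sketch only: an optional, precomputed QParams argument that
# skips the internal statistics pass, mirroring the `if qparams is None` check.
from typing import NamedTuple, Optional

import torch


class QParams(NamedTuple):  # assumed field names, for illustration only
    scales: torch.Tensor
    zero_points: torch.Tensor


def fake_quant_weight(weight: torch.Tensor,
                      n_bits: int = 4,
                      qparams: Optional[QParams] = None) -> torch.Tensor:
    """Symmetric fake-quantization; compute qparams only if none are given."""
    if qparams is None:
        scales = weight.abs().amax() / (2**(n_bits - 1) - 1)
        qparams = QParams(scales=scales, zero_points=torch.zeros(()))
    q = torch.round(weight / qparams.scales + qparams.zero_points)
    q = q.clamp(-(2**(n_bits - 1)), 2**(n_bits - 1) - 1)
    return (q - qparams.zero_points) * qparams.scales


w = torch.randn(8, 16)
reused = QParams(scales=w.abs().amax() / 7, zero_points=torch.zeros(()))
out_default = fake_quant_weight(w)                       # statistics computed inside
out_precomputed = fake_quant_weight(w, qparams=reused)   # caller's statistics reused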
3 changes: 2 additions & 1 deletion lmdeploy/lite/apis/auto_awq.py
@@ -75,7 +75,6 @@ def auto_awq(model: str,
fc2fcs = FC_FCS_MAP[layer_type]
norm2fcs = NORM_FCS_MAP[layer_type]
input_stats = torch.load(work_dir / 'inputs_stats.pth')
act_scales = input_stats['absmax']
layers = collect_target_modules(model, layer_type)
fcs = {}
for l_name, layer in layers.items():
@@ -84,9 +83,11 @@

if search_scale:
awq_ratios = input_stats['ratios']
act_scales = input_stats['absmean']
awq_layers(layers, fc2fcs, norm2fcs, act_scales, awq_ratios,
w_group_size, device)
else:
act_scales = input_stats['absmax']
smooth_layers(layers, fc2fcs, norm2fcs, act_scales, w_group_size,
device)
quant_weights(model, fcs, w_bits, w_sym, w_group_size, device)
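With this change act_scales is chosen per branch: the scale-search path consumes the 'absmean' statistics while the plain smoothing path keeps using 'absmax'. A small sketch of how the two per-channel statistics differ (tensor shapes are assumptions for illustration):

# Sketch of the two per-channel activation statistics the script reads from
# inputs_stats.pth: 'absmax' for smoothing, 'absmean' for scale search.
import torch

x = torch.randn(4, 128, 4096)        # assumed (batch, seq_len, hidden_dim)
flat = x.flatten(0, 1)               # merge batch and sequence dimensions

absmax = flat.abs().amax(dim=0)      # peak magnitude per hidden channel
absmean = flat.abs().mean(dim=0)     # average magnitude per hidden channel

print(absmax.shape, absmean.shape)   # torch.Size([4096]) torch.Size([4096])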
13 changes: 13 additions & 0 deletions lmdeploy/lite/quantization/activation/observer.py
@@ -62,6 +62,7 @@ class ActivationObserver(GlobalAvailMixin):
Also keeps track of the number of batches observed.
"""
observed = False

def __init__(self, dim: int) -> None:
"""Constructor for ActivationObserver.
Expand All @@ -80,6 +81,16 @@ def __init__(self, dim: int) -> None:
self.ratio = None
self.num_ratio_tracked = 0

@classmethod
def disable(cls):
"""To avoid recomputation in search scale process."""
cls.observed = True

@classmethod
def enable(cls):
"""To avoid recomputation in search scale process."""
cls.observed = False

@torch.no_grad()
def observe(self, x: torch.Tensor, save_input: bool = False) -> None:
"""Function to observe the input tensor and update the max, min, mean,
@@ -88,6 +99,8 @@ def observe(self, x: torch.Tensor, save_input: bool = False) -> None:
Args:
x : Input tensor
"""
if self.observed:
return
assert len(x.shape) == 3
assert x.size(2) == self.dim
cur_val = x.flatten(0, 1)
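Because observed is a class-level flag, one call flips every observer instance at once, and observe() simply returns early while it is set. A usage sketch follows; the import path is inferred from the file location above and the constructor signature from the diff, so treat both as assumptions if the package layout differs.

# Usage sketch: observe() becomes a no-op between disable() and enable(),
# so replaying a module during scale search does not skew the statistics.
import torch

from lmdeploy.lite.quantization.activation.observer import ActivationObserver

obs = ActivationObserver(dim=4096)
obs.observe(torch.randn(1, 8, 4096))   # recorded as usual

ActivationObserver.disable()           # class-wide: cls.observed = True
obs.observe(torch.randn(1, 8, 4096))   # early return, nothing recorded
ActivationObserver.enable()            # class-wide: cls.observed = False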
22 changes: 17 additions & 5 deletions lmdeploy/lite/quantization/awq.py
@@ -100,7 +100,7 @@ def smooth_ln_fcs(ln: torch.nn.Module,
w_scales = get_weight_scale(concat_w, group_size)

scales = (act_scales.pow(alpha) /
w_scales.pow(1 - alpha)).to(device).to(dtype)
w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype)

scales = scales / (scales[nonzero_positions].max() *
scales[nonzero_positions].min()).sqrt()
@@ -151,7 +151,7 @@ def smooth_fc_fcs(pre_fc: torch.nn.Module,
w_scales = get_weight_scale(concat_w, group_size)

scales = (act_scales.pow(alpha) /
w_scales.pow(1 - alpha)).to(device).to(dtype)
w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype)
scales = scales / (scales.max() * scales.min()).sqrt()

# (for qwen&baichuan) pre_fc is packed QKV, only V needs to scale
@@ -211,11 +211,16 @@ def quant_weights(model, fcs, bits, symmetry, group_size=-1, device='cuda'):
"""Quantize the weights of the target model's linear layers."""
from lmdeploy.legacy.pytorch.modules import WeightOnlyQLinear
from lmdeploy.lite.quantization import WeightQuantizer
from lmdeploy.lite.utils import QParams
for name, fc in fcs.items():
fc.to(device)
quantizer = WeightQuantizer(bits, symmetry, 'per_group', group_size)
q_linear = WeightOnlyQLinear.from_linear(fc, quantizer)

fc.weight.data, scales, zeros = pseudo_quantize_tensor(
fc.weight.data, bits, group_size, return_scale_zeros=True)
q_linear = WeightOnlyQLinear.from_linear(fc,
quantizer,
qparams=QParams(
scales, zeros))
parent_name, _, child_name = name.rpartition('.')
parent = model.get_submodule(parent_name)
fc.to('cpu')
@@ -253,7 +258,10 @@ def smooth_layers(layers,
print(f'{l_name} smooth weight done.')


def pseudo_quantize_tensor(w, w_bit=8, w_group_size=-1):
def pseudo_quantize_tensor(w,
w_bit=8,
w_group_size=-1,
return_scale_zeros=False):
"""Pseudo quantize tensor."""
org_w_shape = w.shape
if w_group_size > 0:
@@ -274,6 +282,10 @@ def pseudo_quantize_tensor(w, w_bit=8, w_group_size=-1):
assert torch.isnan(w).sum() == 0

w = w.reshape(org_w_shape)
if return_scale_zeros:
zeros = zeros.view(org_w_shape[0], -1)
scales = scales.view(org_w_shape[0], -1)
return w, scales, zeros
return w


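Two things change in this file: the smoothing scales are clamped at 1e-4 so a channel with a near-zero weight scale cannot blow up the division, and pseudo_quantize_tensor can now also return the per-group scales and zero points it computed, reshaped to one row per output channel. A usage sketch of the new return value (bit width and group size are arbitrary choices here):

# Usage sketch: with return_scale_zeros=True the helper also returns the
# per-group scales/zero points, which quant_weights packs into QParams.
import torch

from lmdeploy.lite.quantization.awq import pseudo_quantize_tensor

w = torch.randn(4096, 4096)
w_q, scales, zeros = pseudo_quantize_tensor(
    w, w_bit=4, w_group_size=128, return_scale_zeros=True)

print(w_q.shape)     # torch.Size([4096, 4096]) -- fake-quantized, same shape
print(scales.shape)  # torch.Size([4096, 32])   -- one scale per 128-wide group
print(zeros.shape)   # torch.Size([4096, 32])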
2 changes: 2 additions & 0 deletions lmdeploy/lite/quantization/calibration.py
@@ -501,8 +501,10 @@ def _forward(mod, *args, **kwargs):
*batch_args[i], **batch_kwargs[i]))
obs_group = ActivationObserver.find_group(self.inp_obs_group)
mod_name = self.mod2name[mod]
ActivationObserver.disable()
auto_scale_block(mod, batch_kwargs[i], self.w_bits,
self.w_group_size, obs_group, mod_name)
ActivationObserver.enable()
for key, item in obs_group.items():
if key.startswith(f'{mod_name}.'):
item.value.cpu()
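Here the calibration loop brackets auto_scale_block with the new observer switch so the extra forward passes run during scale search are not recorded. If an exception inside the search left the flag set, observation would stay off; a hypothetical context-manager wrapper (not part of this commit) would make the pairing exception-safe:

# Hypothetical helper (not in the commit): the same disable/enable pairing as
# a context manager, so the flag is always restored even if the body raises.
from contextlib import contextmanager

from lmdeploy.lite.quantization.activation.observer import ActivationObserver


@contextmanager
def observation_paused():
    ActivationObserver.disable()
    try:
        yield
    finally:
        ActivationObserver.enable()

# Sketch of how the calibration loop above could use it:
# with observation_paused():
#     auto_scale_block(mod, batch_kwargs[i], self.w_bits,
#                      self.w_group_size, obs_group, mod_name)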
