Support internlm3 8b
AllentDan committed Jan 14, 2025
1 parent 4ac1894 commit 9c13834
Showing 4 changed files with 42 additions and 4 deletions.
3 changes: 3 additions & 0 deletions lmdeploy/lite/apis/calibrate.py
@@ -16,6 +16,7 @@
LAYER_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMDecoderLayer',
'InternLM2ForCausalLM': 'InternLM2DecoderLayer',
'InternLM3ForCausalLM': 'InternLM3DecoderLayer',
'QWenLMHeadModel': 'QWenBlock',
'Qwen2ForCausalLM': 'Qwen2DecoderLayer',
'BaiChuanForCausalLM': 'DecoderLayer', # Baichuan 7B
@@ -34,6 +35,7 @@
NORM_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMRMSNorm',
'InternLM2ForCausalLM': 'InternLM2RMSNorm',
'InternLM3ForCausalLM': 'InternLM3RMSNorm',
'QWenLMHeadModel': 'RMSNorm',
'Qwen2ForCausalLM': 'Qwen2RMSNorm',
'BaiChuanForCausalLM': 'RMSNorm', # Baichuan 7B
@@ -52,6 +54,7 @@
HEAD_NAME_MAP = {
'InternLMForCausalLM': 'lm_head',
'InternLM2ForCausalLM': 'output',
'InternLM3ForCausalLM': 'output',
'QWenLMHeadModel': 'lm_head',
'Qwen2ForCausalLM': 'lm_head',
'BaiChuanForCausalLM': 'lm_head', # Baichuan 7B
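Supporting a new architecture in the calibration pass mostly amounts to registering its class names: LAYER_TYPE_MAP names the decoder-layer class whose activations are observed, NORM_TYPE_MAP names the matching RMSNorm class, and HEAD_NAME_MAP names the attribute holding the output projection (InternLM2/3 call it 'output' rather than 'lm_head'). A minimal sketch of how such a registry can be consulted; the helper names below are illustrative, not the lmdeploy API:

from torch import nn

# Illustrative subsets of the registries above.
LAYER_TYPE_MAP = {'InternLM3ForCausalLM': 'InternLM3DecoderLayer'}
NORM_TYPE_MAP = {'InternLM3ForCausalLM': 'InternLM3RMSNorm'}
HEAD_NAME_MAP = {'InternLM3ForCausalLM': 'output'}

def collect_target_modules(model: nn.Module, class_name: str) -> dict:
    """Gather every submodule whose class name matches `class_name`."""
    return {
        name: module
        for name, module in model.named_modules()
        if type(module).__name__ == class_name
    }

def find_calibration_targets(model: nn.Module):
    """Hypothetical helper: resolve the decoder layers, norms and LM head."""
    arch = type(model).__name__  # e.g. 'InternLM3ForCausalLM'
    layers = collect_target_modules(model, LAYER_TYPE_MAP[arch])
    norms = collect_target_modules(model, NORM_TYPE_MAP[arch])
    head = model.get_submodule(HEAD_NAME_MAP[arch])  # 'output' for InternLM3
    return layers, norms, head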
4 changes: 4 additions & 0 deletions lmdeploy/lite/apis/smooth_quant.py
@@ -101,6 +101,8 @@ def smooth_quant(model: str,
parent = model.get_submodule(parent_name)
setattr(parent, child_name, q_linear)
linear.to('cpu')
q_linear.to('cpu')
torch.cuda.empty_cache()

for name, norm in rmsnorms.items():
if skipped_module(name):
@@ -111,6 +113,8 @@
parent = model.get_submodule(parent_name)
setattr(parent, child_name, q_norm)
norm.to('cpu')
q_linear.to('cpu')
torch.cuda.empty_cache()

if vl_model:
from .auto_awq import save_vl_model
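The lines added here bound peak GPU memory while modules are being swapped out: once a quantized replacement has been installed with setattr, both the original module and the replacement are moved to CPU and the CUDA caching allocator is flushed before the next module is processed. A minimal sketch of that replace-and-offload pattern, with illustrative names rather than the exact lmdeploy code:

import torch
from torch import nn

def replace_and_offload(model: nn.Module, name: str, new_module: nn.Module) -> None:
    """Swap the submodule called `name` for `new_module`, then release GPU memory."""
    parent_name, _, child_name = name.rpartition('.')
    parent = model.get_submodule(parent_name) if parent_name else model
    old_module = getattr(parent, child_name)
    setattr(parent, child_name, new_module)
    old_module.to('cpu')      # original fp weights are no longer needed on GPU
    new_module.to('cpu')      # keep the replacement off-GPU until it is used
    torch.cuda.empty_cache()  # hand freed blocks back so the next layer fits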
2 changes: 2 additions & 0 deletions lmdeploy/lite/quantization/activation/observer.py
@@ -104,6 +104,8 @@ def observe(self, x: torch.Tensor, save_input: bool = False) -> None:
return
assert x.size(-1) == self.dim
cur_val = x.flatten(0, 1)
if any([s == 0 for s in cur_val.shape]):
return
cur_max = cur_val.max(0)[0].cpu()
cur_min = cur_val.min(0)[0].cpu()
cur_mean = cur_val.mean(0).cpu()
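The observer change is a guard against empty activation batches: reducing over a zero-sized dimension with Tensor.max or Tensor.min raises a RuntimeError, so observe now returns before updating its statistics when the flattened activations contain no rows. A small standalone illustration of the failure mode and the guard:

import torch

x = torch.empty(0, 4096)          # no tokens observed in this call
if any(s == 0 for s in x.shape):  # same check the commit adds
    pass                          # skip the min/max/mean update entirely
else:
    cur_max = x.max(0)[0]         # would raise a RuntimeError on the empty dim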
37 changes: 33 additions & 4 deletions lmdeploy/lite/quantization/awq.py
@@ -19,6 +19,11 @@
'attention_norm': ['attention.wqkv'],
'ffn_norm': ['feed_forward.w1', 'feed_forward.w3']
},
'InternLM3DecoderLayer': {
'input_layernorm':
['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],
'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']
},
'QWenBlock': {
'ln_1': ['attn.c_attn'],
'ln_2': ['mlp.w1', 'mlp.w2']
@@ -72,6 +77,10 @@
'InternLM2DecoderLayer': {
'feed_forward.w3': ['feed_forward.w2']
},
'InternLM3DecoderLayer': {
'self_attn.v_proj': ['self_attn.o_proj'],
'mlp.up_proj': ['mlp.down_proj']
},
'QWenBlock': {
'attn.c_attn': ['attn.c_proj'],
'mlp.w1': ['mlp.c_proj']
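
The two new entries describe how activation scales flow through an InternLM3 block for AWQ smoothing: NORM_FCS_MAP records that input_layernorm feeds the separate q/k/v projections (InternLM3 splits the fused attention.wqkv used by InternLM2) and that post_attention_layernorm feeds gate_proj and up_proj, while FC_FCS_MAP records that v_proj's output enters o_proj and up_proj's output enters down_proj. These pairs are what the smoothing pass balances. A simplified sketch of that balancing step, assuming a fixed migration strength alpha; lmdeploy's smooth_ln_fcs/smooth_fc_fcs additionally handle grouping and weight clipping:

import torch
from torch import nn

@torch.no_grad()
def smooth_norm_into_fcs(norm: nn.Module, fcs: list, act_scale: torch.Tensor, alpha: float = 0.5) -> None:
    """Fold part of the per-channel activation scale out of the norm and into the linears.

    Dividing the norm weight by s and multiplying the downstream weights by s
    leaves the block's function unchanged but flattens the activation range,
    which is the core idea behind AWQ/SmoothQuant-style smoothing.
    """
    w_scale = torch.stack([fc.weight.abs().amax(dim=0) for fc in fcs]).amax(dim=0)
    s = (act_scale.pow(alpha) / w_scale.pow(1 - alpha)).clamp(min=1e-5)
    norm.weight.div_(s)
    for fc in fcs:
        fc.weight.mul_(s.view(1, -1))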
@@ -304,6 +313,7 @@ def quant_weights(model, fcs, bits, symmetry, group_size=-1, device='cuda'):
scales, zeros))
setattr(parent, child_name, q_linear)
fc.to('cpu')
torch.cuda.empty_cache()

print(f'{name} weight {pack_or_skip}.')
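
For context, quant_weights derives per-group scales and zero points for every linear before packing it into the quantized module installed above; the added torch.cuda.empty_cache() returns the blocks freed by fc.to('cpu') to the driver between layers. A minimal sketch of per-group asymmetric quantization parameters — an illustration of the idea, not lmdeploy's exact implementation, and it assumes in_features is divisible by group_size:

import torch

@torch.no_grad()
def pergroup_quant_params(weight: torch.Tensor, bits: int = 4, group_size: int = 128):
    """Compute per-group asymmetric scales/zeros for a (out_features, in_features) weight."""
    out_features, in_features = weight.shape
    w = weight.reshape(out_features, in_features // group_size, group_size)
    w_min = w.amin(dim=-1, keepdim=True)
    w_max = w.amax(dim=-1, keepdim=True)
    q_max = 2**bits - 1
    scales = (w_max - w_min).clamp(min=1e-5) / q_max
    zeros = (-w_min / scales).round()
    q = (w / scales + zeros).round().clamp(0, q_max)  # integer codes in [0, 2**bits - 1]
    return q.reshape_as(weight), scales, zeros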

@@ -318,22 +328,37 @@ def smooth_layers(layers,

for l_name, layer in layers.items():
layer.to(device)
submodule_names = [name for name, _ in layer.named_modules()]
for ln_name, fc_names in norm2fcs.items():
a_name = [f'{l_name}.{n}' for n in fc_names][0]
a_name = [
f'{l_name}.{n}' for n in fc_names if n in submodule_names
][0]

ln = layer.get_submodule(ln_name)
fcs = [layer.get_submodule(n) for n in fc_names]
fcs = [
layer.get_submodule(n) for n in fc_names
if n in submodule_names
]
smooth_ln_fcs(ln, fcs, a_scales[a_name], group_size)

for f_name, fc_names in fc2fcs.items():
a_name = [f'{l_name}.{n}' for n in fc_names][0]
a_name = [
f'{l_name}.{n}' for n in fc_names if n in submodule_names
][0]

fc = layer.get_submodule(f_name)
fcs = [layer.get_submodule(n) for n in fc_names]
fcs = [
layer.get_submodule(n) for n in fc_names
if n in submodule_names
]

smooth_fc_fcs(fc, fcs, a_scales[a_name], group_size)

layer.to('cpu')
torch.cuda.empty_cache()
max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
print(f'{l_name} smooth weight done.'
f' max gpu memory: {max_memory:.2f} GB')
print(f'{l_name} smooth weight done.')
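
The new submodule_names filter guards against names that a particular model does not actually contain: layer.get_submodule raises AttributeError for a missing name, so the names listed in the maps are now checked against named_modules() before being resolved. A small illustration with toy module names:

from torch import nn

class ToyLayer(nn.Module):
    """A layer with split q/k/v projections and no fused wqkv."""

    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(8, 8)
        self.k_proj = nn.Linear(8, 8)
        self.v_proj = nn.Linear(8, 8)

layer = ToyLayer()
wanted = ['q_proj', 'k_proj', 'v_proj', 'wqkv']        # a map entry listing more than this layer has
present = [name for name, _ in layer.named_modules()]  # same check the commit adds
fcs = [layer.get_submodule(n) for n in wanted if n in present]
print(len(fcs))  # 3 -- the absent 'wqkv' is skipped instead of raising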


@@ -402,4 +427,8 @@ def awq_layers(layers,
smooth_fc_fcs(fc, fcs, a_scales[a_name], group_size, ratio)

layer.to('cpu')
torch.cuda.empty_cache()
max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
print(f'{l_name} smooth weight done.'
f' max gpu memory: {max_memory:.2f} GB')
print(f'{l_name} smooth weight done.')
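
Both smoothing loops now also report the running peak CUDA allocation after each layer, which makes it easy to confirm that the per-layer offloading keeps memory roughly flat across a long model. The same reporting pattern in isolation:

import torch

# After finishing a layer: release cached blocks and report the running peak.
torch.cuda.empty_cache()
max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
print(f'layer done. max gpu memory: {max_memory:.2f} GB')

# torch.cuda.reset_peak_memory_stats() would give a per-layer peak instead of
# a running one, if that were preferred.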
