# export_quantized_model.py
# from rwkv_src.modeling_rwkv6 import Rwkv6ForCausalLM
from rwkv_src.rwkv_model import RWKV_RNN, sample_logits
from transformers import AutoTokenizer
import types
import torch
import numpy as np
import onnx
from utils.model_utils import get_dummy_input_for_rwkv_causal_llm, get_input_output_names, split_onnx, get_dummy_state_kvcache
from quantizers.base_quantizer import LLMQuantizer
from utils.dataset_builder import DatasetBuilder
import argparse
import os
import re
import subprocess
import json
from pathlib import Path
import copy
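# Command-line interface: model checkpoint, optional pre-computed linear param encodings,
# calibration data, weight bitwidth, and the number of ONNX chunks to emit.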
parser = argparse.ArgumentParser(description='Convert model')
parser.add_argument('model', type=Path, help='Path to RWKV pth file')
parser.add_argument('--linear_param_encodings', type=Path, default=None, help='Path to linear param encodings')
parser.add_argument('--calib_data_path', type=Path, default=None, help='Path to calibration data')
parser.add_argument('--weights_bitwidth', type=int, default=8, help='Weights bitwidth')
parser.add_argument('--use_cuda', action='store_true', default=True, help='Use CUDA')
parser.add_argument('--test_generate', action='store_true', default=False, help='Test generate')
parser.add_argument('--num_chunks', type=int, default=2, help='Number of chunks')
args_parser = parser.parse_args()
device = torch.device("cuda") if args_parser.use_cuda and torch.cuda.is_available() else torch.device("cpu")
# TODO: add more while keeping the precision
quant_list = [
"att.output",
"ffn",
]
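# Keep only the param encodings whose keys match an op in quant_list and write them
# to a *_processed.encodings file, which is later passed as parameter_encoding_file.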
if args_parser.linear_param_encodings:
    with open(args_parser.linear_param_encodings, "r") as f:
        encodings = json.load(f)
    encodings_new = copy.deepcopy(encodings)
    for k, v in encodings['param_encodings'].items():
        if not any([x in k for x in quant_list]):
            del encodings_new['param_encodings'][k]
    with open(str(args_parser.linear_param_encodings).replace('.encodings', '_processed.encodings'), "w") as f:
        json.dump(encodings_new, f, indent=4)
    del encodings_new
    del encodings
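# QuantSim configuration consumed by LLMQuantizer below (16-bit activations,
# weight bitwidth taken from the command line).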
args = types.SimpleNamespace()
##############################
args.quant_scheme = "tf"
args.activation_bit_width = 16
args.parameter_bit_width = args_parser.weights_bitwidth
args.in_place_quantsim = False
args.config_file = "quantizers/configs/qsim_config_per_channel_with_exceptions.json"
args.num_cands = 20
args.export_dir = "quant_export"
args.output_dir = "quant_export"
args.model_name = str(args_parser.model).replace(".pth", "").split("/")[-1]
args.input_symmetry = "symqt"
args.exceptions_file = "quantizers/configs/rwkv_activation_exceptions.json"
args.act_mse_loss_type = "mse"
args.parameter_encoding_file = str(args_parser.linear_param_encodings).replace('.encodings', '_processed.encodings') if args_parser.linear_param_encodings else None
args.encoding_path = None
args.do_actmse = False
args.disable_act_quantizers = False
args.fp16 = True
args.do_train = False
args.clip_activation = None
args.load_sim_checkpoint = False
args.save_sim_checkpoint = False
##############################
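# Calibration and evaluation dataset settings for DatasetBuilder.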
args.calib_dataset_name = "wikitext"
args.calib_dataset_config_name = "wikitext-2-raw-v1"
args.dataset_cache_dir = "./dataset_cache"
args.calib_dataset_split = None
args.calib_dataset_preprocessor = "gpt2"
args.eval_dataset_name = "wikitext"
args.eval_dataset_config_name = "wikitext-103-raw-v1"
args.eval_dataset_split = "test"
args.eval_dataset_preprocessor = "gptq"
args.num_calibration_batches = 20
args.per_device_calib_batch_size = 1
args.per_device_eval_batch_size = 1
args.block_size = 1024
args.seed = 1234
##############################
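# The QNN SDK provides the qnn-onnx-converter and qnn-model-lib-generator tools used below.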
qnn_sdk_root = os.environ.get("QNN_SDK_ROOT", "")
if not qnn_sdk_root:
    print("Please set the QNN_SDK_ROOT environment variable to the root of the Qualcomm Neural Processing SDK")
    exit(1)
args.device = device
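# Build the floating-point RWKV model that the quantizer will wrap.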
model_args = types.SimpleNamespace()
model_args.USE_CUDA = args_parser.use_cuda
model_args.fp16 = False
model_args.wkv_customop = False
model_args.USE_EMBEDDING = True
model_args.MODEL_NAME = str(args_parser.model)
model_args.RESCALE_LAYER = 0
model_args.eos_token_id = 0
model = RWKV_RNN(model_args)
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-5-world-1b5", trust_remote_code=True)
tokenizer.model_max_length = 1024
dummy_input = get_dummy_input_for_rwkv_causal_llm(1, 1, device, model_cfg=model.args)
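# Build the calibration dataloader and construct the quantization simulation model.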
dataset_builder = DatasetBuilder(args)
dataset_builder.make_dataset(tokenizer=tokenizer, args=args, column_name="text", shuffle=True)
quantizer = LLMQuantizer(model, args, model.args)
quantizer.orig_model = model
quantizer.prepare_quantsim(dummy_input, args, dataset_builder.train_dataloader, tokenizer)
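# Short generation loop to sanity-check the QuantSim model's output quality.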
def test_generate(model, tokenizer, device='cuda'):
    print("Generating inference using QuantSim model")
    prompt = "\n我们发现,"
    print(prompt, end='')
    input_ids = tokenizer(prompt, return_tensors='pt')
    model = model.to(device)
    states = get_dummy_state_kvcache(1, model.args, device)
    inputs = {'in0': input_ids['input_ids'].to(device), 'state': states}
    logits, states = model(**inputs)
    logits = logits[:, -1, :]
    for i in range(100):
        token = sample_logits(logits.flatten().cpu())
        if token == 0:
            break
        print(tokenizer.decode(token), end='', flush=True)
        inputs = {'in0': torch.LongTensor([[token]]).to(device), 'state': states}
        logits, states = model(**inputs)
if args_parser.test_generate:
    test_generate(quantizer.quant_sim.model, tokenizer=tokenizer, device=args.device)
else:
    input_names, output_names = get_input_output_names(model.args)
    dummy_input = get_dummy_input_for_rwkv_causal_llm(1, 1, device, model_cfg=model.args)
    quantizer.export_quantsim(dummy_input=dummy_input, input_names=input_names, output_names=output_names, opset_version=17)
    model_args = model.args
    print("Post processing ONNX model")
    onnx_path = os.path.join(args.export_dir, "onnx", f"{args.model_name}.onnx")
    model = onnx.load(onnx_path, load_external_data=False)
    graph = model.graph
    nodes = graph.node
    pattern = r"blocks\.(\d+)"
    # Rename the per-layer state inputs to "layer{N}_state{K}_in"
    # (input 0 is the token input; each layer contributes three state tensors).
    for i in range(1, len(graph.input), 3):
        graph.input[i].name = "layer" + str((i-1)//3) + "_state0_in"
        graph.input[i+1].name = "layer" + str((i-1)//3) + "_state1_in"
        graph.input[i+2].name = "layer" + str((i-1)//3) + "_state2_in"
    onnx.save_model(model, onnx_path, save_as_external_data=True, all_tensors_to_one_file=True, size_threshold=1024, convert_attribute=False)
print("Post processing encodings")
encodings = None
with open(onnx_path.replace('.onnx', '.encodings'), "+r") as f:
encodings = json.load(f)
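    # When calibration data is supplied, override activation encodings for numerically
    # sensitive ops (wkv matmul, time decay, layernorm) to float16, and keep the head
    # output at 16-bit integer.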
    if args_parser.calib_data_path is not None:
        graph = model.graph
        float_override = [{"bitwidth": 16, "dtype": "float"}]
        act_int_override = [{"bitwidth": 16, "dtype": "int"}]
        for i in range(len(graph.node)):
            if "matmul_kv" in graph.node[i].name \
                    or "mul_time_decay" in graph.node[i].name \
                    or "add_time_decay1" in graph.node[i].name \
                    or "ln" in graph.node[i].name:
                for j in graph.node[i].output:
                    encodings['activation_encodings'][j] = float_override
            if "ln" in graph.node[i].name:
                for j in graph.node[i].input:
                    encodings['activation_encodings'][j] = float_override
            if "add_time_first" in graph.node[i].name:
                for j in graph.node[i].input:
                    if "state" in j:
                        encodings['activation_encodings'][j] = float_override
                for j in graph.node[i].output:
                    encodings['activation_encodings'][j] = float_override
            # a16w8 head
            if "head" in graph.node[i].name:
                for j in graph.node[i].output:
                    encodings['activation_encodings'][j] = act_int_override
    with open(onnx_path.replace('.onnx', '.encodings'), "w") as f:
        json.dump(encodings, f, indent=4)
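    # Split the exported ONNX graph into num_chunks pieces and write all-zero sample
    # inputs plus an input list file for each chunk.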
    split_onnx(onnx_path, args.model_name, args_parser.num_chunks, args.export_dir, False)
    layers_per_chunk = len(quantizer.quant_sim.model.blocks) // args_parser.num_chunks
    os.makedirs(os.path.join(args.export_dir, "sample_inputs"), exist_ok=True)
    sample_input_path = os.path.join(args.export_dir, "sample_inputs", args.model_name)
    os.makedirs(sample_input_path, exist_ok=True)
    # assume the layers are evenly distributed across chunks
    for i in range(args_parser.num_chunks):
        input_list_line = " ".join([f"{args.export_dir}/sample_inputs/{args.model_name}/chunk_{i}/input_{j}.bin" for j in range(3*layers_per_chunk+1)])
        os.makedirs(os.path.join(sample_input_path, f"chunk_{i}"), exist_ok=True)
        with open(os.path.join(sample_input_path, f"input_list_chunk_{i}.txt"), 'w') as f:
            f.write(input_list_line)
        # Chunk 0 takes int32 token ids; later chunks take the previous chunk's hidden state.
        if i == 0:
            np.zeros((1, 1), dtype=np.int32).tofile(os.path.join(sample_input_path, f"chunk_{i}", "input_0.bin"))
        else:
            np.zeros((1, 1, model_args.n_embd), dtype=np.float32).tofile(os.path.join(sample_input_path, f"chunk_{i}", "input_0.bin"))
        for j in range(layers_per_chunk):
            np.zeros((1, 1, model_args.n_embd), dtype=np.float32).tofile(os.path.join(sample_input_path, f"chunk_{i}", f"input_{3*j+1}.bin"))
            np.zeros((model_args.n_head, model_args.head_size, model_args.head_size), dtype=np.float32).tofile(os.path.join(sample_input_path, f"chunk_{i}", f"input_{3*j+2}.bin"))
            np.zeros((1, 1, model_args.n_embd), dtype=np.float32).tofile(os.path.join(sample_input_path, f"chunk_{i}", f"input_{3*j+3}.bin"))
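    # Convert each chunk with qnn-onnx-converter, applying the patched encodings as
    # quantization overrides, then build the model libraries with qnn-model-lib-generator.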
    for i in range(args_parser.num_chunks):
        onnx_file = os.path.join(args.export_dir, "split_onnx", f"{args.model_name}_chunk{i+1}of{args_parser.num_chunks}.onnx")
        cmd = [f"{qnn_sdk_root}/bin/x86_64-linux-clang/qnn-onnx-converter"]
        cmd += ["-i", onnx_file]
        cmd += ["--act_bitwidth", "16"]
        cmd += ["--bias_bitwidth", "8"]
        cmd += ["--float_bitwidth", "32"]
        cmd += ["--quantization_overrides", onnx_path.replace('.onnx', '.encodings')]
        if args_parser.calib_data_path is not None:
            cmd += ["--input_list", os.path.join(args_parser.calib_data_path, f"input_list_chunk{i}.txt")]
        else:
            cmd += ["--input_list", os.path.join(sample_input_path, f"input_list_chunk_{i}.txt")]
        # The recurrent state inputs have no image-style layout, so mark them as NONTRIVIAL.
        for j in range(i*layers_per_chunk, (i+1)*layers_per_chunk):
            for k in range(3):
                cmd += ["--input_layout", f"layer{j}_state{k}_in", "NONTRIVIAL"]
        print(" ".join(cmd))
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output, error = proc.communicate()
        print(error.decode())
        cmd = [f"{qnn_sdk_root}/bin/x86_64-linux-clang/qnn-model-lib-generator"]
        cmd += ["-c", onnx_file.replace('.onnx', '.cpp')]
        cmd += ["-b", onnx_file.replace('.onnx', '.bin')]
        print(" ".join(cmd))
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output, error = proc.communicate()
        print(error.decode())