 logging.getLogger().setLevel(logging.INFO)


+def smart_mask_updator(atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches):
+    for i, k_cache in enumerate(k_caches):
+        k_cache[:, :, pos] = new_k_caches[i][:, :, 0]
+
+    for i, v_cache in enumerate(v_caches):
+        v_cache[:, pos, :] = new_v_caches[i]
+
+    atten_mask[0][pos] = 0
+    pos += 1
+    return (atten_mask, pos, k_caches, v_caches)
+
+
+def shift_pointer_updator(
+    atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+):
+    k_caches = [
+        torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
+        for i, k_cache in enumerate(k_caches)
+    ]
+    v_caches = [
+        torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
+        for i, v_cache in enumerate(v_caches)
+    ]
+
+    pos += 1
+    atten_mask[0][-pos - 1] = 0
+    return (atten_mask, pos, k_caches, v_caches)
+
+
 def _kv_calibrate(
     example_inputs,
     user_prompts,
     module: torch.fx.GraphModule,
     tokenizer,
     max_seq_len=512,
+    updator=smart_mask_updator,
 ):
     _, atten_mask, _, k_caches, v_caches = example_inputs
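To make the behavioral difference concrete, here is a minimal sketch (not part of the patch) that drives both updators on toy single-layer caches. It assumes the two functions above are in scope; the cache shapes follow how the functions index them (k: `[batch, head_dim, seq]`, v: `[batch, seq, head_dim]`), and the mask fill value of `-100.0` is an arbitrary stand-in for "masked".

```python
import torch

# Toy single-layer caches; shapes match how the updators index them.
seq_len, head_dim = 4, 2
new_k = [torch.ones(1, head_dim, 1)]  # this step's K for one layer
new_v = [torch.ones(1, 1, head_dim)]  # this step's V for one layer

# SmartMask: O(1) in-place write at slot `pos`, then unmask that slot.
k = [torch.zeros(1, head_dim, seq_len)]
v = [torch.zeros(1, seq_len, head_dim)]
mask = torch.full((1, seq_len), -100.0)  # assumed "masked" fill value
mask, pos, k, v = smart_mask_updator(mask, 0, k, v, new_k, new_v)
print(mask)  # slot 0 is now 0 (unmasked); pos == 1, buffers never move

# ShiftPointer: drop the oldest slot and torch.cat the newest onto the
# end, unmasking from the tail; cache tensors are re-allocated each step.
k = [torch.zeros(1, head_dim, seq_len)]
v = [torch.zeros(1, seq_len, head_dim)]
mask = torch.full((1, seq_len), -100.0)
mask, pos, k, v = shift_pointer_updator(mask, 0, k, v, new_k, new_v)
print(mask)  # mask[0][-2] is now 0; pos == 1, new tensors allocated
```

The in-place nature of SmartMask is what the shared device buffer introduced later in this patch exploits.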
@@ -105,17 +135,9 @@ def _kv_calibrate(
                 *k_caches,
                 *v_caches,
             )
-            k_caches = [
-                torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
-                for i, k_cache in enumerate(k_caches)
-            ]
-            v_caches = [
-                torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
-                for i, v_cache in enumerate(v_caches)
-            ]
-
-            pos += 1
-            atten_mask[0][-pos - 1] = 0
+            atten_mask, pos, k_caches, v_caches = updator(
+                atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+            )
             if pos >= len(token_list):
                 token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
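The calibration loop now delegates cache maintenance to whatever callable is injected as `updator`, instead of hard-coding the shift-pointer logic. A hedged sketch of the contract such a callable must satisfy (the type alias is illustrative, not in the patch):

```python
from typing import Callable, List, Tuple

import torch

# Illustrative alias: an updator consumes the current attention mask,
# position, and caches plus this step's new K/V, and returns the updated
# mask, position, and caches.
KVUpdator = Callable[
    [torch.Tensor, int, List[torch.Tensor], List[torch.Tensor],
     List[torch.Tensor], List[torch.Tensor]],
    Tuple[torch.Tensor, int, List[torch.Tensor], List[torch.Tensor]],
]
```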
@@ -174,6 +196,7 @@ def calibrate(
     module: torch.fx.GraphModule,
     tokenizer,
     max_seq_len=512,
+    kv_updator=smart_mask_updator,
 ):
     if len(example_inputs) == 2:
         _prefill_calibrate(
@@ -190,6 +213,7 @@ def calibrate(
             module,
             tokenizer,
             max_seq_len,
+            updator=kv_updator,
         )
     else:
         raise RuntimeError("Get wrong inputs")
@@ -319,13 +343,15 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
             self.llama_model, self.inputs, strict=True
         ).module()
         fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
+
         logging.info("Quantizing the model...")
         calibrate(
             self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
             args.prompt,
             fx_graph_module,
             tokenizer=tokenizer,
             max_seq_len=self.llama_meta["get_max_seq_len"],
+            kv_updator=args.kv_updator,
         )

         self.llama_model = convert_pt2e(fx_graph_module)
@@ -337,6 +363,7 @@ def lowering_modules(
         use_fp16=False,
         soc_model=QcomChipset.SM8650,
         num_sharding=0,
+        shared_buffer=False,
     ):
         executorch_config = ExecutorchBackendConfig(
             # For shared buffer, user must pass the memory address
@@ -357,7 +384,7 @@ def lowering_modules(
         compiler_specs = generate_qnn_executorch_compiler_spec(
             soc_model=soc_model,
             backend_options=backend_options,
-            shared_buffer=False,
+            shared_buffer=shared_buffer,
         )
         skip_node_op_set = {"llama.fallback.default"}
         partitioner = QnnPartitioner(
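Before this change, the KV-cache lowering path hard-coded `shared_buffer=False` (and the multi-graph path below hard-coded `True`); both now follow the caller. The point of a shared buffer for SmartMask is aliasing: the runtime mutates the same memory the graph reads, so the in-place cache writes need no per-step copy. A Python analogy (illustrative only, not how the QNN runtime is implemented):

```python
import torch

cache = torch.zeros(1, 8, 16)       # stand-in for a device cache buffer
graph_input = cache                  # alias, as with a shared buffer
cache[:, :, 3] = 1.0                 # SmartMask-style in-place update
assert graph_input[0, 0, 3] == 1.0   # visible to the "graph" without a copy
```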
@@ -530,6 +557,7 @@ def compile(args, pte_filename, tokenizer):
             use_fp16=use_fp16,
             soc_model=get_soc_to_chipset_map()[args.model],
             num_sharding=args.num_sharding,
+            shared_buffer=args.shared_buffer,
         )
         quant_attrs = llama_instance_list[0].get_quant_attrs()
     else:
@@ -564,7 +592,7 @@ def compile(args, pte_filename, tokenizer):
             generate_qnn_executorch_compiler_spec(
                 soc_model=get_soc_to_chipset_map()[args.model],
                 backend_options=backend_options,
-                shared_buffer=True,
+                shared_buffer=args.shared_buffer,
                 multiple_graphs=True,
                 graph_name=graph_name,
             )
@@ -736,6 +764,7 @@ def inference(args, quant_attrs, pte_filename, runtime_tokenizer_path, pre_gen_p
             f"--system_prompt '{args.system_prompt}'",
             f"--logits_scale {quant_attrs['scale']}",
             f"--logits_offset {quant_attrs['zero_point']}",
+            f"--kv_updator {'SmartMask' if args.kv_updator == smart_mask_updator else 'ShiftPointer'}",
         ]
     )
     runner_cmd = " ".join(
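The runner-side flag expects the strings `SmartMask` or `ShiftPointer`, so the inline conditional above maps the Python callable back to a string. An equivalent table-driven form (a hypothetical refactor, not part of the patch):

```python
# Hypothetical alternative to the inline conditional.
KV_UPDATOR_FLAGS = {
    smart_mask_updator: "SmartMask",
    shift_pointer_updator: "ShiftPointer",
}
kv_updator_arg = f"--kv_updator {KV_UPDATOR_FLAGS[args.kv_updator]}"
```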
@@ -907,6 +936,14 @@ def main():
         type=int,
     )

+    parser.add_argument(
+        "--kv_updator",
+        help="Choose how to update the KV cache at runtime",
+        choices=["smart_mask", "shift_pointer"],
+        default="smart_mask",
+        type=str,
+    )
+
     args = parser.parse_args()
     if args.compile_only and args.pre_gen_pte:
         exit("Cannot set both compile_only and pre_gen_pte as true")
@@ -941,6 +978,14 @@ def main():
     else:
         raise RuntimeError(f"Unknown llama_model: {args.llama_model}.")

+    if args.kv_updator == "smart_mask":
+        args.shared_buffer = True
+        args.kv_updator = smart_mask_updator
+    elif args.kv_updator == "shift_pointer":
+        args.kv_updator = shift_pointer_updator
+    else:
+        exit(f"Unknown kv_updator: {args.kv_updator}")
+
     if args.pre_gen_pte:
         quant_attrs = json.load(
             open(f"{args.pre_gen_pte}/{pte_filename}_quant_attrs.txt")
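Note that `smart_mask` force-enables `args.shared_buffer`, overriding whatever the user passed for `--shared_buffer`; the `else` branch is unreachable given the argparse `choices` above, but keeps the mapping safe if the attribute is ever set programmatically. The same dispatch written as a table (an illustrative sketch, not in the patch):

```python
# Illustrative table-driven equivalent of the branch above.
KV_UPDATORS = {
    "smart_mask": smart_mask_updator,
    "shift_pointer": shift_pointer_updator,
}
if args.kv_updator not in KV_UPDATORS:
    exit(f"Unknown kv_updator: {args.kv_updator}")
if args.kv_updator == "smart_mask":
    args.shared_buffer = True  # SmartMask relies on a shared buffer (per this patch)
args.kv_updator = KV_UPDATORS[args.kv_updator]
```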