
Commit 4796da7

Qualcomm AI Engine Direct - Add smart mask kv updator for llama3.2
Differential Revision: D68398036
Pull Request resolved: #7694
1 parent 282c137 commit 4796da7

File tree

11 files changed: +1332 -566 lines changed

backends/qualcomm/runtime/QnnManager.cpp

+2 -1

@@ -154,8 +154,9 @@ Error QnnManager::RegisterMem(
     const std::shared_ptr<TensorWrapper>& tensor_wrapper) {
   SharedBuffer& shared_buffer_manager = SharedBuffer::GetSharedBufferManager();
   // Shared buffer is not enabled
-  if (!options_->shared_buffer())
+  if (!options_->shared_buffer()) {
     return Error::Internal;
+  }
 
   if (backend_params_ptr_->qnn_mem_manager_ptr_ == nullptr) {
     QNN_EXECUTORCH_LOG_WARN(

backends/qualcomm/runtime/QnnManager.h

+1 -1

@@ -145,7 +145,7 @@ class QnnManager {
       {Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_8,
        executorch::aten::ScalarType::Byte},
       {Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_16,
-       executorch::aten::ScalarType::Bits16},
+       executorch::aten::ScalarType::UInt16},
   };
 };
 } // namespace qnn

backends/qualcomm/runtime/backends/QnnMemManager.h

+1 -1

@@ -77,7 +77,7 @@ class QnnMemManager {
        Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_16},
       {executorch::aten::ScalarType::Byte,
        Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_8},
-      {executorch::aten::ScalarType::Bits16,
+      {executorch::aten::ScalarType::UInt16,
        Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_16},
   };
 };

examples/qualcomm/oss_scripts/llama/CMakeLists.txt

+2 -2

@@ -28,8 +28,8 @@ list(
     ${CMAKE_CURRENT_LIST_DIR}/qnn_llama_runner.cpp
     ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
     ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
-    ${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.cpp
-    ${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/io_manager.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/runner/io_manager.h
 )
 
 list(

examples/qualcomm/oss_scripts/llama/llama.py

+58 -13

@@ -72,12 +72,42 @@
 logging.getLogger().setLevel(logging.INFO)
 
 
+def smart_mask_updator(atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches):
+    for i, k_cache in enumerate(k_caches):
+        k_cache[:, :, pos] = new_k_caches[i][:, :, 0]
+
+    for i, v_cache in enumerate(v_caches):
+        v_cache[:, pos, :] = new_v_caches[i]
+
+    atten_mask[0][pos] = 0
+    pos += 1
+    return (atten_mask, pos, k_caches, v_caches)
+
+
+def shift_pointer_updator(
+    atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+):
+    k_caches = [
+        torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
+        for i, k_cache in enumerate(k_caches)
+    ]
+    v_caches = [
+        torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
+        for i, v_cache in enumerate(v_caches)
+    ]
+
+    pos += 1
+    atten_mask[0][-pos - 1] = 0
+    return (atten_mask, pos, k_caches, v_caches)
+
+
 def _kv_calibrate(
     example_inputs,
     user_prompts,
     module: torch.fx.GraphModule,
     tokenizer,
     max_seq_len=512,
+    updator=smart_mask_updator,
 ):
     _, atten_mask, _, k_caches, v_caches = example_inputs
 
@@ -105,17 +135,9 @@ def _kv_calibrate(
             *k_caches,
             *v_caches,
         )
-        k_caches = [
-            torch.cat([k_cache[:, :, 1:], new_k_caches[i]], dim=-1)
-            for i, k_cache in enumerate(k_caches)
-        ]
-        v_caches = [
-            torch.cat([v_cache[:, 1:, :], new_v_caches[i]], dim=1)
-            for i, v_cache in enumerate(v_caches)
-        ]
-
-        pos += 1
-        atten_mask[0][-pos - 1] = 0
+        atten_mask, pos, k_caches, v_caches = updator(
+            atten_mask, pos, k_caches, v_caches, new_k_caches, new_v_caches
+        )
         if pos >= len(token_list):
             token_list.append(torch.argmax(logits[:, -1], dim=-1).item())
 
@@ -174,6 +196,7 @@ def calibrate(
     module: torch.fx.GraphModule,
     tokenizer,
     max_seq_len=512,
+    kv_updator=smart_mask_updator,
 ):
     if len(example_inputs) == 2:
         _prefill_calibrate(
@@ -190,6 +213,7 @@
             module,
             tokenizer,
             max_seq_len,
+            updator=kv_updator,
         )
     else:
         raise RuntimeError("Get wrong inputs")
@@ -319,13 +343,15 @@ def quantize(self, quant_dtype, args, tokenizer, custom_annotations=()):
             self.llama_model, self.inputs, strict=True
         ).module()
         fx_graph_module = prepare_pt2e(fx_graph_module, quantizer)
+
         logging.info("Quantizing the model...")
         calibrate(
             self.get_example_inputs(self.llama_meta["get_use_kv_cache"]),
             args.prompt,
             fx_graph_module,
             tokenizer=tokenizer,
             max_seq_len=self.llama_meta["get_max_seq_len"],
+            kv_updator=args.kv_updator,
         )
 
         self.llama_model = convert_pt2e(fx_graph_module)
@@ -337,6 +363,7 @@ def lowering_modules(
         use_fp16=False,
         soc_model=QcomChipset.SM8650,
         num_sharding=0,
+        shared_buffer=False,
     ):
         executorch_config = ExecutorchBackendConfig(
             # For shared buffer, user must pass the memory address
@@ -357,7 +384,7 @@ def lowering_modules(
         compiler_specs = generate_qnn_executorch_compiler_spec(
             soc_model=soc_model,
             backend_options=backend_options,
-            shared_buffer=False,
+            shared_buffer=shared_buffer,
         )
         skip_node_op_set = {"llama.fallback.default"}
         partitioner = QnnPartitioner(
@@ -530,6 +557,7 @@ def compile(args, pte_filename, tokenizer):
             use_fp16=use_fp16,
             soc_model=get_soc_to_chipset_map()[args.model],
             num_sharding=args.num_sharding,
+            shared_buffer=args.shared_buffer,
         )
         quant_attrs = llama_instance_list[0].get_quant_attrs()
     else:
@@ -564,7 +592,7 @@
             generate_qnn_executorch_compiler_spec(
                 soc_model=get_soc_to_chipset_map()[args.model],
                 backend_options=backend_options,
-                shared_buffer=True,
+                shared_buffer=args.shared_buffer,
                 multiple_graphs=True,
                 graph_name=graph_name,
             )
@@ -736,6 +764,7 @@ def inference(args, quant_attrs, pte_filename, runtime_tokenizer_path, pre_gen_pte
             f"--system_prompt '{args.system_prompt}'",
             f"--logits_scale {quant_attrs['scale']}",
             f"--logits_offset {quant_attrs['zero_point']}",
+            f"--kv_updator {'SmartMask' if args.kv_updator == smart_mask_updator else 'ShiftPointer'}",
         ]
     )
     runner_cmd = " ".join(
@@ -907,6 +936,14 @@ def main():
         type=int,
     )
 
+    parser.add_argument(
+        "--kv_updator",
+        help="Choose how to update kv cache during runtime",
+        choices=["smart_mask", "shift_pointer"],
+        default="smart_mask",
+        type=str,
+    )
+
     args = parser.parse_args()
     if args.compile_only and args.pre_gen_pte:
         exit("Cannot set both compile_only and pre_gen_pte as true")
@@ -941,6 +978,14 @@
     else:
         raise RuntimeError(f"Unknown llama_model: {args.llama_model}.")
 
+    if args.kv_updator == "smart_mask":
+        args.shared_buffer = True
+        args.kv_updator = smart_mask_updator
+    elif args.kv_updator == "shift_pointer":
+        args.kv_updator = shift_pointer_updator
+    else:
+        exit(f"Using an unknown kv updator {args.kv_updator}")
+
     if args.pre_gen_pte:
         quant_attrs = json.load(
             open(f"{args.pre_gen_pte}/{pte_filename}_quant_attrs.txt")

examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp

+6 -1

@@ -51,6 +51,10 @@ DEFINE_int32(
     "0: PromptProcessor(prefill) / 1: TokenGenerator(kv) / 2: HybridMode (prefill+kv)");
 DEFINE_double(logits_scale, 0.0, "Logits scale");
 DEFINE_int32(logits_offset, 0, "Logits offset");
+DEFINE_string(
+    kv_updator,
+    "SmartMask",
+    "How to update kv cache. Choose between SmartMask and ShiftPointer");
 
 int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
@@ -62,7 +66,8 @@ int main(int argc, char** argv) {
       FLAGS_logits_scale,
       FLAGS_logits_offset,
       FLAGS_temperature,
-      FLAGS_eval_mode);
+      FLAGS_eval_mode,
+      FLAGS_kv_updator);
   std::vector<char> buf;
   buf.reserve(5 * FLAGS_seq_len); // assume each token is around 5 char
   std::ofstream fout(FLAGS_output_path.c_str());
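
llama.py forwards the chosen strategy to the runner through this flag (SmartMask or ShiftPointer, matching the f-string in inference() above). For a manual run, a hypothetical invocation, with the model and tokenizer flags elided and illustrative values for the rest, might look like:

./qnn_llama_runner --eval_mode 1 --kv_updator SmartMask --seq_len 128 --output_path result.txt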
