Commit 9484c01
Qualcomm AI Engine Direct - Enable AR-N model for prompt processing in hybrid mode (#8210)
* Qualcomm AI Engine Direct - Enable AR-N mode to process prompt in hybrid mode

  Summary:
  - Add `max_seq_len` to refer to the maximum number of tokens the model can process and consider at once when generating predictions/responses.
  - Add `prefill_ar_n` to determine the number of tokens to consume and the number of logits to produce for the prompt processor in hybrid mode.
  - Remove prefill mode

* fixed CI
* Add the figure to readme and fixed unused variable
* fixed linting
1 parent f965746 commit 9484c01
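To make the two new parameters concrete, here is a small hedged sketch of the block arithmetic they imply; the function name and the worked numbers are illustrative, not code from this commit.

```python
# Illustrative only: the block arithmetic implied by the new parameters.
import math

def prefill_blocks(prompt_len: int, prefill_ar_n: int, max_seq_len: int):
    """Return (number of AR-N forward passes, pad tokens in the last block)."""
    assert prompt_len <= max_seq_len, "prompt must fit within max_seq_len"
    n_blocks = math.ceil(prompt_len / prefill_ar_n)
    padding = n_blocks * prefill_ar_n - prompt_len
    return n_blocks, padding

# Example: a 45-token prompt with prefill_ar_n=32 and max_seq_len=128
# takes 2 passes; the second block carries 19 pad tokens.
print(prefill_blocks(45, 32, 128))  # (2, 19)
```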

File tree: 11 files changed, +769 -466 lines changed

backends/qualcomm/tests/test_qnn_delegate.py (+4 -4)
```diff
@@ -3154,9 +3154,9 @@ def test_llama3_2_1b(self):
             "llama3_2",
             "--model_mode",
             "hybrid",
-            "--prefill_seq_len",
+            "--prefill_ar_len",
             "32",
-            "--kv_seq_len",
+            "--max_seq_len",
             "512",
             "--num_sharding",
             "4",
@@ -3234,9 +3234,9 @@ def test_llama_stories_110m(self):
             "stories110m",
             "--model_mode",
             "hybrid",
-            "--prefill_seq_len",
+            "--prefill_ar_len",
             "32",
-            "--kv_seq_len",
+            "--max_seq_len",
             "128",
         ]
         if self.compile_only:
```

examples/qualcomm/oss_scripts/llama/README.md (+13 -8)
```diff
@@ -8,11 +8,16 @@ This file provides you the instructions to run LLAMA model with different parame
 
 We offer the following modes to execute the model:
 
-Prefill Mode: This is also known as batch prefill mode, where the model takes in a list of tokens as input and generates the next token along with the key-value (KV) cache for all tokens. This mode is efficient for encoding the user's prompt.
-
 KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt.
 
-Hybrid Mode: Hybrid mode leverages the strengths of both batch prefill and KV cache modes to optimize token generation speed. Initially, it uses prefill mode to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens.
+Hybrid Mode: Hybrid mode leverages the strengths of both the AR-N model and KV cache mode to optimize token generation speed. Initially, it uses the AR-N model to efficiently generate the prompt's key-value (KV) cache. Then it switches to KV cache mode, which excels at generating subsequent tokens.
+- AR-N model: The auto-regression (AR) length determines the number of tokens consumed and the number of logits produced per forward pass. It is used to process the prompt and generate the key-value (KV) cache, serving as the prompt processor in hybrid mode.
+- Prompt processing with AR-N model:
+<figure>
+<img src="./assets/PromptProcessingWithARN.png" alt="Prompt Processing With AR-N Model">
+<figcaption>Prompt processing is done using a for-loop: an N-token block is taken and the KV cache is updated for that block, repeating until all tokens are consumed, with the last block potentially requiring padding. For flexibility, the AR-N model can handle any input length less than the maximum sequence length. For time-to-first-token (TTFT), the input length (and thus the number of blocks) varies with the actual prompt length rather than always being the same.
+</figcaption>
+</figure>
 
 
 ## Instructions
```
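The for-loop described in the figure caption above can be sketched in a few lines. This is a hedged illustration only; `init_kv_cache` and `forward` are hypothetical stand-ins for the exported model's interface, not the actual runner API.

```python
# A minimal sketch of AR-N prompt processing, assuming a hypothetical model
# object; the real runner operates on exported QNN graphs, not this API.
from typing import List

def prefill_with_ar_n(tokens: List[int], ar_len: int, pad_id: int, model):
    kv_cache = model.init_kv_cache()  # hypothetical helper
    logits = None
    for start in range(0, len(tokens), ar_len):
        block = tokens[start : start + ar_len]
        # The last block may be short; pad it so every pass has a fixed shape.
        block = block + [pad_id] * (ar_len - len(block))
        logits, kv_cache = model.forward(block, kv_cache)  # hypothetical call
    # The KV cache now covers the whole prompt; decoding continues in
    # KV cache mode starting from `logits`.
    return logits, kv_cache
```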
````diff
@@ -50,13 +55,13 @@ At the end of this step, users should have the following files ready: `consolida
 ### Step3: Run default examples using hybrid mode.
 #### LLAMA2
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --llama_model stories110m --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "Once upon a time"
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --llama_model stories110m --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "Once upon a time"
 ```
 
 #### LLAMA3.2
 Default example using hybrid mode.
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1"
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1"
 ```
 
 ### KV Cache update mechanism
````
````diff
@@ -109,16 +114,16 @@ We have two distinct mechanisms for updating the key-value (KV) cache, which can
 ### Additional Configs when running the script
 If you would like to compile the model only, we have provided the flag `--compile_only`. Taking LLAMA3.2 as an example:
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --compile_only
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --compile_only
 ```
 
 On the other hand, if you already have a pre-compiled .pte model, you can perform inference by providing the flag `--pre_gen_pte` and specifying the folder that contains the .pte model. Taking LLAMA3.2 as an example:
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
 ```
 
 You can select the KV Cache update mechanism at runtime by setting the `KV_UPDATER` variable to either "shift_pointer" or "smart_mask". By default, it is set to "smart_mask".
 `KV_UPDATER` = "shift_pointer"
 ```bash
-python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
 ```
````
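For intuition about the two `KV_UPDATER` options, the sketch below contrasts a mask-editing update with a pointer-advancing update using plain Python lists. The buffer layout and function names are assumptions made for illustration; the runner's real data structures differ.

```python
# Illustrative only: two styles of KV cache update, not the runner's code.

def smart_mask_update(cache, mask, new_entry, pos):
    # Leave existing data in place: write the new entry into its slot
    # and flip the attention mask so that slot becomes visible.
    cache[pos] = new_entry
    mask[pos] = True  # True = this position may be attended to
    return pos + 1    # next write position

def shift_pointer_update(buffer, offset, window, new_entry):
    # Leave the mask untouched: write just past the live window and
    # advance the window's start, so the visible cache is always
    # buffer[offset : offset + window].
    buffer[offset + window] = new_entry
    return offset + 1  # new window start

cache, mask = [0] * 8, [False] * 8
pos = smart_mask_update(cache, mask, new_entry=7, pos=0)
```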