
Commit d67341d

server : add server parameters for draft model cache type (#13782)
Co-authored-by: aa956 <[email protected]>
1 parent 456af35 commit d67341d

4 files changed: +33, -4 lines changed

common/arg.cpp

Lines changed: 26 additions & 0 deletions
@@ -3210,6 +3210,32 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.speculative.model.path = value;
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));
+    add_opt(common_arg(
+        {"-ctkd", "--cache-type-k-draft"}, "TYPE",
+        string_format(
+            "KV cache data type for K for the draft model\n"
+            "allowed values: %s\n"
+            "(default: %s)",
+            get_all_kv_cache_types().c_str(),
+            ggml_type_name(params.speculative.cache_type_k)
+        ),
+        [](common_params & params, const std::string & value) {
+            params.speculative.cache_type_k = kv_cache_type_from_str(value);
+        }
+    ).set_env("LLAMA_ARG_CACHE_TYPE_K_DRAFT"));
+    add_opt(common_arg(
+        {"-ctvd", "--cache-type-v-draft"}, "TYPE",
+        string_format(
+            "KV cache data type for V for the draft model\n"
+            "allowed values: %s\n"
+            "(default: %s)",
+            get_all_kv_cache_types().c_str(),
+            ggml_type_name(params.speculative.cache_type_v)
+        ),
+        [](common_params & params, const std::string & value) {
+            params.speculative.cache_type_v = kv_cache_type_from_str(value);
+        }
+    ).set_env("LLAMA_ARG_CACHE_TYPE_V_DRAFT"));

     add_opt(common_arg(
         {"-mv", "--model-vocoder"}, "FNAME",

common/common.h

Lines changed: 3 additions & 0 deletions
@@ -199,6 +199,9 @@ struct common_params_speculative {
     float p_split = 0.1f; // speculative decoding split probability
     float p_min = 0.75f;  // minimum speculative decoding probability (greedy)

+    ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
+    ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
+
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;

tools/server/README.md

Lines changed: 2 additions & 0 deletions
@@ -187,6 +187,8 @@ The project is under active development, and we are [looking for feedback and co
 | `-devd, --device-draft <dev1,dev2,..>` | comma-separated list of devices to use for offloading the draft model (none = don't offload)<br/>use --list-devices to see a list of available devices |
 | `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | number of layers to store in VRAM for the draft model<br/>(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) |
 | `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused)<br/>(env: LLAMA_ARG_MODEL_DRAFT) |
+| `-ctkd, --cache-type-k-draft TYPE` | KV cache data type for K for speculative decoding model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K_DRAFT) |
+| `-ctvd, --cache-type-v-draft TYPE` | KV cache data type for V for speculative decoding model<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V_DRAFT) |
 | `-mv, --model-vocoder FNAME` | vocoder model for audio generation (default: unused) |
 | `--tts-use-guide-tokens` | Use guide tokens to improve TTS word recall |
 | `--embd-bge-small-en-default` | use default bge-small-en-v1.5 model (note: can download weights from the internet) |
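
The same settings can also be supplied through the environment variables listed above; a hypothetical example (model file names are placeholders):

    # equivalent to passing -ctkd q8_0 -ctvd q8_0 on the command line
    LLAMA_ARG_CACHE_TYPE_K_DRAFT=q8_0 LLAMA_ARG_CACHE_TYPE_V_DRAFT=q8_0 \
        llama-server -m main-model.gguf -md draft-model.gguf

Both options default to f16, matching the F16 cache that the server previously forced for the draft model.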

tools/server/server.cpp

Lines changed: 2 additions & 4 deletions
@@ -1969,10 +1969,8 @@ struct server_context {
             params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
             params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
             params_dft.n_parallel = 1;
-
-            // force F16 KV cache for the draft model for extra performance
-            params_dft.cache_type_k = GGML_TYPE_F16;
-            params_dft.cache_type_v = GGML_TYPE_F16;
+            params_dft.cache_type_k = params_base.speculative.cache_type_k;
+            params_dft.cache_type_v = params_base.speculative.cache_type_v;

             llama_init_dft = common_init_from_params(params_dft);
