
Commit 89fea80

server : fix incorrect usage of llama_get_embeddings() (#14225)

* server : fix incorrect usage of llama_get_embeddings()
  ggml-ci
* cont : fix the fix
  ggml-ci

1 parent: 6adc3c3

2 files changed: 11 additions, 10 deletions

include/llama.h (1 addition, 0 deletions)

@@ -965,6 +965,7 @@ extern "C" {
     LLAMA_API int32_t llama_n_threads_batch(struct llama_context * ctx);
 
     // Set whether the context outputs embeddings or not
+    // TODO: rename to avoid confusion with llama_get_embeddings()
     LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
 
     // Set whether to use causal attention or not
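For context on the TODO: llama_set_embeddings() is a mode switch that tells the context to produce embeddings, while llama_get_embeddings() only reads back the output buffer after a decode. A minimal sketch of the intended pairing (names such as ctx, tokens and n_tokens are placeholders and error handling is trimmed; this is an illustration, not code from the commit):

    // enable embedding output for subsequent decodes
    llama_set_embeddings(ctx, true);

    // decode the prompt as a single batch
    llama_batch batch = llama_batch_get_one(tokens, n_tokens);
    if (llama_decode(ctx, batch) != 0) {
        // handle the decode failure
    }

    // llama_get_embeddings*() are data accessors: they return the embeddings
    // produced by the last decode, they do not report whether embeddings
    // output is enabled - hence the rename TODO above
    const float * embd   = llama_get_embeddings_seq(ctx, 0);
    const int32_t n_embd = llama_model_n_embd(llama_get_model(ctx));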

tools/server/server.cpp (10 additions, 10 deletions)
@@ -1358,6 +1358,14 @@ struct server_slot {
         return server_task_type_need_logits(task_type);
     }
 
+    // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
+    // also we cannot split if the pooling would require any past tokens
+    bool can_split() const {
+        return
+            !need_embd() ||
+            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
+    }
+
     bool can_batch_with(server_slot & other_slot) const {
         return task_type == other_slot.task_type && are_lora_equal(lora, other_slot.lora);
     }
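The fix replaces the context-wide llama_get_embeddings(ctx) test with the slot's own need_embd(), so the decision now depends on what the current task requires rather than on whether the context happens to expose an embeddings buffer. need_embd() itself is not part of this diff; by analogy with the server_task_type_need_logits(task_type) call visible above, it plausibly reduces to a per-task check along these lines (a sketch with presumed enum names, not the actual server.cpp helper):

    // hypothetical shape of the per-slot predicate the new can_split() keys on;
    // the SERVER_TASK_TYPE_* values are presumed from the server's task enum
    bool need_embd() const {
        return task_type == SERVER_TASK_TYPE_EMBEDDING ||
               task_type == SERVER_TASK_TYPE_RERANK;
    }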
@@ -1929,14 +1937,6 @@ struct server_context {
         llama_batch_free(batch);
     }
 
-    // if the context does not have a memory module then all embeddings have to be computed within a single ubatch
-    // also we cannot split if the pooling would require any past tokens
-    bool can_split() const {
-        return
-            !llama_get_embeddings(ctx) ||
-            (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST);
-    }
-
     bool load_model(const common_params & params) {
         SRV_INF("loading model '%s'\n", params.model.path.c_str());
 
@@ -3130,7 +3130,7 @@ struct server_context {
                 continue;
             }
 
-            if (!can_split()) {
+            if (!slot.can_split()) {
                 if (slot.n_prompt_tokens > n_ubatch) {
                     slot.release();
                     send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -3273,7 +3273,7 @@ struct server_context {
                 slot.n_prompt_tokens_processed = 0;
             }
 
-            if (!can_split()) {
+            if (!slot.can_split()) {
                 // cannot fit the prompt in the current batch - will try next iter
                 if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                     continue;
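Both call sites keep their original control flow; only the receiver of can_split() changes, now that the helper lives on server_slot. Condensed, the no-split handling around them looks roughly like this (a paraphrase that merges the two sites for readability, not a verbatim excerpt):

    if (!slot.can_split()) {
        // the whole prompt must be processed in a single ubatch of a single batch
        if (slot.n_prompt_tokens > n_ubatch) {
            slot.release();
            send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
            continue;
        }
        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
            // does not fit in the current batch - retry on the next iteration
            continue;
        }
    }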
