
Commit

refactor: tidy
Signed-off-by: thxCode <[email protected]>
thxCode committed Jul 16, 2024
1 parent f6cb5c9 commit 7287a28
Showing 1 changed file with 69 additions and 67 deletions.
llama-box/main.cpp (136 changes: 69 additions & 67 deletions)
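The diff below follows one recurring pattern: implicit narrowing and sign conversions are made explicit with functional casts (double(...), int32_t(...), uint64_t(...)), repeated system_tokens.size() calls are hoisted into a local n_system_tokens, and NULL is replaced by nullptr. As a rough, hedged illustration of that cast style only (the helper names here are invented for this sketch and do not appear in main.cpp):

#include <cstdint>
#include <string>

// Hypothetical helpers illustrating the explicit-cast style applied in this commit.
static double elapsed_ms(int64_t t_start_us, int64_t t_end_us) {
    // Cast before dividing, instead of relying on an implicit int64 -> double conversion.
    return double(t_end_us - t_start_us) / 1e3;
}

static int32_t prompt_len(const std::string &prompt) {
    // Explicit int32_t(...) instead of a silent size_t -> int narrowing.
    return int32_t(prompt.size());
}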
@@ -237,7 +237,7 @@ struct server_slot {

void release() {
if (state == SLOT_STATE_PROCESSING) {
-t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
+t_token_generation = double(ggml_time_us() - t_start_generation) / 1e3;
command = SLOT_COMMAND_RELEASE;
}
}
@@ -307,15 +307,15 @@ struct server_metrics {
void on_prompt_eval(const server_slot &slot) {
n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
n_prompt_tokens_processed += slot.n_prompt_tokens_processed;
-t_prompt_processing += slot.t_prompt_processing;
-t_prompt_processing_total += slot.t_prompt_processing;
+t_prompt_processing += uint64_t(slot.t_prompt_processing);
+t_prompt_processing_total += uint64_t(slot.t_prompt_processing);
}

void on_prediction(const server_slot &slot) {
n_tokens_predicted_total += slot.n_decoded;
n_tokens_predicted += slot.n_decoded;
-t_tokens_generation += slot.t_token_generation;
-t_tokens_generation_total += slot.t_token_generation;
+t_tokens_generation += uint64_t(slot.t_token_generation);
+t_tokens_generation_total += uint64_t(slot.t_token_generation);
}

void reset_bucket() {
@@ -692,7 +692,7 @@ struct server_context {
// worst case: there is no information about template, we will use chatml by default
return "chatml"; // see llama_chat_apply_template_internal
}
-return std::string(model_template.data(), res);
+return {model_template.data(), (unsigned long)(res)};
}

bool init() {
@@ -857,15 +857,15 @@ struct server_context {
std::string slot_prompt = slot.prompt.get<std::string>();

// length of the current slot's prompt
-int slot_prompt_len = slot_prompt.size();
+int slot_prompt_len = int(slot_prompt.size());

// length of the Longest Common Prefix between the current
// slot's prompt and the input prompt
-int lcp_len = common_part(slot_prompt, prompt);
+int lcp_len = int(common_part(slot_prompt, prompt));

// fraction of the common substring length compared to the
// current slot's prompt length
-similarity = static_cast<float>(lcp_len) / slot_prompt_len;
+similarity = float(lcp_len) / float(slot_prompt_len);

// select the current slot if the criteria match
if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
@@ -1393,7 +1393,7 @@ struct server_context {
std::min(slot.generated_token_probs.size(), stop_word_toks.size());
probs = std::vector<completion_token_output>(slot.generated_token_probs.begin(),
slot.generated_token_probs.end() -
-safe_offset);
+int(safe_offset));
} else {
probs = std::vector<completion_token_output>(slot.generated_token_probs.begin(),
slot.generated_token_probs.end());
@@ -1406,7 +1406,7 @@ struct server_context {
queue_results.send(res);
}

-void send_embedding(const server_slot &slot, const llama_batch &batch) {
+void send_embedding(const server_slot &slot, const llama_batch &batch_view) {
server_task_result res;
res.id = slot.id_task;
res.id_multi = slot.id_multi;
@@ -1417,19 +1417,19 @@

std::vector<float> embd_res(n_embd, 0.0f);

-for (int i = 0; i < batch.n_tokens; ++i) {
-if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+for (int i = 0; i < batch_view.n_tokens; ++i) {
+if (!batch_view.logits[i] || batch_view.seq_id[i][0] != slot.id + 1) {
continue;
}

-const float *embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
-if (embd == NULL) {
+const float *embd = llama_get_embeddings_seq(ctx, batch_view.seq_id[i][0]);
+if (embd == nullptr) {
embd = llama_get_embeddings_ith(ctx, i);
}

-if (embd == NULL) {
+if (embd == nullptr) {
LOG_ERROR("failed to get embeddings",
{{"token", batch.token[i]}, {"seq_id", batch.seq_id[i][0]}});
{{"token", batch_view.token[i]}, {"seq_id", batch_view.seq_id[i][0]}});

res.data = json{
{"embedding", std::vector<float>(n_embd, 0.0f)},
@@ -1501,7 +1501,7 @@ struct server_context {
}

void split_multiprompt_task(int id_multi, const server_task &multiprompt_task) {
-const int prompt_count = multiprompt_task.data.at("prompt").size();
+const auto prompt_count = int32_t(multiprompt_task.data.at("prompt").size());
if (prompt_count <= 1) {
send_error(multiprompt_task, "error while handling multiple prompts");
return;
@@ -1669,7 +1669,7 @@ struct server_context {
slot->cache_tokens.data(), token_count);

const int64_t t_end = ggml_time_us();
-const double t_save_ms = (t_end - t_start) / 1000.0;
+const double t_save_ms = double(t_end - t_start) / 1000.0;

server_task_result result;
result.id = task.id;
@@ -1718,7 +1718,7 @@ struct server_context {
slot->cache_tokens.resize(token_count);

const int64_t t_end = ggml_time_us();
-const double t_restore_ms = (t_end - t_start) / 1000.0;
+const double t_restore_ms = double(t_end - t_start) / 1000.0;

server_task_result result;
result.id = task.id;
@@ -1779,6 +1779,8 @@ struct server_context {
}

void update_slots() {
+const auto n_system_tokens = int32_t(system_tokens.size());
+
// release slots
for (auto &slot : slots) {
if (slot.command == SLOT_COMMAND_RELEASE) {
@@ -1790,7 +1792,7 @@
{"id_task", slot.id_task},
{"n_ctx", n_ctx},
{"n_past", slot.n_past},
{"n_system_tokens", system_tokens.size()},
{"n_system_tokens", n_system_tokens},
{"n_cache_tokens", slot.cache_tokens.size()},
{"truncated", slot.truncated}});

@@ -1814,7 +1816,7 @@
llama_kv_cache_clear(ctx);
} else {
for (int32_t i = 0; i <= params.n_parallel; ++i) {
-llama_kv_cache_seq_rm(ctx, i, system_tokens.size(), -1);
+llama_kv_cache_seq_rm(ctx, i, n_system_tokens, -1);
}
}
return;
@@ -1853,7 +1855,7 @@ struct server_context {

llama_kv_cache_seq_rm(ctx, slot.id + 1, n_keep, n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard,
-system_tokens.size() + slot.n_past, -n_discard);
+n_system_tokens + slot.n_past, -n_discard);

if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@@ -1889,8 +1891,7 @@ struct server_context {

// TODO: we always have to take into account the "system_tokens"
// this is not great and needs to be improved somehow
-llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, {slot.id + 1},
-true);
+llama_batch_add(batch, slot.sampled, n_system_tokens + slot_npast, {slot.id + 1}, true);

slot.n_past += 1;

@@ -1900,8 +1901,8 @@
}

// process in chunks of params.n_batch
-int32_t n_batch = llama_n_batch(ctx);
-int32_t n_ubatch = llama_n_ubatch(ctx);
+auto n_batch = int32_t(llama_n_batch(ctx));
+auto n_ubatch = int32_t(llama_n_ubatch(ctx));

// track if this is an embedding or non-embedding batch
// if we've added sampled tokens above, we are in non-embedding mode
@@ -2016,7 +2017,7 @@ struct server_context {
prompt_tokens = std::move(new_tokens);

slot.truncated = true;
-slot.n_prompt_tokens = prompt_tokens.size();
+slot.n_prompt_tokens = n_system_tokens;

GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}
@@ -2031,7 +2032,8 @@

// reuse any previously computed tokens that are
// common with the new prompt
-slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+slot.n_past =
+int32_t(common_part(slot.cache_tokens, prompt_tokens));

// push the prompt into the sampling context (do
// not apply grammar)
@@ -2123,7 +2125,7 @@ struct server_context {
}

llama_batch_add(batch, prompt_tokens[slot.n_past],
-system_tokens.size() + slot_npast, {slot.id + 1}, false);
+n_system_tokens + slot_npast, {slot.id + 1}, false);

if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -2284,15 +2286,15 @@ struct server_context {

completion_token_output result;
const llama_token id =
-llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+llama_sampling_sample(slot.ctx_sampling, ctx, nullptr, slot.i_batch - i);

llama_sampling_accept(slot.ctx_sampling, ctx, id, true);

slot.n_decoded += 1;
if (slot.n_decoded == 1) {
slot.t_start_generation = ggml_time_us();
slot.t_prompt_processing =
-(slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+double(slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}

@@ -2307,7 +2309,7 @@
// Make sure at least n_probs top tokens are at the front of
// the vector:
if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
-llama_sample_top_k(ctx, &cur_p, n_probs, 0);
+llama_sample_top_k(ctx, &cur_p, int32_t(n_probs), 0);
}

if (slot.sparams.temp == 0.0f) {
@@ -2317,13 +2319,10 @@
result.probs.push_back({cur_p.data[x].id, x == 0 ? 1.0f : 0.0f});
}
} else {
+// Tokens filtered out due to e.g. top_k have 0 probability.
for (size_t x = 0; x < n_probs; ++x) {
-result.probs.push_back({
-cur_p.data[x].id,
-x >= n_valid ? 0.0f
-: cur_p.data[x].p // Tokens filtered out due to e.g.
-// top_k have 0 probability.
-});
+result.probs.push_back(
+{cur_p.data[x].id, x >= n_valid ? 0.0f : cur_p.data[x].p});
}
}
}
@@ -2340,6 +2339,8 @@
}

bool process_vision_prompt(server_slot &slot, int n_batch) {
+const auto n_system_tokens = int32_t(system_tokens.size());
+
const int n_embd = llama_n_embd(model);
const auto sz_i = int32_t(slot.prompt.size());
for (int32_t i = 0; i < sz_i; ++i) {
@@ -2353,8 +2354,8 @@
std::vector<llama_token> tokens = tokenize(jp.at("text"), false);
const auto sz_j = int32_t(tokens.size());
for (int32_t j = 0; j < sz_j; ++j) {
-llama_batch_add(batch, tokens[j], system_tokens.size() + slot.n_past,
-{slot.id + 1}, false);
+llama_batch_add(batch, tokens[j], n_system_tokens + slot.n_past, {slot.id + 1},
+false);
slot.n_past += 1;
slot.n_prompt_tokens += 1;
slot.n_prompt_tokens_processed += 1;
@@ -2405,7 +2406,7 @@ struct server_context {
}
const std::vector<uint8_t> buff = base64_decode(img);
llava_image_embed *img_embd = llava_image_embed_make_with_bytes(
-ctx_clip, params.n_threads, buff.data(), buff.size());
+ctx_clip, params.n_threads, buff.data(), int(buff.size()));
if (!img_embd) {
send_error(slot, "Failed to embed image", ERROR_TYPE_INVALID_REQUEST);
return false;
@@ -2448,8 +2449,7 @@ struct server_context {
}
const auto sz_j = int32_t(tokens.size());
for (int32_t j = 0; j < sz_j; ++j) {
-llama_batch_add(batch, tokens[j], system_tokens.size() + slot.n_past, {slot.id + 1},
-false);
+llama_batch_add(batch, tokens[j], n_system_tokens + slot.n_past, {slot.id + 1}, false);
slot.n_past += 1;
slot.n_prompt_tokens += 1;
slot.n_prompt_tokens_processed += 1;
@@ -2554,21 +2554,21 @@ int main(int argc, char **argv) {
res.set_content(final_response.dump(), "application/json; charset=utf-8");
res.status = json_value(error_data, "code", httplib::StatusCode::InternalServerError_500);
};
-svr.set_exception_handler(
-[&res_error](const httplib::Request &, httplib::Response &res, std::exception_ptr ep) {
-std::string message;
-try {
-std::rethrow_exception(ep);
-} catch (std::exception &e) {
-message = e.what();
-} catch (...) {
-message = "Unknown Exception";
-}
-
-json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
-LOG_ERROR("Got exception", formatted_error);
-res_error(res, formatted_error);
-});
+svr.set_exception_handler([&res_error](const httplib::Request &, httplib::Response &res,
+const std::exception_ptr &ep) {
+std::string message;
+try {
+std::rethrow_exception(ep);
+} catch (std::exception &e) {
+message = e.what();
+} catch (...) {
+message = "Unknown Exception";
+}
+
+json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
+LOG_ERROR("Got exception", formatted_error);
+res_error(res, formatted_error);
+});
svr.set_error_handler([&res_error](const httplib::Request &, httplib::Response &res) {
if (res.status == 404) {
res_error(res, format_error_response("Not Found", ERROR_TYPE_NOT_FOUND));
@@ -2707,23 +2707,24 @@ int main(int argc, char **argv) {
{"value", n_prompt_tokens_processed_total}},
{{"name", "prompt_seconds_total"},
{"help", "Prompt process time"},
{"value", t_prompt_processing_total / 1.e3}},
{"value", double(t_prompt_processing_total) / 1.e3}},
{{"name", "tokens_predicted_total"},
{"help", "Number of generation tokens processed."},
{"value", n_tokens_predicted_total}},
{{"name", "tokens_predicted_seconds_total"},
{"help", "Predict process time"},
{"value", t_tokens_generation_total / 1.e3}}}},
{"value", double(t_tokens_generation_total) / 1.e3}}}},
{"gauge",
{{{"name", "prompt_tokens_seconds"},
{"help", "Average prompt throughput in tokens/s."},
{"value", n_prompt_tokens_processed
-? 1.e3 / t_prompt_processing * n_prompt_tokens_processed
+? 1.e3 / double(t_prompt_processing * n_prompt_tokens_processed)
: 0.}},
{{"name", "predicted_tokens_seconds"},
{"help", "Average generation throughput in tokens/s."},
{"value",
n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}},
{"value", n_tokens_predicted
? 1.e3 / double(t_tokens_generation * n_tokens_predicted)
: 0.}},
{{"name", "kv_cache_usage_ratio"},
{"help", "KV-cache usage. 1 means 100 percent usage."},
{"value", 1. * kv_cache_used_cells / params.n_ctx}},
@@ -3403,12 +3404,13 @@ int main(int argc, char **argv) {
shutdown_handler = [&](int) { ctx_server.queue_tasks.terminate(); };

#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
-struct sigaction sigint_action;
+struct sigaction sigint_action {};
+
sigint_action.sa_handler = signal_handler;
sigemptyset(&sigint_action.sa_mask);
sigint_action.sa_flags = 0;
-sigaction(SIGINT, &sigint_action, NULL);
-sigaction(SIGTERM, &sigint_action, NULL);
+sigaction(SIGINT, &sigint_action, nullptr);
+sigaction(SIGTERM, &sigint_action, nullptr);
#elif defined(_WIN32)
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;