main : add new feature: special commands #10145
@@ -41,13 +41,22 @@ static std::vector<llama_token> * g_output_tokens;
 static bool is_interacting = false;
 static bool need_insert_eot = false;

+static const char * help_special_cmds = "special commands in conversation mode:\n"
+    "  /readfile FILE    read prompt from file\n"
+    "  /savesess FILE    save session to file\n"
+    "  /loadsess FILE    load session from file\n"
+    "  /regen            regenerate the last response\n"
+    "  /dump FILE        dump chat content to a file\n";
+
 static void print_usage(int argc, char ** argv) {
     (void) argc;

     LOG("\nexample usage:\n");
     LOG("\n  text generation:     %s -m your_model.gguf -p \"I believe the meaning of life is\" -n 128\n", argv[0]);
     LOG("\n  chat (conversation): %s -m your_model.gguf -p \"You are a helpful assistant\" -cnv\n", argv[0]);
     LOG("\n");
+    LOG("%s", help_special_cmds);
+    LOG("\n");
 }

 static bool file_exists(const std::string & path) {
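With this addition, the example-usage text printed by print_usage also lists the special commands, so a user running in conversation mode (-cnv) can type, say, /readfile notes.txt at the interactive prompt to pull the contents of a hypothetical notes.txt into the next turn; the other commands follow the same one-line forms shown in help_special_cmds.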
@@ -105,6 +114,21 @@ static void write_logfile(
     fclose(logfile);
 }

+static std::vector<std::string> try_parse_command(std::string text) {
+    if (text.empty() || text[0] != '/') {
+        return {};
+    }
+    std::vector<std::string> elem = string_split<std::string>(text, ' ');
+    std::vector<std::string> res;
+    // filter empty strings
+    for (const auto & e : elem) {
+        if (!e.empty()) {
+            res.push_back(string_strip(e));
+        }
+    }
+    return res;
+}
+
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
 static void sigint_handler(int signo) {
     if (signo == SIGINT) {
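try_parse_command returns an empty vector for ordinary prompts and the whitespace-separated tokens of the line for "/..." commands. Below is a minimal standalone sketch of the same behavior; it uses std::istringstream in place of the repository's string_split/string_strip helpers, and the sample inputs are made up:

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Mirrors try_parse_command: {} for regular prompts, tokens for "/..." commands.
static std::vector<std::string> parse_command(const std::string & text) {
    if (text.empty() || text[0] != '/') {
        return {};                      // not a command, treat as a normal prompt
    }
    std::vector<std::string> res;
    std::istringstream iss(text);
    std::string tok;
    while (iss >> tok) {                // operator>> skips runs of spaces, so empties never appear
        res.push_back(tok);
    }
    return res;
}

int main() {
    for (const std::string s : {"/readfile  notes.txt", "/regen", "hello there"}) {
        const auto cmd = parse_command(s);
        std::cout << "'" << s << "' -> " << cmd.size() << " token(s)\n";
        // '/readfile  notes.txt' -> 2 token(s): "/readfile", "notes.txt"
        // '/regen'               -> 1 token(s)
        // 'hello there'          -> 0 token(s), not a command
    }
}
```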
@@ -127,7 +151,11 @@ static void sigint_handler(int signo) {
 }
 #endif

+// return the formatted turn to be decoded
 static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
+    if (content.empty()) {
+        return "";
+    }
     common_chat_msg new_msg{role, content};
     auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
     chat_msgs.push_back({role, content});
@@ -189,6 +217,7 @@ int main(int argc, char ** argv) {
     llama_context * ctx = nullptr;
     common_sampler * smpl = nullptr;

+    std::vector<int> pos_history; // history of positions of chat messages
     std::vector<common_chat_msg> chat_msgs;

     g_model = &model;
@@ -515,6 +544,7 @@ int main(int argc, char ** argv) {
     display = params.display_prompt;

     std::vector<llama_token> embd;
+    llama_batch batch = llama_batch_init(params.n_batch, 0, 1);

     // tokenized antiprompts
     std::vector<std::vector<llama_token>> antiprompt_ids;
@@ -542,6 +572,8 @@ int main(int argc, char ** argv) {
         embd_inp.push_back(decoder_start_token_id);
     }

+    std::stringstream pending_input; // used by "/readfile" command
+
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (!embd.empty()) {
@@ -648,7 +680,19 @@ int main(int argc, char ** argv) {

             LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

-            if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
+            common_batch_clear(batch);
+            for (int j = 0; j < n_eval; j++) {
+                int idx = i + j;
+                common_batch_add(
+                    batch,
+                    embd[idx],
+                    n_past + idx,
+                    {0},
+                    idx == (int) embd.size() - 1
+                );
+            }
+
+            if (llama_decode(ctx, batch)) {
                 LOG_ERR("%s : failed to eval\n", __func__);
                 return 1;
             }
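Design note: the llama_batch_get_one shortcut is replaced by an explicitly filled batch in which every token is assigned the absolute position n_past + idx in sequence 0, and only the very last token of embd (idx == embd.size() - 1) requests logits, even when embd is split across several n_batch-sized chunks. Presumably the explicit position bookkeeping is what allows /loadsess and /regen to rewind n_past and continue decoding from an arbitrary position in the KV cache.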
@@ -851,6 +895,84 @@ int main(int argc, char ** argv) {

                 LOG_DBG("buffer: '%s'\n", buffer.c_str());

+                // check for special commands
+                const std::vector<std::string> cmd = params.special_cmds
+                    ? try_parse_command(buffer)
+                    : std::vector<std::string>();
+
+                if (cmd.size() == 2 && cmd[0] == "/readfile") {
+                    const std::string filename = cmd[1];
+                    LOG_DBG("reading file: '%s'\n", filename.c_str());
+                    std::ifstream text_file(filename);
+                    if (!text_file) {
+                        LOG("failed to open file '%s'\n", filename.c_str());
+                        continue;
+                    }
+                    pending_input << text_file.rdbuf() << "\n\n";
+                    LOG("read %zu characters from file\n", (size_t) text_file.tellg());
+                    continue;
+                } else if (cmd.size() == 2 && cmd[0] == "/savesess") {
+                    const std::string filename = cmd[1];
+                    LOG("save session file: '%s'\n", filename.c_str());
+                    size_t res = llama_state_save_file(ctx, filename.c_str(), embd_inp.data(), n_past);
+                    if (res == 0) {
+                        LOG("failed to save session file '%s'\n", filename.c_str());
+                    }
+                    continue;
+                } else if (cmd.size() == 2 && cmd[0] == "/loadsess") {
+                    const std::string filename = cmd[1];
+                    LOG("load session file: '%s'\n", filename.c_str());
+                    session_tokens.resize(n_ctx);
+                    size_t n_token_count_out;
+                    size_t res = llama_state_load_file(ctx, filename.c_str(), session_tokens.data(), session_tokens.size(), &n_token_count_out);
+                    if (res == 0) {
+                        LOG("failed to load session file '%s'\n", filename.c_str());
+                    } else {
+                        session_tokens.resize(n_token_count_out);
+                        embd_inp = session_tokens;
+                        n_past = n_token_count_out;
+                        llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
+                        LOG("loaded %zu tokens from session file '%s'\n", n_token_count_out, filename.c_str());
+                    }
+                    continue;
+                } else if (cmd.size() == 1 && cmd[0] == "/regen") {
+                    if (pos_history.empty()) {
+                        LOG("no previous assistant message to regenerate\n");
+                        continue;
+                    }
+                    int last_n_past = pos_history.back();
+                    int n_tokens_removed = n_past - last_n_past;
+                    llama_kv_cache_seq_rm(ctx, 0, last_n_past, -1);
+                    n_remain += n_tokens_removed;
+                    is_interacting = false;
+                    // we intentionally do not reset the sampling, so new message will be more diverse
+                    continue;
Reviewer comment: I remember having problems with regeneration without adjusting … However, I have to say that there is something missing/extra here: when testing …
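To illustrate the rollback bookkeeping with made-up numbers: if pos_history.back() is 120 (the position at which the previous assistant reply started) and n_past is currently 180, then n_tokens_removed is 60; llama_kv_cache_seq_rm drops cache entries from position 120 onward, n_remain is credited back those 60 tokens, and generation resumes from position 120 with the sampler state deliberately left untouched so the regenerated reply can differ from the original.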
+                } else if (cmd.size() == 2 && cmd[0] == "/dump") {
+                    const std::string filename = cmd[1];
+                    std::ofstream dump_file(filename);
+                    if (!dump_file) {
+                        LOG("failed to create file '%s'\n", filename.c_str());
+                        continue;
+                    }
+                    for (const auto & msg : chat_msgs) {
+                        dump_file << msg.role << ":\n" << msg.content << "\n---\n";
+                    }
+                    dump_file.close();
+                    LOG("dumped chat messages to file '%s'\n", filename.c_str());
+                    continue;
+                } else if (!cmd.empty()) {
+                    LOG("unknown command: %s\n", buffer.c_str());
+                    LOG("%s", help_special_cmds);
+                    continue;
+                }
+
+                if (pending_input.tellp() > 0) {
+                    // concatenate read file and the prompt
+                    pending_input << buffer;
+                    buffer = pending_input.str();
+                    pending_input.clear();
+                }
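A small self-contained sketch of the /readfile buffering flow: file contents accumulate in a std::stringstream and the user's next typed prompt is appended before the combined text becomes the buffer that gets tokenized. The file name and prompt below are made up. Note that std::stringstream::clear() resets only the stream's error flags, so the sketch also discards the buffered text with str(std::string()); in the diff above, pending_input.clear() alone appears to leave the previously read file content in the buffer for the following turn.

```cpp
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

int main() {
    // stand-in for the file named in "/readfile notes.txt" (hypothetical content)
    {
        std::ofstream f("notes.txt");
        f << "Meeting notes:\n- ship the release\n";
    }

    std::stringstream pending_input;      // mirrors the variable in main.cpp

    // "/readfile notes.txt": buffer the file, separated from the prompt by a blank line
    std::ifstream text_file("notes.txt");
    pending_input << text_file.rdbuf() << "\n\n";

    // next interactive turn: the typed prompt is appended to the buffered file text
    std::string buffer = "Summarize the notes above.";
    if (pending_input.tellp() > 0) {
        pending_input << buffer;
        buffer = pending_input.str();
        pending_input.str(std::string()); // actually discard the buffered text
        pending_input.clear();            // and reset any error flags
    }

    std::cout << buffer << "\n";          // file contents + prompt, ready to tokenize
}
```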

                 const size_t original_size = embd_inp.size();

                 if (params.escape) {
@@ -885,6 +1007,8 @@ int main(int argc, char ** argv) {
                     output_ss << common_token_to_piece(ctx, token);
                 }

+                pos_history.push_back(n_past + embd_inp.size() - original_size);
+
                 // reset assistant message
                 assistant_ss.str("");
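The value pushed here, n_past + embd_inp.size() - original_size, is the sequence position at which the upcoming assistant reply will begin (n_past plus the number of tokens the user's new turn added to embd_inp); /regen later uses pos_history.back() as its rewind target.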
@@ -930,6 +1054,7 @@ int main(int argc, char ** argv) {

     common_sampler_free(smpl);

+    llama_batch_free(batch);
     llama_free(ctx);
     llama_free_model(model);
Reviewer comment: This might become a problem with n_predict == -1 (infinite) or -2 (stop at context size); see the comment from here.
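Presumably this refers to the n_remain += n_tokens_removed adjustment in the /regen handler: n_remain is initialized from params.n_predict, so with n_predict of -1 or -2 it acts as a negative sentinel rather than a countdown, and adding a positive token count to it could push it toward zero and change how generation terminates.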