main : add new feature: special commands #10145

Draft · wants to merge 2 commits into master
7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -1939,6 +1939,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.simple_io = true;
}
).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL}));
add_opt(common_arg(
{"-nsc", "--no-special-command"},
string_format("disable special commands in conversation mode (default: %s)", params.special_cmds ? "enabled" : "disabled"),
[](common_params & params) {
params.special_cmds = false;
}
).set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(common_arg(
{"-ld", "--logdir"}, "LOGDIR",
"path under which to save YAML logs (no logging if unset)",
1 change: 1 addition & 0 deletions common/common.h
@@ -251,6 +251,7 @@ struct common_params {
bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix)
bool prompt_cache_all = false; // save user input and generations to prompt cache
bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
bool special_cmds = true; // enable special commands in main example

bool escape = true; // escape "\n", "\r", "\t", "\'", "\"", and "\\"
bool multiline_input = false; // reverse the usage of `\`
127 changes: 126 additions & 1 deletion examples/main/main.cpp
@@ -41,13 +41,22 @@ static std::vector<llama_token> * g_output_tokens;
static bool is_interacting = false;
static bool need_insert_eot = false;

static const char * help_special_cmds = "special commands in conversation mode:\n"
" /readfile FILE read prompt from file\n"
" /savesess FILE save session to file\n"
" /loadsess FILE load session from file\n"
" /regen regenerate the last response\n"
" /dump FILE dump chat content to a file\n";

static void print_usage(int argc, char ** argv) {
(void) argc;

LOG("\nexample usage:\n");
LOG("\n text generation: %s -m your_model.gguf -p \"I believe the meaning of life is\" -n 128\n", argv[0]);
LOG("\n chat (conversation): %s -m your_model.gguf -p \"You are a helpful assistant\" -cnv\n", argv[0]);
LOG("\n");
LOG("%s", help_special_cmds);
LOG("\n");
}

static bool file_exists(const std::string & path) {
@@ -105,6 +114,21 @@ static void write_logfile(
fclose(logfile);
}

static std::vector<std::string> try_parse_command(std::string text) {
if (text.empty() || text[0] != '/') {
return {};
}
std::vector<std::string> elem = string_split<std::string>(text, ' ');
std::vector<std::string> res;
// filter empty strings
for (const auto & e : elem) {
if (!e.empty()) {
res.push_back(string_strip(e));
}
}
return res;
}

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void sigint_handler(int signo) {
if (signo == SIGINT) {
@@ -127,7 +151,11 @@ static void sigint_handler(int signo) {
}
#endif

// return the formatted turn to be decoded
static std::string chat_add_and_format(struct llama_model * model, std::vector<common_chat_msg> & chat_msgs, const std::string & role, const std::string & content) {
if (content.empty()) {
return "";
}
common_chat_msg new_msg{role, content};
auto formatted = common_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
chat_msgs.push_back({role, content});
@@ -189,6 +217,7 @@ int main(int argc, char ** argv) {
llama_context * ctx = nullptr;
common_sampler * smpl = nullptr;

std::vector<int> pos_history; // history of positions of chat messages
std::vector<common_chat_msg> chat_msgs;

g_model = &model;
@@ -515,6 +544,7 @@ int main(int argc, char ** argv) {
display = params.display_prompt;

std::vector<llama_token> embd;
llama_batch batch = llama_batch_init(params.n_batch, 0, 1);

// tokenized antiprompts
std::vector<std::vector<llama_token>> antiprompt_ids;
@@ -542,6 +572,8 @@ int main(int argc, char ** argv) {
embd_inp.push_back(decoder_start_token_id);
}

std::stringstream pending_input; // used by "/readfile" command

while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
if (!embd.empty()) {
@@ -648,7 +680,19 @@ int main(int argc, char ** argv) {

LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
common_batch_clear(batch);
for (int j = 0; j < n_eval; j++) {
int idx = i + j;
common_batch_add(
batch,
embd[idx],
n_past + idx,
{0},
idx == (int) embd.size() - 1
);
}

if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}
@@ -851,6 +895,84 @@ int main(int argc, char ** argv) {

LOG_DBG("buffer: '%s'\n", buffer.c_str());

// check for special commands
const std::vector<std::string> cmd = params.special_cmds
? try_parse_command(buffer)
: std::vector<std::string>();

if (cmd.size() == 2 && cmd[0] == "/readfile") {
const std::string filename = cmd[1];
LOG_DBG("reading file: '%s'\n", filename.c_str());
std::ifstream text_file(filename);
if (!text_file) {
LOG("failed to open file '%s'\n", filename.c_str());
continue;
}
pending_input << text_file.rdbuf() << "\n\n";
LOG("read %zu characters from file\n", (size_t) text_file.tellg());
continue;
} else if (cmd.size() == 2 && cmd[0] == "/savesess") {
const std::string filename = cmd[1];
LOG("save session file: '%s'\n", filename.c_str());
size_t res = llama_state_save_file(ctx, filename.c_str(), embd_inp.data(), n_past);
if (res == 0) {
LOG("failed to save session file '%s'\n", filename.c_str());
}
continue;
} else if (cmd.size() == 2 && cmd[0] == "/loadsess") {
const std::string filename = cmd[1];
LOG("load session file: '%s'\n", filename.c_str());
session_tokens.resize(n_ctx);
size_t n_token_count_out;
size_t res = llama_state_load_file(ctx, filename.c_str(), session_tokens.data(), session_tokens.size(), &n_token_count_out);
if (res == 0) {
LOG("failed to load session file '%s'\n", filename.c_str());
} else {
session_tokens.resize(n_token_count_out);
embd_inp = session_tokens;
n_past = n_token_count_out;
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
LOG("loaded %zu tokens from session file '%s'\n", n_token_count_out, filename.c_str());
}
continue;
} else if (cmd.size() == 1 && cmd[0] == "/regen") {
if (pos_history.empty()) {
LOG("no previous assistant message to regenerate\n");
continue;
}
int last_n_past = pos_history.back();
int n_tokens_removed = n_past - last_n_past;
llama_kv_cache_seq_rm(ctx, 0, last_n_past, -1);
n_remain += n_tokens_removed;
Contributor
This might become a problem with n_predict == -1 (infinite) or -2 (stop at context size); see the linked comment.
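One possible guard (a minimal sketch, assuming n_remain follows the usual params.n_predict semantics in this example, where negative values mean no fixed token budget) would be to restore the budget only when it is actually finite:

    // hypothetical guard: with n_predict == -1 (infinite) or -2 (stop at context
    // size), n_remain is not a finite countdown, so adding the removed tokens
    // back could push it toward zero and change the stopping behavior
    if (params.n_predict >= 0) {
        n_remain += n_tokens_removed;
    }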

is_interacting = false;
// we intentionally do not reset the sampling, so the new message will be more diverse
continue;
@MaggotHATE (Contributor), Nov 4, 2024
I remember having problems with regeneration without adjusting prev. I ended up capturing the entire ring buffer and simply restoring it, and I did the same for all the other values: capturing state instead of recalculating it. So far I have stable results over 20 regenerations (my usual bulk testing) with the following:

        restore_smpl(); // rewinds gsmpl->prev
        llama_kv_cache_seq_rm(ctx, 0, rewind_state.kv_cache_pos, -1);
        embd_inp.erase(embd_inp.begin() + rewind_state.embd_inp_size, embd_inp.end()); // not sure
        n_past = rewind_state.n_past_size;
        n_consumed = rewind_state.n_consumed_size;

However, I have to say that something is still missing or extra here: when testing the K-Shift sampler I noticed that the initial logits of the first message differ from those of all later regenerations (the regenerations themselves are consistent with each other). I'm still not sure what's wrong, but maybe this helps.
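For context, the rewind_state fields referenced above suggest a snapshot struct roughly along these lines (a sketch of one possible layout inferred from the snippet, not the commenter's actual code; the struct name and field types are assumptions):

    // hypothetical snapshot, captured right after each assistant turn so that
    // /regen can restore everything wholesale instead of recomputing offsets
    struct rewind_state_t {
        int    n_past_size;      // n_past at the snapshot point
        int    n_consumed_size;  // n_consumed at the snapshot point
        size_t embd_inp_size;    // embd_inp.size() at the snapshot point
        int    kv_cache_pos;     // position later passed to llama_kv_cache_seq_rm
        std::vector<llama_token> smpl_prev; // copy of the sampler's prev ring buffer
    };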

} else if (cmd.size() == 2 && cmd[0] == "/dump") {
const std::string filename = cmd[1];
std::ofstream dump_file(filename);
if (!dump_file) {
LOG("failed to create file '%s'\n", filename.c_str());
continue;
}
for (const auto & msg : chat_msgs) {
dump_file << msg.role << ":\n" << msg.content << "\n---\n";
}
dump_file.close();
LOG("dumped chat messages to file '%s'\n", filename.c_str());
continue;
} else if (!cmd.empty()) {
LOG("unknown command: %s\n", buffer.c_str());
LOG("%s", help_special_cmds);
continue;
}

if (pending_input.tellp() > 0) {
// concatenate read file and the prompt
pending_input << buffer;
buffer = pending_input.str();
pending_input.str(""); // reset the buffered file content (clear() alone only resets stream error flags)
}

const size_t original_size = embd_inp.size();

if (params.escape) {
@@ -885,6 +1007,8 @@ int main(int argc, char ** argv) {
output_ss << common_token_to_piece(ctx, token);
}

pos_history.push_back(n_past + embd_inp.size() - original_size);

// reset assistant message
assistant_ss.str("");

@@ -930,6 +1054,7 @@ int main(int argc, char ** argv) {

common_sampler_free(smpl);

llama_batch_free(batch);
llama_free(ctx);
llama_free_model(model);
