diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index 0839cad3ee6db..5d317f4ee62eb 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -528,12 +528,17 @@ int32_t llm_chat_apply_template( } } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) { // this template requires the model to have "\n\n" as EOT token - for (auto message : chat) { - std::string role(message->role); - if (role == "user") { - ss << "User: " << message->content << "\n\nAssistant:"; - } else { - ss << message->content << "\n\n"; + for (size_t i = 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (role == "system") { + ss << "System: " << trim(chat[i]->content) << "\n\n"; + } else if (role == "user") { + ss << "User: " << trim(chat[i]->content) << "\n\n"; + if (i == chat.size() - 1) { + ss << "Assistant:"; + } + } else if (role == "assistant") { + ss << "Assistant: " << trim(chat[i]->content) << "\n\n"; } } } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) { diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 19b247b0d672f..154b37cdb01d0 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -292,6 +292,7 @@ int main(int argc, char ** argv) { if (!params.system_prompt.empty() || !params.prompt.empty()) { common_chat_templates_inputs inputs; + inputs.use_jinja = g_params->use_jinja; inputs.messages = chat_msgs; inputs.add_generation_prompt = !params.prompt.empty();