diff --git a/backend/src/database.rs b/backend/src/database.rs index 276360c..8195842 100644 --- a/backend/src/database.rs +++ b/backend/src/database.rs @@ -20,6 +20,7 @@ pub struct CompanionData { pub long_term_mem: usize, pub short_term_mem: u32, pub roleplay: u32, + pub dialogue_tuning: u32, pub avatar_path: String, } @@ -60,12 +61,13 @@ impl Database { long_term_mem INTEGER NOT NULL, short_term_mem INTEGER NOT NULL, roleplay INTEGER NOT NULL, + dialogue_tuning INTEGER NOT NULL, avatar_path STRING NOT NULL )", [], )?; if Database::is_table_empty("companion", &con) { con.execute( - "INSERT INTO companion (id, name, persona, example_dialogue, first_message, long_term_mem, short_term_mem, roleplay, avatar_path) VALUES (NULL, \"Assistant\", \"{{char}} is an artificial intelligence chatbot designed to help {{user}}. {{char}} is an artificial intelligence created in ai-companion backend\", \"{{user}}: What is ai-companion?\n{{char}}: AI Companion is a project that aims to provide users with their own personal AI chatbot on their computer. It allows users to engage in friendly and natural conversations with their AI, creating a unique and personalized experience. This software can also be used as a backend or API for other projects that require a personalised AI chatbot.\n{{user}}: Can you tell me about the creator of ai-companion?\n{{char}}: the creator of the ai-companion program is 'Hubert Kasperek', he is a young programmer from Poland who is mostly interested in: web development (Backend), cybersecurity and computer science concepts\", \"Hello {{user}}, how can i help you?\", 2, 5, 1, \"/assets/companion_avatar-4rust.jpg\")", [] + "INSERT INTO companion (id, name, persona, example_dialogue, first_message, long_term_mem, short_term_mem, roleplay, dialogue_tuning, avatar_path) VALUES (NULL, \"Assistant\", \"{{char}} is an artificial intelligence chatbot designed to help {{user}}. {{char}} is an artificial intelligence created in ai-companion backend\", \"{{user}}: What is ai-companion?\n{{char}}: AI Companion is a project that aims to provide users with their own personal AI chatbot on their computer. It allows users to engage in friendly and natural conversations with their AI, creating a unique and personalized experience. This software can also be used as a backend or API for other projects that require a personalised AI chatbot.\n{{user}}: Can you tell me about the creator of ai-companion?\n{{char}}: the creator of the ai-companion program is 'Hubert Kasperek', he is a young programmer from Poland who is mostly interested in: web development (Backend), cybersecurity and computer science concepts\", \"Hello {{user}}, how can i help you?\", 2, 5, 1, 1, \"/assets/companion_avatar-4rust.jpg\")", [] )?; } if Database::is_table_empty("user", &con) { @@ -141,7 +143,8 @@ impl Database { long_term_mem: row.get(5)?, short_term_mem: row.get(6)?, roleplay: row.get(7)?, - avatar_path: row.get(8)?, + dialogue_tuning: row.get(8)?, + avatar_path: row.get(9)?, }) })?; let mut result: CompanionData = Default::default(); @@ -228,9 +231,9 @@ impl Database { Ok(()) } - pub fn change_companion(name: &str, persona: &str, example_dialogue: &str, first_message: &str, long_term_mem: u32, short_term_mem: u32, roleplay: bool) -> Result<(), Error> { + pub fn change_companion(name: &str, persona: &str, example_dialogue: &str, first_message: &str, long_term_mem: u32, short_term_mem: u32, roleplay: bool, dialogue_tuning: bool) -> Result<(), Error> { let con = Connection::open("companion.db")?; - con.execute(&format!("UPDATE companion SET name=?1, persona=?2, example_dialogue=?3, first_message=?4, long_term_mem={}, short_term_mem={}, roleplay={}", long_term_mem, short_term_mem, roleplay), [&name, &persona, &example_dialogue, &first_message])?; + con.execute(&format!("UPDATE companion SET name=?1, persona=?2, example_dialogue=?3, first_message=?4, long_term_mem={}, short_term_mem={}, roleplay={}, dialogue_tuning={}", long_term_mem, short_term_mem, roleplay, dialogue_tuning), [&name, &persona, &example_dialogue, &first_message])?; Ok(()) } @@ -287,4 +290,10 @@ impl Database { con.execute(&format!("UPDATE companion SET roleplay={}", op), [])?; Ok(()) } + + pub fn disable_enable_dialogue_tuning(op: bool) -> Result<(), Error> { + let con = Connection::open("companion.db")?; + con.execute(&format!("UPDATE companion SET dialogue_tuning={}", op), [])?; + Ok(()) + } } diff --git a/backend/src/dialogue_tuning.rs b/backend/src/dialogue_tuning.rs new file mode 100644 index 0000000..c617e9c --- /dev/null +++ b/backend/src/dialogue_tuning.rs @@ -0,0 +1,53 @@ +use rusqlite::{Connection, Error, Result}; + +#[derive(Debug)] +pub struct Dialogue { + pub user_msg: String, + pub ai_msg: String, +} + +pub struct DialogueTuning {} + +impl DialogueTuning { + pub fn create() -> Result { + let con = Connection::open("companion.db")?; + con.execute( + "CREATE TABLE IF NOT EXISTS dialogues ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_msg TEXT NOT NULL, + ai_msg TEXT NOT NULL + )", [], + ) + } + + pub fn add_dialogue(user_msg: &str, ai_msg: &str) -> Result<(), Error> { + let con = Connection::open("companion.db")?; + con.execute( + "INSERT INTO dialogues (user_msg, ai_msg) VALUES (?1, ?2)", + &[user_msg, ai_msg], + )?; + Ok(()) + } + + pub fn get_random_dialogue() -> Result { + let con = Connection::open("companion.db")?; + + let query = "SELECT user_msg, ai_msg FROM dialogues WHERE id = (SELECT id FROM dialogues ORDER BY RANDOM() LIMIT 1);"; + + if let Some(row) = con.query_row(query, [], |row| { + let user_msg: String = row.get(0)?; + let ai_msg: String = row.get(1)?; + Ok(Dialogue { user_msg, ai_msg }) + }).ok() { + Ok(row) + } else { + Err(Error::QueryReturnedNoRows) + } + } + + pub fn clear_dialogues() -> Result<(), Error> { + let con = Connection::open("companion.db")?; + con.execute("DELETE FROM dialogues", [])?; + Ok(()) + } +} diff --git a/backend/src/main.rs b/backend/src/main.rs index 5cf15d6..55e5750 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -12,6 +12,8 @@ mod vectordb; use vectordb::VectorDatabase; mod prompt; use prompt::prompt; +mod dialogue_tuning; +use dialogue_tuning::DialogueTuning; #[get("/")] async fn index() -> HttpResponse { @@ -120,6 +122,24 @@ async fn regenerate_message() -> HttpResponse { } } +#[post("/api/saveTunedDialogue")] +async fn save_tuned_dialogue() -> HttpResponse { + let messages = match Database::get_x_msgs(2) { + Ok(v) => v, + Err(e) => { + eprintln!("Error while fetching two previous messages from sqlite database: {}", e); + return HttpResponse::InternalServerError().body("Error while adding tuned dialogue, check logs for more information"); + } + }; + match DialogueTuning::add_dialogue(&messages[0].text, &messages[1].text) { + Ok(_) => HttpResponse::Ok().body("Saved messages as tuned dialogue"), + Err(e) => { + eprintln!("Error while saving messages as tuned dialogue: {}", e); + return HttpResponse::InternalServerError().body("Error while adding tuned dialogue, check logs for more information"); + }, + } +} + #[get("/api/messages")] async fn get_messages() -> HttpResponse { let messages: Vec = match Database::get_messages() { @@ -142,6 +162,15 @@ async fn clear_messages() -> HttpResponse { HttpResponse::Ok().body("Chat log cleared") } +#[get("/api/clearTuningMessages")] +async fn clear_tuning_dialogues() -> HttpResponse { + match DialogueTuning::clear_dialogues() { + Ok(_) => {}, + Err(e) => eprintln!("Error while removing tuning dialogue messages from sqlite database: {}", e), + }; + HttpResponse::Ok().body("Tuning dialogue messages erased") +} + #[derive(Deserialize)] struct ModMsg { new_text: String, @@ -306,11 +335,12 @@ struct ChangeCompanionData { long_term_mem: u32, short_term_mem: u32, roleplay: bool, + dialogue_tuning: bool } #[post("/api/change/companionData")] async fn change_companion_data(received: web::Json) -> HttpResponse { - match Database::change_companion(&received.name, &received.persona, &received.example_dialogue, &received.first_message, received.long_term_mem, received.short_term_mem, received.roleplay) { + match Database::change_companion(&received.name, &received.persona, &received.example_dialogue, &received.first_message, received.long_term_mem, received.short_term_mem, received.roleplay, received.dialogue_tuning) { Ok(_) => HttpResponse::Ok().body("Data of your ai companion has been changed"), Err(e) => { eprintln!("Error while changing companion data in sqlite database: {}", e); @@ -405,12 +435,12 @@ async fn change_short_term_mem(received: web::Json) -> HttpRespons } #[derive(Deserialize)] -struct ChangeRoleplay { +struct ChangeSwitch { enable: bool, } #[post("/api/change/roleplay")] -async fn change_roleplay(received: web::Json) -> HttpResponse { +async fn change_roleplay(received: web::Json) -> HttpResponse { match Database::disable_enable_roleplay(received.enable) { Ok(_) => {}, Err(e) => { @@ -425,6 +455,22 @@ async fn change_roleplay(received: web::Json) -> HttpResponse { } } +#[post("/api/change/dialogue_tuning")] +async fn change_dialogue_tuning(received: web::Json) -> HttpResponse { + match Database::disable_enable_dialogue_tuning(received.enable) { + Ok(_) => {}, + Err(e) => { + eprintln!("Error while enabling/disabling dialogue tuning in sqlite database: {}", e); + return HttpResponse::InternalServerError().body("Error while enabling/disabling dialogue tuning, check logs for more information"); + }, + }; + if received.enable { + HttpResponse::Ok().body("Enabled dialogue tuning") + } else { + HttpResponse::Ok().body("Disabled dialogue tuning") + } +} + // works with https://zoltanai.github.io/character-editor/ // and with https://github.com/Hukasx0/aichar #[derive(Serialize, Deserialize)] @@ -659,6 +705,11 @@ async fn main() -> std::io::Result<()> { Err(e) => { eprintln!("Cannot connect to tantivy because of: {}",e); } } + match DialogueTuning::create() { + Ok(_) => {}, + Err(e) => { eprintln!("Cannot create Dialogue Tuning table in Sqlite database because of {}", e); } + } + println!("AI companion works at:\n -> http://{}:{}/", hostname, port); println!("You can access it, by entering a link in your browser:\n -> http://localhost:{}/", port); HttpServer::new(|| { @@ -671,9 +722,11 @@ async fn main() -> std::io::Result<()> { .service(project_logo) .service(do_prompt) .service(regenerate_message) + .service(save_tuned_dialogue) .service(get_messages) .service(edit_message) .service(clear_messages) + .service(clear_tuning_dialogues) .service(rm_message) .service(change_first_message) .service(change_companion_name) @@ -690,6 +743,7 @@ async fn main() -> std::io::Result<()> { .service(change_long_term_mem) .service(change_short_term_mem) .service(change_roleplay) + .service(change_dialogue_tuning) .service(change_companion_avatar) .service(import_character_json) .service(import_character_card) diff --git a/backend/src/prompt.rs b/backend/src/prompt.rs index f34164f..a7b041d 100644 --- a/backend/src/prompt.rs +++ b/backend/src/prompt.rs @@ -5,6 +5,7 @@ use chrono::{DateTime, Local}; use crate::Database; use crate::database::{Message, CompanionData, UserData}; use crate::vectordb::VectorDatabase; +use crate::dialogue_tuning::DialogueTuning; pub fn prompt(text_prompt: &str) -> Result { let vector = match VectorDatabase::connect() { @@ -68,17 +69,26 @@ pub fn prompt(text_prompt: &str) -> Result { }; let mut base_prompt: String; let mut rp: &str = ""; + let mut tuned_dialogue: String = String::from(""); if companion.roleplay == 1 { rp = "gestures and other non-verbal actions are written between asterisks (for example, *waves hello* or *moves closer*)"; } + if companion.dialogue_tuning == 1 { + match DialogueTuning::get_random_dialogue() { + Ok(dialogue) => { + tuned_dialogue = format!("{}: {}\n{}: {}", &user.name, &dialogue.user_msg, &companion.name, &dialogue.ai_msg); + }, + Err(_) => {}, + }; + } if is_llama2 { base_prompt = - format!("<>\nYou are {}, {}\nyou are talking with {}, {} is {}\n{}\n[INST]\n{}\n[/INST]", - companion.name, companion.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), user.name, user.name, user.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), rp, companion.example_dialogue.replace("{{char}}", &companion.name).replace("{{user}}", &user.name)); + format!("<>\nYou are {}, {}\nyou are talking with {}, {} is {}\n{}\n[INST]\n{}\n{}\n[/INST]", + companion.name, companion.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), user.name, user.name, user.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), rp, companion.example_dialogue.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), &tuned_dialogue); } else { base_prompt = - format!("Text transcript of a conversation between {} and {}. {}\n{}'s Persona: {}\n{}'s Persona: {}\n{}\n\n", - user.name, companion.name, rp, user.name, user.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), companion.name, companion.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), companion.example_dialogue.replace("{{char}}", &companion.name).replace("{{user}}", &user.name)); + format!("Text transcript of a conversation between {} and {}. {}\n{}'s Persona: {}\n{}'s Persona: {}\n\n{}\n\n{}\n\n", + user.name, companion.name, rp, user.name, user.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), companion.name, companion.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), companion.example_dialogue.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), &tuned_dialogue); } let abstract_memory: Vec = match vector.get_matches(text_prompt, companion.long_term_mem) { Ok(m) => m, diff --git a/src/components/ChatWindow.tsx b/src/components/ChatWindow.tsx index b2a3a50..d41bc54 100644 --- a/src/components/ChatWindow.tsx +++ b/src/components/ChatWindow.tsx @@ -103,6 +103,7 @@ const MessagesList = (companionData: CompanionData | undefined, messages: Messag
{index === messages.length - 1 && index != 0 && !regeneratingMessage && ( + <>
{ setRegeneratingMessage(true); fetch('/api/regenerate_message', { @@ -123,6 +124,22 @@ const MessagesList = (companionData: CompanionData | undefined, messages: Messag
+
{ + fetch('/api/saveTunedDialogue', { + method: 'POST', + }) + .then(response => { + console.log(response); + }) + .catch(error => { + console.error('Error while saving message as tuned dialogue: ', error); + }) + }}> +
+ +
+
+ )} {!(regeneratingMessage && index === messages.length - 1) && diff --git a/src/components/Companion.tsx b/src/components/Companion.tsx index c66ef78..39ff32e 100644 --- a/src/components/Companion.tsx +++ b/src/components/Companion.tsx @@ -25,7 +25,8 @@ const Modal = (companionData: CompanionData | undefined, setCompanionData: React if (companionData) { const updatedCompanionData = { ...companionData, - roleplay: Boolean(companionData.roleplay) + roleplay: Boolean(companionData.roleplay), + dialogue_tuning: Boolean(companionData.dialogue_tuning) }; fetch('/api/change/companionData', { @@ -200,6 +201,15 @@ const eraseButtonPressed = async () => { await eraseLongTermMem(); window.location.reload(); } + +const eraseDialogueTuningMsgs = async () => { + try { + await fetch('/api/clearTuningMessages'); + window.location.reload(); + } catch (error) { + console.log(`Error while erasing tuning messages: ${error}`); + } +} return ( <> @@ -242,7 +252,10 @@ const eraseButtonPressed = async () => {

short term memory entries (how many recent messages to remind ai at once)



- +
+ +
+ Erase liked responses

diff --git a/src/components/interfaces/CompanionData.tsx b/src/components/interfaces/CompanionData.tsx index 400f6df..758938e 100644 --- a/src/components/interfaces/CompanionData.tsx +++ b/src/components/interfaces/CompanionData.tsx @@ -7,5 +7,6 @@ interface CompanionData { long_term_mem?: number; short_term_mem?: number; roleplay?: boolean; + dialogue_tuning?: boolean; avatar_path?: string; }