Skip to content
This repository has been archived by the owner on Sep 16, 2024. It is now read-only.

New feature: Dialogue tuning #36

Merged
merged 2 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions backend/src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub struct CompanionData {
pub long_term_mem: usize,
pub short_term_mem: u32,
pub roleplay: u32,
pub dialogue_tuning: u32,
pub avatar_path: String,
}

Expand Down Expand Up @@ -60,12 +61,13 @@ impl Database {
long_term_mem INTEGER NOT NULL,
short_term_mem INTEGER NOT NULL,
roleplay INTEGER NOT NULL,
dialogue_tuning INTEGER NOT NULL,
avatar_path STRING NOT NULL
)", [],
)?;
if Database::is_table_empty("companion", &con) {
con.execute(
"INSERT INTO companion (id, name, persona, example_dialogue, first_message, long_term_mem, short_term_mem, roleplay, avatar_path) VALUES (NULL, \"Assistant\", \"{{char}} is an artificial intelligence chatbot designed to help {{user}}. {{char}} is an artificial intelligence created in ai-companion backend\", \"{{user}}: What is ai-companion?\n{{char}}: AI Companion is a project that aims to provide users with their own personal AI chatbot on their computer. It allows users to engage in friendly and natural conversations with their AI, creating a unique and personalized experience. This software can also be used as a backend or API for other projects that require a personalised AI chatbot.\n{{user}}: Can you tell me about the creator of ai-companion?\n{{char}}: the creator of the ai-companion program is 'Hubert Kasperek', he is a young programmer from Poland who is mostly interested in: web development (Backend), cybersecurity and computer science concepts\", \"Hello {{user}}, how can i help you?\", 2, 5, 1, \"/assets/companion_avatar-4rust.jpg\")", []
"INSERT INTO companion (id, name, persona, example_dialogue, first_message, long_term_mem, short_term_mem, roleplay, dialogue_tuning, avatar_path) VALUES (NULL, \"Assistant\", \"{{char}} is an artificial intelligence chatbot designed to help {{user}}. {{char}} is an artificial intelligence created in ai-companion backend\", \"{{user}}: What is ai-companion?\n{{char}}: AI Companion is a project that aims to provide users with their own personal AI chatbot on their computer. It allows users to engage in friendly and natural conversations with their AI, creating a unique and personalized experience. This software can also be used as a backend or API for other projects that require a personalised AI chatbot.\n{{user}}: Can you tell me about the creator of ai-companion?\n{{char}}: the creator of the ai-companion program is 'Hubert Kasperek', he is a young programmer from Poland who is mostly interested in: web development (Backend), cybersecurity and computer science concepts\", \"Hello {{user}}, how can i help you?\", 2, 5, 1, 1, \"/assets/companion_avatar-4rust.jpg\")", []
)?;
}
if Database::is_table_empty("user", &con) {
Expand Down Expand Up @@ -141,7 +143,8 @@ impl Database {
long_term_mem: row.get(5)?,
short_term_mem: row.get(6)?,
roleplay: row.get(7)?,
avatar_path: row.get(8)?,
dialogue_tuning: row.get(8)?,
avatar_path: row.get(9)?,
})
})?;
let mut result: CompanionData = Default::default();
Expand Down Expand Up @@ -228,9 +231,9 @@ impl Database {
Ok(())
}

pub fn change_companion(name: &str, persona: &str, example_dialogue: &str, first_message: &str, long_term_mem: u32, short_term_mem: u32, roleplay: bool) -> Result<(), Error> {
pub fn change_companion(name: &str, persona: &str, example_dialogue: &str, first_message: &str, long_term_mem: u32, short_term_mem: u32, roleplay: bool, dialogue_tuning: bool) -> Result<(), Error> {
let con = Connection::open("companion.db")?;
con.execute(&format!("UPDATE companion SET name=?1, persona=?2, example_dialogue=?3, first_message=?4, long_term_mem={}, short_term_mem={}, roleplay={}", long_term_mem, short_term_mem, roleplay), [&name, &persona, &example_dialogue, &first_message])?;
con.execute(&format!("UPDATE companion SET name=?1, persona=?2, example_dialogue=?3, first_message=?4, long_term_mem={}, short_term_mem={}, roleplay={}, dialogue_tuning={}", long_term_mem, short_term_mem, roleplay, dialogue_tuning), [&name, &persona, &example_dialogue, &first_message])?;
Ok(())
}

Expand Down Expand Up @@ -287,4 +290,10 @@ impl Database {
con.execute(&format!("UPDATE companion SET roleplay={}", op), [])?;
Ok(())
}

pub fn disable_enable_dialogue_tuning(op: bool) -> Result<(), Error> {
let con = Connection::open("companion.db")?;
con.execute(&format!("UPDATE companion SET dialogue_tuning={}", op), [])?;
Ok(())
}
}
53 changes: 53 additions & 0 deletions backend/src/dialogue_tuning.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use rusqlite::{Connection, Error, Result};

#[derive(Debug)]
pub struct Dialogue {
pub user_msg: String,
pub ai_msg: String,
}

pub struct DialogueTuning {}

impl DialogueTuning {
pub fn create() -> Result<usize, Error> {
let con = Connection::open("companion.db")?;
con.execute(
"CREATE TABLE IF NOT EXISTS dialogues (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_msg TEXT NOT NULL,
ai_msg TEXT NOT NULL
)", [],
)
}

pub fn add_dialogue(user_msg: &str, ai_msg: &str) -> Result<(), Error> {
let con = Connection::open("companion.db")?;
con.execute(
"INSERT INTO dialogues (user_msg, ai_msg) VALUES (?1, ?2)",
&[user_msg, ai_msg],
)?;
Ok(())
}

pub fn get_random_dialogue() -> Result<Dialogue, Error> {
let con = Connection::open("companion.db")?;

let query = "SELECT user_msg, ai_msg FROM dialogues WHERE id = (SELECT id FROM dialogues ORDER BY RANDOM() LIMIT 1);";

if let Some(row) = con.query_row(query, [], |row| {
let user_msg: String = row.get(0)?;
let ai_msg: String = row.get(1)?;
Ok(Dialogue { user_msg, ai_msg })
}).ok() {
Ok(row)
} else {
Err(Error::QueryReturnedNoRows)
}
}

pub fn clear_dialogues() -> Result<(), Error> {
let con = Connection::open("companion.db")?;
con.execute("DELETE FROM dialogues", [])?;
Ok(())
}
}
60 changes: 57 additions & 3 deletions backend/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ mod vectordb;
use vectordb::VectorDatabase;
mod prompt;
use prompt::prompt;
mod dialogue_tuning;
use dialogue_tuning::DialogueTuning;

#[get("/")]
async fn index() -> HttpResponse {
Expand Down Expand Up @@ -120,6 +122,24 @@ async fn regenerate_message() -> HttpResponse {
}
}

#[post("/api/saveTunedDialogue")]
async fn save_tuned_dialogue() -> HttpResponse {
let messages = match Database::get_x_msgs(2) {
Ok(v) => v,
Err(e) => {
eprintln!("Error while fetching two previous messages from sqlite database: {}", e);
return HttpResponse::InternalServerError().body("Error while adding tuned dialogue, check logs for more information");
}
};
match DialogueTuning::add_dialogue(&messages[0].text, &messages[1].text) {
Ok(_) => HttpResponse::Ok().body("Saved messages as tuned dialogue"),
Err(e) => {
eprintln!("Error while saving messages as tuned dialogue: {}", e);
return HttpResponse::InternalServerError().body("Error while adding tuned dialogue, check logs for more information");
},
}
}

#[get("/api/messages")]
async fn get_messages() -> HttpResponse {
let messages: Vec<Message> = match Database::get_messages() {
Expand All @@ -142,6 +162,15 @@ async fn clear_messages() -> HttpResponse {
HttpResponse::Ok().body("Chat log cleared")
}

#[get("/api/clearTuningMessages")]
async fn clear_tuning_dialogues() -> HttpResponse {
match DialogueTuning::clear_dialogues() {
Ok(_) => {},
Err(e) => eprintln!("Error while removing tuning dialogue messages from sqlite database: {}", e),
};
HttpResponse::Ok().body("Tuning dialogue messages erased")
}

#[derive(Deserialize)]
struct ModMsg {
new_text: String,
Expand Down Expand Up @@ -306,11 +335,12 @@ struct ChangeCompanionData {
long_term_mem: u32,
short_term_mem: u32,
roleplay: bool,
dialogue_tuning: bool
}

#[post("/api/change/companionData")]
async fn change_companion_data(received: web::Json<ChangeCompanionData>) -> HttpResponse {
match Database::change_companion(&received.name, &received.persona, &received.example_dialogue, &received.first_message, received.long_term_mem, received.short_term_mem, received.roleplay) {
match Database::change_companion(&received.name, &received.persona, &received.example_dialogue, &received.first_message, received.long_term_mem, received.short_term_mem, received.roleplay, received.dialogue_tuning) {
Ok(_) => HttpResponse::Ok().body("Data of your ai companion has been changed"),
Err(e) => {
eprintln!("Error while changing companion data in sqlite database: {}", e);
Expand Down Expand Up @@ -405,12 +435,12 @@ async fn change_short_term_mem(received: web::Json<ChangeMemory>) -> HttpRespons
}

#[derive(Deserialize)]
struct ChangeRoleplay {
struct ChangeSwitch {
enable: bool,
}

#[post("/api/change/roleplay")]
async fn change_roleplay(received: web::Json<ChangeRoleplay>) -> HttpResponse {
async fn change_roleplay(received: web::Json<ChangeSwitch>) -> HttpResponse {
match Database::disable_enable_roleplay(received.enable) {
Ok(_) => {},
Err(e) => {
Expand All @@ -425,6 +455,22 @@ async fn change_roleplay(received: web::Json<ChangeRoleplay>) -> HttpResponse {
}
}

#[post("/api/change/dialogue_tuning")]
async fn change_dialogue_tuning(received: web::Json<ChangeSwitch>) -> HttpResponse {
match Database::disable_enable_dialogue_tuning(received.enable) {
Ok(_) => {},
Err(e) => {
eprintln!("Error while enabling/disabling dialogue tuning in sqlite database: {}", e);
return HttpResponse::InternalServerError().body("Error while enabling/disabling dialogue tuning, check logs for more information");
},
};
if received.enable {
HttpResponse::Ok().body("Enabled dialogue tuning")
} else {
HttpResponse::Ok().body("Disabled dialogue tuning")
}
}

// works with https://zoltanai.github.io/character-editor/
// and with https://github.com/Hukasx0/aichar
#[derive(Serialize, Deserialize)]
Expand Down Expand Up @@ -659,6 +705,11 @@ async fn main() -> std::io::Result<()> {
Err(e) => { eprintln!("Cannot connect to tantivy because of: {}",e); }
}

match DialogueTuning::create() {
Ok(_) => {},
Err(e) => { eprintln!("Cannot create Dialogue Tuning table in Sqlite database because of {}", e); }
}

println!("AI companion works at:\n -> http://{}:{}/", hostname, port);
println!("You can access it, by entering a link in your browser:\n -> http://localhost:{}/", port);
HttpServer::new(|| {
Expand All @@ -671,9 +722,11 @@ async fn main() -> std::io::Result<()> {
.service(project_logo)
.service(do_prompt)
.service(regenerate_message)
.service(save_tuned_dialogue)
.service(get_messages)
.service(edit_message)
.service(clear_messages)
.service(clear_tuning_dialogues)
.service(rm_message)
.service(change_first_message)
.service(change_companion_name)
Expand All @@ -690,6 +743,7 @@ async fn main() -> std::io::Result<()> {
.service(change_long_term_mem)
.service(change_short_term_mem)
.service(change_roleplay)
.service(change_dialogue_tuning)
.service(change_companion_avatar)
.service(import_character_json)
.service(import_character_card)
Expand Down
18 changes: 14 additions & 4 deletions backend/src/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use chrono::{DateTime, Local};
use crate::Database;
use crate::database::{Message, CompanionData, UserData};
use crate::vectordb::VectorDatabase;
use crate::dialogue_tuning::DialogueTuning;

pub fn prompt(text_prompt: &str) -> Result<String, String> {
let vector = match VectorDatabase::connect() {
Expand Down Expand Up @@ -68,17 +69,26 @@ pub fn prompt(text_prompt: &str) -> Result<String, String> {
};
let mut base_prompt: String;
let mut rp: &str = "";
let mut tuned_dialogue: String = String::from("");
if companion.roleplay == 1 {
rp = "gestures and other non-verbal actions are written between asterisks (for example, *waves hello* or *moves closer*)";
}
if companion.dialogue_tuning == 1 {
match DialogueTuning::get_random_dialogue() {
Ok(dialogue) => {
tuned_dialogue = format!("{}: {}\n{}: {}", &user.name, &dialogue.user_msg, &companion.name, &dialogue.ai_msg);
},
Err(_) => {},
};
}
if is_llama2 {
base_prompt =
format!("<<SYS>>\nYou are {}, {}\nyou are talking with {}, {} is {}\n{}\n[INST]\n{}\n[/INST]",
companion.name, companion.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), user.name, user.name, user.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), rp, companion.example_dialogue.replace("{{char}}", &companion.name).replace("{{user}}", &user.name));
format!("<<SYS>>\nYou are {}, {}\nyou are talking with {}, {} is {}\n{}\n[INST]\n{}\n{}\n[/INST]",
companion.name, companion.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), user.name, user.name, user.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), rp, companion.example_dialogue.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), &tuned_dialogue);
} else {
base_prompt =
format!("Text transcript of a conversation between {} and {}. {}\n{}'s Persona: {}\n{}'s Persona: {}\n<START>{}\n<START>\n",
user.name, companion.name, rp, user.name, user.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), companion.name, companion.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), companion.example_dialogue.replace("{{char}}", &companion.name).replace("{{user}}", &user.name));
format!("Text transcript of a conversation between {} and {}. {}\n{}'s Persona: {}\n{}'s Persona: {}\n<START>\n{}\n<START>\n{}\n<START>\n",
user.name, companion.name, rp, user.name, user.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), companion.name, companion.persona.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), companion.example_dialogue.replace("{{char}}", &companion.name).replace("{{user}}", &user.name), &tuned_dialogue);
}
let abstract_memory: Vec<String> = match vector.get_matches(text_prompt, companion.long_term_mem) {
Ok(m) => m,
Expand Down
17 changes: 17 additions & 0 deletions src/components/ChatWindow.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ const MessagesList = (companionData: CompanionData | undefined, messages: Messag
</div>
<div className="flex space-x-2 pl-14">
{index === messages.length - 1 && index != 0 && !regeneratingMessage && (
<>
<div className="chat-footer tiny-text opacity-50 cursor-pointer" onClick={() => {
setRegeneratingMessage(true);
fetch('/api/regenerate_message', {
Expand All @@ -123,6 +124,22 @@ const MessagesList = (companionData: CompanionData | undefined, messages: Messag
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="#555A63"><path d="M16.242 17.242a6.04 6.04 0 0 1-1.37 1.027l.961 1.754a8.068 8.068 0 0 0 2.569-2.225l-1.6-1.201a5.938 5.938 0 0 1-.56.645zm1.743-4.671a5.975 5.975 0 0 1-.362 2.528l1.873.701a7.977 7.977 0 0 0 .483-3.371l-1.994.142zm1.512-2.368a8.048 8.048 0 0 0-1.841-2.859l-1.414 1.414a6.071 6.071 0 0 1 1.382 2.146l1.873-.701zm-8.128 8.763c-.047-.005-.094-.015-.141-.021a6.701 6.701 0 0 1-.468-.075 5.923 5.923 0 0 1-2.421-1.122 5.954 5.954 0 0 1-.583-.506 6.138 6.138 0 0 1-.516-.597 5.91 5.91 0 0 1-.891-1.634 6.086 6.086 0 0 1-.247-.902c-.008-.043-.012-.088-.019-.131A6.332 6.332 0 0 1 6 13.002V13c0-1.603.624-3.109 1.758-4.242A5.944 5.944 0 0 1 11 7.089V10l5-4-5-4v3.069a7.917 7.917 0 0 0-4.656 2.275A7.936 7.936 0 0 0 4 12.999v.009c0 .253.014.504.037.753.007.076.021.15.03.227.021.172.044.345.076.516.019.1.044.196.066.295.032.142.065.283.105.423.032.112.07.223.107.333.026.079.047.159.076.237l.008-.003A7.948 7.948 0 0 0 5.6 17.785l-.007.005c.021.028.049.053.07.081.211.272.433.538.681.785a8.236 8.236 0 0 0 .966.816c.265.192.537.372.821.529l.028.019.001-.001a7.877 7.877 0 0 0 2.136.795l-.001.005.053.009c.201.042.405.071.61.098.069.009.138.023.207.03a8.038 8.038 0 0 0 2.532-.137l-.424-1.955a6.11 6.11 0 0 1-1.904.102z"></path></svg>
</div>
</div>
<div className="chat-footer tiny-text opacity-50 cursor-pointer" onClick={() => {
fetch('/api/saveTunedDialogue', {
method: 'POST',
})
.then(response => {
console.log(response);
})
.catch(error => {
console.error('Error while saving message as tuned dialogue: ', error);
})
}}>
<div className="tooltip tooltip-bottom" data-tip="I like this answer">
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="#555A63"><path d="M20 8h-5.612l1.123-3.367c.202-.608.1-1.282-.275-1.802S14.253 2 13.612 2H12c-.297 0-.578.132-.769.36L6.531 8H4c-1.103 0-2 .897-2 2v9c0 1.103.897 2 2 2h13.307a2.01 2.01 0 0 0 1.873-1.298l2.757-7.351A1 1 0 0 0 22 12v-2c0-1.103-.897-2-2-2zM4 10h2v9H4v-9zm16 1.819L17.307 19H8V9.362L12.468 4h1.146l-1.562 4.683A.998.998 0 0 0 13 10h7v1.819z"></path></svg>
</div>
</div>
</>

)}
{!(regeneratingMessage && index === messages.length - 1) &&
Expand Down
Loading
Loading