Skip to content

Commit

Permalink
refactor: separate out ChatContext from the Aichat specific format
Browse files Browse the repository at this point in the history
  • Loading branch information
arcuru committed Oct 15, 2024
1 parent cedd657 commit 58c20aa
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 112 deletions.
18 changes: 8 additions & 10 deletions src/aichat.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::process::Command;
use tracing::info;

use crate::ChatContext;

pub struct AiChat {
binary_location: String,
config_dir: Option<String>,
Expand Down Expand Up @@ -65,28 +67,24 @@ impl AiChat {
.unwrap_or("default".to_string())
}

pub fn execute(
&self,
model: &Option<String>,
prompt: String,
media: Vec<matrix_sdk::media::MediaFileHandle>,
) -> Result<String, String> {
pub fn execute(&self, context: &ChatContext) -> Result<String, String> {
let mut command = Command::new(&self.binary_location);
if let Some(model) = model {
if let Some(model) = &context.model {
command.arg("--model").arg(model);
}
if let Some(config_dir) = &self.config_dir {
command.env("AICHAT_CONFIG_DIR", config_dir);
}
// For each media file, add the media flag and the path to the file
// Note that we must not consume the media files, the handles need to persist until the command is finished
if !media.is_empty() {
if !context.media.is_empty() {
command.arg("--file");
for media_file in &media {
for media_file in &context.media {
command.arg(media_file.path());
}
}
command.arg("--").arg(prompt);
// Adds the full prompt as just a string
command.arg("--").arg(context.string_prompt_with_role());
info!("Running command: {:?}", command);

let output = command.output().expect("Failed to execute command");
Expand Down
220 changes: 135 additions & 85 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod aichat;
use aichat::AiChat;

mod role;
use role::RoleDetails;
use role::{get_role, prepend_role, RoleDetails};

mod defaults;
use defaults::DEFAULT_CONFIG;
Expand Down Expand Up @@ -134,9 +134,8 @@ async fn main() -> anyhow::Result<()> {
None,
Some("Print the conversation".to_string()),
|_, _, room| async move {
let (context, _, _) = get_context(&room).await.unwrap();
let context = add_role(&context);
let content = RoomMessageEventContent::notice_plain(context);
let context = get_context(&room).await.unwrap();
let content = RoomMessageEventContent::notice_plain(context.string_prompt());
room.send(content).await.unwrap();
Ok(())
},
Expand All @@ -159,14 +158,23 @@ async fn main() -> anyhow::Result<()> {
.join(" ");

// But we do need to read the context to figure out the model to use
let (_, model, _) = get_context(&room).await.unwrap();
let context = get_context(&room).await.unwrap();
let no_context = ChatContext {
messages: vec![Message {
sender: "USER".to_string(),
content: input.to_string(),
}],
model: context.model,
role: context.role,
media: Vec::new(),
};

info!(
"Request: {} - {}",
sender.as_str(),
input.replace('\n', " ")
);
if let Ok(result) = get_backend().execute(&model, input.to_string(), Vec::new()) {
if let Ok(result) = get_backend().execute(&no_context) {
// Add the prefix ".response:\n" to the result
// That way we can identify our own responses and ignore them for context
info!(
Expand Down Expand Up @@ -250,17 +258,8 @@ async fn main() -> anyhow::Result<()> {
return Ok(());
}
// If it's not a command, we should send the full context without commands to the server
if let Ok((context, model, media)) = get_context(&room).await {
let mut context = add_role(&context);
// Append "ASSISTANT: " to the context string to indicate the assistant is speaking
context.push_str("ASSISTANT: ");

info!(
"Request: {} - {}",
sender.as_str(),
context.replace('\n', " ")
);
match get_backend().execute(&model, context, media) {
if let Ok(context) = get_context(&room).await {
match get_backend().execute(&context) {
Ok(stdout) => {
info!("Response: {}", stdout.replace('\n', " "));
// Most LLMs like responding with Markdown
Expand Down Expand Up @@ -288,17 +287,6 @@ async fn main() -> anyhow::Result<()> {
Ok(())
}

/// Prepend the role defined in the global config
fn add_role(context: &str) -> String {
let config = GLOBAL_CONFIG.lock().unwrap().clone().unwrap();
role::prepend_role(
context.to_string(),
config.role.clone(),
config.roles.clone(),
DEFAULT_CONFIG.roles.clone(),
)
}

/// Rate limit the user to a set number of messages
/// Returns true if the user is being rate limited
async fn rate_limit(room: &Room, sender: &OwnedUserId) -> bool {
Expand Down Expand Up @@ -354,10 +342,10 @@ async fn rate_limit(room: &Room, sender: &OwnedUserId) -> bool {

/// List the available models
async fn list_models(_: OwnedUserId, _: String, room: Room) -> Result<(), ()> {
let (_, current_model, _) = get_context(&room).await.unwrap();
let context = get_context(&room).await.unwrap();
let response = format!(
"!chaz Current Model: {}\n\nAvailable Models:\n{}",
current_model.unwrap_or(get_backend().default_model()),
context.model.unwrap_or(get_backend().default_model()),
get_backend().list_models().join("\n")
);
room.send(RoomMessageEventContent::notice_plain(response))
Expand Down Expand Up @@ -398,23 +386,20 @@ async fn rename(sender: OwnedUserId, _: String, room: Room) -> Result<(), ()> {
if rate_limit(&room, &sender).await {
return Ok(());
}
if let Ok((context, _, _)) = get_context(&room).await {
let title_prompt= [
&context,
"\nUSER: Summarize this conversation in less than 20 characters to use as the title of this conversation. ",
"The output should be a single line of text describing the conversation. ",
"Do not output anything except for the summary text. ",
"Only the first 20 characters will be used. ",
"\nASSISTANT: ",
].join("");
let model = get_chat_summary_model();

info!(
"Request: {} - {}",
sender.as_str(),
title_prompt.replace('\n', " ")
);
let response = get_backend().execute(&model, title_prompt, Vec::new());
if let Ok(context) = get_context(&room).await {
let mut context = context;
context.model = get_chat_summary_model();
context.messages.push(Message {
sender: "USER".to_string(),
content: [
"Summarize this conversation in less than 20 characters to use as the title of this conversation.",
"The output should be a single line of text describing the conversation.",
"Do not output anything except for the summary text.",
"Only the first 20 characters will be used.",
].join(" "),
});

let response = get_backend().execute(&context);
if let Ok(result) = response {
info!(
"Response: {} - {}",
Expand All @@ -433,22 +418,21 @@ async fn rename(sender: OwnedUserId, _: String, room: Room) -> Result<(), ()> {
return Ok(());
}
}

let topic_prompt = [
&context,
"\nUSER: Summarize this conversation in less than 50 characters. ",
"Do not output anything except for the summary text. ",
"Do not include any commentary or context, only the summary. ",
"\nASSISTANT: ",
]
.join("");

info!(
"Request: {} - {}",
sender.as_str(),
topic_prompt.replace('\n', " ")
);
let response = get_backend().execute(&model, topic_prompt, Vec::new());
// Remove the title summary request
context.messages.pop();

context.model = get_chat_summary_model();
context.messages.push(Message {
sender: "USER".to_string(),
content: [
"Summarize this conversation in less than 50 characters.",
"Do not output anything except for the summary text.",
"Do not include any commentary or context, only the summary.",
]
.join(" "),
});

let response = get_backend().execute(&context);
if let Ok(result) = response {
info!(
"Response: {} - {}",
Expand Down Expand Up @@ -500,21 +484,78 @@ fn get_chat_summary_model() -> Option<String> {
config.chat_summary_model
}

/// A single message in a conversation, tagged with the speaker.
#[derive(Debug, Clone)]
struct Message {
    /// Who produced the message, e.g. "USER" or "ASSISTANT".
    sender: String,
    /// The raw text content of the message.
    content: String,
}

impl Message {
    /// Create a new `Message`.
    ///
    /// Generalized over two independent `Into<String>` parameters so that
    /// mixed argument types (e.g. a `&str` sender with a `String` content)
    /// work without explicit conversion at the call site. Fully backward
    /// compatible with callers passing two values of the same type.
    fn new<S: Into<String>, C: Into<String>>(sender: S, content: C) -> Message {
        Message {
            sender: sender.into(),
            content: content.into(),
        }
    }
}

/// The accumulated context of a conversation, independent of any
/// backend-specific (e.g. Aichat) prompt format.
struct ChatContext {
    /// Messages in chronological order (oldest first).
    messages: Vec<Message>,
    /// Model to select, if one was ever chosen in the room;
    /// `None` means the backend's default model is used.
    model: Option<String>,
    /// Handles to media files referenced by the conversation.
    /// The handles must stay alive until the backend command finishes.
    media: Vec<MediaFileHandle>,
    /// Optional role definition to prepend to the prompt.
    role: Option<RoleDetails>,
}

impl ChatContext {
    /// Render the conversation as a single plain-text prompt.
    ///
    /// Each message becomes one "SENDER: content" line, and a trailing
    /// "ASSISTANT: " marker indicates that the assistant speaks next.
    // TODO: consider making this markdown
    fn string_prompt(&self) -> String {
        let mut prompt: String = self
            .messages
            .iter()
            .map(|message| format!("{}: {}\n", message.sender, message.content))
            .collect();
        // Indicate that the assistant needs to speak next
        prompt.push_str("ASSISTANT: ");
        prompt
    }

    /// Render the conversation as in `string_prompt`, with the configured
    /// role (if any) prepended.
    fn string_prompt_with_role(&self) -> String {
        let prompt = self.string_prompt();
        match &self.role {
            Some(role) => prepend_role(prompt, role),
            None => prompt,
        }
    }
}

/// Gets the context of the current conversation
/// Returns a model if it was ever entered
async fn get_context(room: &Room) -> Result<(String, Option<String>, Vec<MediaFileHandle>), ()> {
// Read all the messages in the room, place them into a single string, and print them out
let mut messages = Vec::new();
///
/// The token_limit is the maximum number of tokens to add into the context.
/// If no token_limit is given, the context will include the full room
async fn get_context(room: &Room) -> Result<ChatContext, ()> {
let mut context = ChatContext {
messages: Vec::new(),
model: None,
media: Vec::new(),
role: None,
};
{
let config = GLOBAL_CONFIG.lock().unwrap().clone().unwrap();
context.role = get_role(
config.role.clone(),
config.roles.clone(),
DEFAULT_CONFIG.roles.clone(),
);
}

let mut options = MessagesOptions::backward();
let mut model_response = None;
let mut media = Vec::new();

let config = GLOBAL_CONFIG.lock().unwrap().clone().unwrap();
let enable_media_context = !config.disable_media_context.unwrap_or(false);

'outer: while let Ok(batch) = room.messages(options).await {
// This assumes that the messages are in reverse order
// This assumes that the messages are in reverse order, which they should be
for message in batch.chunk {
if let Some((sender, content)) = message
.event
Expand Down Expand Up @@ -549,7 +590,7 @@ async fn get_context(room: &Room) -> Result<(String, Option<String>, Vec<MediaFi
.get_media_file(&request, None, &mime, true, None)
.await
.unwrap();
media.insert(0, x);
context.media.push(x);
}
}
MessageType::Text(text_content) => {
Expand All @@ -558,14 +599,14 @@ async fn get_context(room: &Room) -> Result<(String, Option<String>, Vec<MediaFi
// if the message is a valid model command, set the model
// FIXME: hardcoded name
if text_content.body.starts_with("!chaz model")
&& model_response.is_none()
&& context.model.is_none()
{
let model = text_content.body.split_whitespace().nth(2);
if let Some(model) = model {
// Add the config_dir from the global config
let models = get_backend().list_models();
if models.contains(&model.to_string()) {
model_response = Some(model.to_string());
context.model = Some(model.to_string());
}
}
}
Expand Down Expand Up @@ -595,9 +636,15 @@ async fn get_context(room: &Room) -> Result<(String, Option<String>, Vec<MediaFi
.user_id()
.is_some_and(|uid| sender == uid.as_str())
{
messages.push(format!("ASSISTANT: {}\n", command));
context.messages.push(Message::new(
"ASSISTANT".to_string(),
command.to_string(),
));
} else {
messages.push(format!("USER: {}\n", command));
context.messages.push(Message::new(
"USER".to_string(),
command.to_string(),
));
}
}
} else {
Expand All @@ -607,11 +654,16 @@ async fn get_context(room: &Room) -> Result<(String, Option<String>, Vec<MediaFi
.user_id()
.is_some_and(|uid| sender == uid.as_str())
{
// If the sender is the bot, prefix the message with "ASSISTANT: "
messages.push(format!("ASSISTANT: {}\n", text_content.body));
// Sender is the bot
context.messages.push(Message::new(
"ASSISTANT".to_string(),
text_content.body.clone(),
));
} else {
// Otherwise, prefix the message with "USER: "
messages.push(format!("USER: {}\n", text_content.body));
context.messages.push(Message::new(
"USER".to_string(),
text_content.body.clone(),
));
}
}
}
Expand All @@ -625,10 +677,8 @@ async fn get_context(room: &Room) -> Result<(String, Option<String>, Vec<MediaFi
break;
}
}
// Append the messages into a string with newlines in between, in reverse order
Ok((
messages.into_iter().rev().collect::<String>(),
model_response,
media,
))
// Reverse context so that it's in the correct order
context.messages.reverse();
context.media.reverse();
Ok(context)
}
Loading

0 comments on commit 58c20aa

Please sign in to comment.