Skip to content

Commit

Permalink
feat: save function calls in the session (#994)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Nov 13, 2024
1 parent ff0ea19 commit cfa9217
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 15 deletions.
31 changes: 29 additions & 2 deletions src/client/message.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use super::ToolResults;

use crate::utils::dimmed_text;

use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Deserialize, Serialize)]
Expand Down Expand Up @@ -53,6 +55,7 @@ pub enum MessageRole {
System,
Assistant,
User,
Tool,
}

#[allow(dead_code)]
Expand All @@ -76,7 +79,11 @@ pub enum MessageContent {
}

impl MessageContent {
pub fn render_input(&self, resolve_url_fn: impl Fn(&str) -> String) -> String {
pub fn render_input(
&self,
resolve_url_fn: impl Fn(&str) -> String,
agent_info: &Option<(String, Vec<String>)>,
) -> String {
match self {
MessageContent::Text(text) => text.to_string(),
MessageContent::Array(list) => {
Expand All @@ -96,7 +103,27 @@ impl MessageContent {
}
format!(".file {}{}", files.join(" "), concated_text)
}
MessageContent::ToolResults(_) => String::new(),
MessageContent::ToolResults(results) => {
let ToolResults {
tool_results, text, ..
} = results;
let mut lines = vec![];
if !text.is_empty() {
lines.push(text.clone())
}
for tool_result in tool_results {
let mut parts = vec!["Call".to_string()];
if let Some((agent_name, functions)) = agent_info {
if functions.contains(&tool_result.call.name) {
parts.push(agent_name.clone())
}
}
parts.push(tool_result.call.name.clone());
parts.push(tool_result.call.arguments.to_string());
lines.push(dimmed_text(&parts.join(" ")));
}
lines.join("\n")
}
}
}

Expand Down
26 changes: 23 additions & 3 deletions src/client/model.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::{
list_chat_models, list_embedding_models, list_reranker_models,
message::{Message, MessageContent},
message::{Message, MessageContent, MessageContentPart},
ToolResults,
};

use crate::config::Config;
Expand Down Expand Up @@ -229,8 +230,27 @@ impl Model {
.iter()
.map(|v| match &v.content {
MessageContent::Text(text) => estimate_token_length(text),
MessageContent::Array(_) => 0,
MessageContent::ToolResults(_) => 0,
MessageContent::Array(list) => list
.iter()
.map(|v| match v {
MessageContentPart::Text { text } => estimate_token_length(text),
MessageContentPart::ImageUrl { .. } => 0,
})
.sum(),
MessageContent::ToolResults(results) => {
let ToolResults {
tool_results, text, ..
} = results;
estimate_token_length(text)
+ tool_results
.iter()
.map(|v| {
serde_json::to_string(v)
.map(|v| estimate_token_length(&v))
.unwrap_or_default()
})
.sum::<usize>()
}
})
.sum()
}
Expand Down
16 changes: 10 additions & 6 deletions src/config/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub struct Input {
regenerate: bool,
medias: Vec<String>,
data_urls: HashMap<String, String>,
tool_call: Option<ToolResults>,
tool_results: Option<ToolResults>,
rag_name: Option<String>,
role: Role,
with_session: bool,
Expand All @@ -48,7 +48,7 @@ impl Input {
regenerate: false,
medias: Default::default(),
data_urls: Default::default(),
tool_call: None,
tool_results: None,
rag_name: None,
role,
with_session,
Expand Down Expand Up @@ -104,7 +104,7 @@ impl Input {
regenerate: false,
medias,
data_urls,
tool_call: Default::default(),
tool_results: Default::default(),
rag_name: None,
role,
with_session,
Expand All @@ -120,6 +120,10 @@ impl Input {
self.data_urls.clone()
}

pub fn tool_results(&self) -> &Option<ToolResults> {
&self.tool_results
}

pub fn text(&self) -> String {
match self.patched_text.clone() {
Some(text) => text,
Expand Down Expand Up @@ -184,11 +188,11 @@ impl Input {
}

pub fn merge_tool_call(mut self, output: String, tool_results: Vec<ToolResult>) -> Self {
match self.tool_call.as_mut() {
match self.tool_results.as_mut() {
Some(exist_tool_results) => {
exist_tool_results.extend(tool_results, output);
}
None => self.tool_call = Some(ToolResults::new(tool_results, output)),
None => self.tool_results = Some(ToolResults::new(tool_results, output)),
}
self
}
Expand Down Expand Up @@ -228,7 +232,7 @@ impl Input {
} else {
self.role().build_messages(self)
};
if let Some(tool_results) = &self.tool_call {
if let Some(tool_results) = &self.tool_results {
messages.push(Message::new(
MessageRole::Assistant,
MessageContent::ToolResults(tool_results.clone()),
Expand Down
11 changes: 10 additions & 1 deletion src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1050,7 +1050,16 @@ impl Config {
if let Some(session) = &self.session {
let render_options = self.render_options()?;
let mut markdown_render = MarkdownRender::init(render_options)?;
session.render(&mut markdown_render)
let agent_info: Option<(String, Vec<String>)> = self.agent.as_ref().map(|agent| {
let functions = agent
.functions()
.declarations()
.iter()
.filter_map(|v| if v.agent { Some(v.name.clone()) } else { None })
.collect();
(agent.name().to_string(), functions)
});
session.render(&mut markdown_render, &agent_info)
} else {
bail!("No session")
}
Expand Down
22 changes: 19 additions & 3 deletions src/config/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,11 @@ impl Session {
Ok(output)
}

pub fn render(&self, render: &mut MarkdownRender) -> Result<String> {
pub fn render(
&self,
render: &mut MarkdownRender,
agent_info: &Option<(String, Vec<String>)>,
) -> Result<String> {
let mut items = vec![];

if let Some(path) = &self.path {
Expand Down Expand Up @@ -205,7 +209,10 @@ impl Session {
for message in &self.messages {
match message.role {
MessageRole::System => {
lines.push(render.render(&message.content.render_input(resolve_url_fn)));
lines.push(
render
.render(&message.content.render_input(resolve_url_fn, agent_info)),
);
}
MessageRole::Assistant => {
if let MessageContent::Text(text) = &message.content {
Expand All @@ -217,9 +224,12 @@ impl Session {
lines.push(format!(
"{}){}",
self.name,
message.content.render_input(resolve_url_fn)
message.content.render_input(resolve_url_fn, agent_info)
));
}
MessageRole::Tool => {
lines.push(message.content.render_input(resolve_url_fn, agent_info));
}
}
}
}
Expand Down Expand Up @@ -409,6 +419,12 @@ impl Session {
.push(Message::new(MessageRole::User, input.message_content()));
}
self.data_urls.extend(input.data_urls());
if let Some(tool_results) = input.tool_results() {
self.messages.push(Message::new(
MessageRole::Tool,
MessageContent::ToolResults(tool_results.clone()),
))
}
self.messages.push(Message::new(
MessageRole::Assistant,
MessageContent::Text(output.to_string()),
Expand Down

0 comments on commit cfa9217

Please sign in to comment.