From cfa9217422dfa6cf10bed6e6e3fab0722f58588d Mon Sep 17 00:00:00 2001 From: sigoden Date: Thu, 14 Nov 2024 06:03:06 +0800 Subject: [PATCH] feat: save function calls in the session (#994) --- src/client/message.rs | 31 +++++++++++++++++++++++++++++-- src/client/model.rs | 26 +++++++++++++++++++++++--- src/config/input.rs | 16 ++++++++++------ src/config/mod.rs | 11 ++++++++++- src/config/session.rs | 22 +++++++++++++++++++--- 5 files changed, 91 insertions(+), 15 deletions(-) diff --git a/src/client/message.rs b/src/client/message.rs index 77adf471..061c0e2a 100644 --- a/src/client/message.rs +++ b/src/client/message.rs @@ -1,5 +1,7 @@ use super::ToolResults; +use crate::utils::dimmed_text; + use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Deserialize, Serialize)] @@ -53,6 +55,7 @@ pub enum MessageRole { System, Assistant, User, + Tool, } #[allow(dead_code)] @@ -76,7 +79,11 @@ pub enum MessageContent { } impl MessageContent { - pub fn render_input(&self, resolve_url_fn: impl Fn(&str) -> String) -> String { + pub fn render_input( + &self, + resolve_url_fn: impl Fn(&str) -> String, + agent_info: &Option<(String, Vec)>, + ) -> String { match self { MessageContent::Text(text) => text.to_string(), MessageContent::Array(list) => { @@ -96,7 +103,27 @@ impl MessageContent { } format!(".file {}{}", files.join(" "), concated_text) } - MessageContent::ToolResults(_) => String::new(), + MessageContent::ToolResults(results) => { + let ToolResults { + tool_results, text, .. + } = results; + let mut lines = vec![]; + if !text.is_empty() { + lines.push(text.clone()) + } + for tool_result in tool_results { + let mut parts = vec!["Call".to_string()]; + if let Some((agent_name, functions)) = agent_info { + if functions.contains(&tool_result.call.name) { + parts.push(agent_name.clone()) + } + } + parts.push(tool_result.call.name.clone()); + parts.push(tool_result.call.arguments.to_string()); + lines.push(dimmed_text(&parts.join(" "))); + } + lines.join("\n") + } } } diff --git a/src/client/model.rs b/src/client/model.rs index f8a23ff0..af864e2a 100644 --- a/src/client/model.rs +++ b/src/client/model.rs @@ -1,6 +1,7 @@ use super::{ list_chat_models, list_embedding_models, list_reranker_models, - message::{Message, MessageContent}, + message::{Message, MessageContent, MessageContentPart}, + ToolResults, }; use crate::config::Config; @@ -229,8 +230,27 @@ impl Model { .iter() .map(|v| match &v.content { MessageContent::Text(text) => estimate_token_length(text), - MessageContent::Array(_) => 0, - MessageContent::ToolResults(_) => 0, + MessageContent::Array(list) => list + .iter() + .map(|v| match v { + MessageContentPart::Text { text } => estimate_token_length(text), + MessageContentPart::ImageUrl { .. } => 0, + }) + .sum(), + MessageContent::ToolResults(results) => { + let ToolResults { + tool_results, text, .. + } = results; + estimate_token_length(text) + + tool_results + .iter() + .map(|v| { + serde_json::to_string(v) + .map(|v| estimate_token_length(&v)) + .unwrap_or_default() + }) + .sum::() + } }) .sum() } diff --git a/src/config/input.rs b/src/config/input.rs index bc58f321..55a4d45f 100644 --- a/src/config/input.rs +++ b/src/config/input.rs @@ -29,7 +29,7 @@ pub struct Input { regenerate: bool, medias: Vec, data_urls: HashMap, - tool_call: Option, + tool_results: Option, rag_name: Option, role: Role, with_session: bool, @@ -48,7 +48,7 @@ impl Input { regenerate: false, medias: Default::default(), data_urls: Default::default(), - tool_call: None, + tool_results: None, rag_name: None, role, with_session, @@ -104,7 +104,7 @@ impl Input { regenerate: false, medias, data_urls, - tool_call: Default::default(), + tool_results: Default::default(), rag_name: None, role, with_session, @@ -120,6 +120,10 @@ impl Input { self.data_urls.clone() } + pub fn tool_results(&self) -> &Option { + &self.tool_results + } + pub fn text(&self) -> String { match self.patched_text.clone() { Some(text) => text, @@ -184,11 +188,11 @@ impl Input { } pub fn merge_tool_call(mut self, output: String, tool_results: Vec) -> Self { - match self.tool_call.as_mut() { + match self.tool_results.as_mut() { Some(exist_tool_results) => { exist_tool_results.extend(tool_results, output); } - None => self.tool_call = Some(ToolResults::new(tool_results, output)), + None => self.tool_results = Some(ToolResults::new(tool_results, output)), } self } @@ -228,7 +232,7 @@ impl Input { } else { self.role().build_messages(self) }; - if let Some(tool_results) = &self.tool_call { + if let Some(tool_results) = &self.tool_results { messages.push(Message::new( MessageRole::Assistant, MessageContent::ToolResults(tool_results.clone()), diff --git a/src/config/mod.rs b/src/config/mod.rs index e6dd5c8c..9da8daea 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1050,7 +1050,16 @@ impl Config { if let Some(session) = &self.session { let render_options = self.render_options()?; let mut markdown_render = MarkdownRender::init(render_options)?; - session.render(&mut markdown_render) + let agent_info: Option<(String, Vec)> = self.agent.as_ref().map(|agent| { + let functions = agent + .functions() + .declarations() + .iter() + .filter_map(|v| if v.agent { Some(v.name.clone()) } else { None }) + .collect(); + (agent.name().to_string(), functions) + }); + session.render(&mut markdown_render, &agent_info) } else { bail!("No session") } diff --git a/src/config/session.rs b/src/config/session.rs index c072d1c8..34d63931 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -160,7 +160,11 @@ impl Session { Ok(output) } - pub fn render(&self, render: &mut MarkdownRender) -> Result { + pub fn render( + &self, + render: &mut MarkdownRender, + agent_info: &Option<(String, Vec)>, + ) -> Result { let mut items = vec![]; if let Some(path) = &self.path { @@ -205,7 +209,10 @@ impl Session { for message in &self.messages { match message.role { MessageRole::System => { - lines.push(render.render(&message.content.render_input(resolve_url_fn))); + lines.push( + render + .render(&message.content.render_input(resolve_url_fn, agent_info)), + ); } MessageRole::Assistant => { if let MessageContent::Text(text) = &message.content { @@ -217,9 +224,12 @@ impl Session { lines.push(format!( "{}){}", self.name, - message.content.render_input(resolve_url_fn) + message.content.render_input(resolve_url_fn, agent_info) )); } + MessageRole::Tool => { + lines.push(message.content.render_input(resolve_url_fn, agent_info)); + } } } } @@ -409,6 +419,12 @@ impl Session { .push(Message::new(MessageRole::User, input.message_content())); } self.data_urls.extend(input.data_urls()); + if let Some(tool_results) = input.tool_results() { + self.messages.push(Message::new( + MessageRole::Tool, + MessageContent::ToolResults(tool_results.clone()), + )) + } self.messages.push(Message::new( MessageRole::Assistant, MessageContent::Text(output.to_string()),