From 80684ec84047bda0299f8fd845b1c97f6fac5c99 Mon Sep 17 00:00:00 2001 From: sigoden Date: Thu, 14 Nov 2024 08:14:32 +0800 Subject: [PATCH] refactor: improve tool calls (#995) - rename MessageContent:ToolResults to MessageContent:ToolCalls - rename ToolResults to MessageContentToolCalls - persist tool_calls to messages.md --- src/client/bedrock.rs | 7 +++---- src/client/claude.rs | 7 +++---- src/client/cohere.rs | 4 ++-- src/client/ernie.rs | 6 ++++-- src/client/message.rs | 40 ++++++++++++++++++++++++++++++---------- src/client/mod.rs | 2 +- src/client/model.rs | 9 ++++----- src/client/openai.rs | 5 ++--- src/client/vertexai.rs | 3 +-- src/config/input.rs | 26 +++++++++++++------------- src/config/mod.rs | 18 ++++++++++++++++-- src/config/session.rs | 4 ++-- src/function.rs | 23 ----------------------- src/main.rs | 2 +- src/repl/mod.rs | 2 +- src/serve.rs | 2 +- 16 files changed, 84 insertions(+), 76 deletions(-) diff --git a/src/client/bedrock.rs b/src/client/bedrock.rs index e1d6c657..78f9d2be 100644 --- a/src/client/bedrock.rs +++ b/src/client/bedrock.rs @@ -363,10 +363,9 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu "content": content, })] } - MessageContent::ToolResults(results) => { - let ToolResults { - tool_results, text, .. - } = results; + MessageContent::ToolCalls(MessageContentToolCalls { + tool_results, text, .. + }) => { let mut assistant_parts = vec![]; let mut user_parts = vec![]; if !text.is_empty() { diff --git a/src/client/claude.rs b/src/client/claude.rs index 0c905b1a..c7ef88a8 100644 --- a/src/client/claude.rs +++ b/src/client/claude.rs @@ -202,10 +202,9 @@ pub fn claude_build_chat_completions_body( "content": content, })] } - MessageContent::ToolResults(results) => { - let ToolResults { - tool_results, text, .. - } = results; + MessageContent::ToolCalls(MessageContentToolCalls { + tool_results, text, .. + }) => { let mut assistant_parts = vec![]; let mut user_parts = vec![]; if !text.is_empty() { diff --git a/src/client/cohere.rs b/src/client/cohere.rs index c5a31f5d..97263009 100644 --- a/src/client/cohere.rs +++ b/src/client/cohere.rs @@ -213,8 +213,8 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Resu .collect(); Some(json!({ "role": role, "message": list.join("\n\n") })) } - MessageContent::ToolResults(results) => { - tool_results = Some(results.tool_results); + MessageContent::ToolCalls(tool_calls) => { + tool_results = Some(tool_calls.tool_results); None } } diff --git a/src/client/ernie.rs b/src/client/ernie.rs index 0c325e57..7e1f8e70 100644 --- a/src/client/ernie.rs +++ b/src/client/ernie.rs @@ -231,9 +231,11 @@ fn build_chat_completions_body(data: ChatCompletionsData, model: &Model) -> Valu .flat_map(|message| { let Message { role, content } = message; match content { - MessageContent::ToolResults(results) => { + MessageContent::ToolCalls(MessageContentToolCalls { + tool_results, .. + }) => { let mut list = vec![]; - for tool_result in results.tool_results { + for tool_result in tool_results { list.push(json!({ "role": "assistant", "content": format!("Action: {}\nAction Input: {}", tool_result.call.name, tool_result.call.arguments) diff --git a/src/client/message.rs b/src/client/message.rs index 061c0e2a..f6b7f6d2 100644 --- a/src/client/message.rs +++ b/src/client/message.rs @@ -1,6 +1,4 @@ -use super::ToolResults; - -use crate::utils::dimmed_text; +use crate::{function::ToolResult, utils::dimmed_text}; use serde::{Deserialize, Serialize}; @@ -75,7 +73,7 @@ pub enum MessageContent { Text(String), Array(Vec), // Note: This type is primarily for convenience and does not exist in OpenAI's API. - ToolResults(ToolResults), + ToolCalls(MessageContentToolCalls), } impl MessageContent { @@ -103,10 +101,9 @@ impl MessageContent { } format!(".file {}{}", files.join(" "), concated_text) } - MessageContent::ToolResults(results) => { - let ToolResults { - tool_results, text, .. - } = results; + MessageContent::ToolCalls(MessageContentToolCalls { + tool_results, text, .. + }) => { let mut lines = vec![]; if !text.is_empty() { lines.push(text.clone()) @@ -139,7 +136,7 @@ impl MessageContent { *text = replace_fn(text) } } - MessageContent::ToolResults(_) => {} + MessageContent::ToolCalls(_) => {} } } @@ -155,7 +152,7 @@ impl MessageContent { } parts.join("\n\n") } - MessageContent::ToolResults(_) => String::new(), + MessageContent::ToolCalls(_) => String::new(), } } } @@ -172,6 +169,29 @@ pub struct ImageUrl { pub url: String, } +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct MessageContentToolCalls { + pub tool_results: Vec, + pub text: String, + pub sequence: bool, +} + +impl MessageContentToolCalls { + pub fn new(tool_results: Vec, text: String) -> Self { + Self { + tool_results, + text, + sequence: false, + } + } + + pub fn merge(&mut self, tool_results: Vec, _text: String) { + self.tool_results.extend(tool_results); + self.text.clear(); + self.sequence = true; + } +} + pub fn patch_system_message(messages: &mut Vec) { if messages[0].role.is_system() { let system_message = messages.remove(0); diff --git a/src/client/mod.rs b/src/client/mod.rs index b22508b5..5189f9f0 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -6,7 +6,7 @@ mod macros; mod model; mod stream; -pub use crate::function::{ToolCall, ToolResults}; +pub use crate::function::ToolCall; pub use crate::utils::PromptKind; pub use common::*; pub use message::*; diff --git a/src/client/model.rs b/src/client/model.rs index af864e2a..5d496d23 100644 --- a/src/client/model.rs +++ b/src/client/model.rs @@ -1,7 +1,7 @@ use super::{ list_chat_models, list_embedding_models, list_reranker_models, message::{Message, MessageContent, MessageContentPart}, - ToolResults, + MessageContentToolCalls, }; use crate::config::Config; @@ -237,10 +237,9 @@ impl Model { MessageContentPart::ImageUrl { .. } => 0, }) .sum(), - MessageContent::ToolResults(results) => { - let ToolResults { - tool_results, text, .. - } = results; + MessageContent::ToolCalls(MessageContentToolCalls { + tool_results, text, .. + }) => { estimate_token_length(text) + tool_results .iter() diff --git a/src/client/openai.rs b/src/client/openai.rs index 0f01e36f..7f449fb8 100644 --- a/src/client/openai.rs +++ b/src/client/openai.rs @@ -205,12 +205,11 @@ pub fn openai_build_chat_completions_body(data: ChatCompletionsData, model: &Mod .flat_map(|message| { let Message { role, content } = message; match content { - MessageContent::ToolResults(results) => { - let ToolResults { + MessageContent::ToolCalls(MessageContentToolCalls { tool_results, text, sequence, - } = results; + }) => { if !sequence { let tool_calls: Vec<_> = tool_results.iter().map(|tool_result| { json!({ diff --git a/src/client/vertexai.rs b/src/client/vertexai.rs index 452eb090..07ff2a43 100644 --- a/src/client/vertexai.rs +++ b/src/client/vertexai.rs @@ -342,8 +342,7 @@ pub fn gemini_build_chat_completions_body( .collect(); vec![json!({ "role": role, "parts": parts })] }, - MessageContent::ToolResults(results) => { - let tool_results = results.tool_results; + MessageContent::ToolCalls(MessageContentToolCalls { tool_results, .. }) => { let model_parts: Vec = tool_results.iter().map(|tool_result| { json!({ "functionCall": { diff --git a/src/config/input.rs b/src/config/input.rs index 55a4d45f..18b192bc 100644 --- a/src/config/input.rs +++ b/src/config/input.rs @@ -2,9 +2,9 @@ use super::*; use crate::client::{ init_client, patch_system_message, ChatCompletionsData, Client, ImageUrl, Message, - MessageContent, MessageContentPart, MessageRole, Model, + MessageContent, MessageContentPart, MessageContentToolCalls, MessageRole, Model, }; -use crate::function::{ToolResult, ToolResults}; +use crate::function::ToolResult; use crate::utils::{base64_encode, sha256, AbortSignal}; use anyhow::{bail, Context, Result}; @@ -29,7 +29,7 @@ pub struct Input { regenerate: bool, medias: Vec, data_urls: HashMap, - tool_results: Option, + tool_calls: Option, rag_name: Option, role: Role, with_session: bool, @@ -48,7 +48,7 @@ impl Input { regenerate: false, medias: Default::default(), data_urls: Default::default(), - tool_results: None, + tool_calls: None, rag_name: None, role, with_session, @@ -104,7 +104,7 @@ impl Input { regenerate: false, medias, data_urls, - tool_results: Default::default(), + tool_calls: Default::default(), rag_name: None, role, with_session, @@ -120,8 +120,8 @@ impl Input { self.data_urls.clone() } - pub fn tool_results(&self) -> &Option { - &self.tool_results + pub fn tool_calls(&self) -> &Option { + &self.tool_calls } pub fn text(&self) -> String { @@ -187,12 +187,12 @@ impl Input { self.rag_name.as_deref() } - pub fn merge_tool_call(mut self, output: String, tool_results: Vec) -> Self { - match self.tool_results.as_mut() { + pub fn merge_tool_results(mut self, output: String, tool_results: Vec) -> Self { + match self.tool_calls.as_mut() { Some(exist_tool_results) => { - exist_tool_results.extend(tool_results, output); + exist_tool_results.merge(tool_results, output); } - None => self.tool_results = Some(ToolResults::new(tool_results, output)), + None => self.tool_calls = Some(MessageContentToolCalls::new(tool_results, output)), } self } @@ -232,10 +232,10 @@ impl Input { } else { self.role().build_messages(self) }; - if let Some(tool_results) = &self.tool_results { + if let Some(tool_calls) = &self.tool_calls { messages.push(Message::new( MessageRole::Assistant, - MessageContent::ToolResults(tool_results.clone()), + MessageContent::ToolCalls(tool_calls.clone()), )) } Ok(messages) diff --git a/src/config/mod.rs b/src/config/mod.rs index 9da8daea..136fee17 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -10,7 +10,7 @@ use self::session::Session; use crate::client::{ create_client_config, list_chat_models, list_client_types, list_reranker_models, ClientConfig, - Model, OPENAI_COMPATIBLE_PLATFORMS, + MessageContentToolCalls, Model, OPENAI_COMPATIBLE_PLATFORMS, }; use crate::function::{FunctionDeclaration, Functions, ToolResult}; use crate::rag::Rag; @@ -1863,8 +1863,22 @@ impl Config { } else { String::new() }; + let tool_calls = match input.tool_calls() { + Some(MessageContentToolCalls { + tool_results, text, .. + }) => { + let mut lines = vec!["".to_string()]; + if !text.is_empty() { + lines.push(text.clone()); + } + lines.push(serde_json::to_string(&tool_results).unwrap_or_default()); + lines.push("\n".to_string()); + lines.join("\n") + } + None => String::new(), + }; let output = format!( - "# CHAT: {summary} [{timestamp}]{scope}\n{raw_input}\n--------\n{output}\n--------\n\n", + "# CHAT: {summary} [{timestamp}]{scope}\n{raw_input}\n--------\n{tool_calls}{output}\n--------\n\n", ); file.write_all(output.as_bytes()) .with_context(|| "Failed to save message") diff --git a/src/config/session.rs b/src/config/session.rs index 34d63931..633f6215 100644 --- a/src/config/session.rs +++ b/src/config/session.rs @@ -419,10 +419,10 @@ impl Session { .push(Message::new(MessageRole::User, input.message_content())); } self.data_urls.extend(input.data_urls()); - if let Some(tool_results) = input.tool_results() { + if let Some(tool_calls) = input.tool_calls() { self.messages.push(Message::new( MessageRole::Tool, - MessageContent::ToolResults(tool_results.clone()), + MessageContent::ToolCalls(tool_calls.clone()), )) } self.messages.push(Message::new( diff --git a/src/function.rs b/src/function.rs index f77467b9..e4a7454f 100644 --- a/src/function.rs +++ b/src/function.rs @@ -266,29 +266,6 @@ impl ToolCall { } } -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct ToolResults { - pub tool_results: Vec, - pub text: String, - pub sequence: bool, -} - -impl ToolResults { - pub fn new(tool_results: Vec, text: String) -> Self { - Self { - tool_results, - text, - sequence: false, - } - } - - pub fn extend(&mut self, tool_results: Vec, _text: String) { - self.tool_results.extend(tool_results); - self.text.clear(); - self.sequence = true; - } -} - #[cfg(windows)] fn polyfill_cmd_name>(cmd_name: &str, bin_dir: &[T]) -> String { let cmd_name = cmd_name.to_string(); diff --git a/src/main.rs b/src/main.rs index 413301cd..d06b5ab7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -188,7 +188,7 @@ async fn start_directive( if need_send_tool_results(&tool_results) { start_directive( config, - input.merge_tool_call(output, tool_results), + input.merge_tool_results(output, tool_results), code_mode, abort_signal, ) diff --git a/src/repl/mod.rs b/src/repl/mod.rs index 7f971b1d..1e841463 100644 --- a/src/repl/mod.rs +++ b/src/repl/mod.rs @@ -649,7 +649,7 @@ async fn ask( ask( config, abort_signal, - input.merge_tool_call(output, tool_results), + input.merge_tool_results(output, tool_results), false, ) .await diff --git a/src/serve.rs b/src/serve.rs index 84e6c870..c19be1f5 100644 --- a/src/serve.rs +++ b/src/serve.rs @@ -889,7 +889,7 @@ fn parse_messages(message: Vec) -> Result> { } output.push(Message::new( MessageRole::Assistant, - MessageContent::ToolResults(ToolResults::new(list, text)), + MessageContent::ToolCalls(MessageContentToolCalls::new(list, text)), )); tool_results = None; } else {