Skip to content

Commit

Permalink
refactor: improve code regarding tools/agents (#1021)
Browse files Browse the repository at this point in the history
  • Loading branch information
sigoden authored Nov 30, 2024
1 parent 580b40e commit 50bbfc4
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 71 deletions.
39 changes: 28 additions & 11 deletions src/config/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@ use serde::{Deserialize, Serialize};

const DEFAULT_AGENT_NAME: &str = "rag";

pub type AgentVariables = IndexMap<String, String>;

#[derive(Debug, Clone, Serialize)]
pub struct Agent {
name: String,
config: AgentConfig,
definition: AgentDefinition,
#[serde(skip)]
shared_variables: IndexMap<String, String>,
shared_variables: AgentVariables,
#[serde(skip)]
session_variables: Option<IndexMap<String, String>>,
session_variables: Option<AgentVariables>,
#[serde(skip)]
functions: Functions,
#[serde(skip)]
Expand Down Expand Up @@ -108,9 +110,9 @@ impl Agent {

pub fn init_agent_variables(
agent_variables: &[AgentVariable],
variables: &IndexMap<String, String>,
variables: &AgentVariables,
no_interaction: bool,
) -> Result<IndexMap<String, String>> {
) -> Result<AgentVariables> {
let mut output = IndexMap::new();
if agent_variables.is_empty() {
return Ok(output);
Expand Down Expand Up @@ -224,26 +226,41 @@ impl Agent {
self.config.agent_prelude = value;
}

pub fn variables(&self) -> &IndexMap<String, String> {
pub fn variables(&self) -> &AgentVariables {
match &self.session_variables {
Some(variables) => variables,
None => &self.shared_variables,
}
}

pub fn config_variables(&self) -> &IndexMap<String, String> {
pub fn variable_envs(&self) -> HashMap<String, String> {
self.variables()
.iter()
.map(|(k, v)| {
(
format!("LLM_AGENT_VAR_{}", normalize_env_name(k)),
v.clone(),
)
})
.collect()
}

pub fn config_variables(&self) -> &AgentVariables {
&self.config.variables
}

pub fn shared_variables(&self) -> &IndexMap<String, String> {
pub fn shared_variables(&self) -> &AgentVariables {
&self.shared_variables
}

pub fn set_shared_variables(&mut self, shared_variables: IndexMap<String, String>) {
pub fn set_shared_variables(&mut self, shared_variables: AgentVariables) {
self.shared_variables = shared_variables;
}

pub fn set_session_variables(&mut self, session_variables: Option<IndexMap<String, String>>) {
pub fn set_session_variables(&mut self, session_variables: Option<AgentVariables>) {
if self.shared_variables.is_empty() {
self.shared_variables = session_variables.clone().unwrap_or_default();
}
self.session_variables = session_variables;
}

Expand Down Expand Up @@ -319,7 +336,7 @@ pub struct AgentConfig {
pub use_tools: Option<String>,
pub agent_prelude: Option<String>,
#[serde(default)]
pub variables: IndexMap<String, String>,
pub variables: AgentVariables,
}

impl AgentConfig {
Expand Down Expand Up @@ -419,7 +436,7 @@ impl AgentDefinition {
)
}

fn interpolated_instructions(&self, variables: &IndexMap<String, String>) -> String {
fn interpolated_instructions(&self, variables: &AgentVariables) -> String {
let mut output = self.instructions.clone();
for (k, v) in variables {
output = output.replace(&format!("{{{{{k}}}}}"), v)
Expand Down
12 changes: 6 additions & 6 deletions src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod input;
mod role;
mod session;

pub use self::agent::{list_agents, Agent};
pub use self::agent::{list_agents, Agent, AgentVariables};
pub use self::input::Input;
pub use self::role::{
Role, RoleLike, CODE_ROLE, CREATE_TITLE_ROLE, EXPLAIN_SHELL_ROLE, SHELL_ROLE,
Expand Down Expand Up @@ -1487,10 +1487,13 @@ impl Config {
if parts.len() != 2 {
bail!("Usage: .variable <key> <value>");
}
let key = parts[0];
let value = parts[1];
match self.agent.as_mut() {
Some(agent) => {
if let Some(session) = self.session.as_ref() {
session.guard_empty()?;
}
let key = parts[0];
let value = parts[1];
agent.set_variable(key, value)?;
if let Some(session) = self.session.as_mut() {
session.sync_agent(agent, true);
Expand Down Expand Up @@ -2009,9 +2012,6 @@ impl Config {
&all_variables,
self.print_info_only,
)?;
if shared_variables.is_empty() {
agent.set_shared_variables(new_variables.clone());
}
agent.set_session_variables(Some(new_variables));
session.sync_agent(agent, false);
Ok(())
Expand Down
4 changes: 2 additions & 2 deletions src/config/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub struct Session {
#[serde(skip_serializing_if = "Option::is_none")]
role_name: Option<String>,
#[serde(default, skip_serializing_if = "IndexMap::is_empty")]
agent_variables: IndexMap<String, String>,
agent_variables: AgentVariables,

#[serde(default, skip_serializing_if = "Vec::is_empty")]
compressed_messages: Vec<Message>,
Expand Down Expand Up @@ -284,7 +284,7 @@ impl Session {
}
}

pub fn agent_variables(&self) -> &IndexMap<String, String> {
pub fn agent_variables(&self) -> &AgentVariables {
&self.agent_variables
}

Expand Down
112 changes: 60 additions & 52 deletions src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,43 +159,37 @@ impl ToolCall {

pub fn eval(&self, config: &GlobalConfig) -> Result<Value> {
let function_name = self.name.clone();
let (call_name, cmd_name, mut cmd_args, mut envs) = match &config.read().agent {
let (call_name, cmd_name, mut cmd_args, envs, agent_name) = match &config.read().agent {
Some(agent) => match agent.functions().find(&function_name) {
Some(function) => {
let agent_name = agent.name().to_string();
if function.agent {
let envs: HashMap<String, String> = agent
.variables()
.iter()
.map(|(k, v)| {
(
format!("LLM_AGENT_VAR_{}", normalize_env_name(k)),
v.clone(),
)
})
.collect();
(
format!("{}:{}", agent.name(), function_name),
agent.name().to_string(),
format!("{agent_name}-{function_name}"),
agent_name.clone(),
vec![function_name],
envs,
agent.variable_envs(),
Some(agent_name),
)
} else {
(
function_name.clone(),
function_name,
vec![],
Default::default(),
Some(agent_name),
)
}
}
None => bail!("Unexpected call {function_name} {}", self.arguments),
None => bail!("Unexpected call: {function_name} {}", self.arguments),
},
None => match config.read().functions.contains(&function_name) {
true => (
function_name.clone(),
function_name,
vec![],
Default::default(),
None,
),
false => bail!("Unexpected call: {function_name} {}", self.arguments),
},
Expand All @@ -215,50 +209,64 @@ impl ToolCall {
};

cmd_args.push(json_data.to_string());
let prompt = format!("Call {cmd_name} {}", cmd_args.join(" "));

let mut bin_dirs: Vec<PathBuf> = vec![];
if let Some(agent) = config.read().agent.as_ref() {
let dir = Config::agent_functions_dir(agent.name()).join("bin");
if dir.exists() {
bin_dirs.push(dir);
}
}
bin_dirs.push(Config::functions_bin_dir());
let current_path = std::env::var("PATH").context("No PATH environment variable")?;
let prepend_path = bin_dirs
.iter()
.map(|v| format!("{}{PATH_SEP}", v.display()))
.collect::<Vec<_>>()
.join("");
envs.insert("PATH".into(), format!("{prepend_path}{current_path}"));
let output = match run_llm_function(cmd_name, cmd_args, envs, agent_name)? {
Some(contents) => serde_json::from_str(&contents)
.ok()
.unwrap_or_else(|| json!({"result": contents})),
None => Value::Null,
};

let temp_file = temp_file("-eval-", "");
envs.insert("LLM_OUTPUT".into(), temp_file.display().to_string());
Ok(output)
}
}

#[cfg(windows)]
let cmd_name = polyfill_cmd_name(&cmd_name, &bin_dirs);
if *IS_STDOUT_TERMINAL {
println!("{}", dimmed_text(&prompt));
}
let exit_code = run_command(&cmd_name, &cmd_args, Some(envs))
.map_err(|err| anyhow!("Unable to run {cmd_name}, {err}"))?;
if exit_code != 0 {
bail!("Tool call exit with {exit_code}");
fn run_llm_function(
cmd_name: String,
cmd_args: Vec<String>,
mut envs: HashMap<String, String>,
agent_name: Option<String>,
) -> Result<Option<String>> {
let prompt = format!("Call {cmd_name} {}", cmd_args.join(" "));

let mut bin_dirs: Vec<PathBuf> = vec![];
if let Some(agent_name) = agent_name {
let dir = Config::agent_functions_dir(&agent_name).join("bin");
if dir.exists() {
bin_dirs.push(dir);
}
let output = if temp_file.exists() {
let contents =
fs::read_to_string(temp_file).context("Failed to retrieve tool call output")?;
}
bin_dirs.push(Config::functions_bin_dir());
let current_path = std::env::var("PATH").context("No PATH environment variable")?;
let prepend_path = bin_dirs
.iter()
.map(|v| format!("{}{PATH_SEP}", v.display()))
.collect::<Vec<_>>()
.join("");
envs.insert("PATH".into(), format!("{prepend_path}{current_path}"));

serde_json::from_str(&contents)
.ok()
.unwrap_or_else(|| json!({"result": contents}))
} else {
Value::Null
};
let temp_file = temp_file("-eval-", "");
envs.insert("LLM_OUTPUT".into(), temp_file.display().to_string());

Ok(output)
#[cfg(windows)]
let cmd_name = polyfill_cmd_name(&cmd_name, &bin_dirs);
if *IS_STDOUT_TERMINAL {
println!("{}", dimmed_text(&prompt));
}
let exit_code = run_command(&cmd_name, &cmd_args, Some(envs))
.map_err(|err| anyhow!("Unable to run {cmd_name}, {err}"))?;
if exit_code != 0 {
bail!("Tool call exit with {exit_code}");
}
let mut output = None;
if temp_file.exists() {
let contents =
fs::read_to_string(temp_file).context("Failed to retrieve tool call output")?;
if !contents.is_empty() {
output = Some(contents);
}
};
Ok(output)
}

#[cfg(windows)]
Expand Down

0 comments on commit 50bbfc4

Please sign in to comment.