Skip to content

Commit

Permalink
bring back tool retrieval (#63)
Browse files Browse the repository at this point in the history
* added back tool retrieval

* fixed and tested

* updated initialize fxn

* removed old comment

* fixed bug

* separate param for user-provided tools

* removed old output file
  • Loading branch information
qcampbel authored Jan 10, 2024
1 parent 8b61d6b commit f9939b2
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 20 deletions.
52 changes: 39 additions & 13 deletions mdagent/mainagent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from mdagent.subagents import SubAgentSettings
from mdagent.utils import PathRegistry, _make_llm

from ..tools import make_all_tools
from ..tools import get_tools, make_all_tools
from .prompt import openaifxn_prompt, structured_prompt

load_dotenv()
Expand Down Expand Up @@ -35,7 +35,7 @@ class MDAgent:
def __init__(
self,
tools=None,
agent_type="OpenAIFunctionsAgent", # this can also be strucured_chat
agent_type="OpenAIFunctionsAgent", # this can also be structured_chat
model="gpt-4-1106-preview", # current name for gpt-4 turbo
tools_model="gpt-4-1106-preview",
temp=0.1,
Expand All @@ -45,14 +45,16 @@ def __init__(
subagents_model="gpt-4-1106-preview",
ckpt_dir="ckpt",
resume=False,
top_k_tools=10,
top_k_tools=20, # set "all" if you want to use all tools (& skills if resume)
use_human_tool=False,
):
if path_registry is None:
path_registry = PathRegistry.get_instance()
if tools is None:
tools_llm = _make_llm(tools_model, temp, verbose)
tools = make_all_tools(tools_llm, human=use_human_tool)
self.agent_type = agent_type
self.user_tools = tools
self.tools_llm = _make_llm(tools_model, temp, verbose)
self.top_k_tools = top_k_tools
self.use_human_tool = use_human_tool

self.llm = ChatOpenAI(
temperature=temp,
Expand All @@ -61,11 +63,7 @@ def __init__(
streaming=True,
callbacks=[StreamingStdOutCallbackHandler()],
)
self.agent = AgentExecutor.from_agent_and_tools(
tools=tools,
agent=AgentType.get_agent(agent_type).from_llm_and_tools(self.llm, tools),
handle_parsing_errors=True,
)

# assign prompt
if agent_type == "Structured":
self.prompt = structured_prompt
Expand All @@ -80,9 +78,37 @@ def __init__(
verbose=verbose,
ckpt_dir=ckpt_dir,
resume=resume,
retrieval_top_k=top_k_tools,
)

def _initialize_tools_and_agent(self, user_input=None):
"""Retrieve tools and initialize the agent."""
if self.user_tools is not None:
self.tools = self.user_tools
else:
if self.top_k_tools != "all" and user_input is not None:
# retrieve only tools relevant to user input
self.tools = get_tools(
query=user_input,
llm=self.tools_llm,
subagent_settings=self.subagents_settings,
human=self.use_human_tool,
)
else:
# retrieve all tools, including new tools if any
self.tools = make_all_tools(
self.tools_llm,
subagent_settings=self.subagents_settings,
human=self.use_human_tool,
)
return AgentExecutor.from_agent_and_tools(
tools=self.tools,
agent=AgentType.get_agent(self.agent_type).from_llm_and_tools(
self.llm,
self.tools,
),
handle_parsing_errors=True,
)

def run(self, user_input, callbacks=None):
# todo: check this for both agent types
self.agent = self._initialize_tools_and_agent(user_input)
return self.agent.run(self.prompt.format(input=user_input), callbacks=callbacks)
18 changes: 11 additions & 7 deletions mdagent/tools/maketools.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,22 +123,25 @@ def get_tools(
query,
llm: BaseLanguageModel,
subagent_settings: Optional[SubAgentSettings] = None,
ckpt_dir="ckpt",
retrieval_top_k=10,
top_k_tools=15,
subagents_required=True,
human=False,
):
if subagent_settings:
ckpt_dir = subagent_settings.ckpt_dir
else:
ckpt_dir = "ckpt"

retrieved_tools = []
if subagents_required:
# add subagents-related tools by default
PathRegistry.get_instance()
retrieved_tools = [
CreateNewTool(subagent_settings=subagent_settings),
RetryExecuteSkill(subagent_settings=subagent_settings),
SkillRetrieval(subagent_settings=subagent_settings),
WorkflowPlan(subagent_settings=subagent_settings),
]
retrieval_top_k -= len(retrieved_tools)
top_k_tools -= len(retrieved_tools)
all_tools = make_all_tools(
llm, subagent_settings, skip_subagents=True, human=human
)
Expand All @@ -163,7 +166,7 @@ def get_tools(
vectordb.persist()

# retrieve 'k' tools
k = min(retrieval_top_k, vectordb._collection.count())
k = min(top_k_tools, vectordb._collection.count())
if k == 0:
return None
docs = vectordb.similarity_search(query, k=k)
Expand All @@ -173,7 +176,8 @@ def get_tools(
retrieved_tools.append(all_tools[index])
else:
print(f"Invalid index {index}.")
print(f"Try deleting vectordb at {ckpt_dir}/all_tools_vectordb.")
print("Some tools may be duplicated.")
print(f"Try to delete vector DB at {ckpt_dir}/all_tools_vectordb.")
return retrieved_tools


Expand Down Expand Up @@ -217,7 +221,7 @@ def get_all_tools_string(self):
all_tools_string += f"{tool.name}: {tool.description}\n"
return all_tools_string

def _run(self, task, orig_prompt, curr_tools, execute, args=None):
def _run(self, task, orig_prompt, curr_tools, execute=True, args=None):
# run iterator
try:
all_tools_string = self.get_all_tools_string()
Expand Down
Binary file removed notebooks/.DS_Store
Binary file not shown.

0 comments on commit f9939b2

Please sign in to comment.