Skip to content

Commit

Permalink
Learn flag (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 authored Feb 21, 2024
1 parent 4498a76 commit 456e012
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 12 deletions.
9 changes: 8 additions & 1 deletion mdagent/mainagent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
subagents_model="gpt-4-1106-preview",
ckpt_dir="ckpt",
resume=False,
learn=True,
top_k_tools=20, # set "all" if you want to use all tools (& skills if resume)
use_human_tool=False,
curriculum=True,
Expand All @@ -70,7 +71,11 @@ def __init__(
callbacks=[StreamingStdOutCallbackHandler()],
)

# assign prompt
if learn:
self.skip_subagents = False
else:
self.skip_subagents = True

if agent_type == "Structured":
self.prompt = structured_prompt
elif agent_type == "OpenAIFunctionsAgent":
Expand Down Expand Up @@ -99,13 +104,15 @@ def _initialize_tools_and_agent(self, user_input=None):
llm=self.tools_llm,
subagent_settings=self.subagents_settings,
human=self.use_human_tool,
skip_subagents=self.skip_subagents,
)
else:
# retrieve all tools, including new tools if any
self.tools = make_all_tools(
self.tools_llm,
subagent_settings=self.subagents_settings,
human=self.use_human_tool,
skip_subagents=self.skip_subagents,
)
return AgentExecutor.from_agent_and_tools(
tools=self.tools,
Expand Down
8 changes: 4 additions & 4 deletions mdagent/tools/maketools.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def make_all_tools(
ModifyBaseSimulationScriptTool(path_registry=path_instance, llm=llm),
SimulationOutputFigures(),
]

# tools using subagents
if subagent_settings is None:
subagent_settings = SubAgentSettings(path_registry=path_instance)

# tools using subagents
subagents_tools = []
if not skip_subagents:
subagents_tools = [
Expand Down Expand Up @@ -129,7 +129,7 @@ def get_tools(
llm: BaseLanguageModel,
subagent_settings: Optional[SubAgentSettings] = None,
top_k_tools=15,
subagents_required=True,
skip_subagents=False,
human=False,
):
if subagent_settings:
Expand All @@ -138,7 +138,7 @@ def get_tools(
ckpt_dir = "ckpt"

retrieved_tools = []
if subagents_required:
if not skip_subagents:
# add subagents-related tools by default
retrieved_tools = [
CreateNewTool(subagent_settings=subagent_settings),
Expand Down
13 changes: 7 additions & 6 deletions st_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,19 @@
# Streamlit app
st.title("MDAgent")

# option = st.selectbox("Choose an option:", ["Explore & Learn", "Use Learned Skills"])
# if option == "Explore & Learn":
# explore = True
# else:
# explore = False
option = st.selectbox("Choose an option:", ["Explore & Learn", "Use Learned Skills"])
if option == "Explore & Learn":
learn = True
else:
learn = False

resume_op = st.selectbox("Resume:", ["False", "True"])
if resume_op == "True":
resume = True
else:
resume = False


# for now I'm just going to allow pdb and cif files - we can add more later
uploaded_files = st.file_uploader(
"Upload a .pdb or .cif file", type=["pdb", "cif"], accept_multiple_files=True
Expand All @@ -45,7 +46,7 @@
else:
uploaded_file = []

mdagent = MDAgent(resume=resume, uploaded_files=uploaded_file)
mdagent = MDAgent(resume=resume, uploaded_files=uploaded_file, learn=learn)


def generate_response(prompt):
Expand Down
8 changes: 7 additions & 1 deletion tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,13 @@ def test_update_skill_library(skill_manager):
)


# test mdagent with and without curriculum
def test_mdagent_learn_init():
mdagent_skill = MDAgent(learn=False)
assert mdagent_skill.skip_subagents is True
mdagent_learn = MDAgent(learn=True)
assert mdagent_learn.skip_subagents is False


def test_mdagent_curriculum():
mdagent_curr = MDAgent(curriculum=True)
mdagent_no_curr = MDAgent(curriculum=False)
Expand Down

0 comments on commit 456e012

Please sign in to comment.