Skip to content

Commit

Permalink
Merge branch 'main' of github.com:ur-whitelab/md-agent into r_gyration
Browse files Browse the repository at this point in the history
merge main
  • Loading branch information
SamCox822 committed Feb 22, 2024
2 parents 4637619 + 456e012 commit 745db74
Show file tree
Hide file tree
Showing 14 changed files with 938 additions and 251 deletions.
11 changes: 10 additions & 1 deletion mdagent/mainagent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ def __init__(
subagents_model="gpt-4-1106-preview",
ckpt_dir="ckpt",
resume=False,
learn=True,
top_k_tools=20, # set "all" if you want to use all tools (& skills if resume)
use_human_tool=False,
curriculum=True,
uploaded_files=[], # user input files to add to path registry
):
if path_registry is None:
Expand All @@ -69,7 +71,11 @@ def __init__(
callbacks=[StreamingStdOutCallbackHandler()],
)

# assign prompt
if learn:
self.skip_subagents = False
else:
self.skip_subagents = True

if agent_type == "Structured":
self.prompt = structured_prompt
elif agent_type == "OpenAIFunctionsAgent":
Expand All @@ -83,6 +89,7 @@ def __init__(
verbose=verbose,
ckpt_dir=ckpt_dir,
resume=resume,
curriculum=curriculum,
)

def _initialize_tools_and_agent(self, user_input=None):
Expand All @@ -97,13 +104,15 @@ def _initialize_tools_and_agent(self, user_input=None):
llm=self.tools_llm,
subagent_settings=self.subagents_settings,
human=self.use_human_tool,
skip_subagents=self.skip_subagents,
)
else:
# retrieve all tools, including new tools if any
self.tools = make_all_tools(
self.tools_llm,
subagent_settings=self.subagents_settings,
human=self.use_human_tool,
skip_subagents=self.skip_subagents,
)
return AgentExecutor.from_agent_and_tools(
tools=self.tools,
Expand Down
5 changes: 5 additions & 0 deletions mdagent/subagents/subagent_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(
ckpt_dir="ckpt",
resume=False,
retrieval_top_k=5,
curriculum=True,
):
self.path_registry = path_registry
self.subagents_model = subagents_model
Expand All @@ -24,6 +25,7 @@ def __init__(
self.ckpt_dir = ckpt_dir
self.resume = resume
self.retrieval_top_k = retrieval_top_k
self.curriculum = curriculum


class SubAgentInitializer:
Expand All @@ -40,6 +42,7 @@ def __init__(self, settings: Optional[SubAgentSettings] = None):
self.ckpt_dir = settings.ckpt_dir
self.resume = settings.resume
self.retrieval_top_k = settings.retrieval_top_k
self.curriculum = settings.curriculum

def create_action(self, **overrides):
params = {
Expand All @@ -61,6 +64,8 @@ def create_critic(self, **overrides):
return Critic(**params)

def create_curriculum(self, **overrides):
if not self.curriculum:
return None
params = {
"model": self.subagents_model,
"temp": self.temp,
Expand Down
10 changes: 8 additions & 2 deletions mdagent/tools/base_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
RemoveWaterCleaningTool,
SpecializedCleanTool,
)
from .preprocess_tools.pdb_tools import Name2PDBTool, PackMolTool, get_pdb
from .preprocess_tools.pdb_tools import (
PackMolTool,
ProteinName2PDBTool,
SmallMolPDB,
get_pdb,
)
from .simulation_tools.create_simulation import ModifyBaseSimulationScriptTool
from .simulation_tools.setup_and_run import (
InstructionSummary,
Expand All @@ -32,9 +37,10 @@
"InstructionSummary",
"ListRegistryPaths",
"MapPath2Name",
"Name2PDBTool",
"ProteinName2PDBTool",
"PackMolTool",
"PPIDistance",
"SmallMolPDB",
"VisualizeProtein",
"RMSDCalculator",
"RemoveWaterCleaningTool",
Expand Down
5 changes: 3 additions & 2 deletions mdagent/tools/base_tools/preprocess_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
RemoveWaterCleaningTool,
SpecializedCleanTool,
)
from .pdb_tools import Name2PDBTool, PackMolTool, get_pdb
from .pdb_tools import PackMolTool, ProteinName2PDBTool, SmallMolPDB, get_pdb

__all__ = [
"AddHydrogensCleaningTool",
"CleaningTools",
"Name2PDBTool",
"ProteinName2PDBTool",
"PackMolTool",
"RemoveWaterCleaningTool",
"SpecializedCleanTool",
"get_pdb",
"CleaningToolFunction",
"SmallMolPDB",
]
16 changes: 8 additions & 8 deletions mdagent/tools/base_tools/preprocess_tools/clean_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,17 +296,17 @@ def _run(self, **input_args) -> str:
file_description = "Cleaned File: "
CleaningTools()
try:
pdbfile = self.path_registry.get_mapped_path(pdbfile_id)
if "/" in pdbfile:
pdbfile = pdbfile.split("/")[-1]

name = pdbfile.split("_")[0]
end = pdbfile.split(".")[1]
pdbfile_path = self.path_registry.get_mapped_path(pdbfile_id)
if "/" in pdbfile_path:
pdbfile = pdbfile_path.split("/")[-1]
else:
pdbfile = pdbfile_path
name, end = pdbfile.split(".")

except Exception as e:
print(f"error retrieving from path_registry, trying to read file {e}")
return "File not found in path registry. "
fixer = PDBFixer(filename=pdbfile)
fixer = PDBFixer(filename=pdbfile_path)
try:
fixer.findMissingResidues()
except Exception:
Expand Down Expand Up @@ -353,7 +353,7 @@ def _run(self, **input_args) -> str:
file_mode = "w" if add_hydrogens else "a"
file_name = self.path_registry.write_file_name(
type=FileType.PROTEIN,
protein_name=name,
protein_name=name.split("_")[0],
description="Clean",
file_format=end,
)
Expand Down
Loading

0 comments on commit 745db74

Please sign in to comment.