Skip to content

Commit

Permalink
Packmol tool, and registry update to new version (#41)
Browse files Browse the repository at this point in the history
* Packmol tool, and registry update to new version

* trying with typing to avoid error with type descriptions

* add typing.Type to be useful in python8

* correcting pr comments and updated pdb tool

* test file is now .pdb
  • Loading branch information
Jgmedina95 authored Oct 31, 2023
1 parent 9312250 commit 7bb426c
Show file tree
Hide file tree
Showing 22 changed files with 2,285 additions and 35,911 deletions.
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,11 @@ dmypy.json

# testing files generated
*.txt.json

# pdb cif files
*.pdb
*.cif
!3pqr.cif

# generated files for testing
*registry.json
7 changes: 1 addition & 6 deletions mdagent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@
from .tools import (
CheckDirectoryFiles,
CleaningTools,
ListRegistryPaths,
MapPath2Name,
Name2PDBTool,
PathRegistry,
PlanBVisualizationTool,
Scholar2ResultLLM,
SetUpAndRunTool,
Expand All @@ -15,6 +11,7 @@
VisualizationToolRender,
get_pdb,
)
from .utils import PathRegistry

__all__ = [
"MDAgent",
Expand All @@ -29,9 +26,7 @@
"CheckDirectoryFiles",
"PlanBVisualizationTool",
"SetUpAndRunTool",
"ListRegistryPaths",
"PathRegistry",
"MapPath2Name",
"SimulationOutputFigures",
"get_pdb",
]
34 changes: 22 additions & 12 deletions mdagent/agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import langchain
from dotenv import load_dotenv
from langchain.agents import AgentType, initialize_agent
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from rmrkl import ChatZeroShotAgent, RetryAgentExecutor

from .prompt import FORMAT_INSTRUCTIONS, QUESTION_PROMPT, SUFFIX
from .tools import make_tools
Expand Down Expand Up @@ -47,19 +47,29 @@ def __init__(
tools = make_tools(tools_llm, verbose=verbose)

# Initialize agent
self.agent_executor = RetryAgentExecutor.from_agent_and_tools(
tools=tools,
agent=ChatZeroShotAgent.from_llm_and_tools(
self.llm,
tools=tools,
suffix=SUFFIX,
format_instructions=FORMAT_INSTRUCTIONS,
question_prompt=QUESTION_PROMPT,
),
verbose=True,
max_iterations=max_iterations,
# self.agent_executor = RetryAgentExecutor.from_agent_and_tools(
# tools=tools,
# agent=ChatZeroShotAgent.from_llm_and_tools(
# self.llm,
# tools=tools,
# suffix=SUFFIX,
# format_instructions=FORMAT_INSTRUCTIONS,
# question_prompt=QUESTION_PROMPT,
# ),
self.agent_executor = initialize_agent(
tools,
self.llm,
agent=AgentType.OPENAI_FUNCTIONS,
suffix=SUFFIX,
format_instructions=FORMAT_INSTRUCTIONS,
question_prompt=QUESTION_PROMPT,
return_intermediate_steps=True,
max_iterations=max_iterations,
)
# verbose=True,
# max_iterations=max_iterations,
# return_intermediate_steps=True,
# )

def run(self, prompt):
outputs = self.agent_executor({"input": prompt})
Expand Down
5 changes: 3 additions & 2 deletions mdagent/agent/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
ListRegistryPaths,
MapPath2Name,
Name2PDBTool,
PathRegistry,
PackMolTool,
PlanBVisualizationTool,
RemoveWaterCleaningTool,
Scholar2ResultLLM,
Expand All @@ -19,6 +19,7 @@
SpecializedCleanTool,
VisualizationToolRender,
)
from ..utils import PathRegistry


def make_tools(llm: BaseLanguageModel, verbose=False):
Expand Down Expand Up @@ -46,12 +47,12 @@ def make_tools(llm: BaseLanguageModel, verbose=False):
SpecializedCleanTool(path_registry=path_instance),
RemoveWaterCleaningTool(path_registry=path_instance),
AddHydrogensCleaningTool(path_registry=path_instance),
PackMolTool(path_registry=path_instance),
]
# add serpapi tool
serp_key = os.getenv("SERP_API_KEY")
if serp_key:
all_tools.append(SerpGitTool(serp_key))

# add literature search tool
# Get the api keys
pqa_key = os.getenv("PQA_API_KEY")
Expand Down
11 changes: 6 additions & 5 deletions mdagent/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
)
from .git_issues_tool import SerpGitTool
from .md_util_tools import Name2PDBTool, get_pdb
from .pdb_tools import PackMolTool
from .plot_tools import SimulationOutputFigures
from .registry import ListRegistryPaths, MapPath2Name, PathRegistry
from .registry_tools import ListRegistryPaths, MapPath2Name
from .search_tools import Scholar2ResultLLM
from .setup_and_run import SetUpAndRunTool, SimulationFunctions
from .vis_tools import (
Expand All @@ -25,10 +26,6 @@
"VisualizationToolRender",
"CheckDirectoryFiles",
"PlanBVisualizationTool",
"ListRegistryPaths",
"PathRegistry",
"MapPath2Name",
"Name2PDBTool",
"get_pdb",
"SpecializedCleanTool",
"RemoveWaterCleaningTool",
Expand All @@ -38,4 +35,8 @@
"PlanBVisualizationTool",
"SimulationOutputFigures",
"SerpGitTool",
"PackMolTool",
"ListRegistryPaths",
"MapPath2Name",
"Name2PDBTool",
]
2 changes: 1 addition & 1 deletion mdagent/tools/clean_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from openmm.app import PDBFile, PDBxFile
from pdbfixer import PDBFixer

from .registry import PathRegistry
from mdagent.utils import PathRegistry


class CleaningTools:
Expand Down
43 changes: 25 additions & 18 deletions mdagent/tools/md_util_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
import requests
from langchain.tools import BaseTool

from .registry import PathRegistry
from mdagent.utils import PathRegistry


def get_pdb(query_string, PathRegistry):
"""
Search RSCB's protein data bank using the given query string
and return the path to pdb file
and return the path to pdb file in either CIF or PDB format
"""

url = "https://search.rcsb.org/rcsbsearch/v2/query?json={search-request}"
query = {
"query": {
Expand All @@ -23,31 +22,39 @@ def get_pdb(query_string, PathRegistry):
}
r = requests.post(url, json=query)
if r.status_code == 204:
return "No Content Error: PDB ID not found for this substance."
elif "result_set" in r.json() and len(r.json()["result_set"]) > 0:
return None
if "cif" in query_string or "CIF" in query_string:
filetype = "cif"
else:
filetype = "pdb"

if "result_set" in r.json() and len(r.json()["result_set"]) > 0:
pdbid = r.json()["result_set"][0]["identifier"]
url = f"https://files.rcsb.org/download/{pdbid}.cif"
print(f"PDB file found with this ID: {pdbid}")
url = f"https://files.rcsb.org/download/{pdbid}.{filetype}"
pdb = requests.get(url)
filename = f"{pdbid}.cif"
filename = f"{pdbid}.{filetype}"
with open(filename, "w") as file:
file.write(pdb.text)
# add filename to registry
file_description = "PDB file downloaded from RSCB"
file.close()
print(f"{filename} is created.")
file_description = f"PDB file downloaded from RSCB, PDB ID: {pdbid}"
PathRegistry.map_path(filename, filename, file_description)
return filename
return None


class Name2PDBTool(BaseTool):
name = "PDBFileDownloader"
description = """This tool downloads PDB (Protein Data Bank) or
CIF (Crystallographic Information File)
files using commercial chemical names.
It’s ideal for situations where you
need to directly retrieve these files
using a chemical’s commercial name.
Input: Commercial name of the chemical
CIF (Crystallographic Information File) files using
commercial chemical names. It’s ideal for situations where
you need to directly retrieve these file using a chemical’s
commercial name. When a specific file type, either PDB or CIF,
is requested, add file type to the query string with space.
Input: Commercial name of the chemical or file without
file extension
Output: Corresponding PDB or CIF file"""

path_registry: Optional[PathRegistry]

def __init__(self, path_registry: Optional[PathRegistry]):
Expand All @@ -61,12 +68,12 @@ def _run(self, query: str) -> str:
return "Path registry not initialized"
pdb = get_pdb(query, self.path_registry)
if pdb is None:
return "Name2PDB tool failed to download the PDB file."
return "Name2PDB tool failed to find and download PDB file."
else:
return f"Name2PDB tool successfully downloaded the PDB file: {pdb}"
except Exception as e:
return f"Something went wrong. {e}"

async def _arun(self, query) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("Name2PDB does not support async")
raise NotImplementedError("this tool does not support async")
Loading

0 comments on commit 7bb426c

Please sign in to comment.