Skip to content

Commit

Permalink
Rmsd (#22)
Browse files Browse the repository at this point in the history
* rmsd tool draft

* fixed name2pdb to catch error properly

* tested and optimized name2pdb and rmsd tools

* updated rmsd notebook results

* updated test.yml

* updated name2pdb test

* separated ppi distance and rmsd tools (expanded)

* fixed minor error in FixPDBFile tool

* fixed multi-input tool for ppi distance

* changed error handling for ppi distance tool

* fixed rmsd tools

* reorganized tools to prepare for multiagent merge

* fixed make_llm

* fixed makellm and added paths_registry.json to ignore list

* fixed python_repl import

* fixed langchain imports

* updated gpt4 to gpt4 turbo
  • Loading branch information
qcampbel authored Nov 27, 2023
1 parent c9a7120 commit 4a017e0
Show file tree
Hide file tree
Showing 27 changed files with 681 additions and 164 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ jobs:
environment-file: environment.yaml
python-version: ${{ matrix.python-version }}
auto-activate-base: true
- name: Install [openmm pdbfixer] with conda
- name: Install openmm pdbfixer mdanalysis with conda
shell: bash -l {0}
run: |
conda install -c conda-forge openmm pdbfixer
conda install -c conda-forge openmm pdbfixer mdanalysis
- name: Install dependencies
shell: bash -l {0}
run: |
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,4 @@ dmypy.json

# generated files for testing
*registry.json
*paths_registry.json
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ It's built using Langchain and uses a collection of tools to set up and execute
To use the OpenMM features in the agent, please set up a conda environment, following these steps.
- Create conda environment: `conda env create -n mdagent -f environment.yaml`
- Activate your environment: `conda activate mdagent`
- Install the necessary conda dependencies: `conda install -c conda-forge openmm pdbfixer`

If you already have a conda environment, you can install the necessary dependencies with the following steps.
- Install the necessary conda dependencies: `conda install -c conda-forge openmm pdbfixer mdanalysis`


## Installation
Expand Down
1 change: 1 addition & 0 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ channels:
dependencies:
- openmm >= 7.6
- pdbfixer >= 1.5
- mdanalysis
- pip
- pip:
- flake8
Expand Down
33 changes: 2 additions & 31 deletions mdagent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,3 @@
from .agent import MDAgent, make_tools
from .tools import (
CheckDirectoryFiles,
CleaningTools,
PlanBVisualizationTool,
Scholar2ResultLLM,
SetUpAndRunTool,
SimulationFunctions,
SimulationOutputFigures,
VisFunctions,
VisualizationToolRender,
get_pdb,
)
from .utils import PathRegistry
from .agent import MDAgent

__all__ = [
"MDAgent",
"Scholar2ResultLLM",
"Name2PDBTool",
"SimulationFunctions",
"make_tools",
"VisFunctions",
"CleaningTools",
"MDAgent",
"VisualizationToolRender",
"CheckDirectoryFiles",
"PlanBVisualizationTool",
"SetUpAndRunTool",
"PathRegistry",
"SimulationOutputFigures",
"get_pdb",
]
__all__ = ["MDAgent"]
3 changes: 1 addition & 2 deletions mdagent/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .agent import MDAgent
from .tools import make_tools

__all__ = ["MDAgent", "make_tools"]
__all__ = ["MDAgent"]
53 changes: 16 additions & 37 deletions mdagent/agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,19 @@
import langchain
from dotenv import load_dotenv
from langchain.agents import AgentType, initialize_agent
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

from .prompt import FORMAT_INSTRUCTIONS, QUESTION_PROMPT, SUFFIX
from .tools import make_tools
from mdagent.agent.prompt import FORMAT_INSTRUCTIONS, QUESTION_PROMPT, SUFFIX
from mdagent.tools import make_all_tools
from mdagent.utils import _make_llm

load_dotenv()


def _make_llm(model, temp, verbose):
if model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
llm = langchain.chat_models.ChatOpenAI(
temperature=temp,
model_name=model,
request_timeout=1000,
streaming=True if verbose else False,
callbacks=[StreamingStdOutCallbackHandler()] if verbose else [None],
)
elif model.startswith("text-"):
llm = langchain.OpenAI(
temperature=temp,
model_name=model,
streaming=True if verbose else False,
callbacks=[StreamingStdOutCallbackHandler()] if verbose else [None],
)
else:
raise ValueError(f"Invalid model name: {model}")
return llm


class MDAgent:
def __init__(
self,
tools=None,
model="gpt-4",
tools_model="gpt-4",
model="gpt-4-1106-preview", # current name for gpt-4 turbo
tools_model="gpt-4-1106-preview",
temp=0.1,
max_iterations=40,
api_key=None,
Expand All @@ -44,7 +22,7 @@ def __init__(
self.llm = _make_llm(model, temp, verbose)
if tools is None:
tools_llm = _make_llm(tools_model, temp, verbose)
tools = make_tools(tools_llm, verbose=verbose)
tools = make_all_tools(tools_llm, verbose=verbose)

# Initialize agent
# self.agent_executor = RetryAgentExecutor.from_agent_and_tools(
Expand All @@ -56,6 +34,10 @@ def __init__(
# format_instructions=FORMAT_INSTRUCTIONS,
# question_prompt=QUESTION_PROMPT,
# ),
# verbose=True,
# max_iterations=max_iterations,
# return_intermediate_steps=True,
# )
self.agent_executor = initialize_agent(
tools,
self.llm,
Expand All @@ -65,20 +47,17 @@ def __init__(
question_prompt=QUESTION_PROMPT,
return_intermediate_steps=True,
max_iterations=max_iterations,
verbose=verbose,
)
# verbose=True,
# max_iterations=max_iterations,
# return_intermediate_steps=True,
# )

def run(self, prompt):
outputs = self.agent_executor({"input": prompt})
# Parse long output (with intermediate steps)
intermed = outputs["intermediate_steps"]
# intermed = outputs["intermediate_steps"]

final = ""
for step in intermed:
final += f"Thought: {step[0].log}\n" f"Observation: {step[1]}\n"
final += f"Final Answer: {outputs['output']}"
# final = ""
# for step in intermed:
# final += f"Thought: {step[0].log}\n" f"Observation: {step[1]}\n"
final = outputs["output"]

return final
43 changes: 2 additions & 41 deletions mdagent/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,3 @@
from .clean_tools import (
AddHydrogensCleaningTool,
CleaningTools,
RemoveWaterCleaningTool,
SpecializedCleanTool,
)
from .git_issues_tool import SerpGitTool
from .md_util_tools import Name2PDBTool, get_pdb
from .pdb_tools import PackMolTool
from .plot_tools import SimulationOutputFigures
from .registry_tools import ListRegistryPaths, MapPath2Name
from .search_tools import Scholar2ResultLLM
from .setup_and_run import SetUpAndRunTool, SimulationFunctions
from .vis_tools import (
CheckDirectoryFiles,
PlanBVisualizationTool,
VisFunctions,
VisualizationToolRender,
)
from .maketools import make_all_tools

__all__ = [
"Scholar2ResultLLM",
"VisFunctions",
"CleaningTools",
"SimulationFunctions",
"VisualizationToolRender",
"CheckDirectoryFiles",
"PlanBVisualizationTool",
"get_pdb",
"SpecializedCleanTool",
"RemoveWaterCleaningTool",
"AddHydrogensCleaningTool",
"SetUpAndRunTool",
"CheckDirectoryFiles",
"PlanBVisualizationTool",
"SimulationOutputFigures",
"SerpGitTool",
"PackMolTool",
"ListRegistryPaths",
"MapPath2Name",
"Name2PDBTool",
]
__all__ = ["make_all_tools"]
45 changes: 45 additions & 0 deletions mdagent/tools/base_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from .clean_tools import (
AddHydrogensCleaningTool,
CleaningTools,
RemoveWaterCleaningTool,
SpecializedCleanTool,
)
from .git_issues_tool import SerpGitTool
from .md_util_tools import Name2PDBTool, get_pdb
from .pdb_tools import PackMolTool
from .plot_tools import SimulationOutputFigures
from .ppi_tools import PPIDistance
from .registry_tools import ListRegistryPaths, MapPath2Name
from .rmsd_tools import RMSDCalculator
from .search_tools import Scholar2ResultLLM
from .setup_and_run import InstructionSummary, SetUpAndRunTool, SimulationFunctions
from .vis_tools import (
CheckDirectoryFiles,
PlanBVisualizationTool,
VisFunctions,
VisualizationToolRender,
)

# Public names of this package: every entry is re-exported from one of the
# submodules imported above.
__all__ = [
    "AddHydrogensCleaningTool",
    "CheckDirectoryFiles",
    "CleaningTools",
    "InstructionSummary",
    "ListRegistryPaths",
    "MapPath2Name",
    "Name2PDBTool",
    "PackMolTool",
    "PPIDistance",
    "PlanBVisualizationTool",
    "RMSDCalculator",
    "RemoveWaterCleaningTool",
    "Scholar2ResultLLM",
    "SerpGitTool",
    "SetUpAndRunTool",
    "SimulationFunctions",
    "SimulationOutputFigures",
    "SpecializedCleanTool",
    "VisFunctions",
    "VisualizationToolRender",
    "get_pdb",
]
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
from typing import List, Optional

import langchain
import requests
import tiktoken
from langchain import LLMChain, PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.tools import BaseTool
from serpapi import GoogleSearch

from mdagent.utils import _make_llm


class GitToolFunctions:
"""Class to store the functions of the tool."""

"""chain that can be used the tools for summarization or classification"""

llm_ = langchain.chat_models.ChatOpenAI(
temperature=0.05,
model_name="gpt-3.5-turbo-16k",
request_timeout=1000,
max_tokens=2500,
llm_ = _make_llm(
model="gpt-3.5-turbo-16k", temp=0.05, verbose=False, max_tokens=2500
)

def _prompt_summary(self, query: str, output: str, llm: BaseLanguageModel = llm_):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def get_pdb(query_string, PathRegistry):
filetype = "cif"
else:
filetype = "pdb"

if "result_set" in r.json() and len(r.json()["result_set"]) > 0:
pdbid = r.json()["result_set"][0]["identifier"]
print(f"PDB file found with this ID: {pdbid}")
Expand All @@ -36,7 +35,6 @@ def get_pdb(query_string, PathRegistry):
filename = f"{pdbid}.{filetype}"
with open(filename, "w") as file:
file.write(pdb.text)
file.close()
print(f"{filename} is created.")
file_description = f"PDB file downloaded from RSCB, PDB ID: {pdbid}"
PathRegistry.map_path(filename, filename, file_description)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1358,7 +1358,7 @@ def validate_input(cls, values: Union[str, Dict[str, Any]]) -> Dict:
class FixPDBFile(BaseTool):
name: str = "PDB File Fixer"
description: str = "Fixes PDB files columns if needed"
args_schema = PDBFilesFixInp
args_schema: Type[BaseModel] = PDBFilesFixInp

path_registry: typing.Optional[PathRegistry]

Expand Down
File renamed without changes.
70 changes: 70 additions & 0 deletions mdagent/tools/base_tools/ppi_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Optional, Type

import MDAnalysis as mda
import MDAnalysis.analysis.distances as mda_dist
import numpy as np
from langchain.tools import BaseTool
from pydantic import BaseModel, Field


def ppi_distance(pdb_file, binding_site="protein"):
    """
    Calculate the average minimum heavy-atom distance between a peptide and
    its protein partner. Works with any protein-protein interaction (PPI).

    The peptide is assumed to be the segment (chain) with the fewest
    residues; on ties the first such segment wins. For each peptide residue,
    the smallest distance between its heavy atoms and the selected protein
    heavy atoms is taken, and the mean of those per-residue minima is
    returned.

    Args:
        pdb_file: path of the structure file handed to ``MDAnalysis.Universe``.
        binding_site: MDAnalysis selection string restricting which atoms of
            the partner are considered. Defaults to ``"protein"``.

    Returns:
        Mean of the per-residue minimum distances, in the length unit of the
        loaded coordinates (presumably Angstrom, the MDAnalysis default --
        confirm against the input files).

    Raises:
        ValueError: if the structure has no segments, if the partner
            selection matches no heavy atoms (e.g. a single-chain file or a
            binding-site selection that matches nothing), or if the peptide
            has no heavy atoms.
    """
    heavy_atoms = "not name H*"

    # load and find smallest chain
    u = mda.Universe(pdb_file)
    peptide_residues = None
    for chain in u.segments:
        if peptide_residues is None or len(chain.residues) < len(peptide_residues):
            peptide_residues = chain.residues
    if peptide_residues is None:
        raise ValueError("no chains found in the structure")

    protein = u.select_atoms(
        f"({binding_site}) and not segid {peptide_residues.segids[0]} "
        f"and {heavy_atoms}"
    )
    # Fail with a clear message instead of numpy's "argmin of an empty
    # sequence" when there is nothing to measure against.
    if protein.n_atoms == 0:
        raise ValueError(
            "binding site selection matched no heavy atoms outside the peptide chain"
        )

    peptide = peptide_residues.atoms.select_atoms(heavy_atoms)
    if peptide.n_atoms == 0:
        raise ValueError("peptide chain contains no heavy atoms")

    all_d = []
    for r in peptide.residues:
        # ``r.atoms`` holds every atom of the residue, so re-apply the
        # heavy-atom filter to keep hydrogens out of the measurement.
        residue_heavy = r.atoms.select_atoms(heavy_atoms)
        distances = mda_dist.distance_array(
            residue_heavy.positions, protein.positions
        )
        all_d.append(distances.min())
    return np.mean(all_d)


class PPIDistanceInputSchema(BaseModel):
    """Input arguments accepted by the ``ppi_distance`` tool."""

    pdb_file: str = Field(
        description="file with .pdb extension containing protein-protein interaction"
    )
    # Explicit default: without it the field is None under pydantic v1 and
    # required under pydantic v2, neither of which matches the "protein"
    # default used by PPIDistance._run.
    binding_site: Optional[str] = Field(
        default="protein",
        description=(
            "a list of selected residues as the binding site\n"
            "of the protein using MDAnalysis selection syntax."
        ),
    )


class PPIDistance(BaseTool):
    """LangChain tool that reports the peptide-protein distance of a PDB file.

    Wraps ``ppi_distance`` and converts its failures into plain-text messages
    so the calling agent receives an observation instead of an exception.
    """

    name: str = "ppi_distance"
    description: str = (
        "Useful for calculating minimum heavy-atom distance\n"
        "between peptide and protein. First, make sure you have valid PDB file with\n"
        "any protein-protein interaction."
    )
    args_schema: Type[BaseModel] = PPIDistanceInputSchema

    def _run(self, pdb_file: str, binding_site: Optional[str] = "protein"):
        """Compute the average distance and return it as text.

        Args:
            pdb_file: path to a file with a ``.pdb`` extension.
            binding_site: MDAnalysis selection string for the binding site;
                ``None`` or an empty string falls back to ``"protein"``.

        Returns:
            The average distance followed by a newline, or an error message
            describing why it could not be computed.
        """
        if not pdb_file.endswith(".pdb"):
            return "Error with input: PDB file must have .pdb extension"
        # An explicit null/empty value would otherwise be interpolated into
        # the selection string as "(None)" or "()".
        if not binding_site:
            binding_site = "protein"
        try:
            avg_dist = ppi_distance(pdb_file, binding_site=binding_site)
        except ValueError as e:
            return (
                f"ValueError: {e}. \nMake sure to provide valid PDB "
                "file and binding site using MDAnalysis selection syntax."
            )
        except Exception as e:
            # Tool boundary: report any other failure to the agent as text.
            return f"Something went wrong. {type(e).__name__}: {e}"
        return f"{avg_dist}\n"

    def _arun(self, pdb_file: str, binding_site: Optional[str] = "protein"):
        """Async execution is not implemented for this tool."""
        raise NotImplementedError("This tool does not support async")
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from langchain.tools import BaseTool

from ..utils import PathRegistry
from mdagent.utils import PathRegistry


class MapPath2Name(BaseTool):
Expand Down
Loading

0 comments on commit 4a017e0

Please sign in to comment.