Skip to content

Commit

Permalink
Merging Main to my branch to ensure it is current.
Browse files Browse the repository at this point in the history
Merge branch 'main' of https://github.com/ur-whitelab/md-agent into hydrogen_bonding
  • Loading branch information
brittyscience committed Nov 9, 2024
2 parents 34375c4 + 345ac20 commit b5a8d37
Show file tree
Hide file tree
Showing 17 changed files with 177 additions and 231 deletions.
2 changes: 0 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,5 @@
# OpenAI API Key
OPENAI_API_KEY=YOUR_OPENAI_API_KEY_GOES_HERE # pragma: allowlist secret

# PQA API Key to use LiteratureSearch tool (optional) -- it also requires OpenAI key
PQA_API_KEY=YOUR_PQA_API_KEY_GOES_HERE # pragma: allowlist secret

# Optional: add TogetherAI, Fireworks, or Anthropic API key here to use their models
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,28 @@ repos:
- id: mixed-line-ending
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.0.270"
rev: "v0.7.1"
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix ]
- repo: https://github.com/psf/black
rev: "23.3.0"
rev: "24.10.0"
hooks:
- id: black
language_version: python3
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.3.0"
rev: "v1.13.0"
hooks:
- id: mypy
args: [--pretty, --ignore-missing-imports]
additional_dependencies: [types-requests]
- repo: https://github.com/PyCQA/isort
rev: "5.12.0"
rev: "5.13.2"
hooks:
- id: isort
args: [--profile=black]
- repo: https://github.com/Yelp/detect-secrets
rev: v1.0.3
rev: v1.5.0
hooks:
- id: detect-secrets
args: [--exclude-files, ".github/workflows/"]
3 changes: 2 additions & 1 deletion mdagent/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,15 @@ def __init__(
uploaded_files=[], # user input files to add to path registry
run_id="",
use_memory=False,
paper_dir=None, # papers for pqa, relative path within repo
):
self.llm = _make_llm(model, temp, streaming)
if tools_model is None:
tools_model = model
self.tools_llm = _make_llm(tools_model, temp, streaming)

self.use_memory = use_memory
self.path_registry = PathRegistry.get_instance(ckpt_dir=ckpt_dir)
self.path_registry = PathRegistry.get_instance(ckpt_dir, paper_dir)
self.ckpt_dir = self.path_registry.ckpt_dir
self.memory = MemoryManager(self.path_registry, self.tools_llm, run_id=run_id)
self.run_id = self.memory.run_id
Expand Down
2 changes: 1 addition & 1 deletion mdagent/agent/prompt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from langchain.prompts import PromptTemplate

structured_prompt = PromptTemplate(
input_variables=["input, context"],
input_variables=["input", "context"],
template="""
You are an expert molecular dynamics scientist, and
your task is to respond to the question or
Expand Down
10 changes: 2 additions & 8 deletions mdagent/tools/base_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@
from .analysis_tools.plot_tools import SimulationOutputFigures
from .analysis_tools.ppi_tools import PPIDistance
from .analysis_tools.rdf_tool import RDFTool
from .analysis_tools.rgy import (
RadiusofGyrationAverage,
RadiusofGyrationPerFrame,
RadiusofGyrationPlot,
)
from .analysis_tools.rgy import RadiusofGyrationTool
from .analysis_tools.rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF
from .analysis_tools.sasa import SolventAccessibleSurfaceArea
from .analysis_tools.secondary_structure import (
Expand Down Expand Up @@ -83,9 +79,7 @@
"PCATool",
"PPIDistance",
"ProteinName2PDBTool",
"RadiusofGyrationAverage",
"RadiusofGyrationPerFrame",
"RadiusofGyrationPlot",
"RadiusofGyrationTool",
"RDFTool",
"RMSDCalculator",
"Scholar2ResultLLM",
Expand Down
6 changes: 2 additions & 4 deletions mdagent/tools/base_tools/analysis_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .pca_tools import PCATool
from .plot_tools import SimulationOutputFigures
from .ppi_tools import PPIDistance
from .rgy import RadiusofGyrationAverage, RadiusofGyrationPerFrame, RadiusofGyrationPlot
from .rgy import RadiusofGyrationTool
from .rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF
from .sasa import SolventAccessibleSurfaceArea
from .vis_tools import VisFunctions, VisualizeProtein
Expand All @@ -20,9 +20,7 @@
"MomentOfInertia",
"PCATool",
"PPIDistance",
"RadiusofGyrationAverage",
"RadiusofGyrationPerFrame",
"RadiusofGyrationPlot",
"RadiusofGyrationTool",
"RMSDCalculator",
"SimulationOutputFigures",
"SolventAccessibleSurfaceArea",
Expand Down
2 changes: 1 addition & 1 deletion mdagent/tools/base_tools/analysis_tools/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _run(self, file_id: str) -> str:
plotting_tools._find_file(file_id)
plotting_tools.process_csv()
plot_result = plotting_tools.plot_data()
if type(plot_result) == str:
if isinstance(plot_result, str):
return "Succeeded. IDs of figures created: " + plot_result
else:
return "Failed. No figures created."
Expand Down
2 changes: 1 addition & 1 deletion mdagent/tools/base_tools/analysis_tools/rdf_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def validate_input(self, input):
)

if stride:
if type(stride) != int:
if not isinstance(stride, int):
try:
stride = int(stride)
if stride <= 0:
Expand Down
144 changes: 44 additions & 100 deletions mdagent/tools/base_tools/analysis_tools/rgy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, path_registry):
self.top_file = ""
self.traj_file = ""
self.traj = None
self.rgy_file = ""

def _load_traj(self, top_file: str, traj_file: str):
self.traj_file = traj_file
Expand All @@ -25,38 +26,36 @@ def _load_traj(self, top_file: str, traj_file: str):
traj_required=True,
)

def rgy_per_frame(self, force_recompute: bool = False) -> str:
def rgy_per_frame(self) -> str:
rg_per_frame = md.compute_rg(self.traj)
self.rgy_file = (
f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.traj_file}.csv"
)
rgy_id = f"rgy_{self.traj_file}"
if rgy_id in self.path_registry.list_path_names() and force_recompute is False:
print("RGY already computed, skipping re-compute")
# todo -> maybe allow re-compute & save under different id/path
else:
np.savetxt(
self.rgy_file,
rg_per_frame,
delimiter=",",
header="Radius of Gyration (nm)",
)
self.path_registry.map_path(
f"rgy_{self.traj_file}",
self.rgy_file,
description=f"Radii of gyration per frame for {self.traj_file}",
)
np.savetxt(
self.rgy_file,
rg_per_frame,
delimiter=",",
header="Radius of Gyration (nm)",
)
self.path_registry.map_path(
f"rgy_{self.traj_file}",
self.rgy_file,
description=f"Radii of gyration per frame for {self.traj_file}",
)
return f"Radii of gyration saved to {self.rgy_file} with id {rgy_id}."

def rgy_average(self) -> str:
_ = self.rgy_per_frame()
if not self.rgy_file:
_ = self.rgy_per_frame()
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
avg_rg = rg_per_frame.mean()

return f"Average radius of gyration: {avg_rg:.2f} nm"

def plot_rgy(self) -> str:
_ = self.rgy_per_frame()
if not self.rgy_file:
_ = self.rgy_per_frame()
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
fig_analysis = f"rgy_{self.traj_file}"
plot_name = self.path_registry.write_file_name(
Expand All @@ -66,9 +65,9 @@ def plot_rgy(self) -> str:
plot_id = self.path_registry.get_fileid(
file_name=plot_name, type=FileType.FIGURE
)
if plot_name.endswith(".png"):
plot_name = plot_name.split(".png")[0]
plot_path = f"{self.path_registry.ckpt_figures}/{plot_name}"
plot_path = plot_path if plot_path.endswith(".png") else plot_path + ".png"
print("plot_path", plot_path)
plt.plot(rg_per_frame)
plt.xlabel("Frame")
plt.ylabel("Radius of Gyration (nm)")
Expand All @@ -82,101 +81,46 @@ def plot_rgy(self) -> str:
)
plt.close()
plt.clf()
return "Plot saved as: " + f"{plot_name}.png with plot ID {plot_id}"


class RadiusofGyrationAverage(BaseTool):
name = "RadiusofGyrationAverage"
description = """This tool calculates the average radius of gyration
for a trajectory. Give this tool BOTH the trajectory file ID and the
topology file ID."""
return "Plot saved as: " + f"{plot_name} with plot ID {plot_id}"

path_registry: Optional[PathRegistry]
def compute_plot_return_avg(self) -> str:
rgy_per_frame = self.rgy_per_frame()
avg_rgy = self.rgy_average()
plot_rgy = self.plot_rgy()
return rgy_per_frame + plot_rgy + avg_rgy

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.rgy_average()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")


class RadiusofGyrationPerFrame(BaseTool):
name = "RadiusofGyrationPerFrame"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory.
class RadiusofGyrationTool(BaseTool):
name = "RadiusofGyrationTool"
description = """This tool calculates and plots
the radius of gyration
at each frame of a given trajectory and returns the average.
Give this tool BOTH the trajectory file ID and the
topology file ID.
The tool will save the radii of gyration to a csv file and
map it to the registry."""
topology file ID."""

path_registry: Optional[PathRegistry]
rgy: Optional[RadiusofGyration]
load_traj: bool = True

def __init__(self, path_registry):
def __init__(self, path_registry, load_traj=True):
super().__init__()
self.path_registry = path_registry
self.rgy = RadiusofGyration(path_registry)
self.load_traj = load_traj # only for testing

def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.rgy_per_frame()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")


class RadiusofGyrationPlot(BaseTool):
name = "RadiusofGyrationPlot"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory file and plots it.
Give this tool BOTH the trajectory file ID and the
topology file ID.
The tool will save the plot to a png file and map it to the registry."""

path_registry: Optional[PathRegistry]
assert self.rgy is not None, "RadiusofGyration instance is not initialized"

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
if self.load_traj:
try:
self.rgy._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.plot_rgy()
except ValueError as e:
return f"Failed. ValueError: {e}"
return "Succeeded. " + self.rgy.compute_plot_return_avg()
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"
return f"Failed Computing RGY: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
Expand Down
10 changes: 6 additions & 4 deletions mdagent/tools/base_tools/preprocess_tools/uniprot.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def get_sequence_info(self, query: str, primary_accession: str) -> dict:
- 'crc64': The CRC64 hash of the protein sequence (probably not useful)
- 'md5': The MD5 hash of the protein sequence (probably not useful)
"""
seq_info = self.data = self.get_data(query, desired_field="sequence")
seq_info = self.get_data(query, desired_field="sequence")
if not seq_info:
return {}
seq_info_specific = self._match_primary_accession(seq_info, primary_accession)[
Expand Down Expand Up @@ -693,9 +693,11 @@ def get_ids(
if include_uniprotkbids:
all_ids + [entry["uniProtkbId"] for entry in ids_] if ids_ else []
accession = self.get_data(query, desired_field="accession")
all_ids + [
entry["primaryAccession"] for entry in accession
] if accession else []
(
all_ids + [entry["primaryAccession"] for entry in accession]
if accession
else []
)
if single_id:
return [all_ids[0]] if all_ids else []
return list(set(all_ids))
Expand Down
Loading

0 comments on commit b5a8d37

Please sign in to comment.