Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge from main to experiments #164

Merged
merged 3 commits into experiments from main
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,5 @@
# OpenAI API Key
OPENAI_API_KEY=YOUR_OPENAI_API_KEY_GOES_HERE # pragma: allowlist secret

# PQA API Key to use LiteratureSearch tool (optional) -- it also requires OpenAI key
PQA_API_KEY=YOUR_PQA_API_KEY_GOES_HERE # pragma: allowlist secret

# Optional: add TogetherAI, Fireworks, or Anthropic API key here to use their models
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,28 @@ repos:
- id: mixed-line-ending
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.0.270"
rev: "v0.7.1"
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix ]
- repo: https://github.com/psf/black
rev: "23.3.0"
rev: "24.10.0"
hooks:
- id: black
language_version: python3
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.3.0"
rev: "v1.13.0"
hooks:
- id: mypy
args: [--pretty, --ignore-missing-imports]
additional_dependencies: [types-requests]
- repo: https://github.com/PyCQA/isort
rev: "5.12.0"
rev: "5.13.2"
hooks:
- id: isort
args: [--profile=black]
- repo: https://github.com/Yelp/detect-secrets
rev: v1.0.3
rev: v1.5.0
hooks:
- id: detect-secrets
args: [--exclude-files, ".github/workflows/"]
10 changes: 2 additions & 8 deletions mdagent/tools/base_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
from .analysis_tools.plot_tools import SimulationOutputFigures
from .analysis_tools.ppi_tools import PPIDistance
from .analysis_tools.rdf_tool import RDFTool
from .analysis_tools.rgy import (
RadiusofGyrationAverage,
RadiusofGyrationPerFrame,
RadiusofGyrationPlot,
)
from .analysis_tools.rgy import RadiusofGyrationTool
from .analysis_tools.rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF
from .analysis_tools.sasa import SolventAccessibleSurfaceArea
from .analysis_tools.secondary_structure import (
Expand Down Expand Up @@ -80,9 +76,7 @@
"PCATool",
"PPIDistance",
"ProteinName2PDBTool",
"RadiusofGyrationAverage",
"RadiusofGyrationPerFrame",
"RadiusofGyrationPlot",
"RadiusofGyrationTool",
"RDFTool",
"RMSDCalculator",
"Scholar2ResultLLM",
Expand Down
6 changes: 2 additions & 4 deletions mdagent/tools/base_tools/analysis_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .pca_tools import PCATool
from .plot_tools import SimulationOutputFigures
from .ppi_tools import PPIDistance
from .rgy import RadiusofGyrationAverage, RadiusofGyrationPerFrame, RadiusofGyrationPlot
from .rgy import RadiusofGyrationTool
from .rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF
from .sasa import SolventAccessibleSurfaceArea
from .vis_tools import VisFunctions, VisualizeProtein
Expand All @@ -17,9 +17,7 @@
"MomentOfInertia",
"PCATool",
"PPIDistance",
"RadiusofGyrationAverage",
"RadiusofGyrationPerFrame",
"RadiusofGyrationPlot",
"RadiusofGyrationTool",
"RMSDCalculator",
"SimulationOutputFigures",
"SolventAccessibleSurfaceArea",
Expand Down
2 changes: 1 addition & 1 deletion mdagent/tools/base_tools/analysis_tools/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _run(self, file_id: str) -> str:
plotting_tools._find_file(file_id)
plotting_tools.process_csv()
plot_result = plotting_tools.plot_data()
if type(plot_result) == str:
if isinstance(plot_result, str):
return "Succeeded. IDs of figures created: " + plot_result
else:
return "Failed. No figures created."
Expand Down
2 changes: 1 addition & 1 deletion mdagent/tools/base_tools/analysis_tools/rdf_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def validate_input(self, input):
)

if stride:
if type(stride) != int:
if not isinstance(stride, int):
try:
stride = int(stride)
if stride <= 0:
Expand Down
145 changes: 44 additions & 101 deletions mdagent/tools/base_tools/analysis_tools/rgy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, path_registry):
self.top_file = ""
self.traj_file = ""
self.traj = None
self.rgy_file = ""

def _load_traj(self, top_file: str, traj_file: str):
self.traj_file = traj_file
Expand All @@ -25,38 +26,36 @@ def _load_traj(self, top_file: str, traj_file: str):
traj_required=True,
)

def rgy_per_frame(self, force_recompute: bool = False) -> str:
def rgy_per_frame(self) -> str:
rg_per_frame = md.compute_rg(self.traj)
self.rgy_file = (
f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.traj_file}.csv"
)
rgy_id = f"rgy_{self.traj_file}"
if rgy_id in self.path_registry.list_path_names() and force_recompute is False:
print("RGY already computed, skipping re-compute")
# todo -> maybe allow re-compute & save under different id/path
else:
np.savetxt(
self.rgy_file,
rg_per_frame,
delimiter=",",
header="Radius of Gyration (nm)",
)
self.path_registry.map_path(
f"rgy_{self.traj_file}",
self.rgy_file,
description=f"Radii of gyration per frame for {self.traj_file}",
)
np.savetxt(
self.rgy_file,
rg_per_frame,
delimiter=",",
header="Radius of Gyration (nm)",
)
self.path_registry.map_path(
f"rgy_{self.traj_file}",
self.rgy_file,
description=f"Radii of gyration per frame for {self.traj_file}",
)
return f"Radii of gyration saved to {self.rgy_file} with id {rgy_id}."

def rgy_average(self) -> str:
_ = self.rgy_per_frame()
if not self.rgy_file:
_ = self.rgy_per_frame()
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
avg_rg = rg_per_frame.mean()

return f"Average radius of gyration: {avg_rg:.2f} nm"

def plot_rgy(self) -> str:
_ = self.rgy_per_frame()
if not self.rgy_file:
_ = self.rgy_per_frame()
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
fig_analysis = f"rgy_{self.traj_file}"
plot_name = self.path_registry.write_file_name(
Expand All @@ -66,9 +65,8 @@ def plot_rgy(self) -> str:
plot_id = self.path_registry.get_fileid(
file_name=plot_name, type=FileType.FIGURE
)
if plot_name.endswith(".png"):
plot_name = plot_name.split(".png")[0]
plot_path = f"{self.path_registry.ckpt_figures}/{plot_name}"
plot_path = plot_path if plot_path.endswith(".png") else plot_path + ".png"
print("plot_path", plot_path)
plt.plot(rg_per_frame)
plt.xlabel("Frame")
Expand All @@ -78,106 +76,51 @@ def plot_rgy(self) -> str:
plt.savefig(f"{plot_path}")
self.path_registry.map_path(
plot_id,
plot_path + ".png",
plot_path,
description=f"Plot of radii of gyration over time for {self.traj_file}",
)
plt.close()
plt.clf()
return "Plot saved as: " + f"{plot_name}.png with plot ID {plot_id}"


class RadiusofGyrationAverage(BaseTool):
name = "RadiusofGyrationAverage"
description = """This tool calculates the average radius of gyration
for a trajectory. Give this tool BOTH the trajectory file ID and the
topology file ID."""
return "Plot saved as: " + f"{plot_name} with plot ID {plot_id}"

path_registry: Optional[PathRegistry]
def compute_plot_return_avg(self) -> str:
rgy_per_frame = self.rgy_per_frame()
avg_rgy = self.rgy_average()
plot_rgy = self.plot_rgy()
return rgy_per_frame + plot_rgy + avg_rgy

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.rgy_average()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")


class RadiusofGyrationPerFrame(BaseTool):
name = "RadiusofGyrationPerFrame"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory.
class RadiusofGyrationTool(BaseTool):
name = "RadiusofGyrationTool"
description = """This tool calculates and plots
the radius of gyration
at each frame of a given trajectory and retuns the average.
Give this tool BOTH the trajectory file ID and the
topology file ID.
The tool will save the radii of gyration to a csv file and
map it to the registry."""
topology file ID."""

path_registry: Optional[PathRegistry]
rgy: Optional[RadiusofGyration]
load_traj: bool = True

def __init__(self, path_registry):
def __init__(self, path_registry, load_traj=True):
super().__init__()
self.path_registry = path_registry
self.rgy = RadiusofGyration(path_registry)
self.load_traj = load_traj # only for testing

def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.rgy_per_frame()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")


class RadiusofGyrationPlot(BaseTool):
name = "RadiusofGyrationPlot"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory file and plots it.
Give this tool BOTH the trajectory file ID and the
topology file ID.
The tool will save the plot to a png file and map it to the registry."""

path_registry: Optional[PathRegistry]
assert self.rgy is not None, "RadiusofGyration instance is not initialized"

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
if self.load_traj:
try:
self.rgy._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.plot_rgy()
except ValueError as e:
return f"Failed. ValueError: {e}"
return "Succeeded. " + self.rgy.compute_plot_return_avg()
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"
return f"Failed Computing RGY: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
Expand Down
10 changes: 6 additions & 4 deletions mdagent/tools/base_tools/preprocess_tools/uniprot.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,7 +475,7 @@ def get_sequence_info(self, query: str, primary_accession: str) -> dict:
- 'crc64': The CRC64 hash of the protein sequence (probably not useful)
- 'md5': The MD5 hash of the protein sequence (probably not useful)
"""
seq_info = self.data = self.get_data(query, desired_field="sequence")
seq_info = self.get_data(query, desired_field="sequence")
if not seq_info:
return {}
seq_info_specific = self._match_primary_accession(seq_info, primary_accession)[
Expand Down Expand Up @@ -693,9 +693,11 @@ def get_ids(
if include_uniprotkbids:
all_ids + [entry["uniProtkbId"] for entry in ids_] if ids_ else []
accession = self.get_data(query, desired_field="accession")
all_ids + [
entry["primaryAccession"] for entry in accession
] if accession else []
(
all_ids + [entry["primaryAccession"] for entry in accession]
if accession
else []
)
if single_id:
return [all_ids[0]] if all_ids else []
return list(set(all_ids))
Expand Down
Loading
Loading