Skip to content

Commit

Permalink
combined rgy into 1 tool that does all - added unit test for tool
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 committed Nov 4, 2024
1 parent ab25b48 commit 08cc401
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 116 deletions.
10 changes: 2 additions & 8 deletions mdagent/tools/base_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
from .analysis_tools.plot_tools import SimulationOutputFigures
from .analysis_tools.ppi_tools import PPIDistance
from .analysis_tools.rdf_tool import RDFTool
from .analysis_tools.rgy import (
RadiusofGyrationAverage,
RadiusofGyrationPerFrame,
RadiusofGyrationPlot,
)
from .analysis_tools.rgy import RadiusofGyrationTool
from .analysis_tools.rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF
from .analysis_tools.sasa import SolventAccessibleSurfaceArea
from .analysis_tools.secondary_structure import (
Expand Down Expand Up @@ -80,9 +76,7 @@
"PCATool",
"PPIDistance",
"ProteinName2PDBTool",
"RadiusofGyrationAverage",
"RadiusofGyrationPerFrame",
"RadiusofGyrationPlot",
"RadiusofGyrationTool",
"RDFTool",
"RMSDCalculator",
"Scholar2ResultLLM",
Expand Down
6 changes: 2 additions & 4 deletions mdagent/tools/base_tools/analysis_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .pca_tools import PCATool
from .plot_tools import SimulationOutputFigures
from .ppi_tools import PPIDistance
from .rgy import RadiusofGyrationAverage, RadiusofGyrationPerFrame, RadiusofGyrationPlot
from .rgy import RadiusofGyrationTool
from .rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF
from .sasa import SolventAccessibleSurfaceArea
from .vis_tools import VisFunctions, VisualizeProtein
Expand All @@ -17,9 +17,7 @@
"MomentOfInertia",
"PCATool",
"PPIDistance",
"RadiusofGyrationAverage",
"RadiusofGyrationPerFrame",
"RadiusofGyrationPlot",
"RadiusofGyrationTool",
"RMSDCalculator",
"SimulationOutputFigures",
"SolventAccessibleSurfaceArea",
Expand Down
138 changes: 41 additions & 97 deletions mdagent/tools/base_tools/analysis_tools/rgy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, path_registry):
self.top_file = ""
self.traj_file = ""
self.traj = None
self.rgy_file = ""

def _load_traj(self, top_file: str, traj_file: str):
self.traj_file = traj_file
Expand All @@ -25,38 +26,36 @@ def _load_traj(self, top_file: str, traj_file: str):
traj_required=True,
)

def rgy_per_frame(self, force_recompute: bool = False) -> str:
def rgy_per_frame(self) -> str:
rg_per_frame = md.compute_rg(self.traj)
self.rgy_file = (
f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.traj_file}.csv"
)
rgy_id = f"rgy_{self.traj_file}"
if rgy_id in self.path_registry.list_path_names() and force_recompute is False:
print("RGY already computed, skipping re-compute")
# todo -> maybe allow re-compute & save under different id/path
else:
np.savetxt(
self.rgy_file,
rg_per_frame,
delimiter=",",
header="Radius of Gyration (nm)",
)
self.path_registry.map_path(
f"rgy_{self.traj_file}",
self.rgy_file,
description=f"Radii of gyration per frame for {self.traj_file}",
)
np.savetxt(
self.rgy_file,
rg_per_frame,
delimiter=",",
header="Radius of Gyration (nm)",
)
self.path_registry.map_path(
f"rgy_{self.traj_file}",
self.rgy_file,
description=f"Radii of gyration per frame for {self.traj_file}",
)
return f"Radii of gyration saved to {self.rgy_file} with id {rgy_id}."

def rgy_average(self) -> str:
_ = self.rgy_per_frame()
if not self.rgy_file:
_ = self.rgy_per_frame()
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
avg_rg = rg_per_frame.mean()

return f"Average radius of gyration: {avg_rg:.2f} nm"

def plot_rgy(self) -> str:
_ = self.rgy_per_frame()
if not self.rgy_file:
_ = self.rgy_per_frame()
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
fig_analysis = f"rgy_{self.traj_file}"
plot_name = self.path_registry.write_file_name(
Expand Down Expand Up @@ -84,99 +83,44 @@ def plot_rgy(self) -> str:
plt.clf()
return "Plot saved as: " + f"{plot_name} with plot ID {plot_id}"

def compute_plot_return_avg(self) -> str:
    """Compute per-frame radii of gyration, plot them, and report the average.

    ``rgy_per_frame`` runs first so that the CSV is written and
    ``self.rgy_file`` is set before ``rgy_average`` and ``plot_rgy`` read it.

    Returns:
        The three status messages (CSV saved, plot saved, average value),
        in that order, separated by single spaces.
    """
    # Order matters: the per-frame step writes the CSV the other two load.
    per_frame_msg = self.rgy_per_frame()
    avg_msg = self.rgy_average()
    plot_msg = self.plot_rgy()
    # Join with spaces; bare concatenation produced run-together output
    # such as "...with id rgy_x.Plot saved as: ...".
    return " ".join((per_frame_msg, plot_msg, avg_msg))

class RadiusofGyrationAverage(BaseTool):
name = "RadiusofGyrationAverage"
description = """This tool calculates the average radius of gyration
for a trajectory. Give this tool BOTH the trajectory file ID and the
topology file ID."""

path_registry: Optional[PathRegistry]

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.rgy_average()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")


class RadiusofGyrationPerFrame(BaseTool):
name = "RadiusofGyrationPerFrame"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory.
class RadiusofGyrationTool(BaseTool):
name = "RadiusofGyrationTool"
description = """This tool calculates and plots
the radius of gyration
at each frame of a given trajectory and returns the average.
Give this tool BOTH the trajectory file ID and the
topology file ID.
The tool will save the radii of gyration to a csv file and
map it to the registry."""
topology file ID."""

path_registry: Optional[PathRegistry]
rgy: Optional[RadiusofGyration]
load_traj: bool = True

def __init__(self, path_registry):
def __init__(self, path_registry, load_traj=True):
super().__init__()
self.path_registry = path_registry
self.rgy = RadiusofGyration(path_registry)
self.load_traj = load_traj # only for testing

def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.rgy_per_frame()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")


class RadiusofGyrationPlot(BaseTool):
name = "RadiusofGyrationPlot"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory file and plots it.
Give this tool BOTH the trajectory file ID and the
topology file ID.
The tool will save the plot to a png file and map it to the registry."""

path_registry: Optional[PathRegistry]
assert self.rgy is not None, "RadiusofGyration instance is not initialized"

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
if self.load_traj:
try:
self.rgy._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.plot_rgy()
except ValueError as e:
return f"Failed. ValueError: {e}"
return "Succeeded. " + self.rgy.compute_plot_return_avg()
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"
return f"Failed Computing RGY: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
Expand Down
8 changes: 2 additions & 6 deletions mdagent/tools/maketools.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@
PCATool,
PPIDistance,
ProteinName2PDBTool,
RadiusofGyrationAverage,
RadiusofGyrationPerFrame,
RadiusofGyrationPlot,
RadiusofGyrationTool,
RDFTool,
Scholar2ResultLLM,
SetUpandRunFunction,
Expand Down Expand Up @@ -99,9 +97,7 @@ def make_all_tools(
PCATool(path_registry=path_instance),
PPIDistance(path_registry=path_instance),
ProteinName2PDBTool(path_registry=path_instance),
RadiusofGyrationAverage(path_registry=path_instance),
RadiusofGyrationPerFrame(path_registry=path_instance),
RadiusofGyrationPlot(path_registry=path_instance),
RadiusofGyrationTool(path_registry=path_instance),
RDFTool(path_registry=path_instance),
SetUpandRunFunction(path_registry=path_instance),
SimulationOutputFigures(path_registry=path_instance),
Expand Down
31 changes: 30 additions & 1 deletion tests/test_analysis/test_rgy_tool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import pytest

from mdagent.tools.base_tools.analysis_tools.rgy import RadiusofGyration
from mdagent.tools.base_tools.analysis_tools.rgy import (
RadiusofGyration,
RadiusofGyrationTool,
)


@pytest.fixture
Expand All @@ -13,6 +16,24 @@ def rgy(get_registry, loaded_cif_traj):
return rgy


@pytest.fixture
def rgy_tool(get_registry, loaded_cif_traj):
    """Build a RadiusofGyrationTool wired to an already-loaded trajectory.

    The tool is created with ``load_traj=False`` and the trajectory plus
    dummy file IDs are set directly on its ``rgy`` helper.
    """
    tool = RadiusofGyrationTool(
        path_registry=get_registry("raw", False),
        load_traj=False,
    )
    analysis = tool.rgy
    analysis.traj = loaded_cif_traj
    analysis.top_file = "test_top_dummy"
    analysis.traj_file = "test_traj_dummy"
    return tool


def test_rgy_tool(rgy_tool):
    """The combined tool reports the CSV, the plot, and the average in one run."""
    # Pass each dummy ID to its matching parameter. They were swapped before,
    # which went unnoticed only because the fixture builds the tool with
    # load_traj=False, so _run never loads from these IDs.
    output = rgy_tool._run(traj_file="test_traj_dummy", top_file="test_top_dummy")
    assert "Radii of gyration saved to " in output
    assert "Average radius of gyration: " in output
    assert "Plot saved as: " in output
    assert ".png" in output


def test_rgy_per_frame(rgy):
output = rgy.rgy_per_frame()
assert "Radii of gyration saved to " in output
Expand All @@ -27,3 +48,11 @@ def test_plot_rgy(rgy):
output = rgy.plot_rgy()
assert "Plot saved as: " in output
assert ".png" in output


def test_compute_plot_return_avg_rgy(rgy):
    """The combined helper's message mentions the CSV, the average, and the plot."""
    output = rgy.compute_plot_return_avg()
    expected_fragments = (
        "Radii of gyration saved to ",
        "Average radius of gyration: ",
        "Plot saved as: ",
        ".png",
    )
    # Check fragments in the same order as the individual assertions did.
    for fragment in expected_fragments:
        assert fragment in output

0 comments on commit 08cc401

Please sign in to comment.