diff --git a/mdagent/tools/base_tools/__init__.py b/mdagent/tools/base_tools/__init__.py index 23a1fd21..04c4fa35 100644 --- a/mdagent/tools/base_tools/__init__.py +++ b/mdagent/tools/base_tools/__init__.py @@ -4,11 +4,7 @@ from .analysis_tools.plot_tools import SimulationOutputFigures from .analysis_tools.ppi_tools import PPIDistance from .analysis_tools.rdf_tool import RDFTool -from .analysis_tools.rgy import ( - RadiusofGyrationAverage, - RadiusofGyrationPerFrame, - RadiusofGyrationPlot, -) +from .analysis_tools.rgy import RadiusofGyrationTool from .analysis_tools.rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF from .analysis_tools.sasa import SolventAccessibleSurfaceArea from .analysis_tools.secondary_structure import ( @@ -80,9 +76,7 @@ "PCATool", "PPIDistance", "ProteinName2PDBTool", - "RadiusofGyrationAverage", - "RadiusofGyrationPerFrame", - "RadiusofGyrationPlot", + "RadiusofGyrationTool", "RDFTool", "RMSDCalculator", "Scholar2ResultLLM", diff --git a/mdagent/tools/base_tools/analysis_tools/__init__.py b/mdagent/tools/base_tools/analysis_tools/__init__.py index 81562527..26d58784 100644 --- a/mdagent/tools/base_tools/analysis_tools/__init__.py +++ b/mdagent/tools/base_tools/analysis_tools/__init__.py @@ -3,7 +3,7 @@ from .pca_tools import PCATool from .plot_tools import SimulationOutputFigures from .ppi_tools import PPIDistance -from .rgy import RadiusofGyrationAverage, RadiusofGyrationPerFrame, RadiusofGyrationPlot +from .rgy import RadiusofGyrationTool from .rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF from .sasa import SolventAccessibleSurfaceArea from .vis_tools import VisFunctions, VisualizeProtein @@ -17,9 +17,7 @@ "MomentOfInertia", "PCATool", "PPIDistance", - "RadiusofGyrationAverage", - "RadiusofGyrationPerFrame", - "RadiusofGyrationPlot", + "RadiusofGyrationTool", "RMSDCalculator", "SimulationOutputFigures", "SolventAccessibleSurfaceArea", diff --git a/mdagent/tools/base_tools/analysis_tools/rgy.py 
b/mdagent/tools/base_tools/analysis_tools/rgy.py index 9d2de240..5976d9f5 100644 --- a/mdagent/tools/base_tools/analysis_tools/rgy.py +++ b/mdagent/tools/base_tools/analysis_tools/rgy.py @@ -14,6 +14,7 @@ def __init__(self, path_registry): self.top_file = "" self.traj_file = "" self.traj = None + self.rgy_file = "" def _load_traj(self, top_file: str, traj_file: str): self.traj_file = traj_file @@ -25,38 +26,36 @@ def _load_traj(self, top_file: str, traj_file: str): traj_required=True, ) - def rgy_per_frame(self, force_recompute: bool = False) -> str: + def rgy_per_frame(self) -> str: rg_per_frame = md.compute_rg(self.traj) self.rgy_file = ( f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.traj_file}.csv" ) rgy_id = f"rgy_{self.traj_file}" - if rgy_id in self.path_registry.list_path_names() and force_recompute is False: - print("RGY already computed, skipping re-compute") - # todo -> maybe allow re-compute & save under different id/path - else: - np.savetxt( - self.rgy_file, - rg_per_frame, - delimiter=",", - header="Radius of Gyration (nm)", - ) - self.path_registry.map_path( - f"rgy_{self.traj_file}", - self.rgy_file, - description=f"Radii of gyration per frame for {self.traj_file}", - ) + np.savetxt( + self.rgy_file, + rg_per_frame, + delimiter=",", + header="Radius of Gyration (nm)", + ) + self.path_registry.map_path( + f"rgy_{self.traj_file}", + self.rgy_file, + description=f"Radii of gyration per frame for {self.traj_file}", + ) return f"Radii of gyration saved to {self.rgy_file} with id {rgy_id}." 
def rgy_average(self) -> str: - _ = self.rgy_per_frame() + if not self.rgy_file: + _ = self.rgy_per_frame() rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1) avg_rg = rg_per_frame.mean() return f"Average radius of gyration: {avg_rg:.2f} nm" def plot_rgy(self) -> str: - _ = self.rgy_per_frame() + if not self.rgy_file: + _ = self.rgy_per_frame() rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1) fig_analysis = f"rgy_{self.traj_file}" plot_name = self.path_registry.write_file_name( @@ -84,99 +83,44 @@ def plot_rgy(self) -> str: plt.clf() return "Plot saved as: " + f"{plot_name} with plot ID {plot_id}" + def compute_plot_return_avg(self) -> str: + rgy_per_frame = self.rgy_per_frame() + avg_rgy = self.rgy_average() + plot_rgy = self.plot_rgy() + return rgy_per_frame + plot_rgy + avg_rgy -class RadiusofGyrationAverage(BaseTool): - name = "RadiusofGyrationAverage" - description = """This tool calculates the average radius of gyration - for a trajectory. Give this tool BOTH the trajectory file ID and the - topology file ID.""" - - path_registry: Optional[PathRegistry] - - def __init__(self, path_registry): - super().__init__() - self.path_registry = path_registry - def _run(self, traj_file: str, top_file: str) -> str: - """use the tool.""" - RGY = RadiusofGyration(self.path_registry) - try: - RGY._load_traj(top_file=top_file, traj_file=traj_file) - except Exception as e: - return f"Error loading traj: {e}" - try: - return "Succeeded. " + RGY.rgy_average() - except ValueError as e: - return f"Failed. ValueError: {e}" - except Exception as e: - return f"Failed. {type(e).__name__}: {e}" - - async def _arun(self, query: str) -> str: - """Use the tool asynchronously.""" - raise NotImplementedError("custom_search does not support async") - - -class RadiusofGyrationPerFrame(BaseTool): - name = "RadiusofGyrationPerFrame" - description = """This tool calculates the radius of gyration - at each frame of a given trajectory. 
+class RadiusofGyrationTool(BaseTool): + name = "RadiusofGyrationTool" + description = """This tool calculates and plots + the radius of gyration + at each frame of a given trajectory and returns the average. Give this tool BOTH the trajectory file ID and the - topology file ID. - The tool will save the radii of gyration to a csv file and - map it to the registry.""" + topology file ID.""" path_registry: Optional[PathRegistry] + rgy: Optional[RadiusofGyration] + load_traj: bool = True - def __init__(self, path_registry): + def __init__(self, path_registry, load_traj=True): super().__init__() self.path_registry = path_registry + self.rgy = RadiusofGyration(path_registry) + self.load_traj = load_traj # only for testing def _run(self, traj_file: str, top_file: str) -> str: """use the tool.""" - RGY = RadiusofGyration(self.path_registry) - try: - RGY._load_traj(top_file=top_file, traj_file=traj_file) - except Exception as e: - return f"Error loading traj: {e}" - try: - return "Succeeded. " + RGY.rgy_per_frame() - except ValueError as e: - return f"Failed. ValueError: {e}" - except Exception as e: - return f"Failed. {type(e).__name__}: {e}" - - async def _arun(self, query: str) -> str: - """Use the tool asynchronously.""" - raise NotImplementedError("custom_search does not support async") - - -class RadiusofGyrationPlot(BaseTool): - name = "RadiusofGyrationPlot" - description = """This tool calculates the radius of gyration - at each frame of a given trajectory file and plots it. - Give this tool BOTH the trajectory file ID and the - topology file ID. 
- The tool will save the plot to a png file and map it to the registry.""" - - path_registry: Optional[PathRegistry] + assert self.rgy is not None, "RadiusofGyration instance is not initialized" - def __init__(self, path_registry): - super().__init__() - self.path_registry = path_registry - - def _run(self, traj_file: str, top_file: str) -> str: - """use the tool.""" - RGY = RadiusofGyration(self.path_registry) - try: - RGY._load_traj(top_file=top_file, traj_file=traj_file) - except Exception as e: - return f"Error loading traj: {e}" + if self.load_traj: + try: + self.rgy._load_traj(top_file=top_file, traj_file=traj_file) + except Exception as e: + return f"Error loading traj: {e}" try: - return "Succeeded. " + RGY.plot_rgy() - except ValueError as e: - return f"Failed. ValueError: {e}" + return "Succeeded. " + self.rgy.compute_plot_return_avg() except Exception as e: - return f"Failed. {type(e).__name__}: {e}" + return f"Failed Computing RGY: {e}" async def _arun(self, query: str) -> str: """Use the tool asynchronously.""" diff --git a/mdagent/tools/maketools.py b/mdagent/tools/maketools.py index b690d615..f831708e 100644 --- a/mdagent/tools/maketools.py +++ b/mdagent/tools/maketools.py @@ -46,9 +46,7 @@ PCATool, PPIDistance, ProteinName2PDBTool, - RadiusofGyrationAverage, - RadiusofGyrationPerFrame, - RadiusofGyrationPlot, + RadiusofGyrationTool, RDFTool, Scholar2ResultLLM, SetUpandRunFunction, @@ -99,9 +97,7 @@ def make_all_tools( PCATool(path_registry=path_instance), PPIDistance(path_registry=path_instance), ProteinName2PDBTool(path_registry=path_instance), - RadiusofGyrationAverage(path_registry=path_instance), - RadiusofGyrationPerFrame(path_registry=path_instance), - RadiusofGyrationPlot(path_registry=path_instance), + RadiusofGyrationTool(path_registry=path_instance), RDFTool(path_registry=path_instance), SetUpandRunFunction(path_registry=path_instance), SimulationOutputFigures(path_registry=path_instance), diff --git a/tests/test_analysis/test_rgy_tool.py 
b/tests/test_analysis/test_rgy_tool.py index 57f815cd..c131d5be 100644 --- a/tests/test_analysis/test_rgy_tool.py +++ b/tests/test_analysis/test_rgy_tool.py @@ -1,6 +1,9 @@ import pytest -from mdagent.tools.base_tools.analysis_tools.rgy import RadiusofGyration +from mdagent.tools.base_tools.analysis_tools.rgy import ( + RadiusofGyration, + RadiusofGyrationTool, +) @pytest.fixture @@ -13,6 +16,24 @@ def rgy(get_registry, loaded_cif_traj): return rgy +@pytest.fixture +def rgy_tool(get_registry, loaded_cif_traj): + registry = get_registry("raw", False) + rgy_tool = RadiusofGyrationTool(path_registry=registry, load_traj=False) + rgy_tool.rgy.traj = loaded_cif_traj + rgy_tool.rgy.top_file = "test_top_dummy" + rgy_tool.rgy.traj_file = "test_traj_dummy" + return rgy_tool + + +def test_rgy_tool(rgy_tool): + output = rgy_tool._run(traj_file="test_top_dummy", top_file="test_traj_dummy") + assert "Radii of gyration saved to " in output + assert "Average radius of gyration: " in output + assert "Plot saved as: " in output + assert ".png" in output + + def test_rgy_per_frame(rgy): output = rgy.rgy_per_frame() assert "Radii of gyration saved to " in output @@ -27,3 +48,11 @@ def test_plot_rgy(rgy): output = rgy.plot_rgy() assert "Plot saved as: " in output assert ".png" in output + + +def test_compute_plot_return_avg_rgy(rgy): + output = rgy.compute_plot_return_avg() + assert "Radii of gyration saved to " in output + assert "Average radius of gyration: " in output + assert "Plot saved as: " in output + assert ".png" in output