From ab25b487cae7ded2feb7b0d07fb0d286aeb812ea Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Mon, 4 Nov 2024 11:04:38 -0800 Subject: [PATCH 1/3] rgy png fix --- mdagent/tools/base_tools/analysis_tools/rgy.py | 7 +++---- tests/test_analysis/test_rgy_tool.py | 1 + 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mdagent/tools/base_tools/analysis_tools/rgy.py b/mdagent/tools/base_tools/analysis_tools/rgy.py index 61ad7698..9d2de240 100644 --- a/mdagent/tools/base_tools/analysis_tools/rgy.py +++ b/mdagent/tools/base_tools/analysis_tools/rgy.py @@ -66,9 +66,8 @@ def plot_rgy(self) -> str: plot_id = self.path_registry.get_fileid( file_name=plot_name, type=FileType.FIGURE ) - if plot_name.endswith(".png"): - plot_name = plot_name.split(".png")[0] plot_path = f"{self.path_registry.ckpt_figures}/{plot_name}" + plot_path = plot_path if plot_path.endswith(".png") else plot_path + ".png" print("plot_path", plot_path) plt.plot(rg_per_frame) plt.xlabel("Frame") @@ -78,12 +77,12 @@ def plot_rgy(self) -> str: plt.savefig(f"{plot_path}") self.path_registry.map_path( plot_id, - plot_path + ".png", + plot_path, description=f"Plot of radii of gyration over time for {self.traj_file}", ) plt.close() plt.clf() - return "Plot saved as: " + f"{plot_name}.png with plot ID {plot_id}" + return "Plot saved as: " + f"{plot_name} with plot ID {plot_id}" class RadiusofGyrationAverage(BaseTool): diff --git a/tests/test_analysis/test_rgy_tool.py b/tests/test_analysis/test_rgy_tool.py index 4e5033c2..57f815cd 100644 --- a/tests/test_analysis/test_rgy_tool.py +++ b/tests/test_analysis/test_rgy_tool.py @@ -26,3 +26,4 @@ def test_rgy_average(rgy): def test_plot_rgy(rgy): output = rgy.plot_rgy() assert "Plot saved as: " in output + assert ".png" in output From 08cc4015cc6bd4d88c62fcbebeaa0bc3cbab2b5b Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Mon, 4 Nov 2024 11:39:36 -0800 Subject: [PATCH 2/3] combined rgy into 1 tool that does all - added unit test for tool --- mdagent/tools/base_tools/__init__.py | 10 +- .../base_tools/analysis_tools/__init__.py | 6 +- .../tools/base_tools/analysis_tools/rgy.py | 138 ++++++------------ mdagent/tools/maketools.py | 8 +- tests/test_analysis/test_rgy_tool.py | 31 +++- 5 files changed, 77 insertions(+), 116 deletions(-) diff --git a/mdagent/tools/base_tools/__init__.py b/mdagent/tools/base_tools/__init__.py index 23a1fd21..04c4fa35 100644 --- a/mdagent/tools/base_tools/__init__.py +++ b/mdagent/tools/base_tools/__init__.py @@ -4,11 +4,7 @@ from .analysis_tools.plot_tools import SimulationOutputFigures from .analysis_tools.ppi_tools import PPIDistance from .analysis_tools.rdf_tool import RDFTool -from .analysis_tools.rgy import ( - RadiusofGyrationAverage, - RadiusofGyrationPerFrame, - RadiusofGyrationPlot, -) +from .analysis_tools.rgy import RadiusofGyrationTool from .analysis_tools.rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF from .analysis_tools.sasa import SolventAccessibleSurfaceArea from .analysis_tools.secondary_structure import ( @@ -80,9 +76,7 @@ "PCATool", "PPIDistance", "ProteinName2PDBTool", - "RadiusofGyrationAverage", - "RadiusofGyrationPerFrame", - "RadiusofGyrationPlot", + "RadiusofGyrationTool", "RDFTool", "RMSDCalculator", "Scholar2ResultLLM", diff --git a/mdagent/tools/base_tools/analysis_tools/__init__.py b/mdagent/tools/base_tools/analysis_tools/__init__.py index 81562527..26d58784 100644 --- a/mdagent/tools/base_tools/analysis_tools/__init__.py +++ b/mdagent/tools/base_tools/analysis_tools/__init__.py @@ -3,7 +3,7 @@ from .pca_tools import PCATool from .plot_tools import SimulationOutputFigures from .ppi_tools import PPIDistance -from .rgy import RadiusofGyrationAverage, RadiusofGyrationPerFrame, RadiusofGyrationPlot +from .rgy import RadiusofGyrationTool from .rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF from .sasa import SolventAccessibleSurfaceArea from .vis_tools import VisFunctions, VisualizeProtein @@ -17,9 +17,7 @@ "MomentOfInertia", "PCATool", "PPIDistance", - "RadiusofGyrationAverage", - "RadiusofGyrationPerFrame", - "RadiusofGyrationPlot", + "RadiusofGyrationTool", "RMSDCalculator", "SimulationOutputFigures", "SolventAccessibleSurfaceArea", diff --git a/mdagent/tools/base_tools/analysis_tools/rgy.py b/mdagent/tools/base_tools/analysis_tools/rgy.py index 9d2de240..5976d9f5 100644 --- a/mdagent/tools/base_tools/analysis_tools/rgy.py +++ b/mdagent/tools/base_tools/analysis_tools/rgy.py @@ -14,6 +14,7 @@ def __init__(self, path_registry): self.top_file = "" self.traj_file = "" self.traj = None + self.rgy_file = "" def _load_traj(self, top_file: str, traj_file: str): self.traj_file = traj_file @@ -25,38 +26,36 @@ def _load_traj(self, top_file: str, traj_file: str): traj_required=True, ) - def rgy_per_frame(self, force_recompute: bool = False) -> str: + def rgy_per_frame(self) -> str: rg_per_frame = md.compute_rg(self.traj) self.rgy_file = ( f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.traj_file}.csv" ) rgy_id = f"rgy_{self.traj_file}" - if rgy_id in self.path_registry.list_path_names() and force_recompute is False: - print("RGY already computed, skipping re-compute") - # todo -> maybe allow re-compute & save under different id/path - else: - np.savetxt( - self.rgy_file, - rg_per_frame, - delimiter=",", - header="Radius of Gyration (nm)", - ) - self.path_registry.map_path( - f"rgy_{self.traj_file}", - self.rgy_file, - description=f"Radii of gyration per frame for {self.traj_file}", - ) + np.savetxt( + self.rgy_file, + rg_per_frame, + delimiter=",", + header="Radius of Gyration (nm)", + ) + self.path_registry.map_path( + f"rgy_{self.traj_file}", + self.rgy_file, + description=f"Radii of gyration per frame for {self.traj_file}", + ) return f"Radii of gyration saved to {self.rgy_file} with id {rgy_id}." def rgy_average(self) -> str: - _ = self.rgy_per_frame() + if not self.rgy_file: + _ = self.rgy_per_frame() rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1) avg_rg = rg_per_frame.mean() return f"Average radius of gyration: {avg_rg:.2f} nm" def plot_rgy(self) -> str: - _ = self.rgy_per_frame() + if not self.rgy_file: + _ = self.rgy_per_frame() rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1) fig_analysis = f"rgy_{self.traj_file}" plot_name = self.path_registry.write_file_name( @@ -84,99 +83,44 @@ def plot_rgy(self) -> str: plt.clf() return "Plot saved as: " + f"{plot_name} with plot ID {plot_id}" + def compute_plot_return_avg(self) -> str: + rgy_per_frame = self.rgy_per_frame() + avg_rgy = self.rgy_average() + plot_rgy = self.plot_rgy() + return rgy_per_frame + plot_rgy + avg_rgy -class RadiusofGyrationAverage(BaseTool): - name = "RadiusofGyrationAverage" - description = """This tool calculates the average radius of gyration - for a trajectory. Give this tool BOTH the trajectory file ID and the - topology file ID.""" - - path_registry: Optional[PathRegistry] - - def __init__(self, path_registry): - super().__init__() - self.path_registry = path_registry - def _run(self, traj_file: str, top_file: str) -> str: - """use the tool.""" - RGY = RadiusofGyration(self.path_registry) - try: - RGY._load_traj(top_file=top_file, traj_file=traj_file) - except Exception as e: - return f"Error loading traj: {e}" - try: - return "Succeeded. " + RGY.rgy_average() - except ValueError as e: - return f"Failed. ValueError: {e}" - except Exception as e: - return f"Failed. {type(e).__name__}: {e}" - - async def _arun(self, query: str) -> str: - """Use the tool asynchronously.""" - raise NotImplementedError("custom_search does not support async") - - -class RadiusofGyrationPerFrame(BaseTool): - name = "RadiusofGyrationPerFrame" - description = """This tool calculates the radius of gyration - at each frame of a given trajectory. +class RadiusofGyrationTool(BaseTool): + name = "RadiusofGyrationTool" + description = """This tool calculates and plots + the radius of gyration + at each frame of a given trajectory and retuns the average. Give this tool BOTH the trajectory file ID and the - topology file ID. - The tool will save the radii of gyration to a csv file and - map it to the registry.""" + topology file ID.""" path_registry: Optional[PathRegistry] + rgy: Optional[RadiusofGyration] + load_traj: bool = True - def __init__(self, path_registry): + def __init__(self, path_registry, load_traj=True): super().__init__() self.path_registry = path_registry + self.rgy = RadiusofGyration(path_registry) + self.load_traj = load_traj # only for testing def _run(self, traj_file: str, top_file: str) -> str: """use the tool.""" - RGY = RadiusofGyration(self.path_registry) - try: - RGY._load_traj(top_file=top_file, traj_file=traj_file) - except Exception as e: - return f"Error loading traj: {e}" - try: - return "Succeeded. " + RGY.rgy_per_frame() - except ValueError as e: - return f"Failed. ValueError: {e}" - except Exception as e: - return f"Failed. {type(e).__name__}: {e}" - - async def _arun(self, query: str) -> str: - """Use the tool asynchronously.""" - raise NotImplementedError("custom_search does not support async") - - -class RadiusofGyrationPlot(BaseTool): - name = "RadiusofGyrationPlot" - description = """This tool calculates the radius of gyration - at each frame of a given trajectory file and plots it. - Give this tool BOTH the trajectory file ID and the - topology file ID. - The tool will save the plot to a png file and map it to the registry.""" - - path_registry: Optional[PathRegistry] + assert self.rgy is not None, "RadiusofGyration instance is not initialized" - def __init__(self, path_registry): - super().__init__() - self.path_registry = path_registry - - def _run(self, traj_file: str, top_file: str) -> str: - """use the tool.""" - RGY = RadiusofGyration(self.path_registry) - try: - RGY._load_traj(top_file=top_file, traj_file=traj_file) - except Exception as e: - return f"Error loading traj: {e}" + if self.load_traj: + try: + self.rgy._load_traj(top_file=top_file, traj_file=traj_file) + except Exception as e: + return f"Error loading traj: {e}" try: - return "Succeeded. " + RGY.plot_rgy() - except ValueError as e: - return f"Failed. ValueError: {e}" + return "Succeeded. " + self.rgy.compute_plot_return_avg() except Exception as e: - return f"Failed. {type(e).__name__}: {e}" + return f"Failed Computing RGY: {e}" async def _arun(self, query: str) -> str: """Use the tool asynchronously.""" diff --git a/mdagent/tools/maketools.py b/mdagent/tools/maketools.py index b690d615..f831708e 100644 --- a/mdagent/tools/maketools.py +++ b/mdagent/tools/maketools.py @@ -46,9 +46,7 @@ PCATool, PPIDistance, ProteinName2PDBTool, - RadiusofGyrationAverage, - RadiusofGyrationPerFrame, - RadiusofGyrationPlot, + RadiusofGyrationTool, RDFTool, Scholar2ResultLLM, SetUpandRunFunction, @@ -99,9 +97,7 @@ def make_all_tools( PCATool(path_registry=path_instance), PPIDistance(path_registry=path_instance), ProteinName2PDBTool(path_registry=path_instance), - RadiusofGyrationAverage(path_registry=path_instance), - RadiusofGyrationPerFrame(path_registry=path_instance), - RadiusofGyrationPlot(path_registry=path_instance), + RadiusofGyrationTool(path_registry=path_instance), RDFTool(path_registry=path_instance), SetUpandRunFunction(path_registry=path_instance), SimulationOutputFigures(path_registry=path_instance), diff --git a/tests/test_analysis/test_rgy_tool.py b/tests/test_analysis/test_rgy_tool.py index 57f815cd..c131d5be 100644 --- a/tests/test_analysis/test_rgy_tool.py +++ b/tests/test_analysis/test_rgy_tool.py @@ -1,6 +1,9 @@ import pytest -from mdagent.tools.base_tools.analysis_tools.rgy import RadiusofGyration +from mdagent.tools.base_tools.analysis_tools.rgy import ( + RadiusofGyration, + RadiusofGyrationTool, +) @pytest.fixture @@ -13,6 +16,24 @@ def rgy(get_registry, loaded_cif_traj): return rgy +@pytest.fixture +def rgy_tool(get_registry, loaded_cif_traj): + registry = get_registry("raw", False) + rgy_tool = RadiusofGyrationTool(path_registry=registry, load_traj=False) + rgy_tool.rgy.traj = loaded_cif_traj + rgy_tool.rgy.top_file = "test_top_dummy" + rgy_tool.rgy.traj_file = "test_traj_dummy" + return rgy_tool + + +def test_rgy_tool(rgy_tool): + output = rgy_tool._run(traj_file="test_top_dummy", top_file="test_traj_dummy") + assert "Radii of gyration saved to " in output + assert "Average radius of gyration: " in output + assert "Plot saved as: " in output + assert ".png" in output + + def test_rgy_per_frame(rgy): output = rgy.rgy_per_frame() assert "Radii of gyration saved to " in output @@ -27,3 +48,11 @@ def test_plot_rgy(rgy): output = rgy.plot_rgy() assert "Plot saved as: " in output assert ".png" in output + + +def test_compute_plot_return_avg_rgy(rgy): + output = rgy.compute_plot_return_avg() + assert "Radii of gyration saved to " in output + assert "Average radius of gyration: " in output + assert "Plot saved as: " in output + assert ".png" in output From f3b02ac9ee945e22a8f0a46e75f65834cd153b6e Mon Sep 17 00:00:00 2001 From: Sam Cox Date: Mon, 4 Nov 2024 12:00:49 -0800 Subject: [PATCH 3/3] if ff mising -> return list of ff included --- .../base_tools/simulation_tools/setup_and_run.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/mdagent/tools/base_tools/simulation_tools/setup_and_run.py b/mdagent/tools/base_tools/simulation_tools/setup_and_run.py index c552a6b3..ca1b3c30 100644 --- a/mdagent/tools/base_tools/simulation_tools/setup_and_run.py +++ b/mdagent/tools/base_tools/simulation_tools/setup_and_run.py @@ -265,7 +265,7 @@ def setup_system(self): raise ValueError(str(e)) else: raise ValueError( - f"Error building system. Please check the forcefield files {str(e)}" + f"Error building system. Please check the forcefield files {str(e)}. Included force fields are: {FORCEFIELD_LIST}" ) if self.sys_params.get("nonbondedMethod", None) in [ @@ -1497,8 +1497,10 @@ def check_system_params(cls, values): else: for file in forcefield_files: if file not in FORCEFIELD_LIST: - error_msg += "The forcefield file is not present" - + error_msg += ( + "The forcefield file is not present: forcefield files are: " + + str(FORCEFIELD_LIST) + ) save = values.get("save", True) if not isinstance(save, bool): error_msg += "save must be a boolean value" @@ -1558,7 +1560,10 @@ def create_simulation_input(pdb_path, forcefield_files): Water_model = Forcefield_files[1] # check if they are part of the list if Forcefield not in FORCEFIELD_LIST: - raise Exception("Forcefield not recognized") + raise Exception( + "Forcefield not recognized: Possible forcefields are: " + + str(FORCEFIELD_LIST) + ) if Water_model not in FORCEFIELD_LIST: raise Exception("Water model not recognized")