diff --git a/mdagent/tools/base_tools/analysis_tools/rgy.py b/mdagent/tools/base_tools/analysis_tools/rgy.py index 9bded9dc..59872927 100644 --- a/mdagent/tools/base_tools/analysis_tools/rgy.py +++ b/mdagent/tools/base_tools/analysis_tools/rgy.py @@ -6,82 +6,59 @@ import numpy as np from langchain.tools import BaseTool -from mdagent.utils import FileType, PathRegistry +from mdagent.utils import FileType, PathRegistry, load_single_traj class RadiusofGyration: def __init__(self, path_registry): self.path_registry = path_registry - self.includes_top = [".h5", ".lh5", ".pdb"] + self.top_file = "" + self.traj_file = "" + self.traj = None + + def _load_traj(self, top_file: str, traj_file: str): + self.traj_file = traj_file + self.top_file = top_file + self.traj = load_single_traj( + path_registry=self.path_registry, + top_fileid=top_file, + traj_fileid=traj_file, + traj_required=True, + ) - def _grab_files(self, pdb_id: str) -> None: - if "_" in pdb_id: - pdb_id = pdb_id.split("_")[0] - self.pdb_id = pdb_id - all_names = self.path_registry._list_all_paths() - try: - self.pdb_path = [ - name - for name in all_names - if pdb_id in name and ".pdb" in name and "records" in name - ][0] - except IndexError: - raise ValueError(f"No pdb file found for {pdb_id}") - try: - self.dcd_path = [ - name - for name in all_names - if pdb_id in name and ".dcd" in name and "records" in name - ][0] - except IndexError: - self.dcd_path = None - pass - return None - - def _load_traj(self, pdb_id: str) -> None: - self._grab_files(pdb_id) - if self.dcd_path: - self.traj = md.load(self.dcd_path, top=self.pdb_path) - else: - self.traj = md.load(self.pdb_path) - return None - - def rad_gyration_per_frame(self, pdb_id: str) -> str: - self._load_traj(pdb_id) + def rgy_per_frame(self) -> str: rg_per_frame = md.compute_rg(self.traj) - self.rgy_file = ( - f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.pdb_id}.csv" + f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.traj_file}.csv" ) - np.savetxt( self.rgy_file, rg_per_frame, delimiter=",", header="Radius of Gyration (nm)" ) self.path_registry.map_path( - f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.pdb_id}.csv", + f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.traj_file}.csv", self.rgy_file, - description=f"Radii of gyration per frame for {self.pdb_id}", + description=f"Radii of gyration per frame for {self.traj_file}", ) return f"Radii of gyration saved to {self.rgy_file}" - def rad_gyration_average(self, pdb_id: str) -> str: - _ = self.rad_gyration_per_frame(pdb_id) + def rgy_average(self) -> str: + _ = self.rgy_per_frame() rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1) avg_rg = rg_per_frame.mean() return f"Average radius of gyration: {avg_rg:.2f} nm" - def plot_rad_gyration(self, pdb_id: str) -> str: - _ = self.rad_gyration_per_frame(pdb_id) + def plot_rgy(self) -> str: + _ = self.rgy_per_frame() rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1) - fig_analysis = f"rgy_{self.pdb_id}" + fig_analysis = f"rgy_{self.traj_file}" plot_name = self.path_registry.write_file_name( type=FileType.FIGURE, fig_analysis=fig_analysis, file_format="png" ) plot_id = self.path_registry.get_fileid( file_name=plot_name, type=FileType.FIGURE ) - if plot_name.ends_with(".png"): + if plot_name.endswith(".png"): plot_name = plot_name.split(".png"[0]) plot_path = next( f"{self.path_registry.ckpt_figures}/{plot_name}_{i}.png" @@ -94,13 +71,13 @@ def plot_rad_gyration(self, pdb_id: str) -> str: plt.plot(rg_per_frame) plt.xlabel("Frame") plt.ylabel("Radius of Gyration (nm)") - plt.title(f"{pdb_id} - Radius of Gyration Over Time") + plt.title(f"{self.traj_file} - Radius of Gyration Over Time") plt.savefig(f"{plot_path}") self.path_registry.map_path( plot_id, plot_path, - description=f"Plot of radii of gyration over time for {self.pdb_id}", + description=f"Plot of radii of gyration over time for {self.traj_file}", ) plt.close() plt.clf() @@ -110,8 +87,8 @@ def plot_rad_gyration(self, pdb_id: str) -> str: class RadiusofGyrationAverage(BaseTool): name = "RadiusofGyrationAverage" description = """This tool calculates the average radius of gyration - for the given trajectory file. Give this tool the - protein ID (PDB ID) only. The tool will automatically find the necessary files.""" + for a trajectory. Give this tool BOTH the trajectory file ID and the + topology file ID.""" path_registry: Optional[PathRegistry] @@ -119,11 +96,15 @@ def __init__(self, path_registry): super().__init__() self.path_registry = path_registry - def _run(self, pdb_id: str) -> str: + def _run(self, traj_file: str, top_file: str) -> str: """use the tool.""" + RGY = RadiusofGyration(self.path_registry) try: - RGY = RadiusofGyration(self.path_registry) - return "Succeeded. " + RGY.rad_gyration_average(pdb_id) + RGY._load_traj(top_file=top_file, traj_file=traj_file) + except Exception as e: + return f"Error loading traj: {e}" + try: + return "Succeeded. " + RGY.rgy_average() except ValueError as e: return f"Failed. ValueError: {e}" except Exception as e: @@ -137,8 +118,9 @@ async def _arun(self, query: str) -> str: class RadiusofGyrationPerFrame(BaseTool): name = "RadiusofGyrationPerFrame" description = """This tool calculates the radius of gyration - at each frame of a given trajectory file. Give this tool the - protein ID (PDB ID) only. The tool will automatically find the necessary files. + at each frame of a given trajectory. + Give this tool BOTH the trajectory file ID and the + topology file ID. The tool will save the radii of gyration to a csv file and map it to the registry.""" @@ -148,11 +130,15 @@ def __init__(self, path_registry): super().__init__() self.path_registry = path_registry - def _run(self, pdb_id: str) -> str: + def _run(self, traj_file: str, top_file: str) -> str: """use the tool.""" + RGY = RadiusofGyration(self.path_registry) + try: + RGY._load_traj(top_file=top_file, traj_file=traj_file) + except Exception as e: + return f"Error loading traj: {e}" try: - RGY = RadiusofGyration(self.path_registry) - return "Succeeded. " + RGY.rad_gyration_per_frame(pdb_id) + return "Succeeded. " + RGY.rgy_per_frame() except ValueError as e: return f"Failed. ValueError: {e}" except Exception as e: @@ -167,8 +153,8 @@ class RadiusofGyrationPlot(BaseTool): name = "RadiusofGyrationPlot" description = """This tool calculates the radius of gyration at each frame of a given trajectory file and plots it. - Give this tool the protein ID (PDB ID) only. - The tool will automatically find the necessary files. + Give this tool BOTH the trajectory file ID and the + topology file ID. The tool will save the plot to a png file and map it to the registry.""" path_registry: Optional[PathRegistry] @@ -177,11 +163,15 @@ def __init__(self, path_registry): super().__init__() self.path_registry = path_registry - def _run(self, pdb_id: str) -> str: + def _run(self, traj_file: str, top_file: str) -> str: """use the tool.""" + RGY = RadiusofGyration(self.path_registry) + try: + RGY._load_traj(top_file=top_file, traj_file=traj_file) + except Exception as e: + return f"Error loading traj: {e}" try: - RGY = RadiusofGyration(self.path_registry) - return "Succeeded. " + RGY.plot_rad_gyration(pdb_id) + return "Succeeded. " + RGY.plot_rgy() except ValueError as e: return f"Failed. ValueError: {e}" except Exception as e: diff --git a/tests/test_analysis/test_rgy_tool.py b/tests/test_analysis/test_rgy_tool.py new file mode 100644 index 00000000..4e5033c2 --- /dev/null +++ b/tests/test_analysis/test_rgy_tool.py @@ -0,0 +1,28 @@ +import pytest + +from mdagent.tools.base_tools.analysis_tools.rgy import RadiusofGyration + + +@pytest.fixture +def rgy(get_registry, loaded_cif_traj): + registry = get_registry("raw", False) + rgy = RadiusofGyration(path_registry=registry) + rgy.traj = loaded_cif_traj + rgy.top_file = "test_top_dummy" + rgy.traj_file = "test_traj_dummy" + return rgy + + +def test_rgy_per_frame(rgy): + output = rgy.rgy_per_frame() + assert "Radii of gyration saved to " in output + + +def test_rgy_average(rgy): + output = rgy.rgy_average() + assert "Average radius of gyration: " in output + + +def test_plot_rgy(rgy): + output = rgy.plot_rgy() + assert "Plot saved as: " in output