From 18f2a134d956256ec3fc502b4624657f8f1d9a53 Mon Sep 17 00:00:00 2001 From: Samantha Cox Date: Wed, 7 Aug 2024 21:00:07 -0700 Subject: [PATCH] agent can input 'first' 'last' or 'all' for dssp now & adjusted description --- .../analysis_tools/secondary_structure.py | 42 +++++++++++++++---- .../test_analysis/test_secondary_structure.py | 26 ++++++++++++ 2 files changed, 59 insertions(+), 9 deletions(-) diff --git a/mdagent/tools/base_tools/analysis_tools/secondary_structure.py b/mdagent/tools/base_tools/analysis_tools/secondary_structure.py index c558bf1f..35c0da67 100644 --- a/mdagent/tools/base_tools/analysis_tools/secondary_structure.py +++ b/mdagent/tools/base_tools/analysis_tools/secondary_structure.py @@ -43,21 +43,21 @@ def write_raw_x( class ComputeDSSP(BaseTool): name = "ComputeDSSP" description = """Compute the DSSP (secondary structure) assignment - for a protein trajectory. Input is a trajectory file ID + for a protein trajectory. Input is a trajectory file ID and + a target_frames, which can be "first", "last", or "all", and an optional topology file ID. + Input "first" to get DSSP of only the first frame. + Input "last" to get DSSP of only the last frame. + Input "all" to get DSSP of all frames in trajectory, combined. The output is an array with the DSSP code for each residue at each time point.""" path_registry: PathRegistry = PathRegistry.get_instance() simplified: bool = True - last_only: bool = True - def __init__( - self, path_registry: PathRegistry, simplified: bool = True, last_only=True - ): + def __init__(self, path_registry: PathRegistry, simplified: bool = True): super().__init__() self.path_registry = path_registry self.simplified = simplified - self.last_only = last_only def _dssp_codes(self) -> list[str]: """ @@ -144,7 +144,32 @@ def _compute_dssp(self, traj: md.Trajectory) -> np.ndarray: """ return md.compute_dssp(traj, simplified=self.simplified) - def _run(self, traj_file: str, top_file: Optional[str] = None) -> str: + def _get_frame(self, traj, target_frames): + """ + Retrieves the target frame(s) of the trajectory for DSSP. + + Args: + traj: the trajectory + target_frames: the target frames to select. can be first, last, or all + + Returns: + the trajectory with only target frames""" + + if target_frames.lower().strip() == "all": + return traj + if target_frames.lower().strip() == "first": + return traj[0] + if target_frames.lower().strip() == "last": + return traj[-1] + else: + raise ValueError("Target Frames must be 'all', 'first', or 'last'.") + + def _run( + self, + traj_file: str, + top_file: Optional[str] = None, + target_frames: str = "last", + ) -> str: """ Computes the DSSP assignments for a trajectory and saves the results to a file. @@ -164,8 +189,7 @@ def _run(self, traj_file: str, top_file: Optional[str] = None) -> str: ) if not traj: raise Exception("Trajectory could not be loaded.") - if self.last_only and len(traj) > 1: - traj = traj[-1] + traj = self._get_frame(traj, target_frames) except Exception as e: print("Error loading trajectory: ", e) return str(e) diff --git a/tests/test_analysis/test_secondary_structure.py b/tests/test_analysis/test_secondary_structure.py index 497c65c2..699e5aef 100644 --- a/tests/test_analysis/test_secondary_structure.py +++ b/tests/test_analysis/test_secondary_structure.py @@ -37,6 +37,32 @@ def test_compute_dssp(loaded_cif_traj, compute_dssp_simple, compute_dssp): assert np.all(dssp[0][:10] == [" ", " ", " ", "E", "E", "E", "T", "T", "E", "E"]) +def test_get_frame(compute_dssp): + # random dummy traj with 3 frames + xyz = np.random.rand(10, 10, 3) + topology = md.Topology() + chain = topology.add_chain() + residue = topology.add_residue("ALA", chain) + for _ in range(10): + topology.add_atom("CA", md.element.carbon, residue) + traj = md.Trajectory(xyz, topology) + + # first frame + first_frame = compute_dssp._get_frame(traj, "first") + assert first_frame.n_frames == 1 + assert np.array_equal(first_frame.xyz, traj.xyz[0].reshape(1, -1, 3)) + + # last frame + last_frame = compute_dssp._get_frame(traj, "last") + assert last_frame.n_frames == 1 + assert np.array_equal(last_frame.xyz, traj.xyz[-1].reshape(1, -1, 3)) + + # all frames + all_frames = compute_dssp._get_frame(traj, "all") + assert all_frames.n_frames == traj.n_frames + assert np.array_equal(all_frames.xyz, traj.xyz) + + def test_dssp_codes(compute_dssp_simple, compute_dssp): dssp_codes_simple = compute_dssp_simple._dssp_codes() assert dssp_codes_simple == ["H", "E", "C", "NA"]