Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
agent can input 'first' 'last' or 'all' for dssp now & adjusted descr…
Browse files Browse the repository at this point in the history
…iption
SamCox822 committed Aug 8, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent a7dd1d7 commit 18f2a13
Showing 2 changed files with 59 additions and 9 deletions.
42 changes: 33 additions & 9 deletions mdagent/tools/base_tools/analysis_tools/secondary_structure.py
Original file line number Diff line number Diff line change
@@ -43,21 +43,21 @@ def write_raw_x(
class ComputeDSSP(BaseTool):
name = "ComputeDSSP"
description = """Compute the DSSP (secondary structure) assignment
for a protein trajectory. Input is a trajectory file ID
for a protein trajectory. Input is a trajectory file ID and
a target_frames, which can be "first", "last", or "all",
and an optional topology file ID.
Input "first" to get DSSP of only the first frame.
Input "last" to get DSSP of only the last frame.
Input "all" to get DSSP of all frames in trajectory, combined.
The output is an array with the DSSP code for each
residue at each time point."""
path_registry: PathRegistry = PathRegistry.get_instance()
simplified: bool = True
last_only: bool = True

def __init__(
self, path_registry: PathRegistry, simplified: bool = True, last_only=True
):
def __init__(self, path_registry: PathRegistry, simplified: bool = True):
super().__init__()
self.path_registry = path_registry
self.simplified = simplified
self.last_only = last_only

def _dssp_codes(self) -> list[str]:
"""
@@ -144,7 +144,32 @@ def _compute_dssp(self, traj: md.Trajectory) -> np.ndarray:
"""
return md.compute_dssp(traj, simplified=self.simplified)

def _run(self, traj_file: str, top_file: Optional[str] = None) -> str:
def _get_frame(self, traj, target_frames):
"""
Retrieves the target frame(s) of the trajectory for DSSP.
Args:
traj: the trajectory
target_frames: the target frames to select. can be first, last, or all
Returns:
the trajectory with only target frames"""

if target_frames.lower().strip() == "all":
return traj
if target_frames.lower().strip() == "first":
return traj[0]
if target_frames.lower().strip() == "last":
return traj[-1]
else:
raise ValueError("Target Frames must be 'all', 'first', or 'last'.")

def _run(
self,
traj_file: str,
top_file: Optional[str] = None,
target_frames: str = "last",
) -> str:
"""
Computes the DSSP assignments for a trajectory and saves the results
to a file.
@@ -164,8 +189,7 @@ def _run(self, traj_file: str, top_file: Optional[str] = None) -> str:
)
if not traj:
raise Exception("Trajectory could not be loaded.")
if self.last_only and len(traj) > 1:
traj = traj[-1]
traj = self._get_frame(traj, target_frames)
except Exception as e:
print("Error loading trajectory: ", e)
return str(e)
26 changes: 26 additions & 0 deletions tests/test_analysis/test_secondary_structure.py
Original file line number Diff line number Diff line change
@@ -37,6 +37,32 @@ def test_compute_dssp(loaded_cif_traj, compute_dssp_simple, compute_dssp):
assert np.all(dssp[0][:10] == [" ", " ", " ", "E", "E", "E", "T", "T", "E", "E"])


def test_get_frame(compute_dssp):
# random dummy traj with 3 frames
xyz = np.random.rand(10, 10, 3)
topology = md.Topology()
chain = topology.add_chain()
residue = topology.add_residue("ALA", chain)
for _ in range(10):
topology.add_atom("CA", md.element.carbon, residue)
traj = md.Trajectory(xyz, topology)

# first frame
first_frame = compute_dssp._get_frame(traj, "first")
assert first_frame.n_frames == 1
assert np.array_equal(first_frame.xyz, traj.xyz[0].reshape(1, -1, 3))

# last frame
last_frame = compute_dssp._get_frame(traj, "last")
assert last_frame.n_frames == 1
assert np.array_equal(last_frame.xyz, traj.xyz[-1].reshape(1, -1, 3))

# all frames
all_frames = compute_dssp._get_frame(traj, "all")
assert all_frames.n_frames == traj.n_frames
assert np.array_equal(all_frames.xyz, traj.xyz)


def test_dssp_codes(compute_dssp_simple, compute_dssp):
dssp_codes_simple = compute_dssp_simple._dssp_codes()
assert dssp_codes_simple == ["H", "E", "C", "NA"]

0 comments on commit 18f2a13

Please sign in to comment.