Skip to content

Commit

Permalink
adjusting rgy to take traj & wrote missing unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 committed Aug 8, 2024
1 parent 18f2a13 commit 9b4a0b9
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 64 deletions.
118 changes: 54 additions & 64 deletions mdagent/tools/base_tools/analysis_tools/rgy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,82 +6,59 @@
import numpy as np
from langchain.tools import BaseTool

from mdagent.utils import FileType, PathRegistry
from mdagent.utils import FileType, PathRegistry, load_single_traj


class RadiusofGyration:
def __init__(self, path_registry):
self.path_registry = path_registry
self.includes_top = [".h5", ".lh5", ".pdb"]
self.top_file = ""
self.traj_file = ""
self.traj = None

def _load_traj(self, top_file: str, traj_file: str):
self.traj_file = traj_file
self.top_file = top_file
self.traj = load_single_traj(
path_registry=self.path_registry,
top_fileid=top_file,
traj_fileid=traj_file,
traj_required=True,
)

def _grab_files(self, pdb_id: str) -> None:
if "_" in pdb_id:
pdb_id = pdb_id.split("_")[0]
self.pdb_id = pdb_id
all_names = self.path_registry._list_all_paths()
try:
self.pdb_path = [
name
for name in all_names
if pdb_id in name and ".pdb" in name and "records" in name
][0]
except IndexError:
raise ValueError(f"No pdb file found for {pdb_id}")
try:
self.dcd_path = [
name
for name in all_names
if pdb_id in name and ".dcd" in name and "records" in name
][0]
except IndexError:
self.dcd_path = None
pass
return None

def _load_traj(self, pdb_id: str) -> None:
self._grab_files(pdb_id)
if self.dcd_path:
self.traj = md.load(self.dcd_path, top=self.pdb_path)
else:
self.traj = md.load(self.pdb_path)
return None

def rad_gyration_per_frame(self, pdb_id: str) -> str:
self._load_traj(pdb_id)
def rgy_per_frame(self) -> str:
rg_per_frame = md.compute_rg(self.traj)

self.rgy_file = (
f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.pdb_id}.csv"
f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.traj_file}.csv"
)

np.savetxt(
self.rgy_file, rg_per_frame, delimiter=",", header="Radius of Gyration (nm)"
)
self.path_registry.map_path(
f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.pdb_id}.csv",
f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.traj_file}.csv",
self.rgy_file,
description=f"Radii of gyration per frame for {self.pdb_id}",
description=f"Radii of gyration per frame for {self.traj_file}",
)
return f"Radii of gyration saved to {self.rgy_file}"

def rad_gyration_average(self, pdb_id: str) -> str:
_ = self.rad_gyration_per_frame(pdb_id)
def rgy_average(self) -> str:
_ = self.rgy_per_frame()
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
avg_rg = rg_per_frame.mean()

return f"Average radius of gyration: {avg_rg:.2f} nm"

def plot_rad_gyration(self, pdb_id: str) -> str:
_ = self.rad_gyration_per_frame(pdb_id)
def plot_rgy(self) -> str:
_ = self.rgy_per_frame()
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
fig_analysis = f"rgy_{self.pdb_id}"
fig_analysis = f"rgy_{self.traj_file}"
plot_name = self.path_registry.write_file_name(
type=FileType.FIGURE, fig_analysis=fig_analysis, file_format="png"
)
plot_id = self.path_registry.get_fileid(
file_name=plot_name, type=FileType.FIGURE
)
if plot_name.ends_with(".png"):
if plot_name.endswith(".png"):
plot_name = plot_name.split(".png"[0])
plot_path = next(
f"{self.path_registry.ckpt_figures}/{plot_name}_{i}.png"
Expand All @@ -94,13 +71,13 @@ def plot_rad_gyration(self, pdb_id: str) -> str:
plt.plot(rg_per_frame)
plt.xlabel("Frame")
plt.ylabel("Radius of Gyration (nm)")
plt.title(f"{pdb_id} - Radius of Gyration Over Time")
plt.title(f"{self.traj_file} - Radius of Gyration Over Time")

plt.savefig(f"{plot_path}")
self.path_registry.map_path(
plot_id,
plot_path,
description=f"Plot of radii of gyration over time for {self.pdb_id}",
description=f"Plot of radii of gyration over time for {self.traj_file}",
)
plt.close()
plt.clf()
Expand All @@ -110,20 +87,24 @@ def plot_rad_gyration(self, pdb_id: str) -> str:
class RadiusofGyrationAverage(BaseTool):
name = "RadiusofGyrationAverage"
description = """This tool calculates the average radius of gyration
for the given trajectory file. Give this tool the
protein ID (PDB ID) only. The tool will automatically find the necessary files."""
for a trajectory. Give this tool BOTH the trajectory file ID and the
topology file ID."""

path_registry: Optional[PathRegistry]

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, pdb_id: str) -> str:
def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY = RadiusofGyration(self.path_registry)
return "Succeeded. " + RGY.rad_gyration_average(pdb_id)
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.rgy_average()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
Expand All @@ -137,8 +118,9 @@ async def _arun(self, query: str) -> str:
class RadiusofGyrationPerFrame(BaseTool):
name = "RadiusofGyrationPerFrame"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory file. Give this tool the
protein ID (PDB ID) only. The tool will automatically find the necessary files.
at each frame of a given trajectory.
Give this tool BOTH the trajectory file ID and the
topology file ID.
The tool will save the radii of gyration to a csv file and
map it to the registry."""

Expand All @@ -148,11 +130,15 @@ def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, pdb_id: str) -> str:
def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
RGY = RadiusofGyration(self.path_registry)
return "Succeeded. " + RGY.rad_gyration_per_frame(pdb_id)
return "Succeeded. " + RGY.rgy_per_frame()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
Expand All @@ -167,8 +153,8 @@ class RadiusofGyrationPlot(BaseTool):
name = "RadiusofGyrationPlot"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory file and plots it.
Give this tool the protein ID (PDB ID) only.
The tool will automatically find the necessary files.
Give this tool BOTH the trajectory file ID and the
topology file ID.
The tool will save the plot to a png file and map it to the registry."""

path_registry: Optional[PathRegistry]
Expand All @@ -177,11 +163,15 @@ def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, pdb_id: str) -> str:
def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
RGY = RadiusofGyration(self.path_registry)
return "Succeeded. " + RGY.plot_rad_gyration(pdb_id)
return "Succeeded. " + RGY.plot_rgy()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
Expand Down
28 changes: 28 additions & 0 deletions tests/test_analysis/test_rgy_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pytest

from mdagent.tools.base_tools.analysis_tools.rgy import RadiusofGyration


@pytest.fixture
def rgy(get_registry, loaded_cif_traj):
registry = get_registry("raw", False)
rgy = RadiusofGyration(path_registry=registry)
rgy.traj = loaded_cif_traj
rgy.top_file = "test_top_dummy"
rgy.traj_file = "test_traj_dummy"
return rgy


def test_rgy_per_frame(rgy):
output = rgy.rgy_per_frame()
assert "Radii of gyration saved to " in output


def test_rgy_average(rgy):
output = rgy.rgy_average()
assert "Average radius of gyration: " in output


def test_plot_rgy(rgy):
output = rgy.plot_rgy()
assert "Plot saved as: " in output

0 comments on commit 9b4a0b9

Please sign in to comment.