Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

more various fixes (rgy) #162

Merged
merged 3 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions mdagent/tools/base_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
from .analysis_tools.plot_tools import SimulationOutputFigures
from .analysis_tools.ppi_tools import PPIDistance
from .analysis_tools.rdf_tool import RDFTool
from .analysis_tools.rgy import (
RadiusofGyrationAverage,
RadiusofGyrationPerFrame,
RadiusofGyrationPlot,
)
from .analysis_tools.rgy import RadiusofGyrationTool
from .analysis_tools.rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF
from .analysis_tools.sasa import SolventAccessibleSurfaceArea
from .analysis_tools.secondary_structure import (
Expand Down Expand Up @@ -80,9 +76,7 @@
"PCATool",
"PPIDistance",
"ProteinName2PDBTool",
"RadiusofGyrationAverage",
"RadiusofGyrationPerFrame",
"RadiusofGyrationPlot",
"RadiusofGyrationTool",
"RDFTool",
"RMSDCalculator",
"Scholar2ResultLLM",
Expand Down
6 changes: 2 additions & 4 deletions mdagent/tools/base_tools/analysis_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .pca_tools import PCATool
from .plot_tools import SimulationOutputFigures
from .ppi_tools import PPIDistance
from .rgy import RadiusofGyrationAverage, RadiusofGyrationPerFrame, RadiusofGyrationPlot
from .rgy import RadiusofGyrationTool
from .rmsd_tools import ComputeLPRMSD, ComputeRMSD, ComputeRMSF
from .sasa import SolventAccessibleSurfaceArea
from .vis_tools import VisFunctions, VisualizeProtein
Expand All @@ -17,9 +17,7 @@
"MomentOfInertia",
"PCATool",
"PPIDistance",
"RadiusofGyrationAverage",
"RadiusofGyrationPerFrame",
"RadiusofGyrationPlot",
"RadiusofGyrationTool",
"RMSDCalculator",
"SimulationOutputFigures",
"SolventAccessibleSurfaceArea",
Expand Down
145 changes: 44 additions & 101 deletions mdagent/tools/base_tools/analysis_tools/rgy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, path_registry):
self.top_file = ""
self.traj_file = ""
self.traj = None
self.rgy_file = ""

def _load_traj(self, top_file: str, traj_file: str):
self.traj_file = traj_file
Expand All @@ -25,38 +26,36 @@ def _load_traj(self, top_file: str, traj_file: str):
traj_required=True,
)

def rgy_per_frame(self, force_recompute: bool = False) -> str:
def rgy_per_frame(self) -> str:
rg_per_frame = md.compute_rg(self.traj)
self.rgy_file = (
f"{self.path_registry.ckpt_figures}/radii_of_gyration_{self.traj_file}.csv"
)
rgy_id = f"rgy_{self.traj_file}"
if rgy_id in self.path_registry.list_path_names() and force_recompute is False:
print("RGY already computed, skipping re-compute")
# todo -> maybe allow re-compute & save under different id/path
else:
np.savetxt(
self.rgy_file,
rg_per_frame,
delimiter=",",
header="Radius of Gyration (nm)",
)
self.path_registry.map_path(
f"rgy_{self.traj_file}",
self.rgy_file,
description=f"Radii of gyration per frame for {self.traj_file}",
)
np.savetxt(
self.rgy_file,
rg_per_frame,
delimiter=",",
header="Radius of Gyration (nm)",
)
self.path_registry.map_path(
f"rgy_{self.traj_file}",
self.rgy_file,
description=f"Radii of gyration per frame for {self.traj_file}",
)
return f"Radii of gyration saved to {self.rgy_file} with id {rgy_id}."

def rgy_average(self) -> str:
_ = self.rgy_per_frame()
if not self.rgy_file:
_ = self.rgy_per_frame()
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
avg_rg = rg_per_frame.mean()

return f"Average radius of gyration: {avg_rg:.2f} nm"

def plot_rgy(self) -> str:
_ = self.rgy_per_frame()
if not self.rgy_file:
_ = self.rgy_per_frame()
rg_per_frame = np.loadtxt(self.rgy_file, delimiter=",", skiprows=1)
fig_analysis = f"rgy_{self.traj_file}"
plot_name = self.path_registry.write_file_name(
Expand All @@ -66,9 +65,8 @@ def plot_rgy(self) -> str:
plot_id = self.path_registry.get_fileid(
file_name=plot_name, type=FileType.FIGURE
)
if plot_name.endswith(".png"):
plot_name = plot_name.split(".png")[0]
plot_path = f"{self.path_registry.ckpt_figures}/{plot_name}"
plot_path = plot_path if plot_path.endswith(".png") else plot_path + ".png"
print("plot_path", plot_path)
plt.plot(rg_per_frame)
plt.xlabel("Frame")
Expand All @@ -78,106 +76,51 @@ def plot_rgy(self) -> str:
plt.savefig(f"{plot_path}")
self.path_registry.map_path(
plot_id,
plot_path + ".png",
plot_path,
description=f"Plot of radii of gyration over time for {self.traj_file}",
)
plt.close()
plt.clf()
return "Plot saved as: " + f"{plot_name}.png with plot ID {plot_id}"


class RadiusofGyrationAverage(BaseTool):
name = "RadiusofGyrationAverage"
description = """This tool calculates the average radius of gyration
for a trajectory. Give this tool BOTH the trajectory file ID and the
topology file ID."""
return "Plot saved as: " + f"{plot_name} with plot ID {plot_id}"

path_registry: Optional[PathRegistry]
def compute_plot_return_avg(self) -> str:
    """Run the per-frame, average and plot steps and return one summary.

    The steps are invoked in this order: ``rgy_per_frame``, ``rgy_average``,
    ``plot_rgy``. The per-frame step runs first so that the later steps work
    on data produced for the currently loaded trajectory.

    Returns:
        A single string containing the per-frame message, the plot message
        and the average message (in that order), separated by spaces.
    """
    per_frame_msg = self.rgy_per_frame()
    average_msg = self.rgy_average()
    plot_msg = self.plot_rgy()
    # Join with a space: plain concatenation glued the sentences together
    # (e.g. "...with id rgy_x.Plot saved as: ..."), which is hard to read.
    return " ".join((per_frame_msg, plot_msg, average_msg))

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.rgy_average()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")


class RadiusofGyrationPerFrame(BaseTool):
name = "RadiusofGyrationPerFrame"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory.
class RadiusofGyrationTool(BaseTool):
name = "RadiusofGyrationTool"
description = """This tool calculates and plots
the radius of gyration
at each frame of a given trajectory and returns the average.
Give this tool BOTH the trajectory file ID and the
topology file ID.
The tool will save the radii of gyration to a csv file and
map it to the registry."""
topology file ID."""

path_registry: Optional[PathRegistry]
rgy: Optional[RadiusofGyration]
load_traj: bool = True

def __init__(self, path_registry):
def __init__(self, path_registry, load_traj=True):
super().__init__()
self.path_registry = path_registry
self.rgy = RadiusofGyration(path_registry)
self.load_traj = load_traj # only for testing

def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.rgy_per_frame()
except ValueError as e:
return f"Failed. ValueError: {e}"
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
raise NotImplementedError("custom_search does not support async")


class RadiusofGyrationPlot(BaseTool):
name = "RadiusofGyrationPlot"
description = """This tool calculates the radius of gyration
at each frame of a given trajectory file and plots it.
Give this tool BOTH the trajectory file ID and the
topology file ID.
The tool will save the plot to a png file and map it to the registry."""

path_registry: Optional[PathRegistry]
assert self.rgy is not None, "RadiusofGyration instance is not initialized"

def __init__(self, path_registry):
super().__init__()
self.path_registry = path_registry

def _run(self, traj_file: str, top_file: str) -> str:
"""use the tool."""
RGY = RadiusofGyration(self.path_registry)
try:
RGY._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
if self.load_traj:
try:
self.rgy._load_traj(top_file=top_file, traj_file=traj_file)
except Exception as e:
return f"Error loading traj: {e}"
try:
return "Succeeded. " + RGY.plot_rgy()
except ValueError as e:
return f"Failed. ValueError: {e}"
return "Succeeded. " + self.rgy.compute_plot_return_avg()
except Exception as e:
return f"Failed. {type(e).__name__}: {e}"
return f"Failed Computing RGY: {e}"

async def _arun(self, query: str) -> str:
"""Use the tool asynchronously."""
Expand Down
13 changes: 9 additions & 4 deletions mdagent/tools/base_tools/simulation_tools/setup_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def setup_system(self):
raise ValueError(str(e))
else:
raise ValueError(
f"Error building system. Please check the forcefield files {str(e)}"
f"Error building system. Please check the forcefield files {str(e)}. Included force fields are: {FORCEFIELD_LIST}"
)

if self.sys_params.get("nonbondedMethod", None) in [
Expand Down Expand Up @@ -1497,8 +1497,10 @@ def check_system_params(cls, values):
else:
for file in forcefield_files:
if file not in FORCEFIELD_LIST:
error_msg += "The forcefield file is not present"

error_msg += (
"The forcefield file is not present: forcefield files are: "
+ str(FORCEFIELD_LIST)
)
save = values.get("save", True)
if not isinstance(save, bool):
error_msg += "save must be a boolean value"
Expand Down Expand Up @@ -1558,7 +1560,10 @@ def create_simulation_input(pdb_path, forcefield_files):
Water_model = Forcefield_files[1]
# check if they are part of the list
if Forcefield not in FORCEFIELD_LIST:
raise Exception("Forcefield not recognized")
raise Exception(
"Forcefield not recognized: Possible forcefields are: "
+ str(FORCEFIELD_LIST)
)
if Water_model not in FORCEFIELD_LIST:
raise Exception("Water model not recognized")

Expand Down
8 changes: 2 additions & 6 deletions mdagent/tools/maketools.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,7 @@
PCATool,
PPIDistance,
ProteinName2PDBTool,
RadiusofGyrationAverage,
RadiusofGyrationPerFrame,
RadiusofGyrationPlot,
RadiusofGyrationTool,
RDFTool,
Scholar2ResultLLM,
SetUpandRunFunction,
Expand Down Expand Up @@ -99,9 +97,7 @@ def make_all_tools(
PCATool(path_registry=path_instance),
PPIDistance(path_registry=path_instance),
ProteinName2PDBTool(path_registry=path_instance),
RadiusofGyrationAverage(path_registry=path_instance),
RadiusofGyrationPerFrame(path_registry=path_instance),
RadiusofGyrationPlot(path_registry=path_instance),
RadiusofGyrationTool(path_registry=path_instance),
RDFTool(path_registry=path_instance),
SetUpandRunFunction(path_registry=path_instance),
SimulationOutputFigures(path_registry=path_instance),
Expand Down
32 changes: 31 additions & 1 deletion tests/test_analysis/test_rgy_tool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import pytest

from mdagent.tools.base_tools.analysis_tools.rgy import RadiusofGyration
from mdagent.tools.base_tools.analysis_tools.rgy import (
RadiusofGyration,
RadiusofGyrationTool,
)


@pytest.fixture
Expand All @@ -13,6 +16,24 @@ def rgy(get_registry, loaded_cif_traj):
return rgy


@pytest.fixture
def rgy_tool(get_registry, loaded_cif_traj):
    """Build a RadiusofGyrationTool whose trajectory is injected directly.

    The tool is created with ``load_traj=False`` and the already-loaded
    trajectory plus dummy file names are assigned onto its ``rgy`` helper.
    """
    tool = RadiusofGyrationTool(
        path_registry=get_registry("raw", False),
        load_traj=False,
    )
    helper = tool.rgy
    helper.traj = loaded_cif_traj
    helper.top_file = "test_top_dummy"
    helper.traj_file = "test_traj_dummy"
    return tool


def test_rgy_tool(rgy_tool):
    """Check that one tool run reports the CSV, the average and the plot."""
    # Pass the dummy names to the parameters they belong to; the original
    # call had the trajectory and topology values swapped.
    output = rgy_tool._run(traj_file="test_traj_dummy", top_file="test_top_dummy")
    assert "Radii of gyration saved to " in output
    assert "Average radius of gyration: " in output
    assert "Plot saved as: " in output
    assert ".png" in output


def test_rgy_per_frame(rgy):
output = rgy.rgy_per_frame()
assert "Radii of gyration saved to " in output
Expand All @@ -26,3 +47,12 @@ def test_rgy_average(rgy):
def test_plot_rgy(rgy):
    """Check that plot_rgy reports a saved plot with a .png name."""
    message = rgy.plot_rgy()
    for fragment in ("Plot saved as: ", ".png"):
        assert fragment in message


def test_compute_plot_return_avg_rgy(rgy):
    """Check that the combined call reports the CSV, average and plot."""
    message = rgy.compute_plot_return_avg()
    expected_fragments = (
        "Radii of gyration saved to ",
        "Average radius of gyration: ",
        "Plot saved as: ",
        ".png",
    )
    for fragment in expected_fragments:
        assert fragment in message
Loading