Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Registry #94

Merged
merged 16 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 3 additions & 14 deletions mdagent/tools/base_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,18 @@
from .analysis_tools.plot_tools import SimulationOutputFigures
from .analysis_tools.ppi_tools import PPIDistance
from .analysis_tools.rmsd_tools import RMSDCalculator
from .analysis_tools.vis_tools import (
CheckDirectoryFiles,
VisFunctions,
VisualizeProtein,
)
from .analysis_tools.vis_tools import VisFunctions, VisualizeProtein
from .preprocess_tools.clean_tools import (
AddHydrogensCleaningTool,
CleaningToolFunction,
CleaningTools,
RemoveWaterCleaningTool,
SpecializedCleanTool,
)
from .preprocess_tools.pdb_tools import (
PackMolTool,
ProteinName2PDBTool,
SmallMolPDB,
get_pdb,
)
from .preprocess_tools.packing import PackMolTool
from .preprocess_tools.pdb_get import ProteinName2PDBTool, SmallMolPDB, get_pdb
from .simulation_tools.create_simulation import ModifyBaseSimulationScriptTool
from .simulation_tools.setup_and_run import (
InstructionSummary,
SetUpandRunFunction,
SetUpAndRunTool,
SimulationFunctions,
Expand All @@ -32,9 +23,7 @@

__all__ = [
"AddHydrogensCleaningTool",
"CheckDirectoryFiles",
"CleaningTools",
"InstructionSummary",
"ListRegistryPaths",
"MapPath2Name",
"ProteinName2PDBTool",
Expand Down
3 changes: 1 addition & 2 deletions mdagent/tools/base_tools/analysis_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from .plot_tools import SimulationOutputFigures
from .ppi_tools import PPIDistance
from .rmsd_tools import RMSDCalculator
from .vis_tools import CheckDirectoryFiles, VisFunctions, VisualizeProtein
from .vis_tools import VisFunctions, VisualizeProtein

__all__ = [
"PPIDistance",
"RMSDCalculator",
"SimulationOutputFigures",
"CheckDirectoryFiles",
"VisualizeProtein",
"VisFunctions",
]
157 changes: 94 additions & 63 deletions mdagent/tools/base_tools/analysis_tools/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,60 +8,88 @@
from mdagent.utils import PathRegistry


def process_csv(file_name):
    """Read a CSV file and locate its step/time column(s).

    Args:
        file_name: Path of the CSV file to read.

    Returns:
        A ``(data, headers, matched_headers)`` tuple where ``data`` is the
        list of row dicts produced by ``csv.DictReader``, ``headers`` is the
        list of column names, and ``matched_headers`` is a list of
        ``(index, header)`` pairs for every column whose name contains
        "step" or "time" (case-insensitive).
    """
    # newline="" lets the csv module do its own newline handling, so quoted
    # fields containing line breaks are parsed correctly.
    with open(file_name, "r", newline="") as f:
        reader = csv.DictReader(f)
        # fieldnames is None for a completely empty file; normalise to an
        # empty list so the header scan below cannot fail on None.
        headers = list(reader.fieldnames or [])
        data = list(reader)

    matched_headers = [
        (i, header)
        for i, header in enumerate(headers)
        if re.search(r"(step|time)", header, re.IGNORECASE)
    ]

    return data, headers, matched_headers


def plot_data(data, headers, matched_headers):
    """Plot every data column against the first step/time column.

    One PNG named ``{xlab}_vs_{column}.png`` is written per column whose
    values (and the x values) can all be converted to ``float``.

    Args:
        data: List of row dicts, as produced by ``csv.DictReader``.
        headers: Column names of the CSV file.
        matched_headers: ``(index, header)`` pairs of step/time columns; the
            first entry is used as the x axis.

    Returns:
        A comma-separated string of the created plot file names, or ``None``
        (after printing a message) when ``matched_headers`` is empty.

    Raises:
        Exception: If no column other than the x column could be plotted.
    """
    if not matched_headers:
        print("No 'step' or 'time' headers found.")
        return

    # The first matched header is the x axis.
    time_or_step = matched_headers[0][1]
    xlab = "step" if "step" in time_or_step.lower() else "time"

    value_headers = [header for header in headers if header != time_or_step]
    failed_headers = []
    created_plots = []

    x = []
    if value_headers:
        # The x series is the same for every plot, so convert it once.
        # csv.DictReader stores None for fields missing from short rows, and
        # float(None) raises TypeError rather than ValueError.
        try:
            x = [float(row[time_or_step]) for row in data]
        except (ValueError, TypeError):
            # Without a usable x axis no column can be plotted.
            failed_headers = list(value_headers)
            value_headers = []

    for header in value_headers:
        try:
            y = [float(row[header]) for row in data]
        except (ValueError, TypeError):
            failed_headers.append(header)
            continue

        header_lab = (
            header.split("(")[0].strip() if "(" in header else header
        ).lower()
        plot_name = f"{xlab}_vs_{header_lab}.png"

        # Generate and save the plot; always release the figure, even when
        # plotting or saving fails.
        fig = plt.figure()
        try:
            plt.plot(x, y)
            plt.xlabel(xlab)
            plt.ylabel(header)
            plt.title(f"{xlab} vs {header_lab}")
            plt.savefig(plot_name)
        except ValueError:
            failed_headers.append(header)
            continue
        finally:
            plt.close(fig)

        created_plots.append(plot_name)

    if len(failed_headers) == len(headers) - 1:  # -1 to account for time_or_step header
        raise Exception("All plots failed due to non-numeric data.")

    return ", ".join(created_plots)
class PlottingTools:
    """Turn a registered simulation CSV log into one figure per data column.

    The methods are meant to be called in order: ``_find_file`` resolves the
    file id to a path, ``process_csv`` loads the file, and ``plot_data``
    writes the figures. Each step stores its result on the instance.
    """

    def __init__(
        self,
        path_registry,
    ):
        # Registry object used for get_mapped_path() and map_path().
        self.path_registry = path_registry
        self.data = None  # list of row dicts, set by process_csv()
        self.headers = None  # list of column names, set by process_csv()
        self.matched_headers = None  # (index, header) step/time columns
        self.file_id = None  # id passed to _find_file()
        self.file_path = None  # path the registry returned for file_id

    def _find_file(self, file_id: str) -> None:
        """Resolve ``file_id`` through the path registry.

        Raises:
            FileNotFoundError: If the registry returns a falsy path.
        """
        self.file_id = file_id
        self.file_path = self.path_registry.get_mapped_path(file_id)
        if not self.file_path:
            raise FileNotFoundError("File not found.")
        return None

    def process_csv(self) -> None:
        """Load ``self.file_path`` and find its step/time column(s).

        Raises:
            ValueError: If the file has no headers, no rows, or no column
                whose name contains "step" or "time" (case-insensitive).
        """
        # newline="" lets the csv module do its own newline handling, so
        # quoted fields containing line breaks are parsed correctly.
        with open(self.file_path, "r", newline="") as f:
            reader = csv.DictReader(f)
            # fieldnames is None for a completely empty file.
            self.headers = list(reader.fieldnames or [])
            self.data = list(reader)

        self.matched_headers = [
            (i, header)
            for i, header in enumerate(self.headers)
            if re.search(r"(step|time)", header, re.IGNORECASE)
        ]

        if not self.matched_headers or not self.headers or not self.data:
            raise ValueError("File could not be processed.")
        return None

    def plot_data(self) -> str:
        """Plot every data column against the first step/time column.

        One PNG named ``{file_id}_{xlab}_vs_{column}.png`` is written per
        column whose values can all be converted to ``float``, and each
        figure is passed to the registry's ``map_path``.

        Returns:
            A comma-separated string of the created plot file names.

        Raises:
            ValueError: If no step/time column was found.
            Exception: If no column other than the x column could be plotted.
        """
        if not self.matched_headers:
            raise ValueError("No timestep found.")

        # The first matched header is the x axis.
        time_or_step = self.matched_headers[0][1]
        xlab = "step" if "step" in time_or_step.lower() else "time"

        value_headers = [
            header for header in self.headers if header != time_or_step
        ]
        failed_headers = []
        created_plots = []

        x = []
        if value_headers:
            # The x series is the same for every plot, so convert it once.
            # csv.DictReader stores None for fields missing from short rows,
            # and float(None) raises TypeError rather than ValueError.
            try:
                x = [float(row[time_or_step]) for row in self.data]
            except (ValueError, TypeError):
                # Without a usable x axis no column can be plotted.
                failed_headers = list(value_headers)
                value_headers = []

        for header in value_headers:
            try:
                y = [float(row[header]) for row in self.data]
            except (ValueError, TypeError):
                failed_headers.append(header)
                continue

            header_lab = (
                header.split("(")[0].strip() if "(" in header else header
            ).lower()
            plot_name = f"{self.file_id}_{xlab}_vs_{header_lab}.png"

            # Generate and save the plot; always release the figure, even
            # when plotting, saving, or registering fails.
            fig = plt.figure()
            try:
                plt.plot(x, y)
                plt.xlabel(xlab)
                plt.ylabel(header)
                plt.title(f"{self.file_id}_{xlab} vs {header_lab}")
                plt.savefig(plot_name)
                self.path_registry.map_path(
                    plot_name,
                    plot_name,
                    (
                        f"Post Simulation Figure for {self.file_id}"
                        f" - {header_lab} vs {xlab}"
                    ),
                )
            except ValueError:
                failed_headers.append(header)
                continue
            finally:
                plt.close(fig)

            created_plots.append(plot_name)

        if (
            len(failed_headers) == len(self.headers) - 1
        ):  # -1 to account for time_or_step header
            raise Exception("All plots failed due to non-numeric data.")

        return ", ".join(created_plots)


class SimulationOutputFigures(BaseTool):
Expand All @@ -71,24 +99,27 @@ class SimulationOutputFigures(BaseTool):
simulation and create figures for
all physical parameters
versus timestep of the simulation.
Give this tool the path to the
csv file output from the simulation."""
Give this tool the name of the
csv file output from the simulation.
The tool will get the exact path."""

path_registry: Optional[PathRegistry]

def _run(self, file_path: str) -> str:
def __init__(self, path_registry: Optional[PathRegistry] = None):
super().__init__()
self.path_registry = path_registry

def _run(self, file_id: str) -> str:
"""use the tool."""
try:
data, headers, matched_headers = process_csv(file_path)
plot_result = plot_data(data, headers, matched_headers)
plotting_tools = PlottingTools(self.path_registry)
plotting_tools._find_file(file_id)
plotting_tools.process_csv()
plot_result = plotting_tools.plot_data()
if type(plot_result) == str:
return "Figures created: " + plot_result
else:
return "No figures created."
except ValueError:
return "No timestep data found in csv file."
except FileNotFoundError:
return "Issue with CSV file, file not found."
except Exception as e:
return str(e)

Expand Down
23 changes: 18 additions & 5 deletions mdagent/tools/base_tools/analysis_tools/ppi_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from langchain.tools import BaseTool
from pydantic import BaseModel, Field

from mdagent.utils import PathRegistry

def ppi_distance(pdb_file, binding_site="protein"):

def ppi_distance(file_path, binding_site="protein"):
"""
Calculates minimum heavy-atom distance between peptide (assumed to be
smallest chain) and protein. Returns average distance between these two.
Expand All @@ -16,7 +18,7 @@ def ppi_distance(pdb_file, binding_site="protein"):
Can work with any protein-protein interaction (PPI)
"""
# load and find smallest chain
u = mda.Universe(pdb_file)
u = mda.Universe(file_path)
peptide = None
for chain in u.segments:
if peptide is None or len(chain.residues) < len(peptide):
Expand Down Expand Up @@ -49,14 +51,25 @@ class PPIDistance(BaseTool):
name: str = "ppi_distance"
description: str = """Useful for calculating minimum heavy-atom distance
between peptide and protein. First, make sure you have valid PDB file with
any protein-protein interaction."""
any protein-protein interaction. Give this tool the name of the file. The
tool will find the path."""
args_schema: Type[BaseModel] = PPIDistanceInputSchema
path_registry: Optional[PathRegistry]

def __init__(self, path_registry: Optional[PathRegistry]):
super().__init__()
self.path_registry = path_registry

def _run(self, pdb_file: str, binding_site: str = "protein"):
if not pdb_file.endswith(".pdb"):
if not self.path_registry:
return "Error: Path registry is not set" # this should not happen
file_path = self.path_registry.get_mapped_path(pdb_file)
if not file_path:
return f"File not found: {pdb_file}"
if not file_path.endswith(".pdb"):
return "Error with input: PDB file must have .pdb extension"
try:
avg_dist = ppi_distance(pdb_file, binding_site=binding_site)
avg_dist = ppi_distance(file_path, binding_site=binding_site)
except ValueError as e:
return (
f"ValueError: {e}. \nMake sure to provide valid PBD "
Expand Down
Loading
Loading