Skip to content

Commit

Permalink
hydrogen bond code clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
brittyscience committed Oct 7, 2024
1 parent e84b9f4 commit 660758e
Showing 1 changed file with 56 additions and 121 deletions.
177 changes: 56 additions & 121 deletions mdagent/tools/base_tools/analysis_tools/hydrogen_bonding_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def compute_baker_hubbard(traj, freq=0.1):
Returns:
The hydrogen bonds found using the Baker-Hubbard method.
"""
frequency = float(freq)
try:
frequency = float(freq)
except ValueError:
Expand Down Expand Up @@ -117,11 +116,19 @@ def plot_and_save_hb_plot(

class HydrogenBondTool(BaseTool):
name = "hydrogen_bond_tool"
description = """Identifies hydrogen bonds using different methods;
Baker-Hubbard and Wernet-Nilsson. Input a trajectory file ID and a method (either
baker_hubbard or wernet_nilsson). If baker_hubbard is used, a frequency must
be provided as a float. Optionally provide the topology file ID. Output is a
file and plot of the hydrogen bonds found."""
# Tool description shown to the agent. Adjacent string literals are joined
# with no separator, so every fragment must carry its own boundary space:
# previously this rendered as "or  Wernet-Nilsson" and "fileIDs".
description = (
    "Identifies hydrogen bonds using different methods: Baker-Hubbard or "
    "Wernet-Nilsson, and plots the results from the provided trajectory data."
    "\nInputs: \n"
    "\t(str) File ID for the trajectory file. \n"
    "\t(str, optional) File ID for the topology file. \n"
    "\t(str) Method to use for identification ('baker_hubbard' or "
    "'wernet_nilsson'). \n"
    "\t(float, optional) Frequency for the Baker-Hubbard method (default: 0.1). \n"
    "\nOutputs: \n"
    "\t(str) Result of the analysis indicating success or failure, along with file"
    " IDs for results and plots."
)

path_registry: PathRegistry | None = None

Expand All @@ -134,8 +141,10 @@ def _run(
traj_file: str,
top_file: str | None = None,
method: str = "baker_hubbard",
freq: str | None = "0.1",
freq: str = "0.1",
) -> str:
if self.path_registry is None:
raise ValueError("Path registry is not set.")
try:
traj = load_single_traj(self.path_registry, top_file, traj_file)
if not traj:
Expand All @@ -153,63 +162,55 @@ def _run(
# Count the number of hydrogen bonds for each frame
hb_counts = np.array([len(frame) for frame in result])

if self.path_registry is not None:
result_file_id = save_hb_results(
{"results": [list(item) for item in result]},
method,
self.path_registry,
)
result_file_id = save_hb_results(
{"results": [list(item) for item in result]},
method,
self.path_registry,
)

plot_hist_file_id = plot_and_save_hb_plot(
hb_counts,
title=f"{method.capitalize()} Histogram",
plot_type="histogram",
method=method,
path_registry=self.path_registry,
)
plot_hist_file_id = plot_and_save_hb_plot(
hb_counts,
title=f"{method.capitalize()} Histogram",
plot_type="histogram",
method=method,
path_registry=self.path_registry,
)

plot_time_series_file_id = plot_and_save_hb_plot(
hb_counts,
title=f"{method.capitalize()} Time Series",
plot_type="time_series",
method=method,
path_registry=self.path_registry,
ylabel="Bond Energy",
)
return (
"Succeeded. Analysis completed, results saved to file and plot"
"saved. "
f"Results file: {result_file_id}, "
f"Histogram plot: {plot_hist_file_id}, "
f"Time series plot: {plot_time_series_file_id}"
)
else:
return """Failed. Path registry helps track
file locations and it is not set up. Please make sure it is set up
before running this tool."""
plot_time_series_file_id = plot_and_save_hb_plot(
hb_counts,
title=f"{method.capitalize()} Time Series",
plot_type="time_series",
method=method,
path_registry=self.path_registry,
ylabel="Bond Energy",
)
return (
"Succeeded. Analysis completed, results saved to file and plot"
"saved. "
f"Results file: {result_file_id}, "
f"Histogram plot: {plot_hist_file_id}, "
f"Time series plot: {plot_time_series_file_id}"
)

except Exception as e:
return f"Failed. {type(e).__name__}: {e}"

def save_results_to_file(self, results: dict, file_name: str) -> None:
    """Write ``results`` as JSON to ``file_name`` and record it in the registry.

    Args:
        results: JSON-serializable mapping to dump.
        file_name: Destination path; opened in write mode.

    When ``self.path_registry`` is falsy only the file is written; no
    registry entry is created.
    """
    # NOTE(review): nesting reconstructed from a view without indentation;
    # the registry step is assumed to follow the ``with`` block — confirm.
    with open(file_name, "w") as out_handle:
        json.dump(results, out_handle)

    registry = self.path_registry
    if not registry:
        return

    record_id = registry.get_fileid(file_name, FileType.RECORD)
    registry.map_path(
        record_id,
        file_name,
        description=f"Results saved to {file_name}",
    )


class KabschSander(BaseTool):
name = "kabsch_sander"
description = """This function compute the hydrogen bond energy between each pair
of residues in every frame. THe input isthe file ID of a traj file containing
MD data and optional top file with the molecular structure data. The
output is a string telling the user whether the simulation was a success
or a failure."""
# Tool description shown to the agent. Every fragment has to live inside the
# single parenthesized expression: strings placed after the closing paren are
# bare expression statements and are silently dropped from the description.
# Fragments also carry their own boundary spaces, since adjacent literals are
# joined with no separator.
description: str = (
    "Compute the hydrogen bond energy between each pair of residues "
    "in every frame of the trajectory."
    "\n Parameters: \n"
    "\t(str) traj_file: The file ID of the trajectory file containing "
    "molecular dynamics (MD) data.\n"
    "\t(str, optional) top_file: The optional topology file ID "
    "providing molecular structure data. Default is None. \n"
    "\n Returns:\n"
    "\t(str): A message indicating whether the analysis was successful or failed."
)

path_registry: PathRegistry | None = None

Expand Down Expand Up @@ -278,69 +279,3 @@ def save_results_to_file(self, results: dict, file_name: str) -> None:
file_name,
description=f"Results saved to {file_name}",
)


def plot_time_series(
    data, title: str = "Time Series Plot", ylabel: str = "Value", save_path=None
):
    """Draw ``data`` as a line plot, then save it or display it.

    Args:
        data: Values passed straight to ``plt.plot``.
        title: Figure title.
        ylabel: Label for the y axis.
        save_path: When truthy, the figure is written there with
            ``plt.savefig`` and the path is handed to
            ``PathRegistry.get_instance().register_path`` if an instance
            exists; otherwise the figure is shown with ``plt.show``.
    """
    plt.figure(figsize=(10, 6))
    plt.plot(data, label="Hydrogen Bonds")
    plt.title(title)
    plt.xlabel("Time (frames)")
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)

    if not save_path:
        plt.show()
    else:
        plt.savefig(save_path)
        if PathRegistry.get_instance() is not None:
            PathRegistry.get_instance().register_path(save_path)
    # NOTE(review): indentation was lost in this view; the close is assumed
    # to run on both branches — confirm against the original file.
    plt.close()


def plot_histogram(
    data,
    bins: int = 10,
    title: str = "Histogram",
    xlabel: str = "Value",
    save_path=None,
):
    """Draw a histogram of ``data``, then save it or display it.

    Args:
        data: Values passed straight to ``plt.hist``.
        bins: Bin count forwarded to ``plt.hist``.
        title: Figure title.
        xlabel: Label for the x axis (the y axis is always "Frequency").
        save_path: When truthy, the figure is written there with
            ``plt.savefig`` and the path is handed to
            ``PathRegistry.get_instance().register_path`` if an instance
            exists; otherwise the figure is shown with ``plt.show``.
    """
    plt.figure(figsize=(10, 6))
    plt.hist(data, bins=bins, edgecolor="black")
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel("Frequency")
    plt.grid(True)

    if not save_path:
        plt.show()
    else:
        plt.savefig(save_path)
        if PathRegistry.get_instance() is not None:
            PathRegistry.get_instance().register_path(save_path)
    # NOTE(review): indentation was lost in this view; the close is assumed
    # to run on both branches — confirm against the original file.
    plt.close()


if __name__ == "__main__":
    # Manual smoke run: render one time-series plot and one histogram from
    # 100 random samples through plot_and_save_hb_plot.
    example_data = np.random.randn(100)

    demo_plots = (
        # Time series (explicit y-axis label)
        {
            "title": "Example Time Series Plot",
            "plot_type": "time_series",
            "ylabel": "Value",
        },
        # Histogram (no ylabel passed)
        {
            "title": "Example Histogram",
            "plot_type": "histogram",
        },
    )
    for plot_kwargs in demo_plots:
        plot_and_save_hb_plot(
            example_data,
            method="example",
            path_registry=PathRegistry.get_instance(),
            **plot_kwargs,
        )

0 comments on commit 660758e

Please sign in to comment.