Skip to content

Commit

Permalink
Plot test (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
SamCox822 authored Dec 21, 2023
1 parent 4ea8dbe commit c74d2cf
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 18 deletions.
30 changes: 13 additions & 17 deletions mdagent/tools/base_tools/analysis_tools/plot_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,44 +24,40 @@ def plot_data(data, headers, matched_headers):
# Get the first matched header
if matched_headers:
time_or_step = matched_headers[0][1]
xlab = "step" if "step" in time_or_step.lower() else "time"
else:
print("No 'step' or 'time' headers found.")
return

failed_headers = []

created_plots = []
# For each header (except the time/step one), plot time/step vs that header
header_count = 0
for header in headers:
if header != time_or_step:
header_count += 1
try:
# Extract the data for the x and y axes
x = [float(row[time_or_step]) for row in data]
y = [float(row[header]) for row in data]

if "step" in time_or_step.lower():
xlab = "step"
if "(" in header:
header_lab = (header.split("(")[0]).strip()
# Generate the plot
header_lab = (
header.split("(")[0].strip() if "(" in header else header
).lower()
plot_name = f"{xlab}_vs_{header_lab}.png"

# Generate and save the plot
plt.figure()
plt.plot(x, y)
plt.xlabel(xlab)
plt.ylabel(header)
plt.title(f"{xlab} vs {header_lab}")

# Save the figure
plt.savefig(f"{xlab}_vs_{header_lab}.png")
plt.savefig(plot_name)
plt.close()
created_plots.append(f"{xlab}_vs_{header_lab}.png")
except ValueError: # If data cannot be converted to float

created_plots.append(plot_name)
except ValueError:
failed_headers.append(header)

# If all plots failed, raise an exception
if len(failed_headers) == len(headers) - header_count:
if len(failed_headers) == len(headers) - 1: # -1 to account for time_or_step header
raise Exception("All plots failed due to non-numeric data.")

return ", ".join(created_plots)


Expand Down
56 changes: 55 additions & 1 deletion tests/test_fxns.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import warnings
from unittest.mock import mock_open, patch
from unittest.mock import MagicMock, mock_open, patch

import pytest

Expand All @@ -10,6 +10,7 @@
VisFunctions,
get_pdb,
)
from mdagent.tools.base_tools.analysis_tools.plot_tools import plot_data, process_csv
from mdagent.utils import PathRegistry

warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources")
Expand Down Expand Up @@ -60,6 +61,59 @@ def get_registry():
return PathRegistry()


def test_process_csv():
mock_csv_content = "Time,Value1,Value2\n1,10,20\n2,15,25"
mock_reader = MagicMock()
mock_reader.fieldnames = ["Time", "Value1", "Value2"]
mock_reader.__iter__.return_value = iter(
[
{"Time": "1", "Value1": "10", "Value2": "20"},
{"Time": "2", "Value1": "15", "Value2": "25"},
]
)

with patch("builtins.open", mock_open(read_data=mock_csv_content)):
with patch("csv.DictReader", return_value=mock_reader):
data, headers, matched_headers = process_csv("mock_file.csv")

assert headers == ["Time", "Value1", "Value2"]
assert len(matched_headers) == 1
assert matched_headers[0][1] == "Time"
assert len(data) == 2
assert data[0]["Time"] == "1" and data[0]["Value1"] == "10"


def test_plot_data():
# Test successful plot generation
data_success = [
{"Time": "1", "Value1": "10", "Value2": "20"},
{"Time": "2", "Value1": "15", "Value2": "25"},
]
headers = ["Time", "Value1", "Value2"]
matched_headers = [(0, "Time")]

with patch("matplotlib.pyplot.figure"), patch("matplotlib.pyplot.plot"), patch(
"matplotlib.pyplot.xlabel"
), patch("matplotlib.pyplot.ylabel"), patch("matplotlib.pyplot.title"), patch(
"matplotlib.pyplot.savefig"
), patch(
"matplotlib.pyplot.close"
):
created_plots = plot_data(data_success, headers, matched_headers)
assert "time_vs_value1.png" in created_plots
assert "time_vs_value2.png" in created_plots

# Test failure due to non-numeric data
data_failure = [
{"Time": "1", "Value1": "A", "Value2": "B"},
{"Time": "2", "Value1": "C", "Value2": "D"},
]

with pytest.raises(Exception) as excinfo:
plot_data(data_failure, headers, matched_headers)
assert "All plots failed due to non-numeric data." in str(excinfo.value)


@pytest.mark.skip(reason="molrender is not pip installable")
def test_run_molrender(path_to_cif, vis_fxns):
result = vis_fxns.run_molrender(path_to_cif)
Expand Down

0 comments on commit c74d2cf

Please sign in to comment.