diff --git a/mdagent/tools/base_tools/analysis_tools/plot_tools.py b/mdagent/tools/base_tools/analysis_tools/plot_tools.py index f3a8c4fc..7c20784d 100644 --- a/mdagent/tools/base_tools/analysis_tools/plot_tools.py +++ b/mdagent/tools/base_tools/analysis_tools/plot_tools.py @@ -24,44 +24,40 @@ def plot_data(data, headers, matched_headers): # Get the first matched header if matched_headers: time_or_step = matched_headers[0][1] + xlab = "step" if "step" in time_or_step.lower() else "time" else: print("No 'step' or 'time' headers found.") return failed_headers = [] - created_plots = [] - # For each header (except the time/step one), plot time/step vs that header - header_count = 0 for header in headers: if header != time_or_step: - header_count += 1 try: - # Extract the data for the x and y axes x = [float(row[time_or_step]) for row in data] y = [float(row[header]) for row in data] - if "step" in time_or_step.lower(): - xlab = "step" - if "(" in header: - header_lab = (header.split("(")[0]).strip() - # Generate the plot + header_lab = ( + header.split("(")[0].strip() if "(" in header else header + ).lower() + plot_name = f"{xlab}_vs_{header_lab}.png" + + # Generate and save the plot plt.figure() plt.plot(x, y) plt.xlabel(xlab) plt.ylabel(header) plt.title(f"{xlab} vs {header_lab}") - - # Save the figure - plt.savefig(f"{xlab}_vs_{header_lab}.png") + plt.savefig(plot_name) plt.close() - created_plots.append(f"{xlab}_vs_{header_lab}.png") - except ValueError: # If data cannot be converted to float + + created_plots.append(plot_name) + except ValueError: failed_headers.append(header) - # If all plots failed, raise an exception - if len(failed_headers) == len(headers) - header_count: + if len(failed_headers) == len(headers) - 1: # -1 to account for time_or_step header raise Exception("All plots failed due to non-numeric data.") + return ", ".join(created_plots) diff --git a/tests/test_fxns.py b/tests/test_fxns.py index 04cf476a..ac2a9410 100644 --- a/tests/test_fxns.py +++ b/tests/test_fxns.py @@ -1,6 +1,6 @@ import os import warnings -from unittest.mock import mock_open, patch +from unittest.mock import MagicMock, mock_open, patch import pytest @@ -10,6 +10,7 @@ VisFunctions, get_pdb, ) +from mdagent.tools.base_tools.analysis_tools.plot_tools import plot_data, process_csv from mdagent.utils import PathRegistry warnings.filterwarnings("ignore", category=DeprecationWarning, module="pkg_resources") @@ -60,6 +61,59 @@ def get_registry(): return PathRegistry() +def test_process_csv(): + mock_csv_content = "Time,Value1,Value2\n1,10,20\n2,15,25" + mock_reader = MagicMock() + mock_reader.fieldnames = ["Time", "Value1", "Value2"] + mock_reader.__iter__.return_value = iter( + [ + {"Time": "1", "Value1": "10", "Value2": "20"}, + {"Time": "2", "Value1": "15", "Value2": "25"}, + ] + ) + + with patch("builtins.open", mock_open(read_data=mock_csv_content)): + with patch("csv.DictReader", return_value=mock_reader): + data, headers, matched_headers = process_csv("mock_file.csv") + + assert headers == ["Time", "Value1", "Value2"] + assert len(matched_headers) == 1 + assert matched_headers[0][1] == "Time" + assert len(data) == 2 + assert data[0]["Time"] == "1" and data[0]["Value1"] == "10" + + +def test_plot_data(): + # Test successful plot generation + data_success = [ + {"Time": "1", "Value1": "10", "Value2": "20"}, + {"Time": "2", "Value1": "15", "Value2": "25"}, + ] + headers = ["Time", "Value1", "Value2"] + matched_headers = [(0, "Time")] + + with patch("matplotlib.pyplot.figure"), patch("matplotlib.pyplot.plot"), patch( + "matplotlib.pyplot.xlabel" + ), patch("matplotlib.pyplot.ylabel"), patch("matplotlib.pyplot.title"), patch( + "matplotlib.pyplot.savefig" + ), patch( + "matplotlib.pyplot.close" + ): + created_plots = plot_data(data_success, headers, matched_headers) + assert "time_vs_value1.png" in created_plots + assert "time_vs_value2.png" in created_plots + + # Test failure due to non-numeric data + data_failure = [ + {"Time": "1", "Value1": "A", "Value2": "B"}, + {"Time": "2", "Value1": "C", "Value2": "D"}, + ] + + with pytest.raises(Exception) as excinfo: + plot_data(data_failure, headers, matched_headers) + assert "All plots failed due to non-numeric data." in str(excinfo.value) + + @pytest.mark.skip(reason="molrender is not pip installable") def test_run_molrender(path_to_cif, vis_fxns): result = vis_fxns.run_molrender(path_to_cif)