From 847d07c7f166ab1b5ca6aa10f02c2fbb4919af73 Mon Sep 17 00:00:00 2001 From: Sarthak Srivastava Date: Wed, 14 Aug 2024 06:00:10 +0530 Subject: [PATCH] regression tests update --- .../tools/tests/test_liv_plot.py | 472 +++++++++--------- 1 file changed, 246 insertions(+), 226 deletions(-) diff --git a/tardis/visualization/tools/tests/test_liv_plot.py b/tardis/visualization/tools/tests/test_liv_plot.py index 899aab13b5f..4b69c5c44ef 100644 --- a/tardis/visualization/tools/tests/test_liv_plot.py +++ b/tardis/visualization/tools/tests/test_liv_plot.py @@ -1,41 +1,24 @@ +from copy import deepcopy +from itertools import product + +import astropy.units as u import numpy as np import pytest -from numpy import testing as npt -from pandas import testing as pdt -from copy import deepcopy from matplotlib.collections import PolyCollection from matplotlib.lines import Line2D + from tardis.base import run_tardis +from tardis.io.util import HDFWriterMixin from tardis.visualization.tools.liv_plot import LIVPlotter from tardis.tests.fixtures.regression_data import RegressionData -def make_valid_name(testid): - """ - Sanitize pytest IDs to make them valid HDF group names. - - Parameters - ---------- - testid : str - ID to sanitize. - - Returns - ------- - testid : str - Sanitized ID. - """ - return "_" + testid.replace("-", "_") - - -def convert_to_native_type(obj): - if isinstance(obj, dict): - return {k: convert_to_native_type(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [convert_to_native_type(i) for i in obj] - elif isinstance(obj, np.int64): - return int(obj) - else: - return obj +class PlotDataHDF(HDFWriterMixin): + def __init__(self, **kwargs): + self.hdf_properties = [] + for key, value in kwargs.items(): + setattr(self, key, value) + self.hdf_properties.append(key) @pytest.fixture(scope="module") @@ -88,251 +71,288 @@ def plotter(simulation_simple): class TestLIVPlotter: """Test the LIVPlotter class.""" - @pytest.mark.parametrize( - "species_list", [["Si II", "Ca II", "C", "Fe I-V"]] + regression_data = None + species_list = [["Si II", "Ca II", "C", "Fe I-V"], None] + nelements = [1, None] + packets_mode = ["virtual", "real"] + num_bins = [10, 25] + velocity_range = [(18000, 25000)] + + combinations = list( + product( + species_list, + packets_mode, + nelements, + num_bins, + velocity_range, + ) ) - @pytest.mark.parametrize("packets_mode", ["virtual", "real"]) - @pytest.mark.parametrize("nelements", [1, None]) - def test_parse_species_list( - self, - request, - plotter, - species_list, - packets_mode, - nelements, - regression_data, - ): - """ - Test _parse_species_list method. - - Parameters - ---------- - request : _pytest.fixtures.SubRequest - plotter : tardis.visualization.tools.liv_plot.LIVPlotter - species_list : List of species to plot. - packets_mode : str, Packet mode, either 'virtual' or 'real'. - nelements : int, Number of elements to include in plot. - """ - subgroup_name = make_valid_name(request.node.callspec.id) + + @pytest.fixture(scope="class", params=combinations) + def plotter_parse_species_list(self, request, plotter): + ( + _, + packets_mode, + nelements, + _, + _, + ) = request.param plotter._parse_species_list( - species_list=species_list, packets_mode=packets_mode, + species_list=self.species_list[0], nelements=nelements, ) - regression_data_fname = ( - f"livplotter_parse_species_list_{subgroup_name}.h5" - ) - - expected = pd.read_hdf(regression_data_fname, "species_list") - pdt.assert_frame_equal(plotter._species_list, expected) - - expected = pd.read_hdf(regression_data_fname, "keep_colour") - pdt.assert_frame_equal(plotter._keep_colour, expected) - - expected = pd.read_hdf(regression_data_fname, "species_mapped") - pdt.assert_frame_equal( - convert_to_native_type(plotter._species_mapped), expected - ) + return plotter - @pytest.mark.parametrize("packets_mode", ["virtual", "real"]) @pytest.mark.parametrize( - "species_list", [["Si II", "Ca II", "C", "Fe I-V"]] + "attribute", + [ + "_species_list", + "_keep_colour", + "_species_mapped", + ], ) - @pytest.mark.parametrize("cmapname", ["jet"]) - @pytest.mark.parametrize("num_bins", [10, 25]) - @pytest.mark.parametrize("nelements", [1, None]) - def test_prepare_plot_data( + def test_parse_species_list( self, request, - plotter, - packets_mode, - species_list, - cmapname, - num_bins, - nelements, - regression_data, + plotter_parse_species_list, + attribute, ): - """ - Test _prepare_plot_data method. - - Parameters - ---------- - request : _pytest.fixtures.SubRequest - plotter : tardis.visualization.tools.liv_plot.LIVPlotter - packets_mode : str, Packet mode, either 'virtual' or 'real'. - species_list : list of species to plot - cmapname : str - num_bins : int, Number of bins for regrouping within the same range. - nelements : int, Number of elements to include in plot. - """ - subgroup_name = make_valid_name(request.node.callspec.id) + regression_data = RegressionData(request) + if attribute == "_species_mapped": + plot_object = getattr(plotter_parse_species_list, attribute) + plot_object = [ + item + for sublist in list(plot_object.values()) + for item in sublist + ] + data = regression_data.sync_ndarray(plot_object) + np.testing.assert_allclose(plot_object, data) + else: + plot_object = getattr(plotter_parse_species_list, attribute) + data = regression_data.sync_ndarray(plot_object) + np.testing.assert_allclose(plot_object, data) + + @pytest.fixture(scope="class", params=combinations) + def plotter_prepare_plot_data(self, request, plotter): + ( + species_list, + packets_mode, + nelements, + num_bins, + _, + ) = request.param plotter._prepare_plot_data( packets_mode=packets_mode, species_list=species_list, - cmapname=cmapname, + cmapname="jet", num_bins=num_bins, nelements=nelements, ) - plot_data_numeric = [ - [q.value for q in row] for row in plotter.plot_data - ] - flat_list = [item for sublist in plot_data_numeric for item in sublist] - plot_data_list = np.array(flat_list) - regression_data_fname = ( - f"livplotter_prepare_plot_data_{subgroup_name}.h5" - ) - - expected = pd.read_hdf(regression_data_fname, "plot_data") - pdt.assert_frame_equal(plot_data_list, expected) - - expected = pd.read_hdf(regression_data_fname, "plot_colors") - pdt.assert_frame_equal(plotter.plot_colors, expected) - - expected = pd.read_hdf(regression_data_fname, "new_bin_edges") - pdt.assert_frame_equal(plotter.new_bin_edges, expected) + return plotter @pytest.mark.parametrize( - "species_list", [["Si II", "Ca II", "C", "Fe I-V"], None] + "attribute", + [ + "plot_data", + "plot_colors", + "new_bin_edges", + ], ) - @pytest.mark.parametrize("nelements", [1, None]) - @pytest.mark.parametrize("packets_mode", ["virtual", "real"]) - @pytest.mark.parametrize("xlog_scale", [True, False]) - @pytest.mark.parametrize("ylog_scale", [True, False]) - @pytest.mark.parametrize("num_bins", [10, 25]) - @pytest.mark.parametrize("velocity_range", [(18000, 25000)]) - def test_generate_plot_mpl( + def test_prepare_plot_data( self, + plotter_prepare_plot_data, request, - plotter, - species_list, - nelements, - packets_mode, - xlog_scale, - ylog_scale, - num_bins, - velocity_range, - regression_data, + attribute, ): - """ - Test generate_plot_mpl method. - - Parameters - ---------- - request : _pytest.fixtures.SubRequest - plotter : tardis.visualization.tools.liv_plot.LIVPlotter - species_list : List of species to plot. - nelements : int, Number of elements to include in plot. - packets_mode : str, Packet mode, either 'virtual' or 'real'. - xlog_scale : bool, If True, x-axis is scaled logarithmically. - ylog_scale : bool, If True, y-axis is scaled logarithmically. - num_bins : int, Number of bins for regrouping within the same range. - velocity_range : tuple, Limits for the x-axis. - """ - subgroup_name = make_valid_name("mpl" + request.node.callspec.id) + regression_data = RegressionData(request) + if attribute == "plot_data" or attribute == "plot_colors": + plot_object = getattr(plotter_prepare_plot_data, attribute) + plot_object = [item for sublist in plot_object for item in sublist] + if all(isinstance(item, u.Quantity) for item in plot_object): + plot_object = [item.value for item in plot_object] + data = regression_data.sync_ndarray(plot_object) + np.testing.assert_allclose(plot_object, data) + else: + plot_object = getattr(plotter_prepare_plot_data, attribute) + plot_object = plot_object.value + data = regression_data.sync_ndarray(plot_object) + np.testing.assert_allclose(plot_object, data) + + @pytest.fixture(scope="function", params=combinations) + def plotter_generate_plot_mpl(self, request, plotter): + ( + species_list, + packets_mode, + nelements, + num_bins, + velocity_range, + ) = request.param + fig = plotter.generate_plot_mpl( species_list=species_list, nelements=nelements, packets_mode=packets_mode, - xlog_scale=xlog_scale, - ylog_scale=ylog_scale, num_bins=num_bins, velocity_range=velocity_range, ) - fig_data = { + return fig, plotter + + @pytest.fixture(scope="function") + def generate_plot_mpl_hdf(self, plotter_generate_plot_mpl): + fig, plotter = plotter_generate_plot_mpl + + color_list = [ + item for subitem in plotter._color_list for item in subitem + ] + property_group = { "_species_name": plotter._species_name, - "_color_list": plotter._color_list, - "step_x": plotter.step_x, + "_color_list": color_list, + "step_x": plotter.step_x.value, "step_y": plotter.step_y, - "fig_data": [], } - - for index, data in enumerate(fig.get_children()): - trace_data = {} + for index1, data in enumerate(fig.get_children()): if isinstance(data.get_label(), str): - trace_data["label"] = data.get_label() - if isinstance(data, PolyCollection): - trace_data["paths"] = [ - path.vertices for path in data.get_paths() - ] + property_group["label" + str(index1)] = ( + data.get_label().encode() + ) + # save line plots if isinstance(data, Line2D): - trace_data["xydata"] = data.get_xydata() - trace_data["path"] = data.get_path().vertices - fig_data["fig_data"].append(trace_data) + property_group["data" + str(index1)] = data.get_xydata() + property_group["linepath" + str(index1)] = ( + data.get_path().vertices + ) - regression_data_fname = ( - f"livplotter_generate_plot_mpl_{subgroup_name}.h5" - ) + # save artists which correspond to element contributions + if isinstance(data, PolyCollection): + for index2, path in enumerate(data.get_paths()): + property_group[ + "polypath" + "ind_" + str(index1) + "ind_" + str(index2) + ] = path.vertices - expected = pd.read_hdf(regression_data_fname, "fig_data") - pdt.assert_frame_equal(fig_data, expected) + plot_data = PlotDataHDF(**property_group) + return plot_data - @pytest.mark.parametrize( - "species_list", [["Si II", "Ca II", "C", "Fe I-V"], None] - ) - @pytest.mark.parametrize("nelements", [1, None]) - @pytest.mark.parametrize("packets_mode", ["virtual", "real"]) - @pytest.mark.parametrize("xlog_scale", [True, False]) - @pytest.mark.parametrize("ylog_scale", [True, False]) - @pytest.mark.parametrize("num_bins", [10, 25]) - @pytest.mark.parametrize("velocity_range", [(18000, 25000)]) - def test_generate_plot_ply( - self, - request, - plotter, - species_list, - nelements, - packets_mode, - xlog_scale, - ylog_scale, - num_bins, - velocity_range, - regression_data, + def test_generate_plot_mpl( + self, generate_plot_mpl_hdf, plotter_generate_plot_mpl, request ): - """ - Test generate_plot_ply method. - - Parameters - ---------- - request : _pytest.fixtures.SubRequest - plotter : tardis.visualization.tools.liv_plot.LIVPlotter - species_list : List of species to plot. - nelements : int, Number of elements to include in plot. - packets_mode : str, Packet mode, either 'virtual' or 'real'. - xlog_scale : bool, If True, x-axis is scaled logarithmically. - ylog_scale : bool, If True, y-axis is scaled logarithmically. - num_bins : int, Number of bins for regrouping within the same range. - velocity_range : tuple, Limits for the x-axis. - """ - subgroup_name = make_valid_name("ply" + request.node.callspec.id) + fig, _ = plotter_generate_plot_mpl + regression_data = RegressionData(request) + expected = regression_data.sync_hdf_store(generate_plot_mpl_hdf) + for item in ["_species_name", "_color_list", "step_x", "step_y"]: + np.testing.assert_array_equal( + expected.get("plot_data_hdf/" + item).values.flatten(), + getattr(generate_plot_mpl_hdf, item), + ) + labels = expected["plot_data_hdf/scalars"] + for index1, data in enumerate(fig.get_children()): + if isinstance(data.get_label(), str): + assert ( + getattr(labels, "label" + str(index1)).decode() + == data.get_label() + ) + # save line plots + if isinstance(data, Line2D): + np.testing.assert_allclose( + data.get_xydata(), + expected.get("plot_data_hdf/" + "data" + str(index1)), + ) + np.testing.assert_allclose( + data.get_path().vertices, + expected.get("plot_data_hdf/" + "linepath" + str(index1)), + ) + # save artists which correspond to element contributions + if isinstance(data, PolyCollection): + for index2, path in enumerate(data.get_paths()): + np.testing.assert_almost_equal( + path.vertices, + expected.get( + "plot_data_hdf/" + + "polypath" + + "ind_" + + str(index1) + + "ind_" + + str(index2) + ), + ) + + @pytest.fixture(scope="function", params=combinations) + def plotter_generate_plot_ply(self, request, plotter): + ( + species_list, + packets_mode, + nelements, + num_bins, + velocity_range, + ) = request.param + fig = plotter.generate_plot_ply( species_list=species_list, nelements=nelements, packets_mode=packets_mode, - xlog_scale=xlog_scale, - ylog_scale=ylog_scale, num_bins=num_bins, velocity_range=velocity_range, ) - fig_data = { + return fig, plotter + + @pytest.fixture(scope="function") + def generate_plot_plotly_hdf(self, plotter_generate_plot_ply, request): + fig, plotter = plotter_generate_plot_ply + + color_list = [ + item for subitem in plotter._color_list for item in subitem + ] + property_group = { "_species_name": plotter._species_name, - "_color_list": plotter._color_list, - "step_x": plotter.step_x, + "_color_list": color_list, + "step_x": plotter.step_x.value, "step_y": plotter.step_y, - "fig_data": [], } - for index, data in enumerate(fig.data): - trace_data = {} - if isinstance(data.name, str): - trace_data["label"] = data.name - if isinstance(data, go.Scatter): - trace_data["x"] = data.x - trace_data["y"] = data.y - fig_data["fig_data"].append(trace_data) - - regression_data_fname = ( - f"livplotter_generate_plot_ply_{subgroup_name}.h5" - ) + group = "_" + str(index) + if data.stackgroup: + property_group[group + "stackgroup"] = data.stackgroup.encode() + if data.name: + property_group[group + "name"] = data.name.encode() + property_group[group + "x"] = data.x + property_group[group + "y"] = data.y + plot_data = PlotDataHDF(**property_group) + return plot_data + + def test_generate_plot_ply( + self, generate_plot_plotly_hdf, plotter_generate_plot_ply, request + ): + fig, _ = plotter_generate_plot_ply + regression_data = RegressionData(request) + expected = regression_data.sync_hdf_store(generate_plot_plotly_hdf) - expected = pd.read_hdf(regression_data_fname, "fig_data") - pdt.assert_frame_equal(fig_data, expected) + for item in ["_species_name", "_color_list", "step_x", "step_y"]: + np.testing.assert_array_equal( + expected.get("plot_data_hdf/" + item).values.flatten(), + getattr(generate_plot_plotly_hdf, item), + ) + + for index, data in enumerate(fig.data): + group = "plot_data_hdf/" + "_" + str(index) + if data.stackgroup: + assert ( + data.stackgroup + == getattr( + expected["/plot_data_hdf/scalars"], + "_" + str(index) + "stackgroup", + ).decode() + ) + if data.name: + assert ( + data.name + == getattr( + expected["/plot_data_hdf/scalars"], + "_" + str(index) + "name", + ).decode() + ) + np.testing.assert_allclose( + data.x, expected.get(group + "x").values.flatten() + ) + np.testing.assert_allclose( + data.y, expected.get(group + "y").values.flatten() + )