Skip to content

Commit

Permalink
array not equal fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthak-dv committed Aug 20, 2024
1 parent 1ee7466 commit 0e8a8a6
Showing 1 changed file with 16 additions and 41 deletions.
57 changes: 16 additions & 41 deletions tardis/visualization/tools/tests/test_liv_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,20 +180,18 @@ def generate_plot_mpl_hdf(self, plotter_generate_plot_mpl):
property_group = {
"_species_name": plotter._species_name,
"_color_list": color_list,
"step_x": plotter.step_x.value,
"step_y": plotter.step_y,
}
for index1, data in enumerate(fig.get_children()):
if isinstance(data.get_label(), str):
property_group["label" + str(index1)] = (
data.get_label().encode()
)
property_group[
"label" + str(index1)
] = data.get_label().encode()
# save line plots
if isinstance(data, Line2D):
property_group["data" + str(index1)] = data.get_xydata()
property_group["linepath" + str(index1)] = (
data.get_path().vertices
)
property_group[
"linepath" + str(index1)
] = data.get_path().vertices

# save artists which correspond to element contributions
if isinstance(data, PolyCollection):
Expand All @@ -211,21 +209,11 @@ def test_generate_plot_mpl(
fig, _ = plotter_generate_plot_mpl
regression_data = RegressionData(request)
expected = regression_data.sync_hdf_store(generate_plot_mpl_hdf)
for item in ["_species_name", "_color_list", "step_x", "step_y"]:
expected_values = expected.get(
"plot_data_hdf/" + item
).values.flatten()
actual_values = getattr(generate_plot_mpl_hdf, item)

if np.issubdtype(expected_values.dtype, np.number):
np.testing.assert_allclose(
expected_values,
actual_values,
rtol=1e-3,
atol=1e-5,
)
else:
assert np.array_equal(expected_values, actual_values)
for item in ["_species_name", "_color_list"]:
np.testing.assert_array_equal(
expected.get("plot_data_hdf/" + item).values.flatten(),
getattr(generate_plot_mpl_hdf, item),
)
labels = expected["plot_data_hdf/scalars"]
for index1, data in enumerate(fig.get_children()):
if isinstance(data.get_label(), str):
Expand Down Expand Up @@ -290,8 +278,6 @@ def generate_plot_plotly_hdf(self, plotter_generate_plot_ply, request):
property_group = {
"_species_name": plotter._species_name,
"_color_list": color_list,
"step_x": plotter.step_x.value,
"step_y": plotter.step_y,
}
for index, data in enumerate(fig.data):
group = "_" + str(index)
Expand All @@ -311,22 +297,11 @@ def test_generate_plot_ply(
regression_data = RegressionData(request)
expected = regression_data.sync_hdf_store(generate_plot_plotly_hdf)

for item in ["_species_name", "_color_list", "step_x", "step_y"]:
expected_values = expected.get(
"plot_data_hdf/" + item
).values.flatten()
actual_values = getattr(generate_plot_plotly_hdf, item)

if np.issubdtype(expected_values.dtype, np.number):
np.testing.assert_allclose(
expected_values,
actual_values,
rtol=1e-3,
atol=1e-5,
)
else:
assert np.array_equal(expected_values, actual_values)

for item in ["_species_name", "_color_list"]:
np.testing.assert_array_equal(
expected.get("plot_data_hdf/" + item).values.flatten(),
getattr(generate_plot_plotly_hdf, item),
)
for index, data in enumerate(fig.data):
group = "plot_data_hdf/" + "_" + str(index)
if data.stackgroup:
Expand Down

0 comments on commit 0e8a8a6

Please sign in to comment.