From b33a42fe4d68c9ae4e901b9db45a032c7629f329 Mon Sep 17 00:00:00 2001 From: sgbaird Date: Fri, 17 Jun 2022 15:36:37 -0600 Subject: [PATCH] add tests for rgb_scaling=False --- tests/xtal2png_test.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/xtal2png_test.py b/tests/xtal2png_test.py index 7e3052e..19ed586 100644 --- a/tests/xtal2png_test.py +++ b/tests/xtal2png_test.py @@ -4,6 +4,7 @@ from warnings import warn +import numpy as np import plotly.express as px from numpy.testing import assert_allclose, assert_array_equal, assert_equal from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher @@ -125,6 +126,21 @@ def test_structures_to_arrays_single(): return data +def test_structures_to_arrays_zero_one(): + xc = XtalConverter(relax_on_decode=False) + data, _, _ = xc.structures_to_arrays(example_structures, rgb_scaling=False) + + if np.min(data) < 0.0: + raise ValueError( + f"minimum is less than 0 when rgb_output=False: {np.min(data)}" + ) + if np.max(data) > 1.0: + raise ValueError( + f"maximum is greater than 1 when rgb_output=False: {np.max(data)}" + ) + return data + + def test_arrays_to_structures(): xc = XtalConverter(relax_on_decode=False) data, id_data, id_mapper = xc.structures_to_arrays(example_structures) @@ -133,6 +149,16 @@ def test_arrays_to_structures(): return structures +def test_arrays_to_structures_zero_one(): + xc = XtalConverter(relax_on_decode=False) + data, id_data, id_mapper = xc.structures_to_arrays( + example_structures, rgb_scaling=False + ) + structures = xc.arrays_to_structures(data, id_data, id_mapper, rgb_scaling=False) + assert_structures_approximate_match(example_structures, structures) + return structures + + def test_arrays_to_structures_single(): xc = XtalConverter(relax_on_decode=False) data, id_data, id_mapper = xc.structures_to_arrays([example_structures[0]]) @@ -295,6 +321,8 @@ def test_plot_and_save(): if __name__ == "__main__": + test_structures_to_arrays_zero_one() + test_arrays_to_structures_zero_one() test_relax_on_decode() test_primitive_decoding() test_primitive_encoding()