Skip to content

Commit b33a42f

Browse files
committed
add tests for rgb_scaling=False
1 parent a06542f commit b33a42f

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

tests/xtal2png_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from warnings import warn
66

7+
import numpy as np
78
import plotly.express as px
89
from numpy.testing import assert_allclose, assert_array_equal, assert_equal
910
from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher
@@ -125,6 +126,21 @@ def test_structures_to_arrays_single():
125126
return data
126127

127128

129+
def test_structures_to_arrays_zero_one():
130+
xc = XtalConverter(relax_on_decode=False)
131+
data, _, _ = xc.structures_to_arrays(example_structures, rgb_scaling=False)
132+
133+
if np.min(data) < 0.0:
134+
raise ValueError(
135+
f"minimum is less than 0 when rgb_output=False: {np.min(data)}"
136+
)
137+
if np.max(data) > 1.0:
138+
raise ValueError(
139+
f"maximum is greater than 1 when rgb_output=False: {np.max(data)}"
140+
)
141+
return data
142+
143+
128144
def test_arrays_to_structures():
129145
xc = XtalConverter(relax_on_decode=False)
130146
data, id_data, id_mapper = xc.structures_to_arrays(example_structures)
@@ -133,6 +149,16 @@ def test_arrays_to_structures():
133149
return structures
134150

135151

152+
def test_arrays_to_structures_zero_one():
153+
xc = XtalConverter(relax_on_decode=False)
154+
data, id_data, id_mapper = xc.structures_to_arrays(
155+
example_structures, rgb_scaling=False
156+
)
157+
structures = xc.arrays_to_structures(data, id_data, id_mapper, rgb_scaling=False)
158+
assert_structures_approximate_match(example_structures, structures)
159+
return structures
160+
161+
136162
def test_arrays_to_structures_single():
137163
xc = XtalConverter(relax_on_decode=False)
138164
data, id_data, id_mapper = xc.structures_to_arrays([example_structures[0]])
@@ -295,6 +321,8 @@ def test_plot_and_save():
295321

296322

297323
if __name__ == "__main__":
324+
test_structures_to_arrays_zero_one()
325+
test_arrays_to_structures_zero_one()
298326
test_relax_on_decode()
299327
test_primitive_decoding()
300328
test_primitive_encoding()

0 commit comments

Comments
 (0)