4
4
5
5
from warnings import warn
6
6
7
+ import numpy as np
7
8
import plotly .express as px
8
9
from numpy .testing import assert_allclose , assert_array_equal , assert_equal
9
10
from pymatgen .analysis .structure_matcher import ElementComparator , StructureMatcher
@@ -125,6 +126,21 @@ def test_structures_to_arrays_single():
125
126
return data
126
127
127
128
129
+ def test_structures_to_arrays_zero_one ():
130
+ xc = XtalConverter (relax_on_decode = False )
131
+ data , _ , _ = xc .structures_to_arrays (example_structures , rgb_scaling = False )
132
+
133
+ if np .min (data ) < 0.0 :
134
+ raise ValueError (
135
+ f"minimum is less than 0 when rgb_output=False: { np .min (data )} "
136
+ )
137
+ if np .max (data ) > 1.0 :
138
+ raise ValueError (
139
+ f"maximum is greater than 1 when rgb_output=False: { np .max (data )} "
140
+ )
141
+ return data
142
+
143
+
128
144
def test_arrays_to_structures ():
129
145
xc = XtalConverter (relax_on_decode = False )
130
146
data , id_data , id_mapper = xc .structures_to_arrays (example_structures )
@@ -133,6 +149,16 @@ def test_arrays_to_structures():
133
149
return structures
134
150
135
151
152
+ def test_arrays_to_structures_zero_one ():
153
+ xc = XtalConverter (relax_on_decode = False )
154
+ data , id_data , id_mapper = xc .structures_to_arrays (
155
+ example_structures , rgb_scaling = False
156
+ )
157
+ structures = xc .arrays_to_structures (data , id_data , id_mapper , rgb_scaling = False )
158
+ assert_structures_approximate_match (example_structures , structures )
159
+ return structures
160
+
161
+
136
162
def test_arrays_to_structures_single ():
137
163
xc = XtalConverter (relax_on_decode = False )
138
164
data , id_data , id_mapper = xc .structures_to_arrays ([example_structures [0 ]])
@@ -295,6 +321,8 @@ def test_plot_and_save():
295
321
296
322
297
323
if __name__ == "__main__" :
324
+ test_structures_to_arrays_zero_one ()
325
+ test_arrays_to_structures_zero_one ()
298
326
test_relax_on_decode ()
299
327
test_primitive_decoding ()
300
328
test_primitive_encoding ()
0 commit comments