@@ -86,6 +86,7 @@ def test_correctness_with_symbols(self, n_qubits, batch_size,
86
86
out_arr [i ][j ] = np .abs (np .vdot (final_wf , internal_wf ))** 2
87
87
88
88
self .assertAllClose (out , out_arr , atol = 1e-5 )
89
+ self .assertDTypeEqual (out , tf .float32 .as_numpy_dtype )
89
90
90
91
@parameterized .parameters ([
91
92
{
@@ -138,6 +139,7 @@ def test_correctness_without_symbols(self, n_qubits, batch_size,
138
139
out_arr [i ][j ] = np .abs (np .vdot (final_wf , internal_wf ))** 2
139
140
140
141
self .assertAllClose (out , out_arr , atol = 1e-5 )
142
+ self .assertDTypeEqual (out , tf .float32 .as_numpy_dtype )
141
143
142
144
def test_correctness_empty (self ):
143
145
"""Tests the fidelity with empty circuits."""
@@ -151,6 +153,7 @@ def test_correctness_empty(self):
151
153
other_program )
152
154
expected = np .array ([[1.0 ]], dtype = np .complex64 )
153
155
self .assertAllClose (out , expected )
156
+ self .assertDTypeEqual (out , tf .float32 .as_numpy_dtype )
154
157
155
158
qubit = cirq .GridQubit (0 , 0 )
156
159
non_empty_circuit = util .convert_to_tensor (
@@ -235,6 +238,7 @@ def test_tf_gradient_correctness_with_symbols(self, n_qubits, batch_size,
235
238
out_arr [i ][k ] += grad_fid
236
239
237
240
self .assertAllClose (out , out_arr , atol = 1e-3 )
241
+ self .assertDTypeEqual (out , tf .float32 .as_numpy_dtype )
238
242
239
243
@parameterized .parameters ([
240
244
{
@@ -272,6 +276,7 @@ def test_tf_gradient_correctness_without_symbols(self, n_qubits, batch_size,
272
276
other_programs )
273
277
out = tape .gradient (ip , symbol_values )
274
278
self .assertAllClose (out , tf .zeros_like (symbol_values ), atol = 1e-3 )
279
+ self .assertDTypeEqual (out , tf .float32 .as_numpy_dtype )
275
280
276
281
def test_correctness_no_circuit (self ):
277
282
"""Test the inner product between no circuits."""
@@ -284,6 +289,7 @@ def test_correctness_no_circuit(self):
284
289
out = fidelity_op .fidelity (empty_circuit , empty_symbols , empty_values ,
285
290
other_program )
286
291
self .assertShapeEqual (np .zeros ((0 , 0 )), out )
292
+ self .assertDTypeEqual (out , tf .float32 .as_numpy_dtype )
287
293
288
294
def test_tf_gradient_correctness_no_circuit (self ):
289
295
"""Test the inner product grad between no circuits."""
@@ -299,6 +305,7 @@ def test_tf_gradient_correctness_no_circuit(self):
299
305
empty_values , other_program )
300
306
301
307
self .assertShapeEqual (np .zeros ((0 , 0 )), out )
308
+ self .assertDTypeEqual (out , tf .float32 .as_numpy_dtype )
302
309
303
310
304
311
if __name__ == "__main__" :
0 commit comments