Skip to content

Commit 68e0ec3

Browse files
Ensure gradient of tf.math.fidelity remains float32 when autographed. (#596)
1 parent 534f65d commit 68e0ec3

File tree

3 files changed

+32
-6
lines changed

3 files changed

+32
-6
lines changed

tensorflow_quantum/core/ops/math_ops/fidelity_op.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919

2020
@tf.function
21+
@tf.custom_gradient
2122
def fidelity(programs, symbol_names, symbol_values, other_programs):
2223
"""Calculate the fidelity between circuits.
2324
@@ -78,7 +79,16 @@ def fidelity(programs, symbol_names, symbol_values, other_programs):
7879
to the fidelity of `programs[i]` with `symbol_values[i]`
7980
resolved in and `other_programs[i][j]`.
8081
"""
81-
ip = inner_product_op.inner_product(programs, symbol_names,
82-
tf.cast(symbol_values, tf.float32),
82+
f32_vals = tf.cast(symbol_values, tf.float32)
83+
ip = inner_product_op.inner_product(programs, symbol_names, f32_vals,
8384
other_programs)
84-
return tf.math.abs(ip)**2
85+
86+
def grad(dy):
87+
ret_zero = tf.equal(tf.size(symbol_names), 0)
88+
inner_prod_grad = tf.cond(
89+
ret_zero, lambda: tf.zeros_like(symbol_values, dtype=tf.float32),
90+
lambda: tf.math.real(2. * ip * inner_product_op._inner_product_grad(
91+
programs, symbol_names, symbol_values, other_programs, dy)))
92+
return [None, None, inner_prod_grad, None]
93+
94+
return tf.math.abs(ip)**2, grad

tensorflow_quantum/core/ops/math_ops/fidelity_op_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def test_correctness_with_symbols(self, n_qubits, batch_size,
8686
out_arr[i][j] = np.abs(np.vdot(final_wf, internal_wf))**2
8787

8888
self.assertAllClose(out, out_arr, atol=1e-5)
89+
self.assertDTypeEqual(out, tf.float32.as_numpy_dtype)
8990

9091
@parameterized.parameters([
9192
{
@@ -138,6 +139,7 @@ def test_correctness_without_symbols(self, n_qubits, batch_size,
138139
out_arr[i][j] = np.abs(np.vdot(final_wf, internal_wf))**2
139140

140141
self.assertAllClose(out, out_arr, atol=1e-5)
142+
self.assertDTypeEqual(out, tf.float32.as_numpy_dtype)
141143

142144
def test_correctness_empty(self):
143145
"""Tests the fidelity with empty circuits."""
@@ -151,6 +153,7 @@ def test_correctness_empty(self):
151153
other_program)
152154
expected = np.array([[1.0]], dtype=np.complex64)
153155
self.assertAllClose(out, expected)
156+
self.assertDTypeEqual(out, tf.float32.as_numpy_dtype)
154157

155158
qubit = cirq.GridQubit(0, 0)
156159
non_empty_circuit = util.convert_to_tensor(
@@ -235,6 +238,7 @@ def test_tf_gradient_correctness_with_symbols(self, n_qubits, batch_size,
235238
out_arr[i][k] += grad_fid
236239

237240
self.assertAllClose(out, out_arr, atol=1e-3)
241+
self.assertDTypeEqual(out, tf.float32.as_numpy_dtype)
238242

239243
@parameterized.parameters([
240244
{
@@ -272,6 +276,7 @@ def test_tf_gradient_correctness_without_symbols(self, n_qubits, batch_size,
272276
other_programs)
273277
out = tape.gradient(ip, symbol_values)
274278
self.assertAllClose(out, tf.zeros_like(symbol_values), atol=1e-3)
279+
self.assertDTypeEqual(out, tf.float32.as_numpy_dtype)
275280

276281
def test_correctness_no_circuit(self):
277282
"""Test the inner product between no circuits."""
@@ -284,6 +289,7 @@ def test_correctness_no_circuit(self):
284289
out = fidelity_op.fidelity(empty_circuit, empty_symbols, empty_values,
285290
other_program)
286291
self.assertShapeEqual(np.zeros((0, 0)), out)
292+
self.assertDTypeEqual(out, tf.float32.as_numpy_dtype)
287293

288294
def test_tf_gradient_correctness_no_circuit(self):
289295
"""Test the inner product grad between no circuits."""
@@ -299,6 +305,7 @@ def test_tf_gradient_correctness_no_circuit(self):
299305
empty_values, other_program)
300306

301307
self.assertShapeEqual(np.zeros((0, 0)), out)
308+
self.assertDTypeEqual(out, tf.float32.as_numpy_dtype)
302309

303310

304311
if __name__ == "__main__":

tensorflow_quantum/core/ops/math_ops/tfq_inner_product_grad.cc

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -479,9 +479,18 @@ REGISTER_OP("TfqInnerProductGrad")
479479
c->Dim(programs_shape, 0);
480480
tensorflow::shape_inference::DimensionHandle output_cols =
481481
c->Dim(symbol_names_shape, 0);
482-
std::vector<tensorflow::shape_inference::DimensionHandle> dims = {
483-
output_rows, output_cols};
484-
c->set_output(0, c->MakeShape(dims));
482+
483+
// Use kUnknownDim instead to prevent shape inference from breaking
484+
// @tf.custom_gradient code in fidelity_op.py. The grad function has
485+
// an implicit data dependency on `symbol_names` that shape inference
486+
// can't (and shouldn't) see. Not specifying shape prevents this break.
487+
// std::vector<tensorflow::shape_inference::DimensionHandle> dims = {
488+
// output_rows,
489+
// tensorflow::shape_inference::InferenceContext::kUnknownDim};
490+
c->set_output(
491+
0, c->MakeShape(
492+
{output_rows,
493+
tensorflow::shape_inference::InferenceContext::kUnknownDim}));
485494

486495
return tensorflow::Status::OK();
487496
});

0 commit comments

Comments
 (0)