From 07a95903279313658e73004ff0e06fa9f08f1cd9 Mon Sep 17 00:00:00 2001 From: Antonio Martinez Date: Tue, 15 Feb 2022 01:13:07 -0500 Subject: [PATCH] Move operator expectations to QuantumInference (#179) --- qhbmlib/circuit_infer.py | 65 ++++++++------ qhbmlib/hamiltonian_infer.py | 31 ++----- qhbmlib/vqt_loss.py | 17 +--- tests/circuit_infer_test.py | 147 ++++++++++++++++++++++++++++++-- tests/hamiltonian_infer_test.py | 136 +++++++++-------------------- tests/vqt_loss_test.py | 1 + 6 files changed, 237 insertions(+), 160 deletions(-) diff --git a/qhbmlib/circuit_infer.py b/qhbmlib/circuit_infer.py index 4da88805..ec064117 100644 --- a/qhbmlib/circuit_infer.py +++ b/qhbmlib/circuit_infer.py @@ -21,6 +21,8 @@ import tensorflow_quantum as tfq from qhbmlib import circuit_model +from qhbmlib import energy_model +from qhbmlib import hamiltonian_model from qhbmlib import utils @@ -82,40 +84,55 @@ def differentiator(self): return self._differentiator def expectation(self, qnn: circuit_model.QuantumCircuit, - initial_states: tf.Tensor, operators: tf.Tensor): - """Returns the expectation values of the operators against the QNN. + initial_states: tf.Tensor, + observables: Union[tf.Tensor, hamiltonian_model.Hamiltonian]): + """Returns the expectation values of the observables against the QNN. - Args: - qnn: The parameterized quantum circuit on which to do inference. - initial_states: Shape [batch_size, num_qubits] of dtype `tf.int8`. - Each entry is an initial state for the set of qubits. For each state, - `qnn` is applied and the pure state expectation value is calculated. - operators: `tf.Tensor` of strings with shape [n_ops], result of calling - `tfq.convert_to_tensor` on a list of cirq.PauliSum, `[op1, op2, ...]`. - Will be tiled to measure `_((qnn)|initial_states[i]>)` - for each i and j. + Args: + qnn: The parameterized quantum circuit on which to do inference. + initial_states: Shape [batch_size, num_qubits] of dtype `tf.int8`. + Each entry is an initial state for the set of qubits. For each state, + `qnn` is applied and the pure state expectation value is calculated. + observables: Hermitian operators to measure. If `tf.Tensor`, strings with + shape [n_ops], result of calling `tfq.convert_to_tensor` on a list of + cirq.PauliSum, `[op1, op2, ...]`. Otherwise, a Hamiltonian. Will be + tiled to measure `_((qnn)|initial_states[i]>)` for each i and j. + + Returns: + `tf.Tensor` with shape [batch_size, n_ops] whose entries are the + unaveraged expectation values of each `operator` against each + transformed initial state. + """ + if isinstance(observables, tf.Tensor): + u = qnn + ops = observables + post_process = lambda x: x + elif isinstance(observables.energy, energy_model.PauliMixin): + u = qnn + observables.circuit_dagger + ops = observables.operator_shards + post_process = lambda y: tf.map_fn( + lambda x: tf.expand_dims( + observables.energy.operator_expectation(x), 0), y) + else: + raise NotImplementedError( + "General `BitstringEnergy` models not yet supported.") - Returns: - `tf.Tensor` with shape [batch_size, n_ops] whose entries are the - unaveraged expectation values of each `operator` against each - transformed initial state. - """ unique_states, idx, counts = utils.unique_bitstrings_with_counts( initial_states) - circuits = qnn(unique_states) + circuits = u(unique_states) num_circuits = tf.shape(circuits)[0] - num_operators = tf.shape(operators)[0] + num_ops = tf.shape(ops)[0] tiled_values = tf.tile( - tf.expand_dims(qnn.symbol_values, 0), [num_circuits, 1]) - tiled_operators = tf.tile(tf.expand_dims(operators, 0), [num_circuits, 1]) + tf.expand_dims(u.symbol_values, 0), [num_circuits, 1]) + tiled_ops = tf.tile(tf.expand_dims(ops, 0), [num_circuits, 1]) expectations = self._expectation_function( circuits, - qnn.symbol_names, + u.symbol_names, tiled_values, - tiled_operators, - tf.tile(tf.expand_dims(counts, 1), [1, num_operators]), + tiled_ops, + tf.tile(tf.expand_dims(counts, 1), [1, num_ops]), ) - return utils.expand_unique_results(expectations, idx) + return utils.expand_unique_results(post_process(expectations), idx) def sample(self, qnn: circuit_model.QuantumCircuit, initial_states: tf.Tensor, counts: tf.Tensor): diff --git a/qhbmlib/hamiltonian_infer.py b/qhbmlib/hamiltonian_infer.py index 5a6be9a1..6ac20506 100644 --- a/qhbmlib/hamiltonian_infer.py +++ b/qhbmlib/hamiltonian_infer.py @@ -14,13 +14,13 @@ # ============================================================================== """Tools for inference on quantum Hamiltonians.""" +import functools from typing import Union import tensorflow as tf from qhbmlib import circuit_infer from qhbmlib import energy_infer -from qhbmlib import energy_model from qhbmlib import hamiltonian_model from qhbmlib import utils @@ -116,7 +116,7 @@ def circuits(self, model: hamiltonian_model.Hamiltonian, num_samples: int): return states, counts def expectation(self, model: hamiltonian_model.Hamiltonian, - ops: Union[tf.Tensor, hamiltonian_model.Hamiltonian]): + observables: Union[tf.Tensor, hamiltonian_model.Hamiltonian]): """Estimates observable expectation values against the density operator. TODO(#119): add expectation and derivative equations and discussions @@ -130,28 +130,15 @@ def expectation(self, model: hamiltonian_model.Hamiltonian, Args: model: The modular Hamiltonian whose normalized exponential is the density operator against which expectation values will be estimated. - ops: The observables to measure. If `tf.Tensor`, strings with shape - [n_ops], result of calling `tfq.convert_to_tensor` on a list of - cirq.PauliSum, `[op1, op2, ...]`. Otherwise, a Hamiltonian. + obervables: Hermitian operators to measure. See docstring of + `QuantumInference.expectation` for details. Returns: `tf.Tensor` with shape [n_ops] whose entries are are the sample averaged expectation values of each entry in `ops`. """ - - def expectation_f(bitstrings): - if isinstance(ops, tf.Tensor): - return self.q_inference.expectation(model.circuit, bitstrings, ops) - elif isinstance(ops.energy, energy_model.PauliMixin): - u_dagger_u = model.circuit + ops.circuit_dagger - expectation_shards = self.q_inference.expectation( - u_dagger_u, bitstrings, ops.operator_shards) - return tf.map_fn( - lambda x: tf.expand_dims(ops.energy.operator_expectation(x), 0), - expectation_shards) - else: - raise NotImplementedError( - "General `BitstringEnergy` models not yet supported.") - - self.e_inference.infer(model.energy) - return self.e_inference.expectation(expectation_f) + return self.e_inference.expectation( + functools.partial( + self.q_inference.expectation, + model.circuit, + observables=observables)) diff --git a/qhbmlib/vqt_loss.py b/qhbmlib/vqt_loss.py index e3afdd0f..e21820c8 100644 --- a/qhbmlib/vqt_loss.py +++ b/qhbmlib/vqt_loss.py @@ -18,7 +18,6 @@ import tensorflow as tf -from qhbmlib import energy_model from qhbmlib import hamiltonian_infer from qhbmlib import hamiltonian_model @@ -47,19 +46,9 @@ def vqt(qhbm_infer: hamiltonian_infer.QHBM, # See equations B4 and B5 in appendix. TODO(#119): confirm equation number. def f_vqt(bitstrings): - if isinstance(hamiltonian, tf.Tensor): - h_expectations = tf.squeeze( - qhbm_infer.q_inference.expectation(model.circuit, bitstrings, - hamiltonian), 1) - elif isinstance(hamiltonian.energy, energy_model.PauliMixin): - u_dagger_u = model.circuit + hamiltonian.circuit_dagger - expectation_shards = qhbm_infer.q_inference.expectation( - u_dagger_u, bitstrings, hamiltonian.operator_shards) - h_expectations = hamiltonian.energy.operator_expectation( - expectation_shards) - else: - raise NotImplementedError( - "General `BitstringEnergy` hamiltonians not yet supported.") + h_expectations = tf.squeeze( + qhbm_infer.q_inference.expectation(model.circuit, bitstrings, + hamiltonian), 1) beta_h_expectations = beta * h_expectations energies = tf.stop_gradient(model.energy(bitstrings)) return beta_h_expectations - energies diff --git a/tests/circuit_infer_test.py b/tests/circuit_infer_test.py index 8e680364..ccc668b0 100644 --- a/tests/circuit_infer_test.py +++ b/tests/circuit_infer_test.py @@ -16,6 +16,9 @@ import itertools from absl import logging +from absl.testing import parameterized +import random +import string import cirq import math @@ -23,9 +26,13 @@ import tensorflow as tf import tensorflow_probability as tfp import tensorflow_quantum as tfq +from tensorflow_quantum.python import util as tfq_util from qhbmlib import circuit_infer from qhbmlib import circuit_model +from qhbmlib import circuit_model_utils +from qhbmlib import energy_model +from qhbmlib import hamiltonian_model from qhbmlib import utils from tests import test_util @@ -34,7 +41,7 @@ GRAD_ATOL = 2e-4 -class QuantumInferenceTest(tf.test.TestCase): +class QuantumInferenceTest(parameterized.TestCase, tf.test.TestCase): """Tests the QuantumInference class.""" def setUp(self): @@ -52,6 +59,12 @@ def setUp(self): minval=-5.0, maxval=5.0), name="p_qnn") + self.tf_random_seed = 10 + self.tfp_seed = tf.constant([5, 6], dtype=tf.int32) + + self.close_rtol = 1e-2 + self.not_zero_atol = 1e-3 + def test_init(self): """Confirms QuantumInference is initialized correctly.""" expected_backend = "noiseless" @@ -179,6 +192,132 @@ def test_expectation(self): self.assertAllClose( actual_grad_reduced, expected_grad_reduced, atol=GRAD_ATOL) + @test_util.eager_mode_toggle + def test_expectation_cirq(self): + """Compares library expectation values to those from Cirq.""" + # observable + num_bits = 4 + qubits = cirq.GridQubit.rect(1, num_bits) + raw_ops = [ + cirq.PauliSum.from_pauli_strings( + [cirq.PauliString(cirq.Z(q)) for q in qubits]) + ] + ops = tfq.convert_to_tensor(raw_ops) + + # unitary + batch_size = 1 + n_moments = 10 + act_fraction = 0.9 + num_symbols = 2 + symbols = set() + for _ in range(num_symbols): + symbols.add("".join(random.sample(string.ascii_letters, 10))) + symbols = sorted(list(symbols)) + raw_circuits, _ = tfq_util.random_symbol_circuit_resolver_batch( + qubits, symbols, batch_size, n_moments=n_moments, p=act_fraction) + raw_circuit = raw_circuits[0] + random_values = tf.random.uniform([len(symbols)], -1, 1, tf.float32, + self.tf_random_seed).numpy().tolist() + resolver = dict(zip(symbols, random_values)) + + # hamiltonian model and inference + circuit = circuit_model.QuantumCircuit( + tfq.convert_to_tensor([raw_circuit]), qubits, tf.constant(symbols), + [tf.Variable([resolver[s] for s in symbols])], [[]]) + circuit.build([]) + q_infer = circuit_infer.QuantumInference() + + # bitstring injectors + all_bitstrings = list(itertools.product([0, 1], repeat=num_bits)) + bitstring_circuit = circuit_model_utils.bit_circuit(qubits) + bitstring_symbols = sorted(tfq.util.get_circuit_symbols(bitstring_circuit)) + bitstring_resolvers = [ + dict(zip(bitstring_symbols, b)) for b in all_bitstrings + ] + + # calculate expected values + total_circuit = bitstring_circuit + raw_circuit + total_resolvers = [{**r, **resolver} for r in bitstring_resolvers] + raw_expectations = tf.constant([[ + cirq.Simulator().simulate_expectation_values(total_circuit, o, + r)[0].real for o in raw_ops + ] for r in total_resolvers]) + expected_expectations = tf.constant(raw_expectations) + # Check that expectations are a reasonable size + self.assertAllGreater( + tf.math.abs(expected_expectations), self.not_zero_atol) + + expectation_wrapper = tf.function(q_infer.expectation) + actual_expectations = expectation_wrapper(circuit, all_bitstrings, ops) + self.assertAllClose( + actual_expectations, expected_expectations, rtol=self.close_rtol) + + # Ensure circuit parameter update changes the expectation value. + old_circuit_weights = circuit.get_weights() + circuit.set_weights([tf.ones_like(w) for w in old_circuit_weights]) + altered_circuit_expectations = expectation_wrapper(circuit, all_bitstrings, + ops) + self.assertNotAllClose( + altered_circuit_expectations, actual_expectations, rtol=self.close_rtol) + circuit.set_weights(old_circuit_weights) + + # Check that values return to start. + reset_expectations = expectation_wrapper(circuit, all_bitstrings, ops) + self.assertAllClose(reset_expectations, actual_expectations, + self.close_rtol) + + @parameterized.parameters({ + "energy_class": energy_class, + "energy_args": energy_args, + } for energy_class, energy_args in zip( + [energy_model.BernoulliEnergy, energy_model.KOBE], [[], [2]])) + @test_util.eager_mode_toggle + def test_expectation_modular_hamiltonian(self, energy_class, energy_args): + """Confirm expectation of modular Hamiltonians works.""" + # set up the modular Hamiltonian to measure + num_bits = 3 + n_moments = 5 + act_fraction = 1.0 + qubits = cirq.GridQubit.rect(1, num_bits) + energy_h = energy_class(*([list(range(num_bits))] + energy_args)) + energy_h.build([None, num_bits]) + raw_circuit_h = cirq.testing.random_circuit(qubits, n_moments, act_fraction) + circuit_h = circuit_model.DirectQuantumCircuit(raw_circuit_h) + circuit_h.build([]) + hamiltonian_measure = hamiltonian_model.Hamiltonian(energy_h, circuit_h) + raw_shards = tfq.from_tensor(hamiltonian_measure.operator_shards) + + # set up the circuit and inference + model_raw_circuit = cirq.testing.random_circuit(qubits, n_moments, + act_fraction) + model_circuit = circuit_model.DirectQuantumCircuit(model_raw_circuit) + model_circuit.build([]) + model_infer = circuit_infer.QuantumInference() + + # bitstring injectors + all_bitstrings = list(itertools.product([0, 1], repeat=num_bits)) + bitstring_circuit = circuit_model_utils.bit_circuit(qubits) + bitstring_symbols = sorted(tfq.util.get_circuit_symbols(bitstring_circuit)) + bitstring_resolvers = [ + dict(zip(bitstring_symbols, b)) for b in all_bitstrings + ] + + # calculate expected values + total_circuit = bitstring_circuit + model_raw_circuit + raw_circuit_h**-1 + expected_expectations = tf.stack([ + tf.stack([ + hamiltonian_measure.energy.operator_expectation([ + cirq.Simulator().simulate_expectation_values( + total_circuit, o, r)[0].real for o in raw_shards + ]) + ]) for r in bitstring_resolvers + ]) + + expectation_wrapper = tf.function(model_infer.expectation) + actual_expectations = expectation_wrapper(model_circuit, all_bitstrings, + hamiltonian_measure) + self.assertAllClose(actual_expectations, expected_expectations) + @test_util.eager_mode_toggle def test_sample_basic(self): """Confirms correct sampling from identity, bit flip, and GHZ QNNs.""" @@ -239,11 +378,7 @@ def test_sample_uneven(self): test_qnn = circuit_model.DirectQuantumCircuit( cirq.Circuit(cirq.H(cirq.GridQubit(0, 0)))) test_infer = circuit_infer.QuantumInference() - - @tf.function - def sample_wrapper(qnn, bitstrings, counts): - return test_infer.sample(qnn, bitstrings, counts) - + sample_wrapper = tf.function(test_infer.sample) bitstrings = tf.constant([[0], [0]], dtype=tf.int8) _, samples_counts = sample_wrapper(test_qnn, bitstrings, counts) # QNN samples should be half 0 and half 1. diff --git a/tests/hamiltonian_infer_test.py b/tests/hamiltonian_infer_test.py index c5af60d9..e7d9ca24 100644 --- a/tests/hamiltonian_infer_test.py +++ b/tests/hamiltonian_infer_test.py @@ -26,7 +26,6 @@ from qhbmlib import circuit_infer from qhbmlib import circuit_model -from qhbmlib import circuit_model_utils from qhbmlib import energy_infer from qhbmlib import energy_model from qhbmlib import hamiltonian_model @@ -68,6 +67,8 @@ def setUp(self): self.expected_q_inference, self.expected_name) + self.tfp_seed = tf.constant([5, 1], tf.int32) + def test_init(self): """Tests QHBM initialization.""" self.assertEqual(self.actual_qhbm.e_inference, self.expected_e_inference) @@ -154,8 +155,8 @@ def test_circuit_param_update(self): self.assertNotAllEqual(actual_circuits_1, actual_circuits_2) @test_util.eager_mode_toggle - def test_expectation_cirq(self): - """Compares library expectation values to those from Cirq.""" + def test_expectation_pauli(self): + """Compares QHBM expectation value to manual expectation.""" # observable num_bits = 4 qubits = cirq.GridQubit.rect(1, num_bits) @@ -166,80 +167,49 @@ def test_expectation_cirq(self): ops = tfq.convert_to_tensor(raw_ops) # unitary - batch_size = 1 - n_moments = 10 - act_fraction = 0.9 - num_symbols = 2 - symbols = set() - for _ in range(num_symbols): - symbols.add("".join(random.sample(string.ascii_letters, 10))) - symbols = sorted(list(symbols)) - raw_circuits, raw_resolvers = tfq_util.random_symbol_circuit_resolver_batch( - qubits, symbols, batch_size, n_moments=n_moments, p=act_fraction) - raw_circuit = raw_circuits[0] - resolver = {k: raw_resolvers[0].value_of(k) for k in raw_resolvers[0]} - - # hamiltonian model and inference - seed = tf.constant([5, 6], dtype=tf.int32) - energy = energy_model.BernoulliEnergy(list(range(num_bits))) - energy.build([None, num_bits]) - circuit = circuit_model.QuantumCircuit( - tfq.convert_to_tensor([raw_circuit]), qubits, tf.constant(symbols), - [tf.Variable([resolver[s] for s in symbols])], [[]]) - circuit.build([]) - actual_hamiltonian = hamiltonian_model.Hamiltonian(energy, circuit) - e_infer = energy_infer.BernoulliEnergyInference(num_bits, self.num_samples, - seed) - q_infer = circuit_infer.QuantumInference() - actual_h_infer = hamiltonian_infer.QHBM(e_infer, q_infer) + num_layers = 3 + actual_h, actual_h_infer = test_util.get_random_hamiltonian_and_inference( + qubits, + num_layers, + "expectation_test", + self.num_samples, + ebm_seed=self.tfp_seed) # sample bitstrings - e_infer.infer(energy) - samples = e_infer.sample(self.num_samples) + samples = actual_h_infer.e_inference.sample(self.num_samples) bitstrings, _, counts = utils.unique_bitstrings_with_counts(samples) - bit_list = bitstrings.numpy().tolist() - - # bitstring injectors - bitstring_circuit = circuit_model_utils.bit_circuit(qubits) - bitstring_symbols = sorted(tfq.util.get_circuit_symbols(bitstring_circuit)) - bitstring_resolvers = [ - dict(zip(bitstring_symbols, bstr)) for bstr in bit_list - ] # calculate expected values - total_circuit = bitstring_circuit + raw_circuit - total_resolvers = [{**r, **resolver} for r in bitstring_resolvers] - raw_expectations = tf.constant([[ - cirq.Simulator().simulate_expectation_values(total_circuit, o, - r)[0].real for o in raw_ops - ] for r in total_resolvers]) + raw_expectations = actual_h_infer.q_inference.expectation( + actual_h.circuit, bitstrings, ops) expected_expectations = utils.weighted_average(counts, raw_expectations) # Check that expectations are a reasonable size self.assertAllGreater(tf.math.abs(expected_expectations), 1e-3) expectation_wrapper = tf.function(actual_h_infer.expectation) - actual_expectations = expectation_wrapper(actual_hamiltonian, ops) + actual_expectations = expectation_wrapper(actual_h, ops) self.assertAllClose(actual_expectations, expected_expectations, rtol=1e-6) # Ensure energy parameter update changes the expectation value. - old_energy_weights = energy.get_weights() - energy.set_weights([tf.ones_like(w) for w in old_energy_weights]) - altered_energy_expectations = actual_h_infer.expectation( - actual_hamiltonian, ops) + old_energy_weights = actual_h.energy.get_weights() + actual_h.energy.set_weights([tf.ones_like(w) for w in old_energy_weights]) + actual_h_infer.e_inference.infer(actual_h.energy) + altered_energy_expectations = actual_h_infer.expectation(actual_h, ops) self.assertNotAllClose( altered_energy_expectations, actual_expectations, rtol=1e-5) - energy.set_weights(old_energy_weights) + actual_h.energy.set_weights(old_energy_weights) + actual_h_infer.e_inference.infer(actual_h.energy) # Ensure circuit parameter update changes the expectation value. - old_circuit_weights = circuit.get_weights() - circuit.set_weights([tf.ones_like(w) for w in old_circuit_weights]) - altered_circuit_expectations = expectation_wrapper(actual_hamiltonian, ops) + old_circuit_weights = actual_h.circuit.get_weights() + actual_h.circuit.set_weights([tf.ones_like(w) for w in old_circuit_weights]) + altered_circuit_expectations = expectation_wrapper(actual_h, ops) self.assertNotAllClose( altered_circuit_expectations, actual_expectations, rtol=1e-5) - circuit.set_weights(old_circuit_weights) + actual_h.circuit.set_weights(old_circuit_weights) # Check that values return to start. - reset_expectations = expectation_wrapper(actual_hamiltonian, ops) + reset_expectations = expectation_wrapper(actual_h, ops) self.assertAllClose(reset_expectations, actual_expectations, rtol=1e-6) @parameterized.parameters({ @@ -261,51 +231,29 @@ def test_expectation_modular_hamiltonian(self, energy_class, energy_args): circuit_h = circuit_model.DirectQuantumCircuit(raw_circuit_h) circuit_h.build([]) hamiltonian_measure = hamiltonian_model.Hamiltonian(energy_h, circuit_h) - raw_shards = tfq.from_tensor(hamiltonian_measure.operator_shards) - - # hamiltonian model and inference - seed = tf.constant([5, 6], dtype=tf.int32) - model_energy = energy_model.BernoulliEnergy(list(range(num_bits))) - model_energy.build([None, num_bits]) - model_raw_circuit = cirq.testing.random_circuit(qubits, n_moments, - act_fraction) - model_circuit = circuit_model.DirectQuantumCircuit(model_raw_circuit) - model_circuit.build([]) - model_hamiltonian = hamiltonian_model.Hamiltonian(model_energy, - model_circuit) - e_infer = energy_infer.BernoulliEnergyInference(num_bits, self.num_samples, - seed) - q_infer = circuit_infer.QuantumInference() - model_h_infer = hamiltonian_infer.QHBM(e_infer, q_infer) + + # unitary + num_layers = 3 + actual_h, actual_h_infer = test_util.get_random_hamiltonian_and_inference( + qubits, + num_layers, + "expectation_test", + self.num_samples, + ebm_seed=self.tfp_seed) # sample bitstrings - e_infer.infer(model_energy) - samples = e_infer.sample(self.num_samples) + samples = actual_h_infer.e_inference.sample(self.num_samples) bitstrings, _, counts = utils.unique_bitstrings_with_counts(samples) - bit_list = bitstrings.numpy().tolist() - - # bitstring injectors - bitstring_circuit = circuit_model_utils.bit_circuit(qubits) - bitstring_symbols = sorted(tfq.util.get_circuit_symbols(bitstring_circuit)) - bitstring_resolvers = [ - dict(zip(bitstring_symbols, bstr)) for bstr in bit_list - ] # calculate expected values - total_circuit = bitstring_circuit + model_raw_circuit + raw_circuit_h**-1 - raw_expectations = tf.stack([ - tf.stack([ - hamiltonian_measure.energy.operator_expectation([ - cirq.Simulator().simulate_expectation_values( - total_circuit, o, r)[0].real for o in raw_shards - ]) - ]) for r in bitstring_resolvers - ]) + raw_expectations = actual_h_infer.q_inference.expectation( + actual_h.circuit, bitstrings, hamiltonian_measure) expected_expectations = utils.weighted_average(counts, raw_expectations) + # Check that expectations are a reasonable size + self.assertAllGreater(tf.math.abs(expected_expectations), 1e-3) - expectation_wrapper = tf.function(model_h_infer.expectation) - actual_expectations = expectation_wrapper(model_hamiltonian, - hamiltonian_measure) + expectation_wrapper = tf.function(actual_h_infer.expectation) + actual_expectations = expectation_wrapper(actual_h, hamiltonian_measure) self.assertAllClose(actual_expectations, expected_expectations) diff --git a/tests/vqt_loss_test.py b/tests/vqt_loss_test.py index d2825ed8..ae44679f 100644 --- a/tests/vqt_loss_test.py +++ b/tests/vqt_loss_test.py @@ -209,6 +209,7 @@ def test_loss_value_x_rot(self): # Inference definition e_infer = energy_infer.BernoulliEnergyInference( num_qubits, self.num_samples, initial_seed=self.tfp_seed) + e_infer.infer(energy) q_infer = circuit_infer.QuantumInference() qhbm_infer = hamiltonian_infer.QHBM(e_infer, q_infer)