Skip to content

Commit

Permalink
actually do a test statement for product rule
Browse files Browse the repository at this point in the history
  • Loading branch information
Cryoris committed Jan 29, 2021
1 parent e769861 commit e83c866
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions test_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,26 @@ def test_partial_large_circuit(self, method):

@data('reference_gradients', 'iterative_gradients')
def test_product_rule(self, method):
x = Parameter('x')
x, y = Parameter('x'), Parameter('y')
circuit = QuantumCircuit(1)
circuit.rx(x, 0)
circuit.ry(x, 0)
circuit.rz(y, 0)
circuit.rx(x, 0)
circuit.h(0)
circuit.rx(y, 0)

state_in = Statevector.from_int(1, dims=(2,))

grad = StateGradient(Z, circuit, state_in, [x])
grads = getattr(grad, method)({x: 1})
parameter_binds = {x: 1, y: 2}

print(Gradient().convert(~StateFn(Z) @ StateFn(circuit), params=[x]).bind_parameters({x: 1}).eval())
grad = StateGradient(Z, circuit, state_in, [x, y])
grads = getattr(grad, method)(parameter_binds)

ref_grad = Gradient().convert(~StateFn(Z) @ StateFn(circuit), params=[x, y])
ref = ref_grad.bind_parameters(parameter_binds).eval()

np.testing.assert_array_almost_equal(grads, ref)



Expand Down

0 comments on commit e83c866

Please sign in to comment.