diff --git a/demonstrations/tutorial_how_to_quantum_just_in_time_compile_vqe_catalyst.metadata.json b/demonstrations/tutorial_how_to_quantum_just_in_time_compile_vqe_catalyst.metadata.json
index 25b0e7b7f4..2447d6fcf9 100644
--- a/demonstrations/tutorial_how_to_quantum_just_in_time_compile_vqe_catalyst.metadata.json
+++ b/demonstrations/tutorial_how_to_quantum_just_in_time_compile_vqe_catalyst.metadata.json
@@ -9,7 +9,7 @@
         }
     ],
     "dateOfPublication": "2024-04-26T00:00:00+00:00",
-    "dateOfLastModification": "2024-11-08T00:00:00+00:00",
+    "dateOfLastModification": "2024-11-12T00:00:00+00:00",
     "categories": [
         "Quantum Machine Learning",
         "Optimization",
diff --git a/demonstrations/tutorial_how_to_quantum_just_in_time_compile_vqe_catalyst.py b/demonstrations/tutorial_how_to_quantum_just_in_time_compile_vqe_catalyst.py
index f408125fa5..779a85668d 100644
--- a/demonstrations/tutorial_how_to_quantum_just_in_time_compile_vqe_catalyst.py
+++ b/demonstrations/tutorial_how_to_quantum_just_in_time_compile_vqe_catalyst.py
@@ -141,8 +141,11 @@ def cost(params):
 # not JAX compatible.
 #
 # Instead, we can use `Optax `__, a library designed for
-# optimization using JAX, as well as the :func:`~.catalyst.grad` function, which allows us to
-# differentiate through quantum just-in-time compiled workflows.
+# optimization using JAX, as well as the :func:`~.catalyst.value_and_grad` function, which allows us
+# to differentiate through quantum just-in-time compiled workflows while also returning the cost value.
+# Here we use :func:`~.catalyst.value_and_grad` because we want to print out and track the value of
+# our cost function during execution; if this is not required, the :func:`~.catalyst.grad` function
+# can be used instead.
 #
 
 import catalyst
@@ -153,23 +156,17 @@ def cost(params):
 @qml.qjit
 def update_step(i, params, opt_state):
     """Perform a single gradient update step"""
-    grads = catalyst.grad(cost)(params)
+    energy, grads = catalyst.value_and_grad(cost)(params)
     updates, opt_state = opt.update(grads, opt_state)
     params = optax.apply_updates(params, updates)
+    catalyst.debug.print("Step = {i}, Energy = {energy:.8f} Ha", i=i, energy=energy)
     return (params, opt_state)
 
-loss_history = []
-
 opt_state = opt.init(init_params)
 params = init_params
 
 for i in range(10):
     params, opt_state = update_step(i, params, opt_state)
-    loss_val = cost(params)
-
-    print(f"--- Step: {i}, Energy: {loss_val:.8f}")
-
-    loss_history.append(loss_val)
 
 ######################################################################
 # Step 4: QJIT-compile the optimization
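
Below is a minimal, self-contained sketch of the optimization pattern this change introduces, assuming Catalyst, Optax, and the lightning.qubit device are installed. The two-qubit circuit, observable, initial parameters, and learning rate are illustrative stand-ins, not the demo's H2 VQE problem; only the use of catalyst.value_and_grad, catalyst.debug.print, and the Optax update calls mirror the diff above.

import jax.numpy as jnp
import optax
import pennylane as qml
import catalyst

dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(dev)
def cost(params):
    # Stand-in ansatz and observable; the demo itself minimizes an H2
    # molecular Hamiltonian, which is not reproduced here.
    qml.RX(params[0], wires=0)
    qml.RY(params[1], wires=1)
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

opt = optax.sgd(learning_rate=0.4)

@qml.qjit
def update_step(i, params, opt_state):
    """Single gradient step: value_and_grad returns the cost alongside the
    gradient, so the energy can be printed from inside the compiled code."""
    energy, grads = catalyst.value_and_grad(cost)(params)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    catalyst.debug.print("Step = {i}, Energy = {energy:.8f} Ha", i=i, energy=energy)
    return (params, opt_state)

params = jnp.array([0.1, 0.2])
opt_state = opt.init(params)

for i in range(10):
    params, opt_state = update_step(i, params, opt_state)

Compared with the removed loop, which re-evaluated cost(params) after every update just to print it, obtaining the energy from value_and_grad avoids one extra circuit execution per step, and catalyst.debug.print lets the value be reported from inside the qjit-compiled function rather than from the uncompiled Python loop.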