Commit

Merge branch 'master' into dev
actions-user committed Dec 13, 2024
2 parents 95ef8d3 + 8765fb2 commit 57eb851
Showing 2 changed files with 8 additions and 11 deletions.
@@ -9,7 +9,7 @@
}
],
"dateOfPublication": "2024-04-26T00:00:00+00:00",
"dateOfLastModification": "2024-11-08T00:00:00+00:00",
"dateOfLastModification": "2024-11-12T00:00:00+00:00",
"categories": [
"Quantum Machine Learning",
"Optimization",
@@ -141,8 +141,11 @@ def cost(params):
# not JAX compatible.
#
# Instead, we can use `Optax <https://github.com/google-deepmind/optax>`__, a library designed for
-# optimization using JAX, as well as the :func:`~.catalyst.grad` function, which allows us to
-# differentiate through quantum just-in-time compiled workflows.
+# optimization using JAX, as well as the :func:`~.catalyst.value_and_grad` function, which allows us to
+# differentiate through quantum just-in-time compiled workflows while also returning the cost value.
+# Here we use :func:`~.catalyst.value_and_grad` as we want to be able to print out and track our
+# cost function during execution, but if this is not required the :func:`~.catalyst.grad` function
+# can be used instead.
#

import catalyst
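
For readers skimming the diff, here is a minimal sketch (not part of the commit) of what the two Catalyst entry points return, assuming a hypothetical one-qubit cost on `lightning.qubit`; the names `toy_cost`, `gradient_only`, and `value_and_gradient` are illustrative only, not from the demo.

import pennylane as qml
import jax.numpy as jnp
import catalyst

dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def toy_cost(theta):
    # Hypothetical stand-in for the demo's `cost` function.
    qml.RX(theta, wires=0)
    return qml.expval(qml.PauliZ(0))

@qml.qjit
def gradient_only(theta):
    # catalyst.grad: derivative of the qjit-compiled workflow only.
    return catalyst.grad(toy_cost)(theta)

@qml.qjit
def value_and_gradient(theta):
    # catalyst.value_and_grad: cost value and derivative in one pass,
    # so the value can be logged without re-evaluating the circuit.
    return catalyst.value_and_grad(toy_cost)(theta)

print(gradient_only(jnp.pi / 4))       # derivative only
print(value_and_gradient(jnp.pi / 4))  # (value, derivative) pair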
@@ -153,23 +153,17 @@ def cost(params):
@qml.qjit
def update_step(i, params, opt_state):
"""Perform a single gradient update step"""
grads = catalyst.grad(cost)(params)
energy, grads = catalyst.value_and_grad(cost)(params)
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)
catalyst.debug.print("Step = {i}, Energy = {energy:.8f} Ha", i=i, energy=energy)
return (params, opt_state)

-loss_history = []
-
opt_state = opt.init(init_params)
params = init_params

for i in range(10):
    params, opt_state = update_step(i, params, opt_state)
-   loss_val = cost(params)
-
-   print(f"--- Step: {i}, Energy: {loss_val:.8f}")
-
-   loss_history.append(loss_val)
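
Why the print moved inside update_step (an illustrative note, not part of the commit): within a @qml.qjit-compiled function, Python's built-in print runs once at trace time and sees abstract tracers, whereas catalyst.debug.print is compiled into the program and prints concrete values at runtime on every call. A minimal sketch, assuming only PennyLane and Catalyst are installed; `double` is a hypothetical example function:

import pennylane as qml
import catalyst

@qml.qjit
def double(x: float):
    # Prints at execution time, using the same format-string style
    # as the debug.print call added in this commit.
    catalyst.debug.print("x = {x}", x=x)
    return 2 * x

double(3.0)  # prints "x = 3.0" when the compiled function runs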

######################################################################
# Step 4: QJIT-compile the optimization
