How to optimize a QML model using JAX and Optax¶
Published: January 18, 2024. Last updated: October 7, 2024.
Once you have set up a quantum machine learning model, data to train with and cost function to minimize as an objective, the next step is to perform the optimization. That is, setting up a classical optimization loop to find a minimal value of your cost function.
In this example, we’ll show you how to use JAX, an autodifferentiable machine learning framework, and Optax, a suite of JAX-compatible gradient-based optimizers, to optimize a PennyLane quantum machine learning model.

Set up your model, data, and cost¶
Here, we will create a simple QML model for our optimization. In particular:
- We will embed our data through a series of rotation gates.
- We will then have an ansatz of trainable rotation gates with parameters weights; it is these values we will train to minimize our cost function.
- We will train the QML model on data, a (5, 4) array, and optimize the model to match target predictions given by targets.
import pennylane as qml
import jax
from jax import numpy as jnp
import optax
n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])
dev = qml.device("default.qubit", wires=n_wires)
@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit ansatz"""

    # data embedding
    for i in range(n_wires):
        # data[i] will be of shape (4,); we are
        # taking advantage of operation vectorization here
        qml.RY(data[i], wires=i)

    # trainable ansatz
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # we use a sum of local Z's as an observable since a
    # local Z would only be affected by params on that qubit.
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))


def my_model(data, weights, bias):
    return circuit(data, weights) + bias
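To see why qml.RY(data[i], wires=i) receives a vector rather than a single angle, it can help to check the array shapes directly. The following is a minimal sketch using only JAX; it rebuilds data with the same expression as above and inspects it, without running the circuit.

```python
from jax import numpy as jnp

n_wires = 5

# Same construction as in the demo: grid points reshaped to one row per wire
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3

print(data.shape)     # shape of the full array: (wires, samples)
print(data[0].shape)  # slice passed to RY on wire 0: one angle per sample
```

Each row of data holds the rotation angles for one wire across all samples, which is what lets a single circuit call produce one prediction per sample.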
We will define a simple cost function that computes the sum of squared differences between the model output and the target data, divided by len(data), and just-in-time (JIT) compile it:
@jax.jit
def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss
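As a rough illustration of the arithmetic inside loss_fn, the sketch below evaluates the same expression on a hypothetical vector of predictions (the values are made up for readability and are not produced by the circuit). Note that len(data) is the size of the leading axis of data, which is 5 here, so the summed squared residuals are divided by 5.

```python
from jax import numpy as jnp

targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

# Hypothetical predictions, chosen only to make the arithmetic easy to follow
predictions = jnp.array([0.0, 0.5, 0.25, 0.0])

# Stand-in for the (5, 4) data array; only its leading dimension matters here
data = jnp.zeros((5, 4))

residuals = targets - predictions           # [-0.2, -0.1, 0.1, 0.2]
loss = jnp.sum(residuals ** 2 / len(data))  # (0.04 + 0.01 + 0.01 + 0.04) / 5
print(loss)
```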
Note that the model above is just an example for demonstration. There are important considerations that must be taken into account when performing QML research, including methods for data embedding, circuit architecture, and cost function, in order to build models that may be useful. This is still an active area of research; see our demonstrations for details.
Initialize your parameters¶
Now, we can generate our trainable parameters weights and bias that will be used to train our QML model.
weights = jnp.ones([n_wires, 3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}
Plugging the trainable parameters, data, and target labels into our cost function, we can see the current loss as well as the parameter gradients:
print(loss_fn(params, data, targets))
print(jax.grad(loss_fn)(params, data, targets))
0.2923263
{'bias': Array(-0.7543211, dtype=float32, weak_type=True), 'weights': Array([[-1.9507737e-01, 5.2854680e-02, -4.8925218e-01],
[-1.9968897e-02, -5.3287193e-02, 9.2290491e-02],
[-2.7175546e-03, -9.6470118e-05, -4.7957897e-03],
[-6.3544415e-02, 3.6111102e-02, -2.0519719e-01],
[-9.0263695e-02, 1.6375934e-01, -5.6426287e-01]], dtype=float32)}
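The gradient printed above has the same dictionary structure as params, because jax.grad differentiates with respect to its first argument and treats dictionaries as pytrees. A minimal sketch with a toy quadratic loss (toy_loss is a hypothetical stand-in, not the QML model) shows the same behaviour:

```python
import jax
from jax import numpy as jnp


def toy_loss(params):
    # Hypothetical stand-in for loss_fn: a simple quadratic in both entries
    return jnp.sum(params["weights"] ** 2) + params["bias"] ** 2


params = {"weights": jnp.ones([5, 3]), "bias": jnp.array(0.5)}
grads = jax.grad(toy_loss)(params)

# The gradient mirrors the keys and shapes of params
print(jax.tree_util.tree_map(jnp.shape, grads))
```

This matching structure is what allows an optimizer to pair each gradient with its parameter without any manual bookkeeping.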
Create the optimizer¶
We can now use Optax to create an optimizer and train our circuit. Here, we choose the Adam optimizer, but any of the other optimizers available in Optax could be used instead.
opt = optax.adam(learning_rate=0.3)
opt_state = opt.init(params)
We first define our update_step function, which needs to do a couple of things:
- Compute the loss function (so we can track training) and the gradients (so we can apply an optimization step). We can do this in one execution via the jax.value_and_grad function.
- Apply the update step of our optimizer via opt.update.
- Update the parameters via optax.apply_updates.
def update_step(opt, params, opt_state, data, targets):
    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_val


loss_history = []

for i in range(100):
    params, opt_state, loss_val = update_step(opt, params, opt_state, data, targets)

    if i % 5 == 0:
        print(f"Step: {i} Loss: {loss_val}")

    loss_history.append(loss_val)
Step: 0 Loss: 0.29232630133628845
Step: 5 Loss: 0.04476676881313324
Step: 10 Loss: 0.03190240263938904
Step: 15 Loss: 0.036237362772226334
Step: 20 Loss: 0.03370067849755287
Step: 25 Loss: 0.028724007308483124
Step: 30 Loss: 0.023011859506368637
Step: 35 Loss: 0.01871575601398945
Step: 40 Loss: 0.014776408672332764
Step: 45 Loss: 0.010427693836390972
Step: 50 Loss: 0.009645616635680199
Step: 55 Loss: 0.024109352380037308
Step: 60 Loss: 0.008082641288638115
Step: 65 Loss: 0.007608992047607899
Step: 70 Loss: 0.007097674999386072
Step: 75 Loss: 0.006783722899854183
Step: 80 Loss: 0.006901645567268133
Step: 85 Loss: 0.0065842135809361935
Step: 90 Loss: 0.006033502519130707
Step: 95 Loss: 0.004975274670869112
Jitting the optimization loop¶
In the above example, we JIT compiled our cost function loss_fn. However, we can also JIT compile the entire optimization loop; this means that the for-loop around optimization is not happening in Python, but is compiled and executed natively. This avoids (potentially costly) data transfer between Python and our JIT compiled cost function with each update step.
# Define the optimizer we want to work with
opt = optax.adam(learning_rate=0.3)
@jax.jit
def update_step_jit(i, args):
    params, opt_state, data, targets, print_training = args

    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    def print_fn():
        jax.debug.print("Step: {i} Loss: {loss_val}", i=i, loss_val=loss_val)

    # if print_training=True, print the loss every 5 steps
    jax.lax.cond((jnp.mod(i, 5) == 0) & print_training, print_fn, lambda: None)

    return (params, opt_state, data, targets, print_training)


@jax.jit
def optimization_jit(params, data, targets, print_training=False):
    opt_state = opt.init(params)
    args = (params, opt_state, data, targets, print_training)
    (params, opt_state, _, _, _) = jax.lax.fori_loop(0, 100, update_step_jit, args)
    return params
Note that we use jax.lax.fori_loop and jax.lax.cond, rather than a standard Python for loop and if statement, to allow the control flow to be JIT compatible. We also use jax.debug.print to allow printing to take place at function run-time, rather than compile-time.
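To isolate the control-flow pattern from the quantum model, here is a minimal sketch using only JAX. It assumes a toy quadratic loss and plain gradient descent (toy_loss, body, and run are hypothetical names), carries the loop state as a tuple through jax.lax.fori_loop, and uses jax.lax.cond to count how often the "every 5 steps" branch is taken:

```python
import jax
from jax import numpy as jnp


def toy_loss(w):
    # Hypothetical stand-in for loss_fn, minimized at w = 3
    return jnp.sum((w - 3.0) ** 2)


def body(i, carry):
    w, n_hits = carry
    w = w - 0.1 * jax.grad(toy_loss)(w)  # plain gradient-descent step
    # Traced replacement for `if i % 5 == 0: n_hits += 1`
    n_hits = jax.lax.cond(i % 5 == 0, lambda n: n + 1, lambda n: n, n_hits)
    return (w, n_hits)


@jax.jit
def run(w0):
    # The carry must keep the same structure, shapes, and dtypes every iteration
    init = (w0, jnp.zeros((), dtype=jnp.int32))
    return jax.lax.fori_loop(0, 100, body, init)


w_final, n_hits = run(jnp.zeros(2))
print(w_final, n_hits)
```

Both branches passed to jax.lax.cond must return values of the same type, which is why the "do nothing" branch returns its input unchanged.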
params = {"weights": weights, "bias": bias}
optimization_jit(params, data, targets, print_training=True)
Step: 0 Loss: 0.29232630133628845
Step: 5 Loss: 0.04476666823029518
Step: 10 Loss: 0.03190242499113083
Step: 15 Loss: 0.03623739257454872
Step: 20 Loss: 0.03370068594813347
Step: 25 Loss: 0.028724048286676407
Step: 30 Loss: 0.02301187813282013
Step: 35 Loss: 0.01871572807431221
Step: 40 Loss: 0.014776417054235935
Step: 45 Loss: 0.01042767520993948
Step: 50 Loss: 0.009645631536841393
Step: 55 Loss: 0.02410946786403656
Step: 60 Loss: 0.008082677610218525
Step: 65 Loss: 0.007609029300510883
Step: 70 Loss: 0.007097640074789524
Step: 75 Loss: 0.006783587858080864
Step: 80 Loss: 0.006901636719703674
Step: 85 Loss: 0.0065839518792927265
Step: 90 Loss: 0.0060335081070661545
Step: 95 Loss: 0.004975202493369579
{'bias': Array(-0.7529048, dtype=float32), 'weights': Array([[ 1.6309154 , 1.5501628 , 0.6721562 ],
[ 0.72661567, 0.36423036, -0.75626165],
[ 2.7838068 , 0.62710565, 3.4500601 ],
[-1.1012629 , -0.1270336 , 0.89287686],
[ 1.2723556 , 1.1062945 , 2.2205083 ]], dtype=float32)}
Appendix: Timing the two approaches¶
We can time the two approaches (JIT compiling just the cost function, vs JIT compiling the entire optimization loop) to explore the differences in performance:
from timeit import repeat
def optimization(params, data, targets):
    opt = optax.adam(learning_rate=0.3)
    opt_state = opt.init(params)

    for i in range(100):
        params, opt_state, loss_val = update_step(opt, params, opt_state, data, targets)

    return params
reps = 5
num = 2
times = repeat("optimization(params, data, targets)", globals=globals(), number=num, repeat=reps)
result = min(times) / num
print(f"Jitting just the cost (best of {reps}): {result} sec per loop")
times = repeat("optimization_jit(params, data, targets)", globals=globals(), number=num, repeat=reps)
result = min(times) / num
print(f"Jitting the entire optimization (best of {reps}): {result} sec per loop")
Jitting just the cost (best of 5): 0.3644627410000112 sec per loop
Jitting the entire optimization (best of 5): 0.0055419745000051535 sec per loop
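In this example, JIT compiling the entire optimization loop is significantly faster: roughly 0.0055 seconds per loop compared with 0.36 seconds when only the cost function is compiled, a speedup of about 65x in the timings above.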
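When timing JAX code by hand, two details are worth keeping in mind: the first call to a jitted function includes compilation time, and JAX dispatches work asynchronously, so results should be forced with block_until_ready before stopping the clock. A minimal sketch, using a hypothetical toy function (toy_fn) rather than the optimization above:

```python
import time

import jax
from jax import numpy as jnp


@jax.jit
def toy_fn(x):
    # Hypothetical workload standing in for the optimization
    return jnp.sum(jnp.sin(x) ** 2)


x = jnp.linspace(0.0, 1.0, 1000)

# Warm-up call: triggers compilation so it is not counted below
toy_fn(x).block_until_ready()

start = time.perf_counter()
result = toy_fn(x).block_until_ready()  # wait for the computation to finish
elapsed = time.perf_counter() - start
print(f"{elapsed} sec per call")
```

Absolute timings depend on hardware and library versions, so the numbers reported above should be read as indicative rather than exact.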
Josh Izaac
Josh is a theoretical physicist, software tinkerer, and occasional baker. At Xanadu, he contributes to the development and growth of Xanadu’s open-source quantum software products.
Maria Schuld
Dedicated to making quantum machine learning a reality one day.
Total running time of the script: (0 minutes 8.182 seconds)