How to optimize a QML model using JAX and JAXopt

Published: January 18, 2024. Last updated: October 7, 2024.

Once you have set up a quantum machine learning model, data to train with, and a cost function to minimize, the next step is to perform the optimization: setting up a classical optimization loop that searches for a minimal value of your cost function.

In this example, we’ll show you how to use JAX, an autodifferentiable machine learning framework, and JAXopt, a suite of JAX-compatible gradient-based optimizers, to optimize a PennyLane quantum machine learning model.

Set up your model, data, and cost

Here, we will create a simple QML model for our optimization. In particular:

  • We will embed our data through a series of rotation gates.

  • We will then have an ansatz of trainable rotation gates with parameters weights; it is these values we will train to minimize our cost function.

  • We will train the QML model on data, a (5, 4) array, and optimize the model to match target predictions given by target.

import pennylane as qml
import jax
from jax import numpy as jnp
import jaxopt

jax.config.update("jax_platform_name", "cpu")

n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

dev = qml.device("default.qubit", wires=n_wires)

@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit ansatz"""

    # data embedding
    for i in range(n_wires):
        # data[i] will be of shape (4,); we are
        # taking advantage of operation vectorization here
        qml.RY(data[i], wires=i)

    # trainable ansatz
    for i in range(n_wires):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # we use a sum of local Z's as an observable since a
    # local Z would only be affected by params on that qubit.
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))


def my_model(data, weights, bias):
    return circuit(data, weights) + bias
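As a quick check of the shapes described above, the following sketch rebuilds the data array on its own and inspects it. Each row holds the four rotation angles that are passed, as one batch, to the RY gate on a single wire, which is why the model output can be compared against the four entries of targets.

```python
from jax import numpy as jnp

n_wires = 5

# Same construction as in the text: 20 grid points, reshaped to one row per wire.
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3

print(data.shape)     # one row per wire, one column per data point
print(data[0].shape)  # the batch of angles passed to the RY gate on wire 0
```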

We will define a simple cost function that computes the summed squared difference between the model output and the target data, and just-in-time (JIT) compile it:

@jax.jit
def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss
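To make the arithmetic in the cost function concrete, here is a minimal sketch that evaluates the same expression by hand. The predictions array is a hypothetical model output chosen only for illustration, and n_rows stands in for len(data), which is 5 for the (5, 4) data array, so the squared errors are scaled by the number of rows rather than by the number of targets.

```python
from jax import numpy as jnp

targets = jnp.array([-0.2, 0.4, 0.35, 0.2])
predictions = jnp.zeros(4)  # hypothetical model output, for illustration only
n_rows = 5                  # len(data) for the (5, 4) data array

# Same expression as in loss_fn: squared errors scaled by len(data), then summed.
squared_errors = (targets - predictions) ** 2
loss = jnp.sum(squared_errors / n_rows)
print(loss)
```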

Note that the model above is just an example for demonstration. QML research involves important considerations, including the choice of data embedding, circuit architecture, and cost function, that must be taken into account to build models that may be useful. This is still an active area of research; see our demonstrations for details.

Initialize your parameters

Now, we can generate our trainable parameters weights and bias that will be used to train our QML model.

weights = jnp.ones([n_wires, 3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}

Plugging the trainable parameters, data, and target labels into our cost function, we can see the current loss as well as the parameter gradients:

print(loss_fn(params, data, targets))

print(jax.grad(loss_fn)(params, data, targets))
0.29232618
{'bias': Array(-0.754321, dtype=float32, weak_type=True), 'weights': Array([[-1.9507733e-01,  5.2854650e-02, -4.8925212e-01],
       [-1.9968867e-02, -5.3287148e-02,  9.2290469e-02],
       [-2.7175695e-03, -9.6455216e-05, -4.7958046e-03],
       [-6.3544422e-02,  3.6111072e-02, -2.0519713e-01],
       [-9.0263695e-02,  1.6375928e-01, -5.6426275e-01]], dtype=float32)}
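The gradient printed above has the same dictionary structure as params, because JAX differentiates with respect to every array in the first argument. The sketch below shows this with a hypothetical toy_loss that stands in for the QML cost function; the names toy_loss, toy_params, and x are assumptions made for this illustration.

```python
import jax
from jax import numpy as jnp

def toy_loss(params, x):
    # Hypothetical stand-in for the QML loss: a linear model with squared output.
    prediction = jnp.sum(params["weights"] * x) + params["bias"]
    return prediction ** 2

toy_params = {"weights": jnp.ones([5, 3]), "bias": jnp.array(0.0)}
x = jnp.full([5, 3], 0.1)

# jax.grad differentiates with respect to the first argument by default,
# so the result mirrors the structure of toy_params.
grads = jax.grad(toy_loss)(toy_params, x)
print({name: g.shape for name, g in grads.items()})
```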

Create the optimizer

We can now use JAXopt to create a gradient descent optimizer, and train our circuit.

To do so, we first create a function that returns the loss value and the gradient value during training; this allows us to track and print out the loss during training within JAXopt’s internal optimization loop.

def loss_and_grad(params, data, targets, print_training, i):
    loss_val, grad_val = jax.value_and_grad(loss_fn)(params, data, targets)

    def print_fn():
        jax.debug.print("Step: {i}  Loss: {loss_val}", i=i, loss_val=loss_val)

    # if print_training=True, print the loss every 5 steps
    jax.lax.cond((jnp.mod(i, 5) == 0) & print_training, print_fn, lambda: None)

    return loss_val, grad_val

Note that we use a couple of JAX-specific functions here:

  • jax.lax.cond instead of a Python if statement

  • jax.debug.print instead of a Python print function

These JAX-compatible functions are needed because JAXopt will automatically JIT compile the optimizer update step.
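The same pattern can be tried in isolation. In this minimal sketch, square_with_logging is a hypothetical function that is JIT compiled directly, so i and verbose are traced values inside it and a Python if statement on them would raise an error.

```python
import jax
from jax import numpy as jnp

@jax.jit
def square_with_logging(x, i, verbose):
    y = x ** 2

    def print_fn():
        # jax.debug.print runs when the compiled function executes
        jax.debug.print("Step: {i}  Value: {y}", i=i, y=y)

    # i and verbose are traced here, so the branch is chosen with jax.lax.cond
    jax.lax.cond((jnp.mod(i, 5) == 0) & verbose, print_fn, lambda: None)
    return y

for step in range(11):
    result = square_with_logging(jnp.array(2.0), step, True)
```

With verbose set to True, a line is printed only for the steps that are multiples of 5.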

opt = jaxopt.GradientDescent(loss_and_grad, stepsize=0.3, value_and_grad=True)
opt_state = opt.init_state(params)

for i in range(100):
    params, opt_state = opt.update(params, opt_state, data, targets, True, i)
Step: 0  Loss: 0.2923261821269989
Step: 5  Loss: 0.08928847312927246
Step: 10  Loss: 0.07159452140331268
Step: 15  Loss: 0.0573359839618206
Step: 20  Loss: 0.047165658324956894
Step: 25  Loss: 0.039545513689517975
Step: 30  Loss: 0.033213984221220016
Step: 35  Loss: 0.027763623744249344
Step: 40  Loss: 0.02355431765317917
Step: 45  Loss: 0.02116141840815544
Step: 50  Loss: 0.020479023456573486
Step: 55  Loss: 0.020495953038334846
Step: 60  Loss: 0.020188236609101295
Step: 65  Loss: 0.019282221794128418
Step: 70  Loss: 0.018023068085312843
Step: 75  Loss: 0.016644064337015152
Step: 80  Loss: 0.015254518948495388
Step: 85  Loss: 0.013919454999268055
Step: 90  Loss: 0.012653568759560585
Step: 95  Loss: 0.011443670839071274
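For intuition about what one update step does, here is a sketch of plain gradient descent written by hand on a parameter dictionary. It is not JAXopt's implementation: jaxopt.GradientDescent has additional options, such as acceleration, that this sketch leaves out, so its iterates need not match. The toy_loss, gd_step, and toy_params names are hypothetical.

```python
import jax
from jax import numpy as jnp

def toy_loss(params):
    # Hypothetical convex loss with its minimum at weights = 1, bias = -2.
    return jnp.sum((params["weights"] - 1.0) ** 2) + (params["bias"] + 2.0) ** 2

@jax.jit
def gd_step(params, stepsize):
    loss_val, grads = jax.value_and_grad(toy_loss)(params)
    # Move every array in the parameter dictionary against its gradient.
    new_params = jax.tree_util.tree_map(lambda p, g: p - stepsize * g, params, grads)
    return new_params, loss_val

toy_params = {"weights": jnp.zeros(3), "bias": jnp.array(0.0)}
losses = []
for i in range(20):
    toy_params, loss_val = gd_step(toy_params, 0.3)
    losses.append(float(loss_val))

print(losses[0], losses[-1])
```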

Jitting the optimization loop

In the above example, we JIT compiled our cost function loss_fn (and JAXopt automatically JIT compiled the loss_and_grad function behind the scenes). However, we can also JIT compile the entire optimization loop; this means that the for-loop around optimization is not happening in Python, but is compiled and executed natively. This avoids (potentially costly) data transfer between Python and our JIT compiled cost function with each update step.

@jax.jit
def optimization_jit(params, data, targets, print_training=False):
    opt = jaxopt.GradientDescent(loss_and_grad, stepsize=0.3, value_and_grad=True)
    opt_state = opt.init_state(params)

    def update(i, args):
        params, opt_state = opt.update(*args, i)
        return (params, opt_state, *args[2:])

    args = (params, opt_state, data, targets, print_training)
    (params, opt_state, _, _, _) = jax.lax.fori_loop(0, 100, update, args)

    return params

Note that – similar to above – we use jax.lax.fori_loop and jax.lax.cond, rather than a standard Python for loop and if statement, to allow the control flow to be JIT compatible.
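The loop-carry pattern used with jax.lax.fori_loop can be seen in a smaller setting. This sketch assumes a hypothetical toy_loss and a hand-written descent step rather than JAXopt; the point is only that the body function receives the loop index and a carried tuple, and must return a tuple of the same structure.

```python
import jax
from jax import numpy as jnp

def toy_loss(w):
    # Hypothetical loss with its minimum at w = 3.
    return jnp.sum((w - 3.0) ** 2)

@jax.jit
def run_descent(w):
    def body(i, carry):
        # The carry keeps the same structure, shapes, and dtypes on every iteration.
        w, stepsize = carry
        w = w - stepsize * jax.grad(toy_loss)(w)
        return (w, stepsize)

    # Values that do not change, such as stepsize, are passed through untouched.
    w, _ = jax.lax.fori_loop(0, 100, body, (w, jnp.array(0.3)))
    return w

w_final = run_descent(jnp.zeros(2))
print(w_final)
```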

params = {"weights": weights, "bias": bias}
optimization_jit(params, data, targets, print_training=True)
Step: 0  Loss: 0.2923261821269989
Step: 5  Loss: 0.08928847312927246
Step: 10  Loss: 0.07159452140331268
Step: 15  Loss: 0.0573359839618206
Step: 20  Loss: 0.047165658324956894
Step: 25  Loss: 0.039545513689517975
Step: 30  Loss: 0.033213984221220016
Step: 35  Loss: 0.027763623744249344
Step: 40  Loss: 0.02355431765317917
Step: 45  Loss: 0.02116141840815544
Step: 50  Loss: 0.020479023456573486
Step: 55  Loss: 0.020495953038334846
Step: 60  Loss: 0.020188236609101295
Step: 65  Loss: 0.019282221794128418
Step: 70  Loss: 0.018023068085312843
Step: 75  Loss: 0.016644064337015152
Step: 80  Loss: 0.015254518948495388
Step: 85  Loss: 0.013919454999268055
Step: 90  Loss: 0.012653568759560585
Step: 95  Loss: 0.011443670839071274

{'bias': Array(-0.90738827, dtype=float32, weak_type=True), 'weights': Array([[ 1.5734389 ,  1.4267787 ,  0.5062411 ],
       [ 0.17005834,  0.83468944,  1.9624869 ],
       [ 1.4379824 ,  1.1278558 ,  2.2343845 ],
       [-0.2590661 ,  0.5319227 ,  1.3994671 ],
       [ 1.2112952 ,  1.6135516 ,  3.1225502 ]], dtype=float32)}

Appendix: Timing the two approaches

We can time the two approaches (JIT compiling just the cost function, vs JIT compiling the entire optimization loop) to explore the differences in performance:

from timeit import repeat

def optimization(params, data, targets):
    opt = jaxopt.GradientDescent(loss_and_grad, stepsize=0.3, value_and_grad=True)
    opt_state = opt.init_state(params)

    for i in range(100):
        params, opt_state = opt.update(params, opt_state, data, targets, False, i)

    return params


reps = 5
num = 2

times = repeat("optimization(params, data, targets)", globals=globals(), number=num, repeat=reps)
result = min(times) / num

print(f"Jitting just the cost (best of {reps}): {result} sec per loop")

times = repeat("optimization_jit(params, data, targets)", globals=globals(), number=num, repeat=reps)
result = min(times) / num

print(f"Jitting the entire optimization (best of {reps}): {result} sec per loop")
Jitting just the cost (best of 5): 0.8580825305000133 sec per loop
Jitting the entire optimization (best of 5): 0.002648355499985655 sec per loop

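In this example, JIT compiling the entire optimization loop is significantly faster: roughly 0.003 seconds per run, compared with roughly 0.86 seconds when only the cost function is JIT compiled, a speed-up of more than 300 times.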
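When timing JAX code, two details can affect the numbers: the first call to a JIT compiled function includes compilation time, and JAX dispatches computations asynchronously, so a call may return before its result is ready. The sketch below shows one way to account for both, using a hypothetical toy_workload in place of optimization_jit; it is an illustration of the timing pattern, not a re-run of the benchmark above.

```python
import jax
from jax import numpy as jnp
from timeit import repeat

@jax.jit
def toy_workload(x):
    # Hypothetical stand-in for optimization_jit.
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.linspace(0.0, 1.0, 1000)

# Warm-up call: triggers compilation so it is not included in the timings.
jax.block_until_ready(toy_workload(x))

def timed_call():
    # Wait for the result to be computed before the timer stops.
    return jax.block_until_ready(toy_workload(x))

reps = 5
num = 2
times = repeat(timed_call, number=num, repeat=reps)
result = min(times) / num
print(f"Best of {reps}: {result} sec per loop")
```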

Total running time of the script: (0 minutes 14.297 seconds)

About the authors

Josh Izaac

Josh is a theoretical physicist, software tinkerer, and occasional baker. At Xanadu, he contributes to the development and growth of Xanadu’s open-source quantum software products.

Maria Schuld

Dedicated to making quantum machine learning a reality one day.