
Contents

  1. Set up your model, data, and cost
  2. Initialize your parameters
  3. Create the optimizer
  4. JIT-compiling the optimization
  5. Timing the optimization
  6. About the author


How to optimize a QML model using Catalyst and quantum just-in-time (QJIT) compilation

Josh Izaac

Published: April 25, 2024. Last updated: September 09, 2025.

Once you have set up your quantum machine learning model (which typically includes deciding on your circuit architecture/ansatz, determining how you embed or integrate your data, and creating your cost function to minimize a quantity of interest), the next step is optimization: setting up a classical optimization loop to minimize your cost function.

In this example, we’ll show you how to use JAX, an autodifferentiable machine learning framework, and Optax, a suite of JAX-compatible gradient-based optimizers, to optimize a PennyLane quantum machine learning model which has been quantum just-in-time compiled using the qjit() decorator and Catalyst.


Set up your model, data, and cost

Here, we will create a simple QML model for our optimization. In particular:

  • We will embed our data through a series of rotation gates.

  • We will then have an ansatz of trainable rotation gates with parameters weights; it is these values we will train to minimize our cost function.

  • We will train the QML model on data, a (5, 4) array, and optimize the model to match the target predictions given by targets.

import pennylane as qml
from jax import numpy as jnp
import optax
import catalyst

n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

dev = qml.device("lightning.qubit", wires=n_wires)

@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit ansatz"""

    @qml.for_loop(0, n_wires, 1)
    def data_embedding(i):
        qml.RY(data[i], wires=i)

    data_embedding()

    @qml.for_loop(0, n_wires, 1)
    def ansatz(i):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    ansatz()

    # we use a sum of local Z's as an observable since a
    # local Z would only be affected by params on that qubit.
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))
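
Before batching and compiling, it can help to sanity-check the circuit structure. The snippet below is purely illustrative (it is not part of the demo itself): it draws the uncompiled circuit for a single column of data, using placeholder weights of the correct shape.

# Illustrative only: outside of qjit, qml.for_loop falls back to a regular
# Python loop, so we can draw the circuit for one column of data directly.
print(qml.draw(circuit)(data[:, 0], jnp.ones([n_wires, 3])))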

The catalyst.vmap() function allows us to specify that the first argument to circuit (data) contains a batch dimension. In this example, the batch dimension is the second axis (axis 1).

circuit = qml.qjit(catalyst.vmap(circuit, in_axes=(1, None)))
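
As a quick, purely illustrative check (not part of the demo), the batched circuit should now return one expectation value per column of data, i.e. an array of four predictions:

# Illustrative check: with data of shape (5, 4) and in_axes=(1, None),
# the batched circuit returns one expectation value per data column.
print(circuit(data, jnp.ones([n_wires, 3])).shape)  # expected: (4,)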

We will define a simple cost function that computes the squared-error loss between the model output and the target data:

def my_model(data, weights, bias):
    return circuit(data, weights) + bias

@qml.qjit
def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss

Note that the model above is just an example for demonstration purposes. There are important considerations that must be taken into account when performing QML research, including the choice of data embedding, circuit architecture, and cost function, in order to build models that are genuinely useful. This is still an active area of research; see our other demonstrations for details.

Initialize your parameters

Now, we can generate our trainable parameters weights and bias that will be used to train our QML model.

weights = jnp.ones([n_wires, 3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}
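
Both Optax and Catalyst treat this dictionary as a JAX pytree. If you want to confirm the parameter shapes (an illustrative check, not part of the demo), you can map over the pytree:

import jax

# Illustrative: print the shape of every leaf in the parameter pytree.
print(jax.tree_util.tree_map(jnp.shape, params))  # expected: {'weights': (5, 3), 'bias': ()}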

Plugging the trainable parameters, data, and target labels into our cost function, we can see the current loss as well as the parameter gradients (computed here using Catalyst's finite-difference method, method="fd"):

loss_fn(params, data, targets)

print(qml.qjit(catalyst.grad(loss_fn, method="fd"))(params, data, targets))
{'bias': Array(-0.75432067, dtype=float64), 'weights': Array([[-1.95077271e-01,  5.28546590e-02, -4.89252073e-01],
       [-1.99687789e-02, -5.32871558e-02,  9.22904869e-02],
       [-2.71755507e-03, -9.64672786e-05, -4.79570827e-03],
       [-6.35443870e-02,  3.61110009e-02, -2.05196876e-01],
       [-9.02635405e-02,  1.63759364e-01, -5.64262612e-01]],      dtype=float64)}
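
Here, method="fd" tells Catalyst to estimate gradients by finite differences. As a purely illustrative sanity check (not part of the demo), we can reproduce the bias gradient with a manual forward difference:

# Illustrative: forward-difference estimate of the gradient with respect to the bias.
h = 1e-6
shifted = {"weights": weights, "bias": bias + h}
print((loss_fn(shifted, data, targets) - loss_fn(params, data, targets)) / h)
# should be close to the 'bias' entry printed above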

Create the optimizer

We can now use Optax to create an Adam optimizer and train our circuit.

We first define our update_step function, which needs to do a few things:

  • Compute the gradients of the loss function. We can do this via the catalyst.grad() function.

  • Apply the update step of our optimizer via opt.update.

  • Update the parameters via optax.apply_updates.

opt = optax.adam(learning_rate=0.3)

@qml.qjit
def update_step(i, args):
    params, opt_state, data, targets = args

    grads = catalyst.grad(loss_fn, method="fd")(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    return (params, opt_state, data, targets)

loss_history = []

opt_state = opt.init(params)

for i in range(100):
    params, opt_state, _, _ = update_step(i, (params, opt_state, data, targets))
    loss_val = loss_fn(params, data, targets)

    if i % 5 == 0:
        print(f"Step: {i} Loss: {loss_val}")

    loss_history.append(loss_val)
Step: 0 Loss: 0.27303537436157654
Step: 5 Loss: 0.032559241520439784
Step: 10 Loss: 0.029282022870855358
Step: 15 Loss: 0.0333786497016208
Step: 20 Loss: 0.03123631419408755
Step: 25 Loss: 0.027191477959160448
Step: 30 Loss: 0.022688381945923933
Step: 35 Loss: 0.01816272685154787
Step: 40 Loss: 0.014789693833189992
Step: 45 Loss: 0.011206958607269252
Step: 50 Loss: 0.0094097940342296
Step: 55 Loss: 0.017898297254010914
Step: 60 Loss: 0.01286131622894956
Step: 65 Loss: 0.009916040473753393
Step: 70 Loss: 0.008611679289945632
Step: 75 Loss: 0.006585517664039411
Step: 80 Loss: 0.006778109386158929
Step: 85 Loss: 0.006043700205345529
Step: 90 Loss: 0.006139649744176761
Step: 95 Loss: 0.004989541380544226

JIT-compiling the optimization

In the above example, we just-in-time (JIT) compiled our cost function loss_fn. However, we can also JIT compile the entire optimization loop; this means that the for loop around the optimization no longer runs in Python, but is instead compiled and executed natively. This avoids (potentially costly) data transfer between Python and our JIT-compiled cost function at each update step.

params = {"weights": weights, "bias": bias}

@qml.qjit
def optimization(params, data, targets):
    opt_state = opt.init(params)
    args = (params, opt_state, data, targets)
    (params, opt_state, _, _) = qml.for_loop(0, 100, 1)(update_step)(args)
    return params

Note that we use for_loop() rather than a standard Python for loop, to allow the control flow to be JIT compatible.
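
For intuition, qml.for_loop(start, stop, step) decorates a body function that receives the loop index together with the carried state and returns the updated state. The toy example below (illustrative only, not part of the demo) shows the same pattern outside of qjit, where for_loop() simply falls back to a Python loop:

# Illustrative toy example of the for_loop carried-state pattern.
def add_index(i, total):
    return total + i

print(qml.for_loop(0, 5, 1)(add_index)(0))  # expected: 0 + 1 + 2 + 3 + 4 = 10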

final_params = optimization(params, data, targets)

print(final_params)
{'bias': Array(-0.75292884, dtype=float64), 'weights': Array([[ 1.63087003,  1.55018968,  0.67212609],
       [ 0.72660627,  0.36422545, -0.75624708],
       [ 2.78387471,  0.62720991,  3.44996406],
       [-1.10119513, -0.12679492,  0.89283764],
       [ 1.27236318,  1.1063112 ,  2.22051429]], dtype=float64)}

Timing the optimization

We can time the two approaches (JIT compiling just the cost function, vs JIT compiling the entire optimization loop) to explore the differences in performance:

from timeit import repeat

opt = optax.adam(learning_rate=0.3)

def optimization_noqjit(params):
    opt_state = opt.init(params)

    for i in range(100):
        params, opt_state, _, _ = update_step(i, (params, opt_state, data, targets))

    return params

reps = 5
num = 2

times = repeat("optimization_noqjit(params)", globals=globals(), number=num, repeat=reps)
result = min(times) / num

print(f"Quantum jitting just the cost (best of {reps}): {result} sec per loop")

times = repeat("optimization(params, data, targets)", globals=globals(), number=num, repeat=reps)
result = min(times) / num

print(f"Quantum jitting the entire optimization (best of {reps}): {result} sec per loop")
Quantum jitting just the cost (best of 5): 0.7091746075000174 sec per loop
Quantum jitting the entire optimization (best of 5): 0.4395241874999556 sec per loop

About the author

Josh Izaac

Josh is a theoretical physicist, software tinkerer, and occasional baker. At Xanadu, he contributes to the development and growth of Xanadu’s open-source quantum software products.

Total running time of the script: (0 minutes 17.579 seconds)

