PennyLane

Contents

  • 1. Quantum image states
  • 2. Low depth image circuits
  • Downloading the quantum image dataset
  • Reconstructing images from quantum states
  • 3. Quantum classifiers
  • Preparing the training / validation split
  • Training setup and loop
  • Conclusion
  • References
  • Appendix
  • About the authors


Loading classical data with low-depth circuits

Florian Kiwit

Bernhard Jobst

Carlos Riofrío

Published: September 20, 2025. Last updated: September 20, 2025.

The encoding of arbitrary classical data into quantum states usually comes at a high computational cost, either in terms of qubits or gate count. However, real-world data such as images typically has inherent structure that can be leveraged to load it onto a quantum computer at a much smaller cost. The paper “Typical Machine Learning Datasets as Low‑Depth Quantum Circuits” (2025) [1] develops an efficient algorithm for finding low-depth quantum circuits that load classical image data as quantum states.

This demo introduces the main ideas of this paper in three steps: 1) quantum image states, 2) low-depth image circuits, and 3) training a small variational quantum circuit (VQC) classifier on the resulting dataset.

1. Quantum image states

Images with \(2^n\) pixels are mapped to states of the form \(\left| \psi(\mathbf x) \right> = \frac{1}{\sqrt{2^{n}}}\sum_{j=0}^{2^{n}-1} \left| c(x_j) \right> \otimes \left| j \right>\), where the address register \(\left| j\right>\) holds the pixel position (\(n\) qubits), and additional color qubits \(\left| c(x_j)\right>\) encode the pixel intensities. For grayscale images, we use the flexible representation of quantum images (FRQI) [2,3] as an encoding. In this case, the data value \({x}_j\) of each pixel is just a single number corresponding to the grayscale value of that pixel. We can encode this information in the \(z\)-polarization of an additional color qubit as \(\left|c({x}_j)\right> = \cos({\textstyle\frac{\pi}{2}} {x}_j) \left| 0 \right> + \sin({\textstyle\frac{\pi}{2}} {x}_j) \left| 1 \right>\), with the pixel value normalized to \({x}_j \in [0,1]\). Thus, a grayscale image with \(2^n\) pixels is encoded into a quantum state with \(n+1\) qubits.
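As a minimal worked example of this encoding, the standalone NumPy sketch below evaluates the formula directly for a 2 × 2 image, i.e. \(n = 2\) address qubits plus one color qubit. The helpers `frqi_state` and `frqi_pixels` are hypothetical illustrations, not part of the demo: the color qubit is placed first, as in \(\lvert c(x_j)\rangle \otimes \lvert j\rangle\), and pixels are taken in row-major order, which for a 2 × 2 image coincides with the hierarchical ordering used by FRQI_encoding below. Decoding relies on the identity \(\cos^2(\tfrac{\pi}{2}x_j) - \sin^2(\tfrac{\pi}{2}x_j) = \cos(\pi x_j)\).

```python
import numpy as np


def frqi_state(pixels):
    """Hypothetical helper: FRQI state of 2**n grayscale values in [0, 1].

    The color qubit is the most significant qubit, so the first 2**n amplitudes
    are the |0> (cos) components and the last 2**n are the |1> (sin) components.
    """
    x = np.asarray(pixels, dtype=float).ravel()
    cos_part = np.cos(np.pi / 2 * x)
    sin_part = np.sin(np.pi / 2 * x)
    # Global prefactor 1 / sqrt(2**n), where 2**n is the number of pixels
    return np.concatenate([cos_part, sin_part]) / np.sqrt(x.size)


def frqi_pixels(state):
    """Hypothetical helper: recover pixels via cos^2 - sin^2 = cos(pi * x)."""
    num_pixels = state.size // 2
    probs = np.abs(state) ** 2
    z = (probs[:num_pixels] - probs[num_pixels:]) * num_pixels
    # Clipping guards against round-off pushing the argument outside [-1, 1]
    return np.arccos(np.clip(z, -1.0, 1.0)) / np.pi


image = np.array([[0.0, 1.0], [0.5, 0.25]])  # 4 pixels -> n = 2 address qubits
state = frqi_state(image)
print(state.shape)  # (8,): 2**(n + 1) amplitudes for n + 1 = 3 qubits
print(np.isclose(np.linalg.norm(state), 1.0))  # True
print(np.allclose(frqi_pixels(state).reshape(2, 2), image))  # True
```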

For color images, the multi-channel representation of quantum images (MCRQI) [4,5] can be used. Python implementations of the MCRQI encoding and decoding are provided at the end of this demo and are discussed in Ref. [1].

from pennylane import numpy as np

# Grayscale encodings and decodings


def FRQI_encoding(images):
    """
    Input : (batchsize, N, N) ndarray
        A batch of square arrays representing grayscale images.
    Returns : (batchsize, 2, N**2) ndarray
        A batch of quantum states encoding the grayscale images using the FRQI.
    """
    # get image size and number of qubits
    batchsize, N, _ = images.shape
    n = 2 * int(np.log2(N))
    # reorder pixels hierarchically
    states = np.reshape(images, (batchsize, *(2,) * n))
    states = np.transpose(states, [0] + [ax + 1 for q in range(n // 2) for ax in (q, q + n // 2)])
    # FRQI encoding by stacking cos and sin components
    states = np.stack([np.cos(np.pi / 2 * states), np.sin(np.pi / 2 * states)], axis=1)
    # normalize and reshape
    states = np.reshape(states, (batchsize, 2, N**2)) / N
    return states


def FRQI_decoding(states):
    """
    Input : (batchsize, 2, N**2) ndarray
        A batch of quantum states encoding grayscale images using the FRQI.
    Returns : (batchsize, N, N) ndarray
        A batch of square arrays representing the grayscale images.
    """
    # get batchsize and number of qubits
    batchsize = states.shape[0]
    states = np.reshape(states, (batchsize, 2, -1))
    n = int(np.log2(states.shape[2]))
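    # invert FRQI encoding: |a_0|^2 - |a_1|^2 = cos(pi * x_j) / 2**n for each pixel;
    # clipping guards against approximate states pushing the argument outside [-1, 1]
    images = np.arccos(np.clip((np.abs(states[:, 0]) ** 2 - np.abs(states[:, 1]) ** 2) * 2**n, -1, 1)) / np.pi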
    # undo hierarchical ordering
    images = np.reshape(images, (batchsize, *(2,) * n))
    images = np.transpose(images, [0, *range(1, n, 2), *range(2, n + 1, 2)])
    # reshape to square image
    images = np.reshape(images, (batchsize, 2 ** (n // 2), 2 ** (n // 2)))
    return images
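The hierarchical reordering in FRQI_encoding interleaves the bits of the row and column indices, so that for an image of side \(N = 2^k\) the address register reads \(r_{k-1} c_{k-1} \dots r_0 c_0\) (most significant bit first). This is the Morton (Z-order) curve: the pixels of each 2 × 2, 4 × 4, ... block receive contiguous addresses. The standalone sketch below applies the same reshape and transpose to a 4 × 4 image of pixel indices and checks the result against explicit bit interleaving; `morton_address` is a hypothetical helper, not part of the demo.

```python
import numpy as np


def morton_address(row, col, bits):
    """Hypothetical helper: address j of pixel (row, col) with interleaved bits
    r_{bits-1} c_{bits-1} ... r_0 c_0 (most significant first)."""
    j = 0
    for k in reversed(range(bits)):
        j = (j << 2) | (((row >> k) & 1) << 1) | ((col >> k) & 1)
    return j


N = 4
bits = int(np.log2(N))  # bits per row / column index
n = 2 * bits  # number of address qubits

# Image whose value is the row-major index of each pixel
image = np.arange(N * N).reshape(N, N)

# Same reordering as in FRQI_encoding, without the batch axis
ordered = np.reshape(image, (2,) * n)
ordered = np.transpose(ordered, [ax for q in range(bits) for ax in (q, q + bits)])
ordered = ordered.reshape(-1)  # ordered[j] = row-major index of the pixel at address j

for row in range(N):
    for col in range(N):
        assert ordered[morton_address(row, col, bits)] == row * N + col

print(ordered.tolist())  # [0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15]
```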

2. Low depth image circuits

In general, the complexity of preparing the resulting state exactly scales exponentially with the number of qubits: known constructions without auxiliary qubits use \(\mathcal{O}(4^n)\) gates [2,3]. However, encoding typical images this way yields weakly entangled quantum states that are well approximated by tensor-network states such as matrix-product states (MPSs) [6], whose bond dimension \(\chi\) does not need to scale with the image resolution. Thus, the state can be prepared approximately, with a small error, using a number of gates that scales only as \(\mathcal{O}(\chi^2 n)\), i.e., linearly in the number of qubits. While the classical preprocessing may cost about as much as exact state preparation, the resulting quantum circuits are exponentially more efficient.
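The role of the bond dimension \(\chi\) can be made concrete with a small experiment. The standalone NumPy sketch below computes the Schmidt rank of an FRQI state across every cut of the 11-qubit chain (10 address qubits and one color qubit for a 32 × 32 image); this rank is the bond dimension an exact MPS would need at that cut. It is only an illustration under simplifying assumptions (hypothetical helpers, row-major pixel order, exact ranks without truncation), not the paper's circuit-finding algorithm. A constant image gives a product state with rank 1 everywhere, while uncorrelated noise drives the ranks towards their maximum \(\min(2^k, 2^{11-k})\).

```python
import numpy as np


def frqi_state(image):
    """Hypothetical helper: FRQI state (color qubit first, row-major pixel order)."""
    x = np.asarray(image, dtype=float).ravel()
    return np.concatenate([np.cos(np.pi / 2 * x), np.sin(np.pi / 2 * x)]) / np.sqrt(x.size)


def schmidt_ranks(state, num_qubits, tol=1e-10):
    """Schmidt rank across the cut after the first k qubits, for k = 1, ..., num_qubits - 1.

    This equals the bond dimension an exact MPS needs at that cut.
    """
    ranks = []
    for k in range(1, num_qubits):
        singular_values = np.linalg.svd(state.reshape(2**k, -1), compute_uv=False)
        ranks.append(int(np.sum(singular_values > tol * singular_values[0])))
    return ranks


num_qubits = 11  # 32 x 32 image: 10 address qubits + 1 color qubit
rng = np.random.default_rng(0)

constant_image = np.full((32, 32), 0.3)
noise_image = rng.random((32, 32))

const_ranks = schmidt_ranks(frqi_state(constant_image), num_qubits)
noise_ranks = schmidt_ranks(frqi_state(noise_image), num_qubits)
print(const_ranks)  # all ones: a product state
print(noise_ranks)  # much larger in the middle of the chain
```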

The following illustration shows the quantum circuits inspired by MPSs. The left side shows a circuit with a staircase pattern with two layers (represented in turquoise and pink), where two-qubit gates are applied sequentially, corresponding to a right-canonical MPS. The right side shows the proposed circuit architecture corresponding to an MPS in mixed canonical form. By effectively shifting the gates with the dashed outlines to the right, the gates are applied sequentially outward starting from the center. This reduces the circuit depth while maintaining its expressivity.

Illustration of quantum circuits inspired by MPSs
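The depth saving of the center-outward arrangement can be illustrated with a simplified model. The sketch below assumes that each layer is one column of \(n-1\) nearest-neighbour two-qubit gates and schedules every gate as soon as both of its qubits are free, counting two-qubit time steps only. It does not reproduce the exact layout from the paper (which shifts specific gates, drawn dashed in the figure), and the helper names are hypothetical. For 11 qubits, one staircase layer has depth 10, while starting from the central pair and proceeding outward in both directions reduces it to 6, because the two halves of the chain are processed in parallel.

```python
def circuit_depth(gates, num_qubits):
    """Depth of a list of two-qubit gates under as-soon-as-possible scheduling:
    each gate starts right after the last gate acting on either of its qubits."""
    last = [0] * num_qubits
    for a, b in gates:
        layer = max(last[a], last[b]) + 1
        last[a] = last[b] = layer
    return max(last)


def staircase_layer(num_qubits):
    """Hypothetical layout: nearest-neighbour gates from top to bottom."""
    return [(j, j + 1) for j in range(num_qubits - 1)]


def center_outward_layer(num_qubits):
    """Hypothetical layout: the same gates, applied outward from the central pair."""
    center = (num_qubits - 1) // 2
    gates = [(center, center + 1)]
    left = [(j, j + 1) for j in range(center - 1, -1, -1)]
    right = [(j, j + 1) for j in range(center + 1, num_qubits - 1)]
    for k in range(max(len(left), len(right))):
        if k < len(left):
            gates.append(left[k])
        if k < len(right):
            gates.append(right[k])
    return gates


num_qubits = 11
for layers in (1, 2, 4):
    stair = circuit_depth(staircase_layer(num_qubits) * layers, num_qubits)
    center = circuit_depth(center_outward_layer(num_qubits) * layers, num_qubits)
    print(f"{layers} layer(s): staircase depth {stair}, center-outward depth {center}")
```

In this model, each additional layer adds two time steps to either arrangement, so the center-outward layout keeps its head start of roughly \(n/2\) steps.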

Downloading the quantum image dataset

The dataset name is set to 'low-depth-mnist', and a local copy is expected at datasets/low-depth-mnist/low-depth-mnist.h5. If this file exists, it is opened with qml.data.Dataset.open; otherwise, the dataset is downloaded from the PennyLane data repository via qml.data.load. Note that the download is approximately 1 GB. The dataset provides the following attributes:

Attribute           Description
exact_state         The exact state that the corresponding circuit should prepare
labels              The correct labels of the corresponding images
circuit_layout_d4   The layout of the depth-4 circuit
circuit_layout_d8   The layout of the depth-8 circuit
params_d4           Parameters for the depth-4 circuit
params_d8           Parameters for the depth-8 circuit
fidelities_d4       Fidelities between the depth-4 state and the exact state
fidelities_d8       Fidelities between the depth-8 state and the exact state

import os
import jax
import pennylane as qml
from tqdm import tqdm

# JAX uses single-precision (32-bit) floats by default; the following line enables double precision.
jax.config.update("jax_enable_x64", True)
# Run JAX on the CPU; set this to "gpu" or "tpu" to use those devices instead.
jax.config.update("jax_platform_name", "cpu")

# Dataset name; the encoding depth (4 or 8) is selected below via the *_d4 / *_d8 attributes
DATASET_NAME = "low-depth-mnist"

# Expected location of a previously downloaded copy
dataset_path = f"datasets/{DATASET_NAME}/{DATASET_NAME}.h5"

# Load the dataset if already downloaded
if os.path.exists(dataset_path):
    dataset_params = qml.data.Dataset.open(dataset_path)
else:
    # Download the dataset (~ 1 GB)
    [dataset_params] = qml.data.load(DATASET_NAME)

In the following cell, we define the get_circuit function that creates a quantum circuit based on the provided layout. The circuit_layout is an attribute of the dataset that specifies the sequence of quantum gates and their target qubits, which depends on the number of qubits and circuit depth. After defining the circuit function, we extract the relevant data for binary classification (digits 0 and 1 only) and compute the quantum states by executing the circuits with their corresponding parameters. These generated states will be used later for training the quantum classifier.

TARGET_LABELS = [0, 1]


def get_circuit(circuit_layout):
    """
    Create a quantum circuit with a given layout for preparing quantum states.
    The circuit only contains RY rotation gates and CNOT gates, designed for efficient
    state preparation with low circuit depth.

    :param circuit_layout: List of tuples containing gate types ('RY' or 'CNOT') and their target wires.
    :return circuit: A JAX-compiled quantum circuit function that takes parameters and returns the quantum state.
    """
    dev = qml.device("default.qubit", wires=11)

    @jax.jit
    @qml.qnode(dev)
    def circuit(params):
        counter = 0
        for gate, wire in circuit_layout:

            if gate == "RY":
                qml.RY(params[counter], wire)
                counter += 1

            elif gate == "CNOT":
                qml.CNOT(wire)

        return qml.state()

    return circuit


# Unpack the dataset attributes; only digits 0 and 1 are used in this demo
labels = np.asarray(dataset_params.labels)
selection = np.isin(labels, TARGET_LABELS)
labels_01 = labels[selection]
exact_state = np.asarray(dataset_params.exact_state)[selection]

circuit_layout = dataset_params.circuit_layout_d4
circuit = get_circuit(circuit_layout)
params_01 = np.asarray(dataset_params.params_d4)[selection]
states_01 = np.asarray([circuit(params) for params in tqdm(params_01, desc="States for depth 4")])
fidelities_01 = np.asarray(dataset_params.fidelities_d4)[selection]
States for depth 4: 100%|██████████| 14708/14708 [00:18<00:00, 782.68it/s]
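To make the circuit_layout format concrete, the NumPy-only sketch below mimics get_circuit for a tiny hypothetical layout. It assumes entries of the form ("RY", wire) and ("CNOT", [control, target]), as described in the docstring, and PennyLane's wire ordering (wire 0 is the most significant bit); the helper functions are illustrative re-implementations, not PennyLane code. Because RY and CNOT are real matrices, the amplitudes stay real, which is why only the real part of the states is kept later on.

```python
import numpy as np


def apply_ry(state, theta, wire, num_wires):
    """Apply RY(theta) = [[cos(theta/2), -sin(theta/2)], [sin(theta/2), cos(theta/2)]]."""
    psi = state.reshape((2,) * num_wires)
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    ry = np.array([[c, -s], [s, c]])
    psi = np.tensordot(ry, psi, axes=([1], [wire]))  # new axis 0 is the rotated qubit
    return np.moveaxis(psi, 0, wire).reshape(-1)


def apply_cnot(state, control, target, num_wires):
    """Flip the target qubit on the subspace where the control qubit is |1>."""
    psi = state.reshape((2,) * num_wires).copy()
    index = [slice(None)] * num_wires
    index[control] = 1
    # Selecting control = 1 removes that axis, shifting later axes down by one
    target_axis = target if target < control else target - 1
    psi[tuple(index)] = np.flip(psi[tuple(index)], axis=target_axis).copy()
    return psi.reshape(-1)


def run_layout(circuit_layout, params, num_wires):
    """Simulate a layout of RY and CNOT entries, starting from |0...0>."""
    state = np.zeros(2**num_wires)
    state[0] = 1.0
    counter = 0
    for gate, wire in circuit_layout:
        if gate == "RY":
            state = apply_ry(state, params[counter], wire, num_wires)
            counter += 1
        elif gate == "CNOT":
            state = apply_cnot(state, wire[0], wire[1], num_wires)
    return state


# Hypothetical two-qubit layout: RY on wire 0, then CNOT with control 0 and target 1
toy_layout = [("RY", 0), ("CNOT", [0, 1])]
bell = run_layout(toy_layout, [np.pi / 2], num_wires=2)
print(np.round(bell, 4))  # approximately (|00> + |11>) / sqrt(2)
```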

Reconstructing images from quantum states

To investigate how well the low-depth circuits reproduce the target images, we first reconstruct the pictures encoded in each quantum state. The histogram below reports the fidelity \(F = \left|\langle \psi_{\text{exact}} \mid \psi_{\text{circ.}} \rangle\right|^{2}\), i.e. the overlap between the exact FRQI state \(\lvert \psi_{\text{exact}} \rangle\) and its 4-layer center-sequential approximation \(\lvert \psi_{\text{circ.}} \rangle\).

  • Digit 1 samples (orange) cluster at a fidelity \(F\) close to 1, indicating that four layers already capture these images almost perfectly.

  • Digit 0 samples (blue) display a broader, slightly lower-fidelity distribution, hinting at the greater entanglement required to reproduce their curved outline.

On the right, we decode the states back into pixel space, showing for each digit the sample with the highest fidelity. In line with the histogram, the reconstructed “1” is virtually indistinguishable from its original, whereas the reconstructed “0” shows minor blurring. Selecting a deeper circuit (for example, the depth-8 layouts provided in the dataset) would improve the reconstruction quality at the cost of a higher gate count.
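For two exact FRQI states the fidelity has a closed form. Since \(\langle c(x_j) \vert c(y_j)\rangle = \cos\bigl(\tfrac{\pi}{2}(x_j - y_j)\bigr)\), one finds \(F = \bigl|\langle \psi(\mathbf x) \vert \psi(\mathbf y)\rangle\bigr|^2 = \Bigl(\frac{1}{2^n}\sum_j \cos\bigl(\tfrac{\pi}{2}(x_j - y_j)\bigr)\Bigr)^2\). The standalone NumPy sketch below evaluates \(F\) from the state vectors with a hypothetical `fidelity` helper and checks it against this formula for a 2 × 2 image with one perturbed pixel. It only illustrates the definition; the fidelities in the dataset compare an exact state with a circuit state rather than two FRQI states.

```python
import numpy as np


def frqi_state(image):
    """Hypothetical helper: FRQI state (color qubit first, row-major pixel order)."""
    x = np.asarray(image, dtype=float).ravel()
    return np.concatenate([np.cos(np.pi / 2 * x), np.sin(np.pi / 2 * x)]) / np.sqrt(x.size)


def fidelity(psi, phi):
    """F = |<psi|phi>|^2; np.vdot conjugates its first argument."""
    return np.abs(np.vdot(psi, phi)) ** 2


x = np.array([[0.0, 1.0], [0.5, 0.25]])
y = x.copy()
y[0, 0] = 0.5  # perturb a single pixel

f_numeric = fidelity(frqi_state(x), frqi_state(y))
f_closed = np.mean(np.cos(np.pi / 2 * (x - y))) ** 2  # (1/2^n * sum_j cos(...))^2
print(f_numeric, f_closed)  # both approximately 0.859
```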

import matplotlib.pyplot as plt

# Select images with highest fidelity
idx_0 = np.argmax(fidelities_01[labels_01 == 0])
idx_1 = np.argmax(fidelities_01[labels_01 == 1])

orig_0 = FRQI_decoding(exact_state[labels_01 == 0][idx_0][None, :])[0]
orig_1 = FRQI_decoding(exact_state[labels_01 == 1][idx_1][None, :])[0]

rec_0 = FRQI_decoding(states_01[labels_01 == 0][idx_0][None, :])[0]
rec_1 = FRQI_decoding(states_01[labels_01 == 1][idx_1][None, :])[0]

# Create a grid of figures to show both the fidelity distribution and the original and reconstructed images
fig = plt.figure(figsize=(9, 5))
gs = fig.add_gridspec(2, 3, width_ratios=[1.2, 1, 1], wspace=0.05)

# Histogram (spans both rows, leftmost column)
ax_hist = fig.add_subplot(gs[:, 0])
ax_hist.hist(fidelities_01[labels_01 == 0], bins=20, alpha=0.5, label="Digit 0")
ax_hist.hist(fidelities_01[labels_01 == 1], bins=20, alpha=0.5, label="Digit 1")
ax_hist.set_xlabel("Fidelity")
ax_hist.set_ylabel("Count")
ax_hist.legend(loc="upper right")

# Image axes (2 × 2 on the right)
ax00 = fig.add_subplot(gs[0, 1])
ax01 = fig.add_subplot(gs[0, 2])
ax10 = fig.add_subplot(gs[1, 1])
ax11 = fig.add_subplot(gs[1, 2])

ax00.imshow(np.abs(orig_0), cmap="gray")
ax00.set_title("Original 0")
ax01.imshow(np.abs(orig_1), cmap="gray")
ax01.set_title("Original 1")
ax10.imshow(np.abs(rec_0), cmap="gray")
ax10.set_title("Reconstructed 0")
ax11.imshow(np.abs(rec_1), cmap="gray")
ax11.set_title("Reconstructed 1")

# Remove all tick marks from image axes
for ax in [ax00, ax01, ax10, ax11]:
    ax.set_xticks([])
    ax.set_yticks([])
Original 0, Original 1, Reconstructed 0, Reconstructed 1

3. Quantum classifiers

In this demo, we train a variational quantum circuit as a classifier. The FRQI states of the 32 × 32 images use N_QUBITS = 11 qubits (10 address qubits and one color qubit), so the classifier acts on the same 11 qubits. Given a data state \(\rho(x)=\lvert\psi(x)\rangle\langle\psi(x)\rvert\), a generic quantum classifier evaluates \(f_{\ell}(x) = \operatorname{Tr}\bigl[ O_{\ell}(\theta)\,\rho(x) \bigr]\), with trainable circuit parameters \(\theta\) that rotate a measurement operator \(O_\ell\). Variants explored in the paper [1] include:

  • Linear VQC — sequential two‑qubit SU(4) layers (15 parameters per gate).

  • Non‑linear VQC — gate parameters depend on input data x via auxiliary retrieval circuits.

  • Quantum‑kernel SVMs — replacing inner products by quantum state overlaps.

  • Tensor‑network (MPS/MPO) classifiers for large qubit counts.
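For a pure data state and a projective readout, the classifier output has a simple interpretation. If the circuit \(U(\theta)\) rotates a projector \(P_\ell\) on the readout qubit, i.e. \(O_\ell(\theta) = U^\dagger(\theta) P_\ell U(\theta)\), then \(f_\ell(x) = \langle\psi(x)\vert U^\dagger(\theta) P_\ell U(\theta)\vert\psi(x)\rangle\) is the probability of outcome \(\ell\) on that qubit after applying \(U(\theta)\), which is the kind of quantity a `qml.probs` measurement returns. The standalone NumPy sketch below checks this identity on three qubits, with a random real orthogonal matrix from a QR decomposition standing in for the trained circuit (an assumption made only for illustration).

```python
import numpy as np

rng = np.random.default_rng(1)
num_qubits = 3
dim = 2**num_qubits

# Random normalized real data state |psi(x)> and its density matrix rho(x)
psi = rng.normal(size=dim)
psi /= np.linalg.norm(psi)
rho = np.outer(psi, psi.conj())

# Random real orthogonal matrix standing in for the variational circuit U(theta)
U, _ = np.linalg.qr(rng.normal(size=(dim, dim)))

results = []
for label in (0, 1):
    # Projector |label><label| on the last qubit (least significant bit)
    proj_1q = np.zeros((2, 2))
    proj_1q[label, label] = 1.0
    P = np.kron(np.eye(dim // 2), proj_1q)
    O = U.conj().T @ P @ U  # rotated measurement operator O_label(theta)
    f_trace = np.trace(O @ rho).real

    # Probability of outcome `label` on the last qubit after applying U
    amplitudes = (U @ psi).reshape(dim // 2, 2)
    p_label = np.sum(np.abs(amplitudes[:, label]) ** 2)
    results.append((f_trace, p_label))
    print(label, f_trace, p_label)  # the two numbers agree
```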

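In this demo, we use a small linear VQC. After an initial layer of single-qubit RY rotations, it applies two-qubit blocks built from two CNOT gates and four RY rotations, which implement SO(4) gates as illustrated below. The blocks are arranged in the sequential layout, and this staircase is repeated DEPTH times.

Illustration of the SO(4) decomposition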
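Each block applies CNOT, RY ⊗ RY, CNOT, RY ⊗ RY. All factors are real orthogonal matrices, and each CNOT has determinant −1, so the two CNOTs cancel and every block lies in SO(4). Since SO(4) is six-dimensional, a single four-angle block covers only part of it. The standalone NumPy sketch below checks orthogonality and the determinant for random angles, assuming PennyLane's RY convention and the wire order used in the circuit that follows, and reproduces the parameter count N_PARAMS_NETWORK = 11 + 10 · 4 · 4 = 171.

```python
import numpy as np


def ry(theta):
    """RY matrix in PennyLane's convention."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]])


# CNOT with the first (more significant) qubit as control
CNOT = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], dtype=float)


def so4_block(a, b, c, d):
    """Matrix of CNOT -> RY(a) x RY(b) -> CNOT -> RY(c) x RY(d).

    Gates applied later in time appear further to the left in the product.
    """
    return np.kron(ry(c), ry(d)) @ CNOT @ np.kron(ry(a), ry(b)) @ CNOT


rng = np.random.default_rng(0)
block = so4_block(*rng.uniform(0, 2 * np.pi, size=4))
print(np.allclose(block.T @ block, np.eye(4)))  # True: real orthogonal
print(np.isclose(np.linalg.det(block), 1.0))  # True: determinant +1, so in SO(4)

# Parameter count: one RY per qubit, then 4 angles per block
n_qubits, depth, n_params_block = 11, 4, 4
n_params = n_qubits + (n_qubits - 1) * depth * n_params_block
print(n_params)  # 171
```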

import optax

# Define the hyperparameters
EPOCHS = 5
BATCH = 128
VAL_FRAC = 0.2
N_QUBITS = 11
DEPTH = 4
N_CLASSES = 2
SEED = 0

# Explicitly compute the number of model parameters
N_PARAMS_FIRST_LAYER = N_QUBITS
N_PARAMS_BLOCK = 4
N_PARAMS_NETWORK = N_PARAMS_FIRST_LAYER + (N_QUBITS - 1) * DEPTH * N_PARAMS_BLOCK

key = jax.random.PRNGKey(SEED)

# Define the model and training functions
dev = qml.device("default.qubit", wires=N_QUBITS)


@jax.jit
@qml.qnode(dev, interface="jax")
def circuit(network_params, state):
    p = iter(network_params)
    qml.StatePrep(state, wires=range(N_QUBITS))

    # Initial layer of single-qubit RY rotations
    for w in range(N_QUBITS):
        qml.RY(next(p), wires=w)

    # SO(4) building blocks
    for _ in range(DEPTH):
        for j in range(N_QUBITS - 1):
            qml.CNOT(wires=[j, j + 1])
            qml.RY(next(p), wires=j)
            qml.RY(next(p), wires=j + 1)
            qml.CNOT(wires=[j, j + 1])
            qml.RY(next(p), wires=j)
            qml.RY(next(p), wires=j + 1)

    # Probability of computational basis states of the last qubit
    # Can be extended to more qubits for multiclass case
    return qml.probs(N_QUBITS - 1)


model = jax.vmap(circuit, in_axes=(None, 0))


def loss_acc(params, batch_x, batch_y):
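    # The circuit returns probabilities (qml.probs); they are used directly as logits
    # in the softmax cross entropy, so the loss stays well above zero even for perfect predictions
    logits = model(params, batch_x)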
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch_y).mean()
    acc = (logits.argmax(-1) == batch_y).mean()
    return loss, acc


# training step
@jax.jit
def train_step(params, opt_state, batch_x, batch_y):
    (loss, acc), grads = jax.value_and_grad(lambda p: loss_acc(p, batch_x, batch_y), has_aux=True)(
        params
    )
    updates, opt_state = opt.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state, loss, acc


# data loader
def loader(X, y, batch_size, rng_key):
    idx = jax.random.permutation(rng_key, len(X))
    for i in range(0, len(X), batch_size):
        sl = idx[i : i + batch_size]
        yield X[sl], y[sl]
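Because loss_acc passes the probabilities from qml.probs to optax.softmax_cross_entropy_with_integer_labels, which interprets its input as unnormalized logits, the loss for two classes with probability \(p\) on the correct class equals \(\log(1 + e^{1-2p})\). It is therefore bounded below by \(\log(1 + e^{-1}) \approx 0.313\) even for a perfect prediction, which is consistent with training losses around 0.55 at high accuracy; the argmax, and hence the accuracy, is unaffected. The standalone sketch below checks this closed form with a hypothetical NumPy re-implementation of the softmax cross entropy, not the optax function itself.

```python
import numpy as np


def softmax_cross_entropy(logits, label):
    """Hypothetical NumPy re-implementation for one sample: -log softmax(logits)[label]."""
    logits = np.asarray(logits, dtype=float)
    log_softmax = logits - np.log(np.sum(np.exp(logits)))
    return -log_softmax[label]


# Circuit output [1 - p, p] with probability p on the correct class (label 1)
for p in (0.5, 0.75, 0.99, 1.0):
    loss_as_logits = softmax_cross_entropy([1.0 - p, p], label=1)  # what loss_acc computes
    closed_form = np.log1p(np.exp(1.0 - 2.0 * p))  # log(1 + e^(1 - 2p))
    print(f"p = {p:.2f}: loss = {loss_as_logits:.4f} (closed form {closed_form:.4f})")

# Lower bound, reached only for a perfect prediction (p = 1)
print(f"lower bound log(1 + e^-1) = {np.log1p(np.exp(-1.0)):.4f}")  # 0.3133
```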

Preparing the training / validation split

We start by casting the FRQI amplitude vectors and their digit labels into JAX arrays. Next, the sample indices are shuffled with a random permutation drawn from a key derived from the global SEED. The shuffled indices are then split so that a fraction VAL_FRAC of the samples is held out for validation. Finally, we gather the corresponding training (X_train, y_train) and validation (X_val, y_val) sets.

from jax import numpy as jnp

# Prepare the data

# Keep only the real part: RY and CNOT gates are real-valued, so the imaginary part is zero
X_all = jnp.asarray(states_01.real, dtype=jnp.float64)
y_all = jnp.asarray(labels_01, dtype=jnp.int32)

key_split, key_perm = jax.random.split(jax.random.PRNGKey(SEED))
perm = jax.random.permutation(key_perm, len(X_all))
split_pt = int(len(X_all) * (1 - VAL_FRAC))

train_idx = perm[:split_pt]
val_idx = perm[split_pt:]

X_train, y_train = X_all[train_idx], y_all[train_idx]
X_val, y_val = X_all[val_idx], y_all[val_idx]
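The permutation-based split can be sanity-checked with a NumPy stand-in; NumPy's random generator replaces jax.random here, so the specific permutation differs from the demo's. With the 14,708 digit-0/1 states loaded above and VAL_FRAC = 0.2, the split yields 11,766 training and 2,942 validation samples, i.e. 92 mini-batches of at most 128 states per epoch, the last one holding 118 states. Because that last batch has a different shape, the jitted train_step is traced a second time, which adds to the compilation overhead of the first epoch.

```python
import numpy as np

num_samples = 14708  # number of digit-0/1 states reported by the progress bar above
val_frac = 0.2
batch = 128

rng = np.random.default_rng(0)
perm = rng.permutation(num_samples)
split_pt = int(num_samples * (1 - val_frac))
train_idx, val_idx = perm[:split_pt], perm[split_pt:]

# The two index sets are disjoint and together cover every sample exactly once
assert np.intersect1d(train_idx, val_idx).size == 0
assert np.array_equal(np.sort(np.concatenate([train_idx, val_idx])), np.arange(num_samples))

n_batches = -(-len(train_idx) // batch)  # ceiling division
last_batch = len(train_idx) - (n_batches - 1) * batch
print(len(train_idx), len(val_idx), n_batches, last_batch)  # 11766 2942 92 118
```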

Training setup and loop

We begin by initializing the network weights params with values drawn uniformly from \([0, 2\pi]\) and initialize the Adam optimizer with a learning rate of \(1 \times 10^{-2}\). The training loop then iterates for EPOCHS and displays the progress via tqdm:

  1. For each mini-batch, train_step performs a forward pass, computes the cross-entropy loss and accuracy, back-propagates gradients, and updates params through the optimizer state opt_state.

  2. Using the current parameters, we evaluate the same metrics on the validation set without gradient updates.

  3. Epoch-mean loss (tl, vl) and accuracy (ta, va) are appended to the tracking lists for later plotting.

The first epoch takes longer than subsequent epochs because of just-in-time compilation.

from tqdm.auto import trange

# Define the training setup and start the training loop

# Initialize the parameters uniformly in [0, 2π) and set up the Adam optimizer
params = 2 * jnp.pi * jax.random.uniform(key, (N_PARAMS_NETWORK,), dtype=jnp.float64)
opt = optax.adam(1e-2)
opt_state = opt.init(params)

# training loop
rng = key_split
train_loss_curve, val_loss_curve = [], []
train_acc_curve, val_acc_curve = [], []
bar = trange(1, EPOCHS + 1, desc="Epochs", unit="ep")
for epoch in bar:
    # train
    rng, sub = jax.random.split(rng)
    train_losses, train_accs = [], []
    for bx, by in loader(X_train, y_train, BATCH, sub):
        params, opt_state, l, a = train_step(params, opt_state, bx, by)
        train_losses.append(l)
        train_accs.append(a)

    # validation
    val_losses, val_accs = [], []
    for bx, by in loader(X_val, y_val, BATCH, rng):
        l, a = loss_acc(params, bx, by)
        val_losses.append(l)
        val_accs.append(a)

    tl = jnp.mean(jnp.stack(train_losses))
    vl = jnp.mean(jnp.stack(val_losses))
    ta = jnp.mean(jnp.stack(train_accs))
    va = jnp.mean(jnp.stack(val_accs))

    train_loss_curve.append(tl)
    val_loss_curve.append(vl)
    train_acc_curve.append(ta)
    val_acc_curve.append(va)
    bar.set_postfix(
        train_loss=f"{tl:.4f}",
        val_loss=f"{vl:.4f}",
        train_acc=f"{ta:.4f}",
        val_acc=f"{va:.4f}",
    )

# Plot the training curves
fig, ax = plt.subplots(1, 2, figsize=(12.8, 4.8))
ax[0].plot(train_loss_curve, label="Train")
ax[0].plot(val_loss_curve, label="Validation")
ax[0].set_xlabel("Epoch")
ax[0].set_ylabel("Loss")
ax[0].legend()
ax[1].plot(train_acc_curve)
ax[1].plot(val_acc_curve)
ax[1].set_xlabel("Epoch")
ax[1].set_ylabel("Accuracy")
Training and validation loss (left) and accuracy (right) per epoch
Epochs:  20%|██        | 1/5 [01:02<04:11, 62.77s/ep, train_acc=0.6485, train_loss=0.6590, val_acc=0.8644, val_loss=0.5911]
Epochs:  40%|████      | 2/5 [01:34<02:13, 44.57s/ep, train_acc=0.9426, train_loss=0.5701, val_acc=0.9895, val_loss=0.5615]
Epochs:  60%|██████    | 3/5 [02:06<01:17, 38.75s/ep, train_acc=0.9648, train_loss=0.5562, val_acc=0.9551, val_loss=0.5541]
Epochs:  60%|██████    | 3/5 [02:38<01:17, 38.75s/ep, train_acc=0.9661, train_loss=0.5536, val_acc=0.9310, val_loss=0.5533]
Epochs:  80%|████████  | 4/5 [02:38<00:36, 36.02s/ep, train_acc=0.9661, train_loss=0.5536, val_acc=0.9310, val_loss=0.5533]
Epochs:  80%|████████  | 4/5 [03:10<00:36, 36.02s/ep, train_acc=0.9674, train_loss=0.5524, val_acc=0.9711, val_loss=0.5519]
Epochs: 100%|██████████| 5/5 [03:10<00:00, 34.53s/ep, train_acc=0.9674, train_loss=0.5524, val_acc=0.9711, val_loss=0.5519]
Epochs: 100%|██████████| 5/5 [03:10<00:00, 38.03s/ep, train_acc=0.9674, train_loss=0.5524, val_acc=0.9711, val_loss=0.5519]

Text(655.0631313131313, 0.5, 'Accuracy')

Conclusion

In this demo, we used low-depth quantum circuits to load a subset of the MNIST dataset and subsequently classify it. After five epochs, the classifier reached a validation accuracy of about 97% (0.9711 in the run above).
By filtering to specific target labels, constructing parameterized circuits from the provided layouts, and evaluating their states and fidelities, we gained hands-on experience with quantum machine learning workflows on real data encodings.
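The fidelity between two image states can be sketched directly from the FRQI definition given at the start of this demo. The snippet below is a minimal NumPy illustration, not the dataset's own loader: the helpers `frqi_state` and `frqi_fidelity` are hypothetical, and the sketch assumes a simple row-major pixel order with the color qubit as the leading tensor factor, whereas the provided datasets may order the address qubits differently (compare the hierarchical ordering in the appendix). Because \(\left<c(x_j)|c(y_j)\right> = \cos(\tfrac{\pi}{2}(x_j - y_j))\), the fidelity reduces to \(\big(2^{-n}\sum_j \cos(\tfrac{\pi}{2}(x_j - y_j))\big)^2\), which the code checks against the explicit overlap.

```python
import numpy as np


def frqi_state(image):
    """Hypothetical helper: FRQI statevector of a grayscale image.

    Assumes pixel values in [0, 1], 2**n pixels, row-major pixel order,
    and the color qubit as the first (most significant) tensor factor,
    i.e. amplitudes ordered as [cos part for all j, sin part for all j].
    """
    x = np.asarray(image, dtype=float).ravel()
    num_pixels = x.size
    if num_pixels & (num_pixels - 1):
        raise ValueError("number of pixels must be a power of 2")
    cos_part = np.cos(np.pi / 2 * x)
    sin_part = np.sin(np.pi / 2 * x)
    return np.concatenate([cos_part, sin_part]) / np.sqrt(num_pixels)


def frqi_fidelity(state_a, state_b):
    """Fidelity |<a|b>|^2 between two pure states (vdot conjugates the first)."""
    return np.abs(np.vdot(state_a, state_b)) ** 2


rng = np.random.default_rng(0)
img_a = rng.random((4, 4))  # 16 pixels -> n = 4 address qubits + 1 color qubit
img_b = np.clip(img_a + 0.1 * rng.standard_normal((4, 4)), 0.0, 1.0)

psi_a, psi_b = frqi_state(img_a), frqi_state(img_b)
print(psi_a.shape)                                # (32,)
print(np.isclose(np.linalg.norm(psi_a), 1.0))     # True
fid = frqi_fidelity(psi_a, psi_b)
closed_form = np.mean(np.cos(np.pi / 2 * (img_a - img_b))) ** 2
print(np.isclose(fid, closed_form))               # True
```

Since the closed form is a sum over pixels, the fidelity does not depend on the chosen pixel ordering, as long as both states use the same one.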

We encourage you to explore the full set of provided datasets, which cover several image datasets at varying circuit depths, parameterizations, and target classes. You can adapt the presented workflow to other subsets and datasets, experiment with your own models, and share insights on how these benchmark datasets can best support the development of practical quantum machine learning approaches.
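Adapting the workflow to a different subset of classes mainly comes down to the label filter. The following is a minimal sketch, assuming the data are NumPy arrays `X` (one row per sample) and `y` (integer labels); the function `filter_targets` is hypothetical and not part of the dataset API. Kept labels are remapped to 0, ..., k−1 in the order given by `targets`, which suits a k-class loss.

```python
import numpy as np


def filter_targets(X, y, targets):
    """Hypothetical helper: keep only samples whose label is in `targets`.

    Returns the filtered samples and labels remapped to 0..len(targets)-1,
    following the order of `targets` (which should contain unique labels).
    """
    X = np.asarray(X)
    y = np.asarray(y)
    targets = np.asarray(targets)
    mask = np.isin(y, targets)
    y_kept = y[mask]
    # the position of each kept label within `targets` is its new class index
    new_labels = np.argmax(y_kept[:, None] == targets[None, :], axis=1)
    return X[mask], new_labels


X = np.arange(12).reshape(6, 2)  # six toy samples with two features each
y = np.array([0, 1, 7, 1, 0, 9])
X_sub, y_sub = filter_targets(X, y, targets=[1, 0])
print(y_sub)        # [1 0 0 1]
print(X_sub[:, 0])  # [0 2 6 8]
```

Comparing each kept label against the whole `targets` array at once avoids an explicit Python loop over samples.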

References

[1] F.J. Kiwit, B. Jobst, A. Luckow, F. Pollmann, and C.A. Riofrío. Typical Machine Learning Datasets as Low-Depth Quantum Circuits. Quantum Sci. Technol., in press (2025). DOI: https://doi.org/10.1088/2058-9565/ae0123.

[2] P.Q. Le, F. Dong, and K. Hirota. A flexible representation of quantum images for polynomial preparation, image compression, and processing operations. Quantum Inf. Process. 10, 63–84 (2011). DOI: https://doi.org/10.1007/s11128-010-0177-y.

[3] P.Q. Le, A.M. Iliyasu, F. Dong, and K. Hirota. A Flexible Representation and Invertible Transformations for Images on Quantum Computers. In: A.E. Ruano and A.R. Várkonyi-Kóczy (eds.), New Advances in Intelligent Signal Processing, Studies in Computational Intelligence, vol. 372. Springer, Berlin, Heidelberg (2011). DOI: https://doi.org/10.1007/978-3-642-11739-8_9.

[4] B. Sun et al. A Multi-Channel Representation for images on quantum computers using the RGBα color space. 2011 IEEE 7th International Symposium on Intelligent Signal Processing, Floriana, Malta, pp. 1–6 (2011). DOI: https://doi.org/10.1109/WISP.2011.6051718.

[5] B. Sun, A. Iliyasu, F. Yan, F. Dong, and K. Hirota. An RGB Multi-Channel Representation for Images on Quantum Computers. J. Adv. Comput. Intell. Intell. Inform. 17(3), 404–417 (2013). DOI: https://doi.org/10.20965/jaciii.2013.p0404.

[6] B. Jobst, K. Shen, C.A. Riofrío, E. Shishenina, and F. Pollmann. Efficient MPS representations and quantum circuits from the Fourier modes of classical image data. Quantum 8, 1544 (2024). DOI: https://doi.org/10.22331/q-2024-12-03-1544.

Appendix

# The CIFAR-10 and Imagenette datasets use the following MCRQI color encoding and decoding [4,5]
import numpy as np


def MCRQI_encoding(images):
    """
    Input : (batchsize, N, N, 3) ndarray
        A batch of arrays representing square RGB images with pixel values
        in [0, 1]; N must be a power of 2.
    Returns : (batchsize, 8, N**2) ndarray
        A batch of quantum states encoding the RGB images using the MCRQI.
    """
    # get image size and number of qubits
    batchsize, N, _, channels = images.shape
    n = 2 * int(np.log2(N))
    # reorder pixels hierarchically
    states = np.reshape(images, (batchsize, *(2,) * n, channels))
    states = np.transpose(
        states,
        [0] + [ax + 1 for q in range(n // 2) for ax in (q, q + n // 2)] + [n + 1],
    )
    # MCRQI encoding by stacking cos and sin components
    states = np.stack(
        [
            np.cos(np.pi / 2 * states[..., 0]),
            np.cos(np.pi / 2 * states[..., 1]),
            np.cos(np.pi / 2 * states[..., 2]),
            np.ones(states.shape[:-1]),
            np.sin(np.pi / 2 * states[..., 0]),
            np.sin(np.pi / 2 * states[..., 1]),
            np.sin(np.pi / 2 * states[..., 2]),
            np.zeros(states.shape[:-1]),
        ],
        axis=1,
    )
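    # flatten the pixel axes and normalize: each pixel contributes squared norm 4
    # (3 * (cos^2 + sin^2) + 1), so the N**2 pixels have total norm 2 * N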
    states = np.reshape(states, (batchsize, 8, N**2)) / (2 * N)
    return states


def MCRQI_decoding(states):
    """
    Input : (batchsize, 8, N**2) ndarray
        A batch of quantum states encoding RGB images using the MCRQI.
    Returns : (batchsize, N, N, 3) ndarray
        A batch of arrays representing the square RGB images.
    """
    # get batchsize and number of qubits
    batchsize = states.shape[0]
    states = np.reshape(states, (batchsize, 8, -1))
    N2 = states.shape[2]
    N = int(np.sqrt(N2))
    n = int(np.log2(N2))
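    # invert MCRQI encoding to get pixel values: (2N)^2 * (|cos|^2 - |sin|^2) = cos(pi * x);
    # abs() handles complex amplitudes and clip() guards against round-off outside [-1, 1]
    probs = np.abs(states) ** 2 * 4 * N2
    images = np.arccos(np.clip(probs[:, :3] - probs[:, 4:7], -1.0, 1.0)) / np.pi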
    # undo hierarchical ordering
    images = np.reshape(images, (batchsize, 3, *(2,) * n))
    images = np.transpose(images, [0, *range(2, n + 1, 2), *range(3, n + 2, 2), 1])
    # reshape to square image
    images = np.reshape(images, (batchsize, N, N, 3))
    return images
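The normalization by \(2N\) and the inversion via \(\arccos\) in the functions above both follow from the color state of a single pixel. The following self-contained sketch (the name `mcrqi_pixel` is hypothetical, not part of the appendix code) builds the unnormalized 8-component color vector of one RGB pixel in the same component order as `MCRQI_encoding`, checks that its squared norm is 4, and recovers the channel values from \(\cos^2(\tfrac{\pi}{2}x) - \sin^2(\tfrac{\pi}{2}x) = \cos(\pi x)\), the identity `MCRQI_decoding` relies on. The `np.clip` call is an added safeguard against floating-point round-off.

```python
import numpy as np


def mcrqi_pixel(rgb):
    """Hypothetical helper: unnormalized 8-component MCRQI color vector of one pixel.

    Component order matches MCRQI_encoding above:
    (cos R, cos G, cos B, 1, sin R, sin G, sin B, 0), with angles pi/2 * value.
    """
    rgb = np.asarray(rgb, dtype=float)
    return np.concatenate([np.cos(np.pi / 2 * rgb), [1.0], np.sin(np.pi / 2 * rgb), [0.0]])


pixel = np.array([0.2, 0.5, 0.9])  # RGB values normalized to [0, 1]
c = mcrqi_pixel(pixel)

# Each pixel contributes squared norm 3 * (cos^2 + sin^2) + 1 = 4, so an
# N x N image has total squared norm 4 * N**2, hence the division by 2 * N.
print(np.isclose(np.sum(c**2), 4.0))  # True

# Decoding: cos^2 - sin^2 = cos(pi * x), and arccos(.) / pi recovers x in [0, 1].
recovered = np.arccos(np.clip(c[:3] ** 2 - c[4:7] ** 2, -1.0, 1.0)) / np.pi
print(np.allclose(recovered, pixel))  # True
```

The fourth component is the constant 1 (and the eighth is 0), so the three color channels share one fixed reference amplitude per pixel.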

About the authors

Florian Kiwit

Bernhard Jobst

Carlos Riofrío

Total running time of the script: (6 minutes 1.056 seconds)

© Copyright 2025 | Xanadu | All rights reserved