Decoding Quantum Errors on the Steane code with Belief Propagation & Catalyst
Published: August 25, 2025. Last updated: August 25, 2025.
Learn how to build, simulate, and decode the Steane code using JAX and Catalyst, blending quantum circuits with fast classical decoders in a seamless workflow.

Introduction
This tutorial walks you through a simplified error correction cycle using the Steane $[[7,1,3]]$
code. You’ll encode a logical qubit, introduce noise, extract syndromes, and apply decoding using
two different strategies: a simple lookup table and a belief propagation (BP) decoder. Both decoders
are implemented in JAX and JIT-compiled by Catalyst, allowing everything to run inside a single @qml.qjit circuit.
Why is this exciting? Quantum error correction (QEC) is essential for building reliable quantum computers, but it requires more than just quantum operations. Fast classical feedback is needed as well. Catalyst addresses this by fusing the classical and quantum workflows, giving us one unified, hardware-agnostic kernel that runs on CPUs, GPUs, and beyond.
What We’ll Build
By the end of this tutorial, you’ll have:
Encoded a logical $|0⟩$ using the Steane code.
Simulated noise using configurable bit-flip and phase-flip channels.
Extracted syndromes via ancilla-assisted stabilizer measurements.
Decoded errors using:
a Lookup Table (LUT) decoder
a Belief Propagation (BP) decoder
Benchmarked performance across different physical error rates.
Understanding Quantum Error Correction
At its core, QEC protects quantum information through redundant encoding. Stabilizer codes are a foundational tool for this purpose. Each stabilizer is a multi-qubit Pauli operator (composed of I, X, Y, Z gates) that defines parity constraints the code space must satisfy.
Symplectic Representation
To formalize stabilizers and errors, we use symplectic vectors. For \(n\) physical qubits, any \(n\)-qubit Pauli operator can be represented (up to phase) as a binary vector of length \(2n\):
\[(v \mid u) = (v_1, \ldots, v_n \mid u_1, \ldots, u_n),\]
where:
\(v_i = 1\) if the Pauli has an X on qubit \(i\) (0 otherwise),
\(u_i = 1\) if the Pauli has a Z on qubit \(i\) (0 otherwise),
and \(v_i = u_i = 1\) if the Pauli has a Y on qubit \(i\), since \(X \cdot Z \propto Y\).
For example, the operator \(XZIY\) on 4 qubits corresponds to \((1,0,0, 1 | 0,1,0, 1)\). The stabilizer group is then generated by a set of such symplectic vectors, forming a stabilizer matrix of size \(m \times 2n\) (with \(m\) generators).
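As a small illustration, here is a minimal sketch of this conversion (pauli_to_symplectic is a hypothetical helper introduced only for this example):
import jax.numpy as jnp

def pauli_to_symplectic(pauli: str) -> jnp.ndarray:
    """Map a Pauli string to its binary (v | u) symplectic vector."""
    v = jnp.array([1 if p in "XY" else 0 for p in pauli], dtype=jnp.int32)
    u = jnp.array([1 if p in "ZY" else 0 for p in pauli], dtype=jnp.int32)
    return jnp.concatenate([v, u])

print(pauli_to_symplectic("XZIY"))  # [1 0 0 1 0 1 0 1], i.e. (1,0,0,1 | 0,1,0,1)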
The commutation condition between two Pauli operators \((v|u)\) and \((v'|u')\) is captured by their symplectic inner product:
\[\langle (v|u), (v'|u') \rangle = v \cdot u' + u \cdot v' \pmod{2}.\]
The two operators commute exactly when this product is \(0\) and anti-commute when it is \(1\). For a valid stabilizer code, all generators must commute, which implies that the symplectic inner product between any two rows of the stabilizer matrix must be zero.
Measurement and Syndrome
When an error \(e\) occurs, represented as a symplectic vector \((v_e | u_e)\), the syndrome is calculated by taking the symplectic inner product of \(e\) with each stabilizer generator. This produces a syndrome bit \(s_i\) for each generator:
\[s_i = v_e \cdot u_g^{(i)} + u_e \cdot v_g^{(i)} \pmod{2},\]
where \((v_g^{(i)} | u_g^{(i)})\) is the \(i\)-th generator.
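Continuing the sketch from above (and reusing the hypothetical pauli_to_symplectic helper from the previous snippet), the syndrome is just this inner product evaluated against every generator:
def symplectic_syndrome(error: str, generators: list[str]) -> jnp.ndarray:
    """Syndrome bits s_i = v_e · u_g^(i) + u_e · v_g^(i) (mod 2)."""
    n = len(error)
    e = pauli_to_symplectic(error)
    bits = []
    for g in generators:
        gv = pauli_to_symplectic(g)
        bits.append((e[:n] @ gv[n:] + e[n:] @ gv[:n]) % 2)
    return jnp.array(bits)

# a Z error on qubit 1 anti-commutes with the X-type check XXXX
# but commutes with the Z-type check ZZZZ
print(symplectic_syndrome("IZII", ["XXXX", "ZZZZ"]))  # [1 0]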
Below, we show the basic quantum circuit for extracting a syndrome value. The ancilla qubit is initialized in the \(|+\rangle\) state, and controlled operations are applied based on the stabilizer generators. Finally, the ancilla is measured in the \(X\)-basis to obtain the syndrome value. Check out Arthur Pesah’s excellent blog post series 1 on Stabilizer codes for a deeper introduction.
from typing import Callable, Optional, Dict, Union, Sequence

import pennylane as qml

syndromes = ["XXZIZIX", "XXIIZZI"]
# one ancilla per stabilizer, placed after the data qubits
dev = qml.device("lightning.qubit", wires := max(map(len, syndromes)) + len(syndromes))

@qml.qnode(device=dev)
def ancilla_assisted_syndrome_extraction(syndromes: list[str]):
    ancilla = wires - len(syndromes)
    for syndrome in syndromes:
        qml.Hadamard(ancilla)
        for q, s in enumerate(syndrome):
            if s == "X":
                qml.CNOT(wires=[ancilla, q])
            elif s == "Z":
                qml.CZ(wires=[ancilla, q])
        qml.Hadamard(ancilla)
        qml.measure(ancilla)
        qml.Barrier()
        ancilla += 1
print(qml.draw(ancilla_assisted_syndrome_extraction, show_all_wires=True)(syndromes))
0: ────╭X──────────────────────||────╭X───────────────────||─┤
1: ────│──╭X───────────────────||────│──╭X────────────────||─┤
2: ────│──│──╭Z────────────────||────│──│─────────────────||─┤
3: ────│──│──│─────────────────||────│──│─────────────────||─┤
4: ────│──│──│──╭Z─────────────||────│──│──╭Z─────────────||─┤
5: ────│──│──│──│──────────────||────│──│──│──╭Z──────────||─┤
6: ────│──│──│──│──╭X──────────||────│──│──│──│───────────||─┤
7: ──H─╰●─╰●─╰●─╰●─╰●──H──┤↗├──||────│──│──│──│───────────||─┤
8: ────────────────────────────||──H─╰●─╰●─╰●─╰●──H──┤↗├──||─┤
Decoding: The Classical Half of QEC
Once you have syndrome bits from your stabilizer measurements, you need to figure out what error likely occurred. This is the job of the decoder. Formally, given the syndrome, you’re solving for the most probable error, usually called the maximum likelihood estimate (MLE) of the error.
However, exact MLE decoding depends on precise knowledge of your noise model and is generally computationally intractable (NP-hard): the \(m\) one-bit syndrome measurements can take on \(2^m\) unique values, so the decoding problem grows exponentially with the code size. In practice, we rely on approximate methods tuned to assumptions about the noise model.

CSS Codes: Simplifying the Structure
CSS (Calderbank–Shor–Steane) codes are a special class of stabilizer code where the generators are split into X-type and Z-type operators. Their symplectic vectors look like this:
X-type generator: \((v | 0)\) (only Xs)
Z-type generator: \((0 | u)\) (only Zs)
This allows us to represent the stabilizers with two \(m \times n\) matrices:
\(H_X\) for X-type generators
\(H_Z\) for Z-type generators
The commutation condition that ensures all generators can be measured simultaneously is:
\[H_X H_Z^T = 0 \pmod{2},\]
which ensures that all X and Z stabilizers commute pairwise. When measuring syndromes:
X-type stabilizers detect Z errors via \(s_X = H_X e_Z^T \pmod{2}\).
Z-type stabilizers detect X errors via \(s_Z = H_Z e_X^T \pmod{2}\).
This separation makes decoding modular, allowing you to handle X and Z errors independently. When we introduce the Steane code later, you’ll see these matrices explicitly and how they simplify syndrome calculation and decoding. See a similar diagram below for the CSS code cycle structure.
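To make these relations concrete, here is a minimal numerical sketch (H below is the Hamming-code matrix that, as we will see shortly, serves as both \(H_X\) and \(H_Z\) for the Steane code):
import jax.numpy as jnp

H = jnp.array([[0, 0, 0, 1, 1, 1, 1],
               [0, 1, 1, 0, 0, 1, 1],
               [1, 0, 1, 0, 1, 0, 1]], dtype=jnp.int32)

# CSS condition: every X-type generator commutes with every Z-type generator
assert jnp.all((H @ H.T) % 2 == 0)

# a single X error on qubit 3 is detected by the Z-type checks
e_x = jnp.zeros(7, dtype=jnp.int32).at[3].set(1)
print((H @ e_x) % 2)  # [1 0 0]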

The Steane Code
The Steane code is one of the simplest quantum error correcting codes: a CSS code built from two classical Hamming codes. It encodes one logical qubit into seven physical qubits and can correct any single-qubit error. Traditionally, the error correcting ability of a code is referred to as its distance \(d\), and the number of errors a code can correct is \(\lfloor (d-1)/2 \rfloor\). Since the Steane code can correct a single error, it is said to have distance \(3\). This code uses six stabilizer generators, three X-type and three Z-type, whose supports follow the rows of the classical Hamming parity-check matrix:
\[
\begin{aligned}
g_1 &= IIIXXXX, & g_4 &= IIIZZZZ, \\
g_2 &= IXXIIXX, & g_5 &= IZZIIZZ, \\
g_3 &= XIXIXIX, & g_6 &= ZIZIZIZ.
\end{aligned}
\]
We’ll start by implementing two decoding strategies:
Lookup Table (LUT): Pre-compute minimal corrections for every syndrome (possible for small codes like this one).
Belief Propagation (BP): An iterative message-passing algorithm that operates on the code’s Tanner graph (a bipartite graph representing the relationships between qubits and stabilizers). It approximates the marginal probabilities of errors on each qubit, offering greater scalability for larger, sparser codes.
We’ll begin with the LUT decoder due to its simplicity and then explore BP, which is more flexible for larger or sparser codes.
Lookup‑table (LUT) decoding
For the Steane code, with three \(X\)-type and three \(Z\)-type stabilizer generators, there are \(2^3=8\) possible syndromes for each error type. We can create a small table that maps each three-bit syndrome to a weight-1 error.
import jax.numpy as jnp
from itertools import combinations
from jax.typing import ArrayLike
import jax
from tabulate import tabulate
def lookup_decoder(matrix: ArrayLike, max_weight: int = 1):
    m, n = matrix.shape
    lut = jnp.zeros((1 << m, n), dtype=jnp.int8)
    # fill the table with the lowest-weight correction for each syndrome
    # by iterating over candidate errors and computing their corresponding syndromes;
    # going from highest weight down ensures the lowest-weight correction is written last
    for w in range(max_weight, 0, -1):
        # iterate over all possible weight-w errors
        for qs in combinations(range(n), w):
            err = jnp.zeros(n, dtype=jnp.int8).at[jnp.array(qs)].set(1)  # error mask
            syn = (matrix @ err) % 2  # syndrome for this error
            idx = jnp.dot(syn, 1 << jnp.arange(m, dtype=jnp.int8))  # syndrome bits to base-10 index
            lut = lut.at[idx].set(err)

    @jax.jit
    def _decode(syndrome: ArrayLike):
        # convert the syndrome to base 10 and look it up in the table
        idx = jnp.dot(syndrome, 1 << jnp.arange(m))
        return lut[idx]

    return _decode
H_steane = jnp.array(
    [[0, 0, 0, 1, 1, 1, 1], [0, 1, 1, 0, 0, 1, 1], [1, 0, 1, 0, 1, 0, 1]], dtype=int
)
lut_steane = lookup_decoder(H_steane)
# we see that the steane code has a nice property where counting up in binary shifts the error to the right
table_data = []
for i in range(8):
    decoded = lut_steane(jnp.array([int(x) for x in f"{i:03b}"]))
    table_data.append([f"{i:03b}", "".join(map(str, decoded))])
print(tabulate(table_data, headers=["Syndrome", "LUT Error"]))
Syndrome LUT Error
---------- -----------
000 0000000
001 1000000
010 0100000
011 0010000
100 0001000
101 0000100
110 0000010
111 0000001
While this approach is optimal for small codes, it rapidly becomes infeasible for larger examples. For instance, the distance-\(30\) rotated surface code, which encodes only one logical qubit, has roughly \(450\) stabilizers of each type. A full lookup table decoder for just one check type would require approximately \(2^{450} \approx 2.9\times 10^{135}\) entries.
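As a quick sanity check on that count (a throwaway calculation, not part of the decoder):
# 450 syndrome bits -> 2**450 possible syndromes
print(f"{float(2**450):.2e}")  # 2.91e+135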
Belief-Propagation (BP) Decoder
Belief propagation is an iterative message-passing algorithm used to decode errors by working on the Tanner graph 2 of the code. This graph has two types of nodes:
Variable nodes represent the physical qubits, which may or may not have experienced an error. These correspond to the bits of the error vector \(e = (e_1, e_2, \dots, e_n)\).
Check nodes represent stabilizers, which enforce parity constraints on subsets of qubits. Each check node corresponds to a row of the parity-check matrix \(H\).
There is an edge between a check node \(c\) and a variable node \(v\) if and only if \(H_{cv} = 1\), meaning that qubit \(v\) participates in stabilizer \(c\).
The goal is to estimate the probability that each qubit has been flipped (i.e., that \(e_v = 1\)), given the observed syndrome bits \(s_c\). BP updates these beliefs iteratively by exchanging messages between variable and check nodes.
The Sum-Product Algorithm
The BP decoder is based on the sum-product algorithm, which computes marginal probabilities over the graph. Here’s the procedure:
Initialization
Each variable node \(v\) sends an initial message to its neighboring checks that reflects the intrinsic belief about whether an error has occurred. This is the log-likelihood ratio (LLR) based on the physical error rate \(p\):
\[L_0 = \log\frac{1 - p}{p}\]
This expresses the prior belief: if \(p\) is small (e.g., 0.01), then \(L_0\) is positive, favoring no error; if \(p\) is close to 0.5, \(L_0\) is near zero (no strong prior). In general, \(p\) is a parameter of the algorithm that can be tuned to your specific noise source.
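As a quick numerical illustration of the prior (a throwaway check, not part of the decoder we build below):
import jax.numpy as jnp

# prior LLR for a few channel error rates: the prior weakens as p grows
for p in (0.01, 0.1, 0.4):
    print(p, float(jnp.log((1 - p) / p)))  # ≈ 4.60, 2.20, 0.41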
Variable-to-Check Update
Each variable node updates its message to a neighboring check \(c\) by combining its intrinsic belief with the incoming messages from other connected checks:
\[m_{v \to c} = L_0 + \sum_{c' \in N(v) \setminus c} m_{c' \to v}\]
Here:
\(m_{v \to c}\) is the message from variable \(v\) to check \(c\).
\(N(v)\) is the set of checks connected to variable \(v\).
\(m_{c' \to v}\) are messages received from neighboring checks other than \(c\).
Check-to-Variable Update
Each check node updates its message to a neighboring variable \(v\) based on the syndrome bit \(s_c\) and the incoming messages from the other variables connected to it:
\[m_{c \to v} = (-1)^{s_c} \; 2 \, \operatorname{arctanh} \biggl( \prod_{v' \in N(c) \setminus v} \tanh\frac{m_{v' \to c}}{2} \biggr)\]
Here:
\(m_{c \to v}\) is the message from check \(c\) to variable \(v\).
\(s_c\) is the syndrome bit for check \(c\) (0 if the stabilizer is satisfied, 1 if violated).
\(N(c)\) is the set of variables connected to check \(c\).
The \(\tanh\) and \(\operatorname{arctanh}\) functions implement the sum-product rule for combining binary parity checks derived from classical probability theory.
What’s going on? This formula indicates that if the product of incoming \(\tanh\) terms is close to +1 or -1, it means there is a strong belief about whether the parity is satisfied or violated. The \(\operatorname{arctanh}\) converts that back into an LLR-style message. The \((-1)^{s_c}\) factor flips the sign if the syndrome is 1, signaling that a parity error was detected.
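To make this concrete, here is a hand-run of the update for a single degree-3 check, with illustrative message values chosen arbitrarily:
import jax.numpy as jnp

# messages m_{v'→c} from the two *other* variables in the check
incoming = jnp.array([4.0, -1.5])
prod = jnp.prod(jnp.tanh(incoming / 2.0))  # ≈ 0.964 * (-0.635) ≈ -0.612
m_satisfied = 2.0 * jnp.arctanh(prod)      # s_c = 0: message ≈ -1.42
m_violated = -m_satisfied                  # s_c = 1 flips the sign
print(float(m_satisfied), float(m_violated))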
Iteration
Steps 2 and 3 are repeated for a fixed number of iterations (e.g., 10–20) or until the messages converge (i.e., stop changing significantly). Traditional theory and heuristics in error correction suggest running BP for a number of iterations on the order of \(O(n)\).
Decision Rule
After the iterations, each variable node computes its posterior LLR by summing its intrinsic belief and all incoming messages:
\[L_v = L_0 + \sum_{c \in N(v)} m_{c \to v}\]
The decoder then makes a hard decision:
If \(L_v < 0\), it guesses \(e_v = 1\) (error detected).
If \(L_v > 0\), it guesses \(e_v = 0\) (no error).
Why This Works
Belief propagation is exact on tree-like graphs, where no cycles exist. The Tanner graphs of quantum codes always contain cycles, but BP still provides a good approximation to the maximum-likelihood decoder using only local, iterative computations. Its performance can degrade, however, when the Tanner graph contains many short cycles, a common characteristic of popular quantum codes, leading to poor convergence. In practice, extensions like BP-OSD 3, BP-LSD 4 or Ambiguity Clustering 5 are used to fix these issues.
See the following summary article 6 as well as Chapter 5 in Bayesian Reasoning and Machine Learning 7 for a deeper dive into message passing algorithms on graphs.
BP in JAX
Below, we implement a BP decoder using JAX, broken down into its core components.
Before we can pass messages, we need to establish the connectivity between nodes. The _build_graph function scans the parity-check matrix once and records, for every variable node, which checks touch it and vice versa.
We convert the neighbour lists to tuples so they become immutable, hashable static data. JAX can then embed their values as compile-time constants in the XLA program and reliably reuse the compiled kernel multiple times. A nice feature of JAX/XLA is that the individual integers in such simple static parameters are baked into the XLA program as compile-time constants, so we can truly compile a high-performance decoder for our specific parity-check matrix. The toy example after this paragraph illustrates the idea.
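Here is a minimal, self-contained sketch of that behaviour (a toy function, not part of the decoder): the tuple neigh is plain Python data, so the loop over it unrolls at trace time and its indices become constants in the compiled program.
import jax
import jax.numpy as jnp

neigh = ((0, 1), (1, 2))  # static, hashable Python data

@jax.jit
def f(x):
    # this Python loop unrolls during tracing; `neigh` never becomes a traced array
    total = 0.0
    for a, b in neigh:
        total += x[a] * x[b]
    return total

print(f(jnp.arange(3.0)))  # 0*1 + 1*2 = 2.0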
def _build_graph(
    pcm: ArrayLike,
) -> tuple[tuple[tuple[int, ...], ...], tuple[tuple[int, ...], ...]]:
    """
    Pre-compute variable-node and check-node neighbors.

    Returns
    -------
    var_neighbors : tuple[tuple[int, ...], ...]    # length = n
    check_neighbors : tuple[tuple[int, ...], ...]  # length = m
    """
    m, n = pcm.shape
    vars_, checks_ = [[] for _ in range(n)], [[] for _ in range(m)]
    for c in range(m):
        for v in range(n):
            if pcm[c, v]:
                vars_[v].append(c)
                checks_[c].append(v)
    return tuple(map(tuple, vars_)), tuple(map(tuple, checks_))
A nice way to visualize this Tanner graph is with the networkx package. Below is an example on the Steane code.
import matplotlib.pyplot as plt
import networkx as nx

vars, checks = _build_graph(H_steane)
G = nx.Graph()
num_vars = len(vars)
num_checks = len(checks)
# build the nx graph object from our vars and checks
for v in range(num_vars):
    G.add_node(f"v{v}", bipartite=0)
for c in range(num_checks):
    G.add_node(f"c{c}", bipartite=1)
for c in range(num_checks):
    for v in checks[c]:
        G.add_edge(f"c{c}", f"v{v}")
pos = nx.bipartite_layout(G, nodes=[f"v{i}" for i in range(num_vars)])
plt.figure(figsize=(10, 7))
nx.draw(G, pos, with_labels=True, node_color="skyblue", node_size=500, font_weight="bold")
plt.title("Bipartite Graph for H_steane", fontsize=16)
plt.show()

The _c2v_update helper function performs one full sweep of check-to-variable updates (step 3 of the sum-product algorithm). It takes the previous messages, the syndrome, the neighbor tables, and two scalars (L_int for the intrinsic log-likelihood ratio and eps for numerical safety). It loops only over existing edges, multiplies the relevant \(\operatorname{tanh}\) terms, clips the product, applies \(\operatorname{arctanh}\), and writes the new message into the next matrix.
def _c2v_update(
    m_c2v_prev: ArrayLike,
    syndrome: ArrayLike,
    var_nei: tuple[tuple[int, ...], ...],
    check_nei: tuple[tuple[int, ...], ...],
    L_int: float,
    eps: float,
) -> ArrayLike:
    """
    Compute the next round of check-to-variable messages.
    """
    m, n = m_c2v_prev.shape
    m_c2v_next = jnp.zeros_like(m_c2v_prev)
    # Loop over checks (outer) then their vars (inner)
    for c in range(m):
        Vc = check_nei[c]
        if len(Vc) < 2:
            continue  # degree-1 checks carry no new info
        for v in Vc:
            prod = 1.0
            # product over all v' ≠ v in this check
            for v_p in Vc:
                if v_p == v:
                    continue
                # variable-to-check message m_{v'→c} (step 2, computed inline):
                # intrinsic LLR plus all incoming check messages except from c itself
                incoming = L_int
                for c_p in var_nei[v_p]:
                    if c_p != c:
                        incoming += m_c2v_prev[c_p, v_p]
                prod *= jnp.tanh(incoming / 2.0)
            prod = jnp.clip(prod, -1.0 + eps, 1.0 - eps)
            msg = ((-1) ** syndrome[c]) * 2.0 * jnp.arctanh(prod)
            m_c2v_next = m_c2v_next.at[c, v].set(msg)
    return m_c2v_next
Once the main loop finishes, we still need a hard decision. The function _posterior_llrs folds every final check-to-variable message for bit v into its intrinsic LLR, yielding the posterior belief for that bit. A negative value means “error likely,” a positive value means “no error.”
def _posterior_llrs(
    m_c2v_final: ArrayLike, var_nei: tuple[tuple[int, ...], ...], L_int: float
) -> ArrayLike:
    """
    Combine intrinsic LLR with all incoming messages.
    """
    n = m_c2v_final.shape[1]
    llr = jnp.full(n, L_int)
    for v in range(n):
        for c in var_nei[v]:
            llr = llr.at[v].add(m_c2v_final[c, v])
    return llr
build_bp_decoder serves as the main entry point for compiling our decoder. It takes the parity-check matrix and channel error rate, builds the parity graph, pre-computes the intrinsic LLR, and returns a JIT-compiled function _decode.
Inside _decode, the following steps are executed:
All messages are zero-initialized.
_c2v_update is called inside a jax.lax.fori_loop for max_iter rounds.
Final messages are converted to posterior LLRs with _posterior_llrs.
A binary error vector is output by thresholding the LLRs at zero.
Because the whole _decode body is wrapped in @jax.jit, the first call compiles everything into an XLA kernel; subsequent calls run at full device speed.
def build_bp_decoder(
    parity_check_matrix: ArrayLike,
    error_rate: float,
    max_iter: int = 10,
    epsilon: float = 1e-9,
) -> Callable[[ArrayLike], ArrayLike]:
    """
    Return a JIT-compiled BP decoder for the given code and channel.

    Parameters
    ----------
    parity_check_matrix : array-like (m, n)
    error_rate : float   # BSC crossover probability p
    max_iter : int
    epsilon : float      # numerical safety margin
    """
    pcm = jnp.asarray(parity_check_matrix, dtype=jnp.int32)
    m, n = pcm.shape
    L_int = jnp.log((1.0 - error_rate) / error_rate)
    var_nei, check_nei = _build_graph(pcm)

    @jax.jit
    def _decode(syndrome: ArrayLike) -> ArrayLike:
        syndrome = jnp.asarray(syndrome, dtype=jnp.int32)
        # Initialise all messages to zero
        m_c2v = jnp.zeros((m, n), dtype=jnp.float32)

        # BP loop
        def body(_, msgs):
            return _c2v_update(msgs, syndrome, var_nei, check_nei, L_int, epsilon)

        m_c2v = jax.lax.fori_loop(0, max_iter, body, m_c2v)
        # Hard decision from posterior LLRs
        llr = _posterior_llrs(m_c2v, var_nei, L_int)
        return (llr < 0).astype(jnp.int32)

    # optionally, we can force our decoder to compile right away by calling it on a test input
    _decode(jnp.zeros(m, dtype=jnp.int32))
    return _decode
Let’s test the performance of the BP decoder on the Steane code compared to the LUT decoder.
bp_steane = build_bp_decoder(H_steane, error_rate=0.05, max_iter=7)
n_bits = H_steane.shape[0]
correct = 0
total_syndromes = 2**n_bits
table_data = []
headers = ["Syndrome", "BP Estimated Error", "LUT Exact Error", "Match"]
for i in range(total_syndromes):
    syndrome_binary_string = f"{i:0{n_bits}b}"
    s_array = jnp.array([int(x) for x in syndrome_binary_string])
    # Get error patterns from BP decoder and LUT
    bp_pattern = bp_steane(s_array)
    lut_pattern = lut_steane(s_array)
    match = jnp.all(bp_pattern == lut_pattern)
    bp_pattern_str = "".join(map(str, bp_pattern.tolist()))
    lut_pattern_str = "".join(map(str, lut_pattern.tolist()))
    table_data.append([syndrome_binary_string, bp_pattern_str, lut_pattern_str, str(match)])
    # Increment correct count if patterns match
    if match:
        correct += 1
print(tabulate(table_data, headers=headers))
# Calculate and print the BP accuracy
accuracy = (correct / total_syndromes) * 100 if total_syndromes > 0 else 0
print(f"\nBP Accuracy: {accuracy:.2f}%")
Syndrome BP Estimated Error LUT Exact Error Match
---------- -------------------- ----------------- -------
000 0000000 0000000 True
001 1000000 1000000 True
010 0100000 0100000 True
011 0010000 0010000 True
100 0001000 0001000 True
101 0000100 0000100 True
110 0000010 0000010 True
111 0000001 0000001 True
BP Accuracy: 100.00%
Next, let’s test our belief-propagation (BP) decoder on a bigger example: the n-bit repetition code. This code stores each logical bit by repeating it \(n\) times (e.g. \(0 \mapsto 00\ldots0\) and \(1 \mapsto 11\ldots1\)). Its parity-check matrix consists of \(n-1\) rows, each enforcing that two neighbouring bits are equal. Below, we measure how often the BP decoder corrects random errors on a 50-bit repetition code and compare its success rate to an optimal maximum-likelihood (ML) decoder, which simply picks the lower-weight error pattern consistent with the observed syndrome.
def rep_code(n: int) -> ArrayLike:
    """
    Build the (n − 1) × n parity-check matrix H for the [n, 1] repetition code.

    Each row enforces equality between two neighboring bits:
    H[i] has 1s in positions i and i+1, zeros elsewhere.
    """
    # First row: parity check on bits 0 and 1 → [1, 1, 0, 0, …, 0]
    first_row = jnp.zeros(n, dtype=jnp.int8).at[jnp.array([0, 1])].set(1)
    rows = [first_row]
    # Remaining rows: slide the two-bit “window” to the right
    for _ in range(n - 2):
        rows.append(jnp.roll(rows[-1], 1))  # shift previous row by 1 position
    return jnp.stack(rows)  # shape = (n-1, n)
@jax.jit
def ml_rep_decoder(syndrome: ArrayLike) -> ArrayLike:
    """
    Minimum-weight decoder for the repetition code.

    Parameters
    ----------
    syndrome : ArrayLike, shape (n-1,)
        The syndrome s = H e (mod 2).

    Returns
    -------
    error : ArrayLike, shape (n,)
        A lowest-weight error vector consistent with `syndrome`.
    """
    # Candidate 1: assume e[0] = 0, then recover the rest via cumulative XOR.
    # e[k+1] = e[k] ⊕ s[k] ⇒ e = [0, cumsum(s) mod 2]
    e0 = jnp.concatenate((jnp.array([0], dtype=jnp.int32), jnp.mod(jnp.cumsum(syndrome), 2)))
    # Candidate 2: flip every bit (equivalent to choosing e[0] = 1).
    e1 = (e0 + 1) & 1  # fast “add-one then mod 2”
    # Compare Hamming weights.
    w0, w1 = jnp.sum(e0), jnp.sum(e1)
    # Return the lighter candidate (ties resolved in favour of e0).
    return jax.lax.cond(w0 <= w1, lambda _: e0, lambda _: e1, operand=None)
We run a short experiment on a \(50\)-bit repetition code. We sample 10,000 random syndrome vectors and compute the accuracy of our BP decoder relative to our baseline ml_rep_decoder:
H_rep = rep_code(n := 50)
bp_rep = build_bp_decoder(parity_check_matrix=H_rep, error_rate=0.1, max_iter=n)
# sample random syndromes
N = 10_000
key = jax.random.PRNGKey(0)
syndromes = jax.random.randint(key, shape=(N, n - 1), minval=0, maxval=2)
# use jax to map the decoder over the syndromes
# since our decoders are jit compiled jax functions they can be used with jax.vmap
success_rate = jnp.mean(
    jnp.all(jax.vmap(ml_rep_decoder)(syndromes) == jax.vmap(bp_rep)(syndromes), axis=1)
)
print(f"Decoding success rate: {success_rate * 100:.2f}%")
Decoding success rate: 85.73%
Catalyst hybrid kernel
Now that we understand a good chunk of theory behind CSS codes, the Steane code and decoding algorithms, let’s put this into action with Catalyst!
Catalyst lets us build hybrid quantum-classical workflows, compiling both quantum operations and classical decoding logic into a single, efficient kernel. We’ll start with a quantum-classical circuit to prepare the logical zero state \(|0\rangle_L\) for our Steane code. This method is also general for initializing logical zero states for any CSS codes.
Start with a \(+1\) eigenstate (or stabilizer state) of all the \(Z\)-type stabilizers. The state \(|0\ldots 0\rangle\) is stabilized by every \(Z\)-type Pauli operator, making it a suitable choice.
Then, for each X-type generator:
Prepare an ancilla qubit in the \(|+\rangle\) state.
Apply CNOTs from the ancilla onto the data qubits in the generator’s support.
Measure the ancilla in the \(X\) basis.
Next:
Use measurement outcomes (syndromes) to determine necessary corrections using our decoder.
Apply Z-type corrections based on decoding results.
This procedure uses projective measurements to force the data qubits to be in the \(+1\) eigenstate of our \(X\)-type generators. Since the state was already a \(+1\) eigenstate of our \(Z\)-type generators, and by virtue of the CSS code all \(X\) and \(Z\) generators simultaneously commute, we are left with a state in the \(+1\) eigenspace of all the generators.
import pennylane as qml
from jax import random
import catalyst
r, n = H_steane.shape
n_wires = n + r
dev = qml.device("lightning.qubit", wires=n_wires)
def measure_x_stabilizers(H: ArrayLike):
    """
    Measure all X-type stabilizers based on the parity-check matrix H,
    then apply Z-type corrections from our decoder.

    :param H: X-type parity-check matrix
    """
    r, n = H.shape
    # Encode logical |0>
    # (Hadamard on ancillas, controlled X stabilizers)
    for a in range(r):
        qml.H(wires=n + a)
    for a, row in enumerate(H):
        for q, x in enumerate(row):
            if x:
                qml.CNOT(wires=[n + a, q])
    for a in range(r):
        qml.H(wires=n + a)
    # Measure + reset ancillas (X stabilizers)
    sx = jnp.stack([catalyst.measure(n + a) for a in range(r)])
    for a, bit in enumerate(sx):
        if bit:
            qml.PauliX(wires=n + a)  # reset ancilla
    # Z-correction
    # Since the BP and LUT decoders were both perfect on the Steane code,
    # we'll use the LUT for simplicity
    rec_z = lut_steane(sx)
    for q, bit in enumerate(rec_z):
        if bit:
            qml.PauliZ(wires=q)
@qml.qjit(autograph=True)
@qml.qnode(dev)
def encode_zero_steane():
    measure_x_stabilizers(H_steane)
    return qml.state()
A simple utility function to display the state
from pprint import pprint
def state_vector_to_dict(
    sv: ArrayLike,
    wires: Optional[Sequence[int]],
    eps: float = 1e-8,
    probability: bool = False,
    display: bool = True,
) -> Dict[str, Union[float, complex]]:
    """
    Convert a state vector into {bitstring: amplitude | probability}.
    """
    n_qubits = int(jnp.log2(len(sv)))
    out: Dict[str, Union[float, complex]] = {}
    for idx, amp in enumerate(sv):
        mag = jnp.abs(amp) ** 2 if probability else jnp.abs(amp)
        if mag <= eps:
            continue
        bitstring = f"{idx:0{n_qubits}b}"
        key = "".join(b for i, b in enumerate(bitstring) if wires is None or i in wires)
        if probability:
            out[key] = out.get(key, 0.0) + float(mag)
        else:
            out[key] = amp.item()
    if display:
        pprint(out)
    return out
We run the encode_zero_steane function and see that we recover the correct logical zero state for the Steane code:
sv_clean = encode_zero_steane()
state_vector_to_dict(sv_clean, display=True, wires=range(n))
{'0000000': (0.35355339059327373+0j),
'0001111': (0.35355339059327373+0j),
'0110011': (0.35355339059327373-0j),
'0111100': (0.35355339059327373-0j),
'1010101': (0.35355339059327373-0j),
'1011010': (0.35355339059327373-0j),
'1100110': (0.35355339059327373+0j),
'1101001': (0.35355339059327373+0j)}
Simulating Errors and Full Correction
We’re now ready to wrap everything together:
Prepare the zero state.
Simulate noise using a depolarizing channel.
Perform one complete round of stabilizer measurements and corrections.
def noise_channel(n: int, p_err: float, key: random.PRNGKey):
    """
    Apply a single-qubit Pauli noise channel independently to each of `n` qubits.

    For every qubit the channel does:
        0 → I with probability 1 - p_err
        1 → X with probability p_err / 3
        2 → Z with probability p_err / 3
        3 → Y with probability p_err / 3
    """
    probs = jnp.array([1.0 - p_err, p_err / 3, p_err / 3, p_err / 3])
    outcomes = random.choice(key, 4, shape=(n,), p=probs)
    for idx, outcome in enumerate(outcomes):
        if outcome == 1:
            qml.X(wires=idx)
        elif outcome == 2:
            qml.Z(wires=idx)
        elif outcome == 3:
            qml.Y(wires=idx)
# this is a helper function to get the specific error we used in a given round based on the key
def get_error(n: int, p_err: float, key: random.PRNGKey):
    err = []
    probs = jnp.array([1.0 - p_err, p_err / 3, p_err / 3, p_err / 3])
    outcomes = random.choice(key, 4, shape=(n,), p=probs)
    for idx, outcome in enumerate(outcomes):
        if outcome == 1:
            err.append(qml.X(wires=idx))
        elif outcome == 2:
            err.append(qml.Z(wires=idx))
        elif outcome == 3:
            err.append(qml.Y(wires=idx))
    return qml.ops.prod(*err)
This is similar to measure_x_stabilizers; however, we now apply CNOTs from the data qubits onto an ancilla prepared in the \(|0\rangle\) state and perform a \(Z\)-basis measurement.
def measure_z_stabilizers(H):
    r, n = H.shape
    for a, row in enumerate(H):
        for q, x in enumerate(row):
            if x:
                qml.CNOT(wires=[q, n + a])
    sz = jnp.stack([catalyst.measure(n + a) for a in range(r)])
    for a, bit in enumerate(sz):
        if bit:
            qml.PauliX(wires=n + a)
    rec_x = lut_steane(sz)
    for q, bit in enumerate(rec_x):
        if bit:
            qml.PauliX(wires=q)
Now, let’s run qec_round: state preparation, followed by one round of noise injection and one round of \(X\) and \(Z\) correction. We’ll print the error that occurred in our noisy channel and compare the output state to the noiseless state we observed previously.
@qml.qjit(autograph=True)
@qml.qnode(dev, interface="jax")
def qec_round(H: ArrayLike, p_err=1e-3, key=random.PRNGKey(0)):
    """One round of Steane code QEC with LUT decoding."""
    measure_x_stabilizers(H)      # prepare 0 state
    noise_channel(n, p_err, key)  # inject IID Pauli noise
    measure_x_stabilizers(H)      # correct X errors
    measure_z_stabilizers(H)      # correct Z errors
    return qml.state()
p_err = 0.1
key = random.PRNGKey(10)
print(f"Running Steane Code QEC Round with error: {get_error(n, p_err=p_err, key=key)}")
state_vector_to_dict(qec_round(H_steane, p_err=p_err, key=key), display=True, wires=range(n))
Running Steane Code QEC Round with error: X(0) @ X(6)
{'0010110': (0.3535533905932738+0j),
'0011001': (0.3535533905932738+0j),
'0100101': (0.3535533905932738+0j),
'0101010': (0.3535533905932738+0j),
'1000011': (0.3535533905932738+0j),
'1001100': (0.3535533905932738+0j),
'1110000': (0.3535533905932738+0j),
'1111111': (0.3535533905932738+0j)}
If we increase the likelihood of errors, we are more likely to end up with an error pattern that can’t be corrected.
p_err = 0.3
key = random.PRNGKey(8)
print(f"Running Steane Code QEC Round with error: {get_error(n, p_err=p_err, key=key)}")
state_vector_to_dict(qec_round(H_steane, p_err=p_err, key=key), display=True, wires=range(n))
Running Steane Code QEC Round with error: X(3) @ X(6)
{'0010110': (0.3535533905932738+0j),
'0011001': (0.3535533905932738+0j),
'0100101': (0.3535533905932738+0j),
'0101010': (0.3535533905932738+0j),
'1000011': (0.3535533905932738+0j),
'1001100': (0.3535533905932738+0j),
'1110000': (0.3535533905932738+0j),
'1111111': (0.3535533905932738+0j)}
Benchmarking logical vs. physical error rates
In the final section of this demo, we will compute the average performance of our Steane code error correction circuit over a range of physical error rates. We’ll define logical error rates by comparing the state vector from our noisy simulation with the clean state vector sv_clean of the Steane code logical zero state.
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

clean_idx = jnp.where(sv_clean)[0]

def logical_error(sv):
    st = sv[clean_idx]
    return 1 - jnp.all(jnp.isclose(st / st[0], 1))
Simulate 1000 noisy shots for several noise levels. We’ll use jax.vmap to efficiently map our Catalyst kernel over a set of random keys.
@jax.jit
def single_trial_error(key, p_err, H_steane):
    """Performs one QEC round and checks for a logical error."""
    round_output = qec_round(H_steane, p_err, key)
    err = logical_error(round_output)
    return err

batch_trial_errors = jax.vmap(single_trial_error, in_axes=(0, None, None))

N = 1000
p_rng = 2 ** jnp.arange(-5, -1.75, 0.25, dtype=jnp.float32)
res = []
master_key = random.PRNGKey(0)
for p in tqdm(p_rng):
    keys_for_batch, master_key = random.split(master_key)
    all_keys = random.split(keys_for_batch, N)
    errors_batch = batch_trial_errors(all_keys, p, H_steane)
    p_value = p.item()
    for idx, err in enumerate(errors_batch):
        res.append({"p": p_value, "seed": idx, "err": err.item()})
df = pd.DataFrame(res)
Plot the results using seaborn
p_rng_min = p_rng[0]
p_rng_max = p_rng[-1]
sns.set_theme(style="whitegrid", context="talk")
plt.figure(figsize=(10, 7))
sns.lineplot(
    data=df,
    x="p",
    y="err",
    marker="o",
    markersize=8,
    linewidth=2.5,
    label="Simulated Logical Error Rate",
)
plt.plot(
    [p_rng_min, p_rng_max],
    [p_rng_min, p_rng_max],
    linestyle="--",
    color="gray",
    linewidth=1.5,
    label="$p_{physical} = p_{logical}$",  # Label for legend
)
plt.xlabel("Physical Error Rate ($p$)", fontsize=16)
plt.ylabel("Logical Error Rate ($P_L$)", fontsize=16)
plt.xscale("log", base=2)
plt.yscale("log", base=2)
plt.title("Logical vs. Physical Error Rate", fontsize=18, pad=20)
plt.legend(fontsize=14)
plt.grid(True, which="both", ls="--", c="lightgray", alpha=0.7)  # 'both' for major and minor ticks
sns.despine()
plt.tight_layout()
plt.show()

Conclusion and Limitations
In this tutorial, we successfully built, simulated, and decoded a simple quantum error correction cycle using the Steane code. We demonstrated encoding a logical qubit, introduced errors through noise simulation, and performed error correction using stabilizer measurements combined with classical decoding. Performance was benchmarked by measuring the logical versus physical error rates.
However, our approach relied on a significant simplifying assumption known as the code capacity model, where errors are assumed to occur at only one stage of the circuit, with otherwise perfect encoding and syndrome extraction. A more realistic approach—called circuit-level noise—accounts for potential errors at every gate and measurement within the circuit. This model significantly complicates decoding, as it requires mapping every possible error location not only in space but also across multiple syndrome measurement rounds, forming a complex space-time hypergraph. Decoding then involves interpreting error events over both spatial and temporal dimensions.
Nevertheless, the fundamental decoding principles explored here, particularly the Belief Propagation algorithm, remain highly relevant. BP is flexible enough to operate effectively on more comprehensive circuit-level decoding hypergraphs.
References
- 1
Pesah, Arthur. “The stabilizer trilogy I — Stabilizer codes.” Arthur Pesah, 31 Jan. 2023, https://arthurpesah.me/blog/2023-01-31-stabilizer-formalism-1/.
- 2
Wiberg, Niclas. (2001). Codes and Decoding on General Graphs. https://www.essrl.wustl.edu/~jao/itrg/wiberg.pdf
- 3
Panteleev, Pavel. “Degenerate Quantum LDPC Codes With Good Finite Length Performance.” arXiv.org, 04 Apr. 2019, https://arxiv.org/abs/1904.02703v3.
- 4
Hillmann, Timo. “Localized statistics decoding: A parallel decoding algorithm for quantum low-density parity-check codes.” arXiv.org, 26 Jun. 2024, https://arxiv.org/abs/2406.18655v1.
- 5
Wolanski, Stasiu. “Ambiguity Clustering: an accurate and efficient decoder for qLDPC codes.” arXiv.org, 20 Jun. 2024, https://arxiv.org/abs/2406.14527v2.
- 6
Loeliger, Hans-Andrea. “An introduction to factor graphs” in IEEE Signal Processing Magazine, vol. 21, no. 1, pp. 28-41, Jan. 2004, https://www.isiweb.ee.ethz.ch/papers/arch/aloe-2004-spmagffg.pdf.
- 7
Barber, David. “Bayesian Reasoning and Machine Learning”. Cambridge University Press, USA. 2012, http://web4.cs.ucl.ac.uk/staff/D.Barber/textbook/180325.pdf#page=107.
About the author
Tom Ginsberg
Tom leads BEIT's efforts in quantum error correction & fault tolerant compilation.