Learning shallow quantum circuits with local inversions and circuit sewing¶
Published: January 24, 2024. Last updated: October 7, 2024.
In the recent paper Learning shallow quantum circuits [1], Huang et al. introduce and prove performance bounds on efficient algorithms for learning constant-depth circuits. At the heart of the paper lie local inversions, which locally undo a quantum circuit, and a circuit “sewing” technique that lets one construct a global inversion from them. We are going to review these new concepts and showcase them with an implementation in PennyLane.
Introduction¶
Shallow, constant-depth quantum circuits are provably powerful [2]. At the same time, they are known to be difficult to train [3]. The authors of [1] tackle the question of whether or not shallow circuits are efficiently learnable.
Given some unknown unitary circuit $U,$ learning the circuit means finding a unitary $V$ that faithfully resembles $U$'s action. This can mean either performing the same operation in full ($U V^\dagger = 1$) or resembling the action on a fixed input state ($U |\phi\rangle = V |\phi\rangle,$ where often $|\phi\rangle = |0 \rangle^{\otimes n}$). The authors go through both scenarios with different levels of restrictions on the allowed gate set and locality of the target circuit $U.$ In this demo, we mainly focus on learning the action on $|0 \rangle^{\otimes n},$ i.e. $U |0\rangle^{\otimes n} = V |0\rangle^{\otimes n}.$
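To make the difference between the two notions of learning concrete, here is a minimal NumPy sketch. The two-qubit matrices `U` and `V` are illustrative choices of our own (they do not come from the paper); they are constructed so that both agree on $|0\rangle^{\otimes n}$ while differing as unitaries.

```python
import numpy as np

# Toy 2-qubit "unknown" unitary: a Hadamard on each qubit (illustrative choice)
H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)
U = np.kron(H, H)

# V differs from U by a phase on every basis state except |00>
D = np.diag([1, 1j, 1j, 1j])
V = U @ D

psi0 = np.eye(4)[0]  # |00>

same_on_zero = np.allclose(U @ psi0, V @ psi0)         # U|00> == V|00> ?
same_unitary = np.allclose(U @ V.conj().T, np.eye(4))  # U V^dagger == 1 ?

print(same_on_zero, same_unitary)
```

Here `V` would count as a valid answer in the fixed-input scenario, but not in the scenario where the full unitary has to be reproduced.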
At the heart of the solutions to all these scenarios lies the use of local inversions that undo the effect of the unitary, and sewing them together to form a global inversion.
Local Inversions¶
A local inversion is a unitary circuit that locally disentangles one qubit after a previous, different unitary entangled it with the rest of the system. Let us make this explicit with an example in PennyLane, after some boilerplate imports. We look at a very shallow unitary circuit $U^\text{test} = \text{CNOT}_{(0, 1)}\text{CNOT}_{(2, 3)}\text{CNOT}_{(1, 2)} H^{\otimes n}.$
import pennylane as qml
import numpy as np
import matplotlib.pyplot as plt
def U_test():
    for i in range(4):
        qml.Hadamard(i)
    qml.CNOT((1, 2))
    qml.CNOT((2, 3))
    qml.CNOT((0, 1))

qml.draw_mpl(U_test)()
plt.show()

We now want to locally invert it. That is, we want to apply a second unitary $V_0$ such that $\text{tr}_{\neq 0} \left[V_0 U (|0 \rangle \langle 0|)^{\otimes n} U^\dagger V^\dagger_0\right] = |0 \rangle \langle 0|_0,$ where we trace out all but wire 0. For that, we just follow the light-cone of the qubit that we want to invert and perform the inverse operations in reverse order in $V_0.$
def V_0():
    qml.CNOT((0, 1))
    qml.CNOT((1, 2))
    qml.Hadamard(0)

dev = qml.device("default.qubit")

@qml.qnode(dev)
def local_inversion():
    U_test() # some shallow unitary circuit
    V_0() # supposed to disentangle qubit 0
    return qml.density_matrix(wires=[0])

print(np.allclose(local_inversion(), np.array([[1., 0.], [0., 0.]])))
True
After performing $U^\text{test}$ and the inversion of qubit $0$, $V_0,$ we find the reduced state $|0 \rangle \langle 0|$ on the correct qubit.
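The partial trace in the condition above can also be checked by hand. The following is a minimal sketch in plain NumPy, independent of PennyLane; the helpers `apply_h`, `apply_cnot` and `reduced_dm` are hypothetical names introduced only for this illustration. The state is stored as a tensor with one leg per qubit, and the reduced state of a wire is obtained by grouping all other legs together.

```python
import numpy as np

H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)

def apply_h(state, wire):
    # contract the gate with the chosen tensor leg, then restore the leg order
    state = np.tensordot(H, state, axes=([1], [wire]))
    return np.moveaxis(state, 0, wire)

def apply_cnot(state, control, target):
    # flip the target leg on the slice where the control is 1
    state = state.copy()
    idx = [slice(None)] * state.ndim
    idx[control] = 1
    axis = target if target < control else target - 1  # control leg is dropped in the slice
    state[tuple(idx)] = np.flip(state[tuple(idx)], axis=axis).copy()
    return state

def reduced_dm(state, wire):
    # partial trace over every wire except `wire`
    psi = np.moveaxis(state, wire, 0).reshape(2, -1)
    return psi @ psi.conj().T

n = 4
state = np.zeros((2,) * n, dtype=complex)
state[(0,) * n] = 1.0  # |0000>

# U_test: Hadamards followed by CNOT(1, 2), CNOT(2, 3), CNOT(0, 1)
for w in range(n):
    state = apply_h(state, w)
for c, t in [(1, 2), (2, 3), (0, 1)]:
    state = apply_cnot(state, c, t)

# V_0: CNOT(0, 1), CNOT(1, 2), Hadamard(0)
for c, t in [(0, 1), (1, 2)]:
    state = apply_cnot(state, c, t)
state = apply_h(state, 0)

rho_after = reduced_dm(state, 0)
print(np.round(rho_after.real, 6))
```

The printed matrix should match the $|0 \rangle \langle 0|$ found with `qml.density_matrix` above.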
Local inversions are not unique and finding one is easier than finding global inversions. But constructing a global inversion from all possible local inversions is highly non-trivial (as it constitutes a variant of the quantum marginal problem). However, the circuit sewing technique introduced in [1] lets us circumvent that problem and construct a global inversion from just a single local inversion per qubit.
To prepare for the sewing, we construct the other local inversions in the same way as before by following back the light-cones of the respective qubits. In general, these would have to be learned (more on that later). Here we just reverse-engineer them from knowing $U^\text{test}.$
def V_1():
    qml.CNOT((0, 1))
    qml.CNOT((1, 2))
    qml.Hadamard(1)

def V_2():
    qml.CNOT((2, 3))
    qml.CNOT((1, 2))
    qml.Hadamard(2)

def V_3():
    qml.CNOT((2, 3))
    qml.CNOT((1, 2))
    qml.Hadamard(3)
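The light-cone bookkeeping used to write down these local inversions can be sketched in a few lines of plain Python. This is only an illustration under our own assumptions: the circuit is represented as a list of `(name, wires)` tuples and `backward_light_cone` is a hypothetical helper, not a PennyLane function.

```python
def backward_light_cone(circuit, qubit):
    """Return the gates in the backward light cone of `qubit`, last gate first."""
    support = {qubit}
    cone = []
    for name, wires in reversed(circuit):
        if support & set(wires):   # gate touches the current support
            cone.append((name, wires))
            support |= set(wires)  # the support grows to include the gate's wires
    return cone

# U_test as a gate list, in the order the gates are applied
U_test_gates = [("H", (i,)) for i in range(4)] + [
    ("CNOT", (1, 2)),
    ("CNOT", (2, 3)),
    ("CNOT", (0, 1)),
]

for q in range(4):
    print(q, backward_light_cone(U_test_gates, q))
```

Because Hadamard and CNOT are self-inverse, each returned list (last gate first) is already the inverse of the light-cone part of the circuit. The hand-written `V_0` to `V_3` follow the same CNOT pattern but keep only the Hadamard on the qubit being inverted.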
Circuit Sewing¶
So how does knowing local inversions $\{V_0, V_1, V_2, V_3\}$ help us with solving the original goal of finding a global inversion $U V = \mathbb{1}?$ The authors introduce a clever trick that they coin circuit sewing. It works by swapping out the decoupled qubit with an ancilla register and restoring (“repairing”) the unitary on the remaining wires. Let us walk through this process step by step.
We already saw how to decouple qubit $0$ in local_inversion() above. We continue by swapping out the decoupled wire with an ancilla qubit and “repairing” the circuit by applying $V^\dagger_0.$ (This is called “repairing” because $V_1$ can now decouple qubit 1, which would not be possible in general without that step.) We label all ancilla wires by [n + 0, n + 1, ..., 2n - 1] to have an easy one-to-one correspondence, and we see that qubit 1 is successfully decoupled. For completeness, we also check that the swapped-out qubit (now moved to wire n + 0) is still decoupled.
n = 4 # number of qubits

@qml.qnode(dev)
def sewing_1():
    U_test() # some shallow unitary circuit
    qml.Barrier()
    V_0() # disentangle qubit 0
    qml.Barrier()
    qml.SWAP((0, n)) # swap out disentangled qubit 0 and n+0
    qml.Barrier()
    qml.adjoint(V_0)() # repair circuit from V_0
    qml.Barrier()
    V_1() # disentangle qubit 1
    return qml.density_matrix(wires=[1]), qml.density_matrix(wires=[n])

# The Barriers are to see which part of the circuit corresponds to which gate
qml.draw_mpl(sewing_1)()
plt.show()

r1, rn = sewing_1()
print("Sewing qubit 1")
print(f"rho_1 = |0⟩⟨0| {np.allclose(r1, np.array([[1, 0], [0, 0]]))}")
print(f"rho_0+n = |0⟩⟨0| {np.allclose(rn, np.array([[1, 0], [0, 0]]))}")

Sewing qubit 1
rho_1 = |0⟩⟨0| True
rho_0+n = |0⟩⟨0| True
We can continue this process for all qubits. Let us be tedious and do all steps one by one.
@qml.qnode(dev)
def sewing_2():
    U_test() # some shallow unitary circuit
    V_0() # disentangle qubit 0
    qml.SWAP((0, n)) # swap out disentangled qubit 0 and n+0
    qml.adjoint(V_0)() # repair circuit from V_0
    V_1() # disentangle qubit 1
    qml.SWAP((1, n + 1)) # swap out disentangled qubit 1 to n+1
    qml.adjoint(V_1)() # repair circuit from V_1
    V_2() # disentangle qubit 2
    return (
        qml.density_matrix(wires=[2]),
        qml.density_matrix(wires=[n]),
        qml.density_matrix(wires=[n + 1]),
    )

r2, rn, rn1 = sewing_2()
print("Sewing qubit 2")
print(f"rho_2 = |0⟩⟨0| {np.allclose(r2, np.array([[1, 0], [0, 0]]))}")
print(f"rho_0+n = |0⟩⟨0| {np.allclose(rn, np.array([[1, 0], [0, 0]]))}")
print(f"rho_1+n = |0⟩⟨0| {np.allclose(rn1, np.array([[1, 0], [0, 0]]))}")
Sewing qubit 2
rho_2 = |0⟩⟨0| True
rho_0+n = |0⟩⟨0| True
rho_1+n = |0⟩⟨0| True
We continue to show that the swapped out wires remain decoupled, as well as the qubit we are currently decoupling.
@qml.qnode(dev)
def sewing_3():
    U_test() # some shallow unitary circuit
    V_0() # disentangle qubit 0
    qml.SWAP((0, n)) # swap out disentangled qubit 0 and n+0
    qml.adjoint(V_0)() # repair circuit from V_0
    V_1() # disentangle qubit 1
    qml.SWAP((1, n + 1)) # swap out disentangled qubit 1 to n+1
    qml.adjoint(V_1)() # repair circuit from V_1
    V_2() # disentangle qubit 2
    qml.SWAP((2, n + 2)) # swap out disentangled qubit 2 to n+2
    qml.adjoint(V_2)() # repair circuit from V_2
    V_3() # disentangle qubit 3
    return (
        qml.density_matrix(wires=[3]),
        qml.density_matrix(wires=[n]),
        qml.density_matrix(wires=[n + 1]),
        qml.density_matrix(wires=[n + 2]),
    )

r3, rn, rn1, rn2 = sewing_3()
print("Sewing qubit 3")
print(f"rho_3 = |0⟩⟨0| {np.allclose(r3, np.array([[1, 0], [0, 0]]))}")
print(f"rho_0+n = |0⟩⟨0| {np.allclose(rn, np.array([[1, 0], [0, 0]]))}")
print(f"rho_1+n = |0⟩⟨0| {np.allclose(rn1, np.array([[1, 0], [0, 0]]))}")
print(f"rho_2+n = |0⟩⟨0| {np.allclose(rn2, np.array([[1, 0], [0, 0]]))}")
Sewing qubit 3
rho_3 = |0⟩⟨0| True
rho_0+n = |0⟩⟨0| True
rho_1+n = |0⟩⟨0| True
rho_2+n = |0⟩⟨0| True
After one final swap and repair, we arrive at a state where all original qubits are decoupled. We just need to move them back to their original position with a global SWAP. Beyond that, we now also know that the original $U^\text{test}$ is inverted globally.
@qml.qnode(dev)
def V_dagger_test():
    U_test() # some shallow unitary circuit
    V_0() # disentangle qubit 0
    qml.SWAP((0, n)) # swap out disentangled qubit 0 and n+0
    qml.adjoint(V_0)() # repair circuit from V_0
    V_1() # disentangle qubit 1
    qml.SWAP((1, n + 1)) # swap out disentangled qubit 1 to n+1
    qml.adjoint(V_1)() # repair circuit from V_1
    V_2() # disentangle qubit 2
    qml.SWAP((2, n + 2)) # swap out disentangled qubit 2 to n+2
    qml.adjoint(V_2)() # repair circuit from V_2
    V_3() # disentangle qubit 3
    qml.SWAP((3, n + 3)) # swap out disentangled qubit 3 to n+3
    qml.adjoint(V_3)() # repair circuit from V_3
    for i in range(n): # swap back all decoupled wires to their original registers
        qml.SWAP((i + n, i))
    return qml.density_matrix([0, 1, 2, 3])

psi0 = np.eye(2**4)[0] # |0>^n
print(np.allclose(V_dagger_test(), np.outer(psi0, psi0)))
True
Everything after U_test() in V_dagger_test constitutes $(V^\text{sew})^\dagger.$
def V_dagger():
    V_0() # disentangle qubit 0
    qml.SWAP((0, n)) # swap out disentangled qubit 0 and n+0
    qml.adjoint(V_0)() # repair circuit from V_0
    V_1() # disentangle qubit 1
    qml.SWAP((1, n + 1)) # swap out disentangled qubit 1 to n+1
    qml.adjoint(V_1)() # repair circuit from V_1
    V_2() # disentangle qubit 2
    qml.SWAP((2, n + 2)) # swap out disentangled qubit 2 to n+2
    qml.adjoint(V_2)() # repair circuit from V_2
    V_3() # disentangle qubit 3
    qml.SWAP((3, n + 3)) # swap out disentangled qubit 3 to n+3
    qml.adjoint(V_3)() # repair circuit from V_3
    for i in range(n): # swap back all decoupled wires to their original registers
        qml.SWAP((i + n, i))
$(V^\text{sew})^\dagger$ is such that the action of $U^\text{test}$ on $|0 \rangle^{\otimes n}$ is reverted when tracing out the ancilla qubits. From the paper we know that, in fact, the overall action of the sewn $V^\text{sew}$ is
$$V^\text{sew} = U \otimes U^\dagger,$$
where $U$ acts on the first n qubits, whereas $U^\dagger$ acts on the n ancilla qubits.
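The step-by-step construction of this section can be condensed into a loop. The following plain-NumPy sketch is an independent cross-check rather than the PennyLane implementation: circuits are represented as lists of `(name, wires)` tuples, and `apply` and `adjoint` are hypothetical helpers that only cover the three self-inverse gates used here.

```python
import numpy as np

H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)

def apply(state, gate):
    name, wires = gate
    if name == "H":
        (w,) = wires
        return np.moveaxis(np.tensordot(H, state, axes=([1], [w])), 0, w)
    if name == "CNOT":
        c, t = wires
        state = state.copy()
        idx = [slice(None)] * state.ndim
        idx[c] = 1
        axis = t if t < c else t - 1  # control leg is dropped in the slice
        state[tuple(idx)] = np.flip(state[tuple(idx)], axis=axis).copy()
        return state
    if name == "SWAP":
        return np.swapaxes(state, *wires)
    raise ValueError(f"unknown gate {name}")

def adjoint(circuit):
    # H, CNOT and SWAP are self-inverse, so reversing the order is enough
    return circuit[::-1]

n = 4
U_test = [("H", (i,)) for i in range(n)] + [("CNOT", (1, 2)), ("CNOT", (2, 3)), ("CNOT", (0, 1))]
V = [
    [("CNOT", (0, 1)), ("CNOT", (1, 2)), ("H", (0,))],
    [("CNOT", (0, 1)), ("CNOT", (1, 2)), ("H", (1,))],
    [("CNOT", (2, 3)), ("CNOT", (1, 2)), ("H", (2,))],
    [("CNOT", (2, 3)), ("CNOT", (1, 2)), ("H", (3,))],
]

# sewing: invert, swap out, repair for every qubit, then one global SWAP
sewn = []
for i in range(n):
    sewn += V[i] + [("SWAP", (i, n + i))] + adjoint(V[i])
sewn += [("SWAP", (i, n + i)) for i in range(n)]

state = np.zeros((2,) * (2 * n), dtype=complex)
state[(0,) * (2 * n)] = 1.0
for gate in U_test + sewn:
    state = apply(state, gate)

psi = state.reshape(2**n, 2**n)  # rows: system register, columns: ancilla register
rho_system = psi @ psi.conj().T
target = np.zeros((2**n, 2**n))
target[0, 0] = 1.0

# reference: U_test applied to |0000> on n qubits only
ref = np.zeros((2,) * n, dtype=complex)
ref[(0,) * n] = 1.0
for gate in U_test:
    ref = apply(ref, gate)

print(np.allclose(rho_system, target), np.allclose(psi[0], ref.reshape(-1)))
```

The first check reproduces the result of `V_dagger_test`. The second check shows that the ancilla register ends up holding $U^\text{test}|0\rangle^{\otimes n},$ which is what one expects when $U^\dagger \otimes U$ acts on $U|0\rangle^{\otimes n} \otimes |0\rangle^{\otimes n}.$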
Numerical Experiment¶
The actual learning in this procedure happens in obtaining the local inversions $\{V_0, V_1, V_2, V_3\}.$ The paper relies on existence proofs from gate synthesis [4] and suggests brute-force searching via enumerate-and-test to find suitable $\{V_i\}.$ The idea is essentially to take a big enough set of possible $\{V_i\}$ and post-select those that fulfill $||V^\dagger_i U^\dagger P_i U V_i - P_i|| < \epsilon$ for $P_i \in \{X, Y, Z\},$ which the authors show suffices as a criterion for an approximate local inversion.
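As a rough illustration of enumerate-and-test, the sketch below applies the criterion exactly as written above to a toy two-qubit problem in NumPy. Everything specific here is an assumption made for the example: the target `U`, the hand-picked `gate_pool`, the restriction to two-gate candidates, and the helper `inversion_error`. The paper's search space and error estimation are far more general.

```python
import itertools
import numpy as np

I2 = np.eye(2, dtype=complex)
X = np.array([[0, 1], [1, 0]], dtype=complex)
Y = np.array([[0, -1j], [1j, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)
H = (X + Z) / np.sqrt(2)
CNOT = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], dtype=complex)

# toy 2-qubit target (illustrative): Hadamard on qubit 0, then CNOT(0, 1)
U = CNOT @ np.kron(H, I2)

# small hand-picked gate pool (hypothetical)
gate_pool = {
    "H0": np.kron(H, I2),
    "H1": np.kron(I2, H),
    "X1": np.kron(I2, X),
    "CNOT01": CNOT,
}

def inversion_error(V, U):
    """Largest spectral norm of V^dag U^dag P_0 U V - P_0 over P in {X, Y, Z} on qubit 0."""
    W = U @ V
    errors = []
    for P in (X, Y, Z):
        P0 = np.kron(P, I2)
        errors.append(np.linalg.norm(W.conj().T @ P0 @ W - P0, ord=2))
    return max(errors)

eps = 1e-8
accepted = []
for names in itertools.product(gate_pool, repeat=2):
    V = np.eye(4, dtype=complex)
    for name in names:  # gates are applied from left to right
        V = gate_pool[name] @ V
    if inversion_error(V, U) < eps:
        accepted.append(names)

print(accepted)
```

Candidates that pass the test are kept as local inversions of qubit 0. The cost of such a search grows quickly with the size of the gate pool and the candidate depth, which is one reason to look at a variational alternative.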
Instead of doing that, let us here look at an explicit example by constructing a target unitary of some structure and a variational Ansatz for the local inversions that has a different structure. There are many ways to obtain local inversions; this is just one that we find convenient.
First, let us construct the target unitary.
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")

n = 4
wires = range(n)
U_params = jax.random.normal(jax.random.PRNGKey(0), shape=(2, n//2), dtype=float)

def U_target(wires):
    for i in range(n):
        qml.Hadamard(wires=wires[i])
    # brick-wall ansatz
    # Note: U_params has only n//2 columns but is indexed with i up to n-1.
    # JAX clamps out-of-range indices instead of raising, so the last column is reused.
    for i in range(0, n, 2):
        qml.IsingXX(U_params[0, i], wires=(wires[i], wires[(i+1)%len(wires)]))
    for i in range(1, n, 2):
        qml.IsingXX(U_params[1, i], wires=(wires[i], wires[(i+1)%len(wires)]))

qml.draw_mpl(U_target)(wires)
plt.show()

Putting on blindfolds and assuming we don’t know the circuit structure of $U^\text{target},$ we set up a variational Ansatz for the local inversions $V_i$ with the following structure.
n_layers = 2

def V_i(params, wires):
    # initial layer of single-qubit rotations
    for i in range(n):
        qml.RX(params[0, 0, i], wires[i])
    for i in range(n):
        qml.RY(params[0, 1, i], wires[i])
    for ll in range(n_layers):
        # brick-wall layer of CNOTs followed by single-qubit rotations
        for i in range(0, n, 2):
            qml.CNOT((wires[i], wires[i+1]))
        for i in range(1, n, 2):
            qml.CNOT((wires[i], wires[(i+1)%n]))
        for i in range(n):
            qml.RX(params[ll+1, 0, i], wires[i])
        for i in range(n):
            qml.RY(params[ll+1, 1, i], wires[i])

params = jax.random.normal(jax.random.PRNGKey(10), shape=(n_layers+1, 2, n), dtype=float)

qml.draw_mpl(V_i)(params, wires)
plt.show()

Next, we are going to run optimizations for each $V_i$ to find a local inversion. For that we need some boilerplate code; see our demo on optimizing quantum circuits in JAX.
import optax
from datetime import datetime
from functools import partial

X, Y, Z = qml.PauliX, qml.PauliY, qml.PauliZ

def run_opt(value_and_grad, theta, n_epochs=100, lr=0.1, b1=0.9, b2=0.999):
    optimizer = optax.adam(learning_rate=lr, b1=b1, b2=b2)
    opt_state = optimizer.init(theta)

    energy = np.zeros(n_epochs)
    thetas = []

    @jax.jit
    def step(opt_state, theta):
        val, grad_circuit = value_and_grad(theta)
        updates, opt_state = optimizer.update(grad_circuit, opt_state)
        theta = optax.apply_updates(theta, updates)
        return opt_state, theta, val

    t0 = datetime.now()
    # Optimization loop
    for epoch in range(n_epochs):
        opt_state, theta, val = step(opt_state, theta)
        energy[epoch] = val
        thetas.append(theta)
    t1 = datetime.now()

    print(f"final loss: {val}; min loss: {np.min(energy)}; after {t1 - t0}")
    return thetas, energy

dev = qml.device("default.qubit")
As a cost function, we perform state tomography after applying $U^\text{target}$ and our Ansatz $V_i.$ Our aim is to bring the state on qubit i back to the north pole of the Bloch sphere, and we specify our cost function accordingly.
@qml.qnode(dev, interface="jax")
def qnode_i(params, i):
    U_target(wires)
    V_i(params, wires)
    return [qml.expval(P(i)) for P in [X, Y, Z]]

@partial(jax.jit, static_argnums=1)
@jax.value_and_grad
def cost_i(params, i):
    x, y, z = qnode_i(params, i)
    return x**2 + y**2 + (1-z)**2
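The cost function is the squared Euclidean distance between the Bloch vector $(x, y, z)$ of qubit i and the north pole $(0, 0, 1).$ As a quick sanity check, the NumPy sketch below evaluates it on a few reference single-qubit states, together with the fidelity with $|0\rangle,$ which equals $(1 + z)/2.$ The helpers `bloch_vector`, `cost` and `fidelity_with_zero` are names introduced only for this illustration.

```python
import numpy as np

X = np.array([[0, 1], [1, 0]], dtype=complex)
Y = np.array([[0, -1j], [1j, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)

def bloch_vector(rho):
    # expectation values of X, Y, Z for a single-qubit density matrix
    return np.array([np.trace(rho @ P).real for P in (X, Y, Z)])

def cost(rho):
    # squared distance of the Bloch vector from the north pole (0, 0, 1)
    x, y, z = bloch_vector(rho)
    return x**2 + y**2 + (1 - z) ** 2

def fidelity_with_zero(rho):
    # <0|rho|0> = (1 + z) / 2
    return (1 + bloch_vector(rho)[2]) / 2

ket0 = np.array([1, 0], dtype=complex)
ket1 = np.array([0, 1], dtype=complex)
ketp = (ket0 + ket1) / np.sqrt(2)

states = {
    "|0>": np.outer(ket0, ket0.conj()),
    "|1>": np.outer(ket1, ket1.conj()),
    "|+>": np.outer(ketp, ketp.conj()),
    "mixed": np.eye(2, dtype=complex) / 2,
}

for label, rho in states.items():
    print(label, cost(rho), fidelity_with_zero(rho))
```

The cost vanishes only for $|0\rangle \langle 0|$. A maximally mixed qubit, i.e. one that is still fully entangled with the rest of the system, has cost 1.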
We can now run the optimization. In this case, it suffices to use the same random initial values from above for each optimization.
params_i = []
for i in range(n):
    cost = partial(cost_i, i=i)
    thetas, _ = run_opt(cost, params)
    params_i.append(thetas[-1])
final loss: 4.168309870147805e-05; min loss: 4.168309870147805e-05; after 0:00:01.706601
final loss: 4.426389047108856e-06; min loss: 4.426389047108856e-06; after 0:00:01.671405
final loss: 7.119522980849117e-06; min loss: 6.6226916886298415e-06; after 0:00:01.665516
final loss: 3.7957338559167357e-06; min loss: 3.4988574855738884e-06; after 0:00:01.658421
For consistency, we check the resulting coordinates of qubit i on the Bloch sphere.
for i in range(n):
    X_res, Y_res, Z_res = qnode_i(params_i[i], i)
    print(f"Bloch sphere coordinates of qubit {i} after inversion: {X_res:.5f}, {Y_res:.5f} {Z_res:.5f}")
Bloch sphere coordinates of qubit 0 after inversion: -0.00183, -0.00021 0.99393
Bloch sphere coordinates of qubit 1 after inversion: -0.00016, 0.00015 0.99795
Bloch sphere coordinates of qubit 2 after inversion: -0.00085, 0.00013 0.99753
Bloch sphere coordinates of qubit 3 after inversion: 0.00062, 0.00044 0.99837
We see that they are all approximately inverting the circuit as the resulting state is close to $|0\rangle$ (associated with coordinates $(x, y, z) = (0, 0, 1)$). With these local inversions, we can sew together again a unitary that globally inverts the circuit.
def V_sew():
    for i in range(n):
        # local sewing: inversion, exchange, heal
        V_i(params_i[i], range(n))
        qml.SWAP((i, i+n))
        qml.adjoint(V_i)(params_i[i], range(n))
    # global SWAP
    for i in range(n):
        qml.SWAP((i, i+n))

@qml.qnode(dev, interface="jax")
def sewing_test():
    U_target(range(n))
    V_sew()
    return qml.density_matrix(range(4))

print(np.allclose(sewing_test(), np.outer(psi0, psi0), atol=1e-1))
True
The final test confirms that $V^\text{sew}$ approximately inverts $U^\text{target}$ on the system wires.
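An element-wise comparison with `atol=1e-1` is a fairly loose test. A single-number alternative is the fidelity with $|0\rangle^{\otimes n},$ which for a density matrix in the computational basis is simply its top-left entry. The sketch below illustrates this on a made-up two-qubit state (a small rotation of qubit 0 by a hypothetical angle `theta`), not on the actual output of the demo; `fidelity_with_all_zeros` is a helper name of our own.

```python
import numpy as np

def fidelity_with_all_zeros(rho):
    """<0...0| rho |0...0> for a density matrix given in the computational basis."""
    return float(np.real(rho[0, 0]))

# hypothetical imperfect inversion on 2 qubits: qubit 0 is left rotated by a small angle
theta = 0.1
q0 = np.array([np.cos(theta / 2), np.sin(theta / 2)])
q1 = np.array([1.0, 0.0])
psi = np.kron(q0, q1)
rho = np.outer(psi, psi.conj())

target = np.zeros((4, 4))
target[0, 0] = 1.0

passes_loose_check = np.allclose(rho, target, atol=1e-1)
fid = fidelity_with_all_zeros(rho)

print(passes_loose_check, fid)
```

Applied to the demo, one could evaluate the same helper on the output of `sewing_test()` to quantify how close the sewn inversion gets to $|0\rangle^{\otimes n}.$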
Conclusion¶
We saw how one can construct a global inversion from sewing together local inversions. This is a powerful new technique that may find applications in different domains of quantum computing. The technique cleverly circumvents the quantum marginal problem of constructing a global inversion from local ones compatible with each other.
The authors use this technique to prove that constant depth quantum circuits are learnable (i.e. can be reconstructed) in a variety of different scenarios.
Note
We mainly focussed on the case of constructing $V^\text{sew}$ such that $V^\text{sew} U |0 \rangle^{\otimes n} = |0 \rangle^{\otimes n}$ as it already nicely captures the main technical method, circuit sewing. This is different from learning the full unitary, i.e. $V$ such that $U V = 1.$
For this, circuit sewing works in exactly the same way. The main difference is that the local inversions are now full inversions in the sense of $\text{tr}_{\neq i}\left[V_i U\right] = \mathbb{1}_i$ (whereas before we just had $V_i U |0 \rangle^{\otimes n} = |0 \rangle^{\otimes n},$ which is a simpler case). The authors show that a sufficient condition for full inversion is achieved by minimizing
$$||V^\dagger_i U^\dagger P_i U V_i - P_i||$$
for each $P_i \in \{X, Y, Z\}.$
In the paper, the authors suggest brute-force searching the whole space of possible $V_i$ and post-selecting those for which the distance to $P_i$ is small. The terms are evaluated by randomly sampling input (product) states $|\phi_j\rangle$ and computing the expectation values $\langle \phi_j | V^\dagger_i U^\dagger P_i U V_i |\phi_j\rangle.$ In particular, samples for all possible candidates of $V_i$ are generated. Another possibility is to perform state tomography of the single-qubit states and compare the result with the input state. Either way, the circuit sewing after obtaining the learned local inversions is the same as described above.
References¶
[1] Hsin-Yuan Huang, Yunchao Liu, Michael Broughton, Isaac Kim, Anurag Anshu, Zeph Landau, Jarrod R. McClean, “Learning shallow quantum circuits”, arXiv:2401.10095, 2024.
[2] Sergey Bravyi, David Gosset, Robert Koenig, “Quantum advantage with shallow circuits”, arXiv:1704.00690, 2017.
[3] Eric R. Anschuetz, Bobak T. Kiani, “Beyond Barren Plateaus: Quantum Variational Algorithms Are Swamped With Traps”, arXiv:2205.05786, 2022.
[4] Vivek V. Shende, Stephen S. Bullock, Igor L. Markov, “Synthesis of Quantum Logic Circuits”, arXiv:quant-ph/0406176, 2004.
Korbinian Kottmann
Quantum simulation & open source software