{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# This cell is added by sphinx-gallery\n# It can be customized to whatever you like\n%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Generalized parameter-shift rules\n=================================\n\n*Author: David Wierichs (Xanadu resident). Posted: 23 August 2021.*\n\nIn this demo we will look at univariate quantum functions, i.e., those\nthat depend on a single parameter. We will investigate the form such\nfunctions usually take and demonstrate how we can *reconstruct* them as\nclassical functions, capturing the full dependence on the input\nparameter. Once we have this reconstruction, we use it to compute\nanalytically exact derivatives of the quantum function. We implement\nthis in two ways: first, by using autodifferentiation on the classical\nfunction that is produced by the reconstruction, which is flexible with\nrespect to the degree of the derivative. Second, by computing the\nderivative manually, resulting in generalized parameter-shift rules for\nquantum functions that is more efficient (regarding classical cost) than\nthe autodifferentiation approach, but requires manual computations if we\nwant to access higher-order derivatives. All we will need for the demo\nis the insight that these functions are Fourier series in their\nvariable, and the reconstruction itself is a [trigonometric\ninterpolation](https://en.wikipedia.org/wiki/Trigonometric_interpolation).\n\nA full description of the reconstruction, the technical derivation of\nthe parameter-shift rules, and considerations for multivariate functions\ncan be found in the paper [General parameter-shift rules for quantum\ngradients](https://arxiv.org/abs/2107.12390) \\[\\#GenPar\\]\\_. The core\nidea to consider these quantum functions as Fourier series was first\npresented in the preprint [Calculus on parameterized quantum\ncircuits](https://arxiv.org/abs/1812.06323) \\[\\#CalcPQC\\]\\_. 
We will\nfollow the paper [General parameter-shift rules for quantum\ngradients](https://arxiv.org/abs/2107.12390), but there are also two preprints discussing\ngeneral parameter-shift rules: an algebraic approach in [Analytic\ngradients in variational quantum algorithms: Algebraic extensions of the\nparameter-shift rule to general unitary\ntransformations](https://arxiv.org/abs/2107.08131) and\none focusing on special gates and spectral decompositions, namely\n[Generalized quantum circuit differentiation\nrules](https://arxiv.org/abs/2108.01218).\n\n![Function reconstruction and differentiation via parameter\nshifts.](../demonstrations/general_parshift/thumbnail_genpar.png){width=\"50%\"}\n\nCost functions arising from quantum gates\n-----------------------------------------\n\nWe start our investigation by considering a cost function that arises\nfrom measuring the expectation value of an observable in a quantum\nstate, created with a parametrized quantum operation that depends on a\nsingle variational parameter $x$. That is, the state may be prepared by\nany circuit, but we will only allow a single parameter in a single\noperation to enter the circuit. For this we will use a handy gate\nstructure that allows us to tune the complexity of the operation, and\nthus of the cost function. More concretely, we initialize a qubit\nregister in a random state $|\\psi\\rangle$ and apply a layer of Pauli-$Z$\nrotations `RZ` to all qubits, where all rotations are parametrized by\nthe *same* angle $x$. We then measure the expectation value of a random\nHermitian observable $B$ in the created state, so that our cost function\noverall has the form\n\n$$E(x)=\\langle\\psi | U^\\dagger(x) B U(x)|\\psi\\rangle.$$\n\nHere, $U(x)$ consists of a layer of `RZ` gates,\n\n$$U(x)=\\prod_{a=1}^N R_Z^{(a)}(x) = \\prod_{a=1}^N \\exp\\left(-i\\frac{x}{2} Z_a\\right).$$\n\nLet's implement such a cost function using PennyLane. We begin with\nfunctions that generate the random initial state $|\\psi\\rangle$ and the\nrandom observable $B$ for a given number of qubits $N$ and a fixed seed:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from scipy.stats import unitary_group\nimport numpy.random as rnd\n\n\ndef random_state(N, seed):\n \"\"\"Create a random state on N qubits.\"\"\"\n states = unitary_group.rvs(2 ** N, random_state=rnd.default_rng(seed))\n return states[0]\n\n\ndef random_observable(N, seed):\n \"\"\"Create a random observable on N qubits.\"\"\"\n rnd.seed(seed)\n # Generate real and imaginary part separately and (anti-)symmetrize them for Hermiticity\n real_part, imag_part = rnd.random((2, 2 ** N, 2 ** N))\n real_part += real_part.T\n imag_part -= imag_part.T\n return real_part + 1j * imag_part"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's set up a \"cost function generator\", namely a function that\nwill create the `cost` function we discussed above, using $|\\psi\\rangle$\nas initial state and measuring the expectation value of $B$. This\ngenerator has the advantage that we can quickly create the cost function\nfor various numbers of qubits --- and therefore cost functions with\ndifferent complexity.\n\nWe will use the default qubit simulator with its JAX backend and also\nwill rely on the NumPy implementation of JAX. To obtain precise results,\nwe enable 64-bit `float` precision via the JAX config.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from jax.config import config\n\nconfig.update(\"jax_enable_x64\", True)\nimport jax\nfrom jax import numpy as np\nimport pennylane as qml\n\ndef make_cost(N, seed):\n \"\"\"Create a cost function on N qubits with N frequencies.\"\"\"\n dev = qml.device(\"default.qubit\", wires=N)\n\n @jax.jit\n @qml.qnode(dev, interface=\"jax\")\n def cost(x):\n \"\"\"Cost function on N qubits with N frequencies.\"\"\"\n qml.QubitStateVector(random_state(N, seed), wires=dev.wires)\n for w in dev.wires:\n qml.RZ(x, wires=w, id=\"x\")\n return qml.expval(qml.Hermitian(random_observable(N, seed), wires=dev.wires))\n\n return cost"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We also prepare some plotting functionalities and colors:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n\n# Set a plotting range on the x-axis\nxlim = (-np.pi, np.pi)\nX = np.linspace(*xlim, 60)\n# Colors\ngreen = \"#209494\"\norange = \"#ED7D31\"\nred = \"xkcd:brick red\"\nblue = \"xkcd:cerulean\"\npink = \"xkcd:bright pink\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we took care of these preparations, let's dive right into it:\nIt can be shown \\[\\#GenPar\\]\\_ that $E(x)$ takes the form of a Fourier\nseries in the variable $x$. That is to say that\n\nHere, $a_{\\ell}$ and $b_{\\ell}$ are the *Fourier coefficients*. If you\nwould like to understand this a bit better still, have a read of\n\\~.pennylane.fourier and remember to check out the\nFourier module tutorial </demos/tutorial\\_expressivity\\_fourier\\_series>.\n\nDue to $B$ being Hermitian, $E(x)$ is a real-valued function, so only\npositive frequencies and real coefficients appear in the Fourier series\nfor $E(x)$. This is true for any number of qubits (and therefore `RZ`\ngates) we use.\n\nUsing our function `make_cost` from above, we create the cost function\nfor several numbers of qubits and store both the function and its\nevaluations on the plotting range `X`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Qubit numbers\nNs = [1, 2, 4, 5]\n# Fix a seed\nseed = 7658741\n\ncost_functions = []\nevaluated_cost = []\nfor N in Ns:\n # Generate the cost function for N qubits and evaluate it\n cost = make_cost(N, seed)\n evaluated_cost.append([cost(x) for x in X])\n cost_functions.append(cost)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a look at the created $E(x)$ for the various numbers of\nqubits:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Figure with multiple axes\nfig, axs = plt.subplots(1, len(Ns), figsize=(12, 2))\n\nfor ax, N, E in zip(axs, Ns, evaluated_cost):\n # Plot cost function evaluations\n ax.plot(X, E, color=green)\n # Axis and plot labels\n ax.set_title(f\"{N} qubits\")\n ax.set_xlabel(\"$x$\")\n\n_ = axs[0].set_ylabel(\"$E$\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"|\n\nIndeed we see that $E(x)$ is a periodic function whose complexity grows\nwhen increasing $N$ together with the number of `RZ` gates. To take a\nlook at the frequencies that are present in these functions, we may use\nPennyLane's \\~.pennylane.fourier module.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from pennylane.fourier import spectrum\n\nspectra = []\nfor N, cost_function in zip(Ns, cost_functions):\n # Compute spectrum with respect to parameter x\n spec = spectrum(cost_function.__wrapped__)(X[0])[\"x\"]\n print(f\"For {N} qubits the spectrum is {spec}.\")\n # Store spectrum\n spectra.append([freq for freq in spec if freq>0.0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The number of positive frequencies that appear in $E(x)$ is the same as\nthe number of `RZ` gates we used in the circuit! Recall that we only\nneed to consider the positive frequencies because $E(x)$ is real-valued,\nand that we accounted for the zero-frequency contribution in the\ncoefficient $a_0$. If you are interested why the number of gates\ncoincides with the number of frequencies, check out the\nFourier module tutorial </demos/tutorial\\_expressivity\\_fourier\\_series>.\n\nBefore moving on, let's also have a look at the Fourier coefficients in\nthe functions we created:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from pennylane.fourier.visualize import bar\n\nfig, axs = plt.subplots(2, len(Ns), figsize=(12, 4.5))\nfor i, (cost_function, spec) in enumerate(zip(cost_functions, spectra)):\n # Compute the Fourier coefficients\n coeffs = qml.fourier.coefficients(cost_function, 1, len(spec)+2)\n # Show the Fourier coefficients\n bar(coeffs, 1, axs[:, i], show_freqs=True, colour_dict={\"real\": green, \"imag\": orange})\n axs[0, i].set_title(f\"{Ns[i]} qubits\")\n # Set x-axis labels\n axs[1, i].text(Ns[i] + 2, axs[1, i].get_ylim()[0], f\"Frequency\", ha=\"center\", va=\"top\")\n # Clean up y-axis labels\n if i == 0:\n _ = [axs[j, i].set_ylabel(lab) for j, lab in enumerate([\"$a_\\ell/2$\", \"$b_\\ell/2$\"])]\n else:\n _ = [axs[j, i].set_ylabel(\"\") for j in [0, 1]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We find the real (imaginary) Fourier coefficients to be\n(anti-)symmetric. This is expected because $E(x)$ is real-valued and we\nagain see why it is enough to consider positive frequencies: the\ncoefficients of the negative frequencies follow from those of the\npositive frequencies.\n\nDetermining the full dependence on $x$\n======================================\n\nNext we will show how to determine the *full* dependence of the cost\nfunction on $x$, i.e., we will *reconstruct* $E(x)$. The key idea is not\nnew: Since $E(x)$ is periodic with known, integer frequencies, we can\nreconstruct it *exactly* by using trigonometric interpolation. For this,\nwe evaluate $E$ at shifted positions $x_\\mu$. We will show the\nreconstruction both for *equidistant* and random shifts, corresponding\nto a [uniform](https://en.wikipedia.org/wiki/Discrete_Fourier_transform)\nand a\n[non-uniform](https://en.wikipedia.org/wiki/Non-uniform_discrete_Fourier_transform)\ndiscrete Fourier transform (DFT), respectively.\n\nEquidistant shifts\n------------------\n\nFor the equidistant case we can directly implement the trigonometric\ninterpolation:\n\nwhere we reformulated $E$ in the second expression using the [sinc\nfunction](https://en.wikipedia.org/wiki/Sinc_function) to enhance the\nnumerical stability. Note that we have to take care of a rescaling\nfactor of $\\pi$ between this definition of $\\operatorname{sinc}$ and the\nNumPy implementation `np.sinc`.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"sinc = lambda x: np.sinc(x / np.pi)\n\ndef full_reconstruction_equ(fun, R):\n \"\"\"Reconstruct a univariate function with up to R frequencies using equidistant shifts.\"\"\"\n # Shift angles for the reconstruction\n shifts = [2 * mu * np.pi / (2 * R + 1) for mu in range(-R, R + 1)]\n # Shifted function evaluations\n evals = np.array([fun(shift) for shift in shifts])\n\n @jax.jit\n def reconstruction(x):\n \"\"\"Univariate reconstruction using equidistant shifts.\"\"\"\n kernels = np.array(\n [sinc((R + 0.5) * (x - shift)) / sinc(0.5 * (x - shift)) for shift in shifts]\n )\n return np.dot(evals, kernels)\n\n return reconstruction\n\nreconstructions_equ = list(map(full_reconstruction_equ, cost_functions, Ns))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So how is this reconstruction doing? We will plot it along with the\noriginal function $E$, mark the shifted evaluation points $x_\\mu$ (with\ncrosses), and also show its deviation from $E(x)$ (lower plots). For\nthis, a function for the whole procedure of comparing the functions\ncomes in handy, and we will reuse it further below. For convenience,\nshowing the deviation will be an optional feature controled by the\n`show_diff` keyword argument.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def compare_functions(originals, reconstructions, Ns, shifts, show_diff=True):\n \"\"\"Plot two sets of functions next to each other and show their difference (in pairs).\"\"\"\n # Prepare the axes; we need fewer axes if we don't show the deviations\n if show_diff:\n fig, axs = plt.subplots(2, len(originals), figsize=(12, 4.5))\n else:\n fig, axs = plt.subplots(1, len(originals), figsize=(12, 2))\n _axs = axs[0] if show_diff else axs\n\n # Run over the functions and reconstructions\n for i, (orig, recon, N, _shifts) in enumerate(zip(originals, reconstructions, Ns, shifts)):\n # Evaluate the original function and its reconstruction over the plotting range\n E = np.array(list(map(orig, X)))\n E_rec = np.array(list(map(recon, X)))\n # Evaluate the original function at the positions used in the reconstruction\n E_shifts = np.array(list(map(orig, _shifts)))\n\n # Show E, the reconstruction, and the shifts (top axes)\n _axs[i].plot(X, E, lw=2, color=orange)\n _axs[i].plot(X, E_rec, linestyle=\":\", lw=3, color=green)\n _axs[i].plot(_shifts, E_shifts, ls=\"\", marker=\"x\", c=red)\n # Manage plot titles and xticks\n _axs[i].set_title(f\"{N} qubits\")\n\n if show_diff:\n # [Optional] Show the reconstruction deviation (bottom axes)\n axs[1, i].plot(X, E - E_rec, color=blue)\n axs[1, i].set_xlabel(\"$x$\")\n # Hide the xticks of the top x-axes if we use the bottom axes\n _axs[i].set_xticks([])\n\n # Manage y-axis labels\n _ = _axs[0].set_ylabel(\"$E$\")\n if show_diff:\n _ = axs[1, 0].set_ylabel(\"$E-E_{rec}$\")\n\n return fig, axs\n\nequ_shifts = [[2 * mu * np.pi / (2 * N + 1) for mu in range(-N, N + 1)] for N in Ns]\nfig, axs = compare_functions(cost_functions, reconstructions_equ, Ns, equ_shifts)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*It works!*\n\nNon-equidistant shifts\n======================\n\nNow let's test the reconstruction with less regular sampling points on\nwhich to evaluate $E$. This means we can no longer use the closed-form\nexpression from above, but switch to solving the set of equations\n\nwith the---now irregular---sampling points $x_\\mu$. For this, we set up\nthe matrix\n\ncollect the Fourier coefficients of $E$ into the vector\n$\\boldsymbol{W}=(a_0, \\boldsymbol{a}, \\boldsymbol{b})$, and the\nevaluations of $E$ into another vector called $\\boldsymbol{E}$ so that\n\nLet's implement this right away! We will take the function and the\nshifts $x_\\mu$ as inputs, inferring $R$ from the number of the provided\nshifts, which is $2R+1$.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def full_reconstruction_gen(fun, shifts):\n \"\"\"Reconstruct a univariate trigonometric function using arbitrary shifts.\"\"\"\n R = (len(shifts) - 1) // 2\n frequencies = np.array(list(range(1, R + 1)))\n\n # Construct the matrix C case by case\n C1 = np.ones((2 * R + 1, 1))\n C2 = np.cos(np.outer(shifts, frequencies))\n C3 = np.sin(np.outer(shifts, frequencies))\n C = np.hstack([C1, C2, C3])\n\n # Evaluate the function to reconstruct at the shifted positions\n evals = np.array(list(map(fun, shifts)))\n\n # Solve the system of linear equations by inverting C\n W = np.linalg.inv(C) @ evals\n\n # Extract the Fourier coefficients\n a0 = W[0]\n a = W[1 : R + 1]\n b = W[R + 1 :]\n\n # Construct the Fourier series\n @jax.jit\n def reconstruction(x):\n \"\"\"Univariate reconstruction based on arbitrary shifts.\"\"\"\n return a0 + np.dot(a, np.cos(frequencies * x)) + np.dot(b, np.sin(frequencies * x))\n\n return reconstruction"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see this version of the reconstruction in action, we will sample the\nshifts $x_\\mu$ at random in $[-\\pi,\\pi)$:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"shifts = [rnd.random(2 * N + 1) * 2 * np.pi - np.pi for N in Ns]\nreconstructions_gen = list(map(full_reconstruction_gen, cost_functions, shifts))\nfig, axs = compare_functions(cost_functions, reconstructions_gen, Ns, shifts)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Again, we obtain a perfect reconstruction of $E(x)$ up to numerical\nerrors. We see that the deviation from the original cost function became\nlarger than for equidistant shifts for some of the qubit numbers but it\nstill remains much smaller than any energy scale of relevance in\napplications. The reason for these larger deviations is that some\nevaluation positions $x_\\mu$ were sampled very close to each other, so\nthat inverting the matrix $C$ becomes less stable numerically.\nConceptually, we see that the reconstruction does *not* rely on\nequidistant evaluations points.\n\nDifferentiation via reconstructions\n===================================\n\nNext, we look at a modified reconstruction strategy that only obtains\nthe odd or even part of $E(x)$. This can be done by slightly modifying\nthe shifted positions at which we evaluate $E$ and the kernel functions.\n\nFrom a perspective of implementing the derivatives there are two\napproaches, differing in which parts we derive on paper and which we\nleave to the computer: In the first approach, we perform a partial\nreconstruction using the evaluations of the original cost function $E$\non the quantum computer, as detailed below. This gives us a function\nimplemented in `jax.numpy` and we may afterwards apply `jax.grad` to\nthis function and obtain the derivative function. $E(0)$ then is only\none evaluation of this function away. In the second approach, we compute\nthe derivative of the partial reconstructions *manually* and directly\nimplement the resulting shift rule that multiplies the quantum computer\nevaluations with coefficients and sums them up. This means that the\npartial reconstruction is not performed at all by the classical\ncomputer, but only was used on paper to derive the formula for the\nderivative.\n\n*Why do we look at both approaches?*, you might ask. That is because\nneither of them is better than the other for *all* applications. 
The\nfirst approach offers us derivatives of any order without additional\nmanual work by iteratively applying `jax.grad`, which is very\nconvenient. However, the automatic differentiation via JAX becomes\nincreasingly expensive with the order and we always reconstruct the\n*same* type of function, namely Fourier series, so that computing the\nrespective derivatives once manually and coding up the resulting\ncoefficients of the parameter-shift rule pays off in the long run. This\nis the strength of the second approach. We start with the first\napproach.\n\nAutomatically differentiated reconstructions\n--------------------------------------------\n\nWe implement the partial reconstruction method as a function; using\nPennyLane's automatic differentiation backends, this then enables us to\nobtain the derivatives at the point of interest. For odd-order\nderivatives, we use the reconstruction of the odd part, for the\neven-order derivatives that of the even part.\n\nWe make use of modified [Dirichlet\nkernels](https://en.wikipedia.org/wiki/Dirichlet_kernel)\n$\\tilde{D}_\\mu(x)$ and equidistant shifts for this. For the odd\nreconstruction we have\n\n$$E_\\text{odd}(x) = \\sum_{\\mu=1}^R E_\\text{odd}(x_\\mu) \\tilde{D}_\\mu(x), \\qquad \\tilde{D}_\\mu(x) = \\frac{\\sin(R (x-x_\\mu))}{2R \\tan\\left(\\frac{1}{2} (x-x_\\mu)\\right)} - \\frac{\\sin(R (x+x_\\mu))}{2R \\tan\\left(\\frac{1}{2} (x+x_\\mu)\\right)},$$\n\nwith the shifts $x_\\mu = \\frac{2\\mu-1}{2R}\\pi$ and\n$E_\\text{odd}(x_\\mu) = \\frac{1}{2}\\left(E(x_\\mu) - E(-x_\\mu)\\right)$,\nwhich we can implement using the reformulation\n\n$$\\frac{\\sin(X)}{\\tan(Y)}=\\frac{X}{Y}\\frac{\\operatorname{sinc}(X)}{\\operatorname{sinc}(Y)}\\cos(Y)$$\n\nfor the kernel.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"shifts_odd = lambda R: [(2 * mu - 1) * np.pi / (2 * R) for mu in range(1, R + 1)]\n# Odd linear combination of Dirichlet kernels\nD_odd = lambda x, R: np.array(\n [\n (\n sinc(R * (x - shift)) / sinc(0.5 * (x - shift)) * np.cos(0.5 * (x - shift))\n - sinc(R * (x + shift)) / sinc(0.5 * (x + shift)) * np.cos(0.5 * (x + shift))\n )\n for shift in shifts_odd(R)\n ]\n)\n\n\ndef odd_reconstruction_equ(fun, R):\n \"\"\"Reconstruct the odd part of an ``R``-frequency input function via equidistant shifts.\"\"\"\n evaluations = np.array([(fun(shift) - fun(-shift)) / 2 for shift in shifts_odd(R)])\n\n @jax.jit\n def reconstruction(x):\n \"\"\"Odd reconstruction based on equidistant shifts.\"\"\"\n return np.dot(evaluations, D_odd(x, R))\n\n return reconstruction\n\n\nodd_reconstructions = list(map(odd_reconstruction_equ, cost_functions, Ns))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The even part on the other hand takes the form\n\nNote that not only the kernels $\\hat{D}_\\mu(x)$ but also the shifted\npositions $\\{x_\\mu\\}$ differ between the odd and even case.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"shifts_even = lambda R: [mu * np.pi / R for mu in range(1, R)]\n# Even linear combination of Dirichlet kernels\nD_even = lambda x, R: np.array(\n [\n (\n sinc(R * (x - shift)) / sinc(0.5 * (x - shift)) * np.cos(0.5 * (x - shift))\n + sinc(R * (x + shift)) / sinc(0.5 * (x + shift)) * np.cos(0.5 * (x + shift))\n )\n for shift in shifts_even(R)\n ]\n)\n# Special cases of even kernels\nD0 = lambda x, R: sinc(R * x) / (sinc(x / 2)) * np.cos(x / 2)\nDpi = lambda x, R: sinc(R * (x - np.pi)) / sinc((x - np.pi) / 2) * np.cos((x - np.pi) / 2)\n\n\ndef even_reconstruction_equ(fun, R):\n \"\"\"Reconstruct the even part of ``R``-frequency input function via equidistant shifts.\"\"\"\n _evaluations = np.array([(fun(shift) + fun(-shift)) / 2 for shift in shifts_even(R)])\n evaluations = np.array([fun(0), *_evaluations, fun(np.pi)])\n kernels = lambda x: np.array([D0(x, R), *D_even(x, R), Dpi(x, R)])\n\n @jax.jit\n def reconstruction(x):\n \"\"\"Even reconstruction based on equidistant shifts.\"\"\"\n return np.dot(evaluations, kernels(x))\n\n return reconstruction\n\n\neven_reconstructions = list(map(even_reconstruction_equ, cost_functions, Ns))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We also set up a function that performs both partial reconstructions and\nsums the resulting functions to the full Fourier series.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def summed_reconstruction_equ(fun, R):\n \"\"\"Sum an odd and an even reconstruction into the full function.\"\"\"\n _odd_part = odd_reconstruction_equ(fun, R)\n _even_part = even_reconstruction_equ(fun, R)\n\n def reconstruction(x):\n \"\"\"Full function based on separate odd/even reconstructions.\"\"\"\n return _odd_part(x) + _even_part(x)\n\n return reconstruction\n\n\nsummed_reconstructions = list(map(summed_reconstruction_equ, cost_functions, Ns))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We show these even (blue) and odd (red) reconstructions and how they\nindeed sum to the full function (orange, dashed). We will again use the\n`compare_functions` utility from above for the comparison.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"from matplotlib.lines import Line2D\n\n# Obtain the shifts for the reconstruction of both parts\nodd_and_even_shifts = [\n (\n shifts_odd(R)\n + shifts_even(R)\n + list(-1 * np.array(shifts_odd(R)))\n + list(-1 * np.array(shifts_odd(R)))\n + [0, np.pi]\n )\n for R in Ns\n]\n\n# Show the reconstructed parts and the sums\nfig, axs = compare_functions(cost_functions, summed_reconstructions, Ns, odd_and_even_shifts)\nfor i, (odd_recon, even_recon) in enumerate(zip(odd_reconstructions, even_reconstructions)):\n # Odd part\n E_odd = np.array(list(map(odd_recon, X)))\n axs[0, i].plot(X, E_odd, color=red)\n # Even part\n E_even = np.array(list(map(even_recon, X)))\n axs[0, i].plot(X, E_even, color=blue)\n axs[0, i].set_title('')\n_ = axs[1, 0].set_ylabel(\"$E-(E_{odd}+E_{even})$\")\ncolors = [green, red, blue, orange]\nstyles = ['-', '-', '-', '--']\nhandles = [Line2D([0], [0], color=c, ls=ls, lw=1.2) for c, ls in zip(colors, styles)]\nlabels = ['Original', 'Odd reconstruction', 'Even reconstruction', 'Summed reconstruction']\n_ = fig.legend(handles, labels, bbox_to_anchor=(0.2, 0.89), loc='lower left', ncol=4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Great! The even and odd part indeed sum to the correct function again.\nBut what did we gain?\n\nNothing, actually, for the full reconstruction! Quite the opposite, we\nspent $2R$ evaluations of $E$ on each part, that is $4R$ evaluations\noverall to obtain a description of the full function $E$! This is way\nmore than the $2R+1$ evaluations needed for the full reconstructions\nfrom the beginning.\n\nHowever, remember that we set out to compute derivatives of $E$ at $0$,\nso that for derivatives of odd/even order only the odd/even\nreconstruction is required. Using an autodifferentiation framework, e.g.\nJAX, we can easily compute such higher-order derivatives:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# An iterative function computing the ``order``th derivative of a function ``f`` with JAX\ngrad_gen = lambda f, order: grad_gen(jax.grad(f), order - 1) if order > 0 else f\n\n# Compute the first, second, and fourth derivative\nfor order, name in zip([1, 2, 4], [\"First\", \"Second\", \"4th\"]):\n recons = odd_reconstructions if order % 2 else even_reconstructions\n recon_name = \"odd \" if order % 2 else \"even\"\n cost_grads = [grad_gen(orig, order)(0.0) for orig in cost_functions]\n recon_grads = [grad_gen(recon, order)(0.0) for recon in recons]\n all_equal = (\n \"All entries match\" if np.allclose(cost_grads, recon_grads) else \"Some entries differ!\"\n )\n print(f\"{name} derivatives via jax: {all_equal}\")\n print(\"From the cost functions: \", np.round(np.array(cost_grads), 6))\n print(f\"From the {recon_name} reconstructions: \", np.round(np.array(recon_grads), 6), \"\\n\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The derivatives coincide.\n\nGeneralized parameter-shift rules\n=================================\n\nThe second method is based on the previous one. Instead of consulting\nJAX, we may compute the wanted derivative of the odd/even kernel\nfunction manually and thus derive general parameter-shift rules from\nthis. We will leave the technical derivation of these rules to the paper\n\\[\\#GenPar\\]\\_. Start with the first derivative, which certainly is used\nthe most:\n\nThis is straight-forward to implement by defining the coefficients and\nevaluating $E$ at the shifted positions $x_\\mu$:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def parameter_shift_first(fun, R):\n \"\"\"Compute the first-order derivative of a function with R frequencies at 0.\"\"\"\n shifts = (2 * np.arange(1, 2 * R + 1) - 1) * np.pi / (4 * R)\n # Classically computed coefficients\n coeffs = np.array(\n [(-1) ** mu / (4 * R * np.sin(shift) ** 2) for mu, shift in enumerate(shifts)]\n )\n # Evaluations of the cost function E(x_mu)\n evaluations = np.array(list(map(fun, 2 * shifts)))\n # Contract coefficients with evaluations\n return np.dot(coeffs, evaluations)\n\nps_der1 = list(map(parameter_shift_first, cost_functions, Ns))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The second-order derivative takes a similar form, but we have to take\ncare of the evaluation at $0$ and the corresponding coefficient\nseparately:\n\nLet's code this up, again we only get slight complications from the\nspecial evaluation at $0$:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def parameter_shift_second(fun, R):\n \"\"\"Compute the second-order derivative of a function with R frequencies at 0.\"\"\"\n shifts = np.arange(1, 2 * R) * np.pi / (2 * R)\n # Classically computed coefficients for the main sum\n _coeffs = [(-1) ** mu / (2 * np.sin(shift) ** 2) for mu, shift in enumerate(shifts)]\n # Include the coefficients for the \"special\" term E(0).\n coeffs = np.array([-(2 * R ** 2 + 1) / 6] + _coeffs)\n # Evaluate at the regularily shifted positions\n _evaluations = list(map(fun, 2 * shifts))\n # Include the \"special\" term E(0).\n evaluations = np.array([fun(0.0)] + _evaluations)\n # Contract coefficients with evaluations.\n return np.dot(coeffs, evaluations)\n\nps_der2 = list(map(parameter_shift_second, cost_functions, Ns))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will compare these two shift rules to the finite-difference\nderivative commonly used for numerical differentiation. We choose a\nfinite difference of $d_x=5\\times 10^{-5}$.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"dx = 5e-5\n\ndef finite_diff_first(fun):\n \"\"\"Compute the first order finite difference derivative.\"\"\"\n return (fun(dx/2) - fun(-dx/2))/dx\n\nfd_der1 = list(map(finite_diff_first, cost_functions))\n\ndef finite_diff_second(fun):\n \"\"\"Compute the second order finite difference derivative.\"\"\"\n fun_p, fun_0, fun_m = fun(dx), fun(0.0), fun(-dx)\n return ((fun_p - fun_0)/dx - (fun_0 - fun_m)/dx) /dx\n\nfd_der2 = list(map(finite_diff_second, cost_functions))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"All that is left is to compare the computed parameter-shift and\nfinite-difference derivatives:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"print(\"Number of qubits/RZ gates: \", *Ns, sep=\" \" * 9)\nprint(f\"First-order parameter-shift rule: {np.round(np.array(ps_der1), 6)}\")\nprint(f\"First-order finite difference: {np.round(np.array(fd_der1), 6)}\")\nprint(f\"Second-order parameter-shift rule: {np.round(np.array(ps_der2), 6)}\")\nprint(f\"Second-order finite difference: {np.round(np.array(fd_der2), 6)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter-shift rules work as expected! And we were able to save a\ncircuit evaluation as compared to a full reconstruction.\n\nAnd this is all we want to show here about univariate function\nreconstructions and generalized parameter shift rules. Note that the\ntechniques above can partially be extended to frequencies that are not\ninteger-valued, but many closed form expressions are no longer valid.\nFor the reconstruction, the approach via Dirichlet kernels no longer\nworks in the general case; instead, a system of equations has to be\nsolved, but with generalized frequencies $\\{\\Omega_\\ell\\}$ instead of\n$\\{\\ell\\}$ (see e.g. Sections III A-C in [^1])\n\nReferences\n==========\n\n[^1]: David Wierichs, Josh Izaac, Cody Wang, Cedric Yen-Yu Lin. \"General\n parameter-shift rules for quantum gradients\". [arXiv preprint\n arXiv:2107.12390](https://arxiv.org/abs/2107.12390).\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
}
},
"nbformat": 4,
"nbformat_minor": 0
}