Multiclass margin classifier
Published: April 08, 2020. Last updated: November 05, 2024.
In this tutorial, we show how to use the PyTorch interface for PennyLane to implement a multiclass variational classifier. We consider the Iris dataset from the UCI Machine Learning Repository, which has 4 features and 3 classes. We use multiple one-vs-all classifiers with a margin loss (see Multiclass Linear SVM) to classify the data. Each classifier is implemented as its own variational circuit, whose architecture is inspired by Farhi and Neven (2018) as well as Schuld et al. (2018).

Initial Setup
We import PennyLane, NumPy, and the relevant torch modules, and define the constants that will be used in this tutorial.
Our feature size is 4, and we will use amplitude embedding. This means that each amplitude of the quantum state (in the computational basis) encodes a single feature. With 2 qubits (wires) there are \(2^2 = 4\) computational basis states, so we can encode a feature vector of size 4. In general, \(d\) features require \(\lceil \log_2 d \rceil\) qubits.
import pennylane as qml
import torch
import numpy as np
from torch.autograd import Variable
import torch.optim as optim
np.random.seed(0)
torch.manual_seed(0)
num_classes = 3
margin = 0.15
feature_size = 4
batch_size = 10
lr_adam = 0.01
train_split = 0.75
# the number of the required qubits is calculated from the number of features
num_qubits = int(np.ceil(np.log2(feature_size)))
num_layers = 6
total_iterations = 100
dev = qml.device("default.qubit", wires=num_qubits)
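As a rough illustration of the amplitude-embedding arithmetic described above, the NumPy sketch below pads a feature vector with zeros up to the next power of two and rescales it to unit \(\ell_2\) norm, which is what AmplitudeEmbedding is configured to do in this tutorial via pad_with=0.0 and normalize=True. The helper name amplitude_encode is hypothetical and this is not PennyLane's implementation; it only shows how the qubit count and the amplitudes follow from the features.

```python
import numpy as np


def amplitude_encode(features, pad_with=0.0):
    """Hypothetical helper: features -> (num_qubits, normalized amplitude vector)."""
    features = np.asarray(features, dtype=float)
    # Same qubit count as the tutorial: ceil(log2(d)), with at least one qubit
    num_qubits = max(1, int(np.ceil(np.log2(len(features)))))
    padded = np.full(2 ** num_qubits, pad_with, dtype=float)
    padded[: len(features)] = features
    # Rescale so that the squared amplitudes sum to 1
    return num_qubits, padded / np.linalg.norm(padded)


n, psi = amplitude_encode([5.1, 3.5, 1.4, 0.2])  # first Iris sample
print(n)                  # 2
print(np.round(psi, 4))   # amplitudes of |00>, |01>, |10>, |11>
print(np.sum(psi ** 2))   # 1 up to floating-point rounding

n3, psi3 = amplitude_encode([1.0, 2.0, 2.0])  # 3 features still need 2 qubits
print(n3, psi3)           # the 4th amplitude is the zero padding
```

For the first Iris sample, the resulting amplitudes agree with the normalized first sample printed in the training log, and their squares sum to 1, as required of a quantum state.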
Quantum Circuit
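We first create the layer that will be repeated in our variational quantum circuits. It consists of a general single-qubit rotation (qml.Rot) on each qubit, followed by CNOT gates that entangle neighboring qubits; when there are at least two qubits, a final CNOT connects the last qubit back to the first.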
def layer(W):
    # W has shape (num_qubits, 3): one set of Rot angles per qubit
    for i in range(num_qubits):
        qml.Rot(W[i, 0], W[i, 1], W[i, 2], wires=i)
    # Entangle neighboring qubits in a chain
    for j in range(num_qubits - 1):
        qml.CNOT(wires=[j, j + 1])
    if num_qubits >= 2:
        # Apply additional CNOT to entangle the last with the first qubit
        qml.CNOT(wires=[num_qubits - 1, 0])
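To make the gate ordering of layer explicit, the plain-Python sketch below lists the gates a single call applies as (gate_name, wires) tuples. The helper layer_schedule is hypothetical and does not call PennyLane; it simply mirrors the loops in layer. It shows that with num_qubits = 2 the "ring" reduces to CNOT(0, 1) followed by CNOT(1, 0), and that each layer carries 3 × num_qubits rotation angles.

```python
def layer_schedule(num_qubits):
    """Hypothetical helper: gates applied by one `layer` call, in order."""
    gates = [("Rot", (i,)) for i in range(num_qubits)]
    gates += [("CNOT", (j, j + 1)) for j in range(num_qubits - 1)]
    if num_qubits >= 2:
        gates.append(("CNOT", (num_qubits - 1, 0)))  # close the ring
    return gates


print(layer_schedule(2))
# [('Rot', (0,)), ('Rot', (1,)), ('CNOT', (0, 1)), ('CNOT', (1, 0))]
print(layer_schedule(3)[-3:])
# [('CNOT', (0, 1)), ('CNOT', (1, 2)), ('CNOT', (2, 0))]
```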
We now define the quantum nodes that will be used. As we are implementing our multiclass classifier as multiple one-vs-all classifiers, we will use 3 QNodes, each representing one such classifier. That is, the QNode at index 0 of the list qnodes classifies whether a sample belongs to class 0 or not, and so on. All QNodes share the same circuit architecture but are trained with separate parameters. We use the PyTorch interface for the QNodes.
Data is embedded in each circuit using amplitude embedding.
Note
For demonstration purposes we are using a very simple circuit here. You may find that other choices, for example more elaborate measurements, increase the power of the classifier.
def circuit(weights, feat=None):
    qml.AmplitudeEmbedding(feat, range(num_qubits), pad_with=0.0, normalize=True)
    for W in weights:
        layer(W)
    return qml.expval(qml.PauliZ(0))
qnodes = []
for iq in range(num_classes):
    qnode = qml.QNode(circuit, dev, interface="torch")
    qnodes.append(qnode)
The variational quantum circuit is parametrized by the weights. We use a classical bias term that is applied after processing the quantum circuit’s output. Both variational circuit weights and classical bias term are optimized.
def variational_classifier(q_circuit, params, feat):
    weights = params[0]
    bias = params[1]
    return q_circuit(weights, feat=feat) + bias
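To show what a single classifier actually computes, here is a small NumPy statevector sketch of the same pipeline: amplitude embedding, repeated layers of Rot gates plus CNOTs, the \(\langle Z \rangle\) expectation on wire 0, and the classical bias. It is an illustrative re-implementation under stated assumptions, not PennyLane's simulator: the helper names (rot, apply_1q, apply_cnot, circuit_np, score_np) are hypothetical, wire 0 is assumed to be the most significant bit of the basis index, Rot is taken as \(R_Z(\omega) R_Y(\theta) R_Z(\phi)\), and the feature vector is assumed to need no padding.

```python
import numpy as np


def rot(phi, theta, omega):
    """Matrix of Rot(phi, theta, omega) = RZ(omega) @ RY(theta) @ RZ(phi)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array(
        [
            [np.exp(-0.5j * (phi + omega)) * c, -np.exp(0.5j * (phi - omega)) * s],
            [np.exp(-0.5j * (phi - omega)) * s, np.exp(0.5j * (phi + omega)) * c],
        ]
    )


def bit(k, wire, n):
    # Assumed convention: wire 0 is the most significant bit of basis index k
    return (k >> (n - 1 - wire)) & 1


def apply_1q(state, gate, wire, n):
    out = np.zeros_like(state)
    mask = 1 << (n - 1 - wire)
    for k in range(2 ** n):
        if (k & mask) == 0:  # visit each (bit=0, bit=1) amplitude pair once
            a0, a1 = state[k], state[k | mask]
            out[k] = gate[0, 0] * a0 + gate[0, 1] * a1
            out[k | mask] = gate[1, 0] * a0 + gate[1, 1] * a1
    return out


def apply_cnot(state, control, target, n):
    out = state.copy()
    for k in range(2 ** n):
        if bit(k, control, n):  # flip the target bit where the control bit is 1
            out[k ^ (1 << (n - 1 - target))] = state[k]
    return out


def circuit_np(weights, feat):
    """<Z> on wire 0; weights has shape (num_layers, n, 3), len(feat) == 2**n."""
    n = weights.shape[1]
    state = np.asarray(feat, dtype=complex)
    state = state / np.linalg.norm(state)  # like normalize=True
    for W in weights:
        for i in range(n):
            state = apply_1q(state, rot(*W[i]), i, n)
        for j in range(n - 1):
            state = apply_cnot(state, j, j + 1, n)
        if n >= 2:
            state = apply_cnot(state, n - 1, 0, n)
    probs = np.abs(state) ** 2
    signs = np.array([1 - 2 * bit(k, 0, n) for k in range(2 ** n)])
    return float(np.sum(signs * probs))


def score_np(weights, bias, feat):
    # Classical bias added after the quantum expectation value
    return circuit_np(weights, feat) + bias


rng = np.random.default_rng(0)
weights = 0.1 * rng.standard_normal((6, 2, 3))  # same shape as one classifier's weights
print(score_np(weights, 0.1, [5.1, 3.5, 1.4, 0.2]))
```

The expectation value is always in \([-1, 1]\); the bias then shifts it, so the final score can fall slightly outside that interval.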
Loss Function
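Implementing a multiclass classifier as a set of one-vs-all classifiers naturally leads to the margin loss. The output of the \(i\)th classifier \(c_i\) on input \(x\) is interpreted as a score \(s_i\); the circuit's \(\langle Z \rangle\) expectation value lies in \([-1, 1]\), and the classical bias shifts it. More concretely, we have:
\[ s_i = c_i(x; \theta_i), \]
where \(\theta_i\) denotes the circuit weights and bias of classifier \(i\).
The multiclass margin loss attempts to ensure that the score for the correct class is higher than that of every incorrect class by some margin. For a sample \((x, y)\), where \(y\) denotes the class label, we can analytically express the multiclass loss on this sample as:
\[ L(x, y) = \sum_{j \neq y} \max\left(0,\, s_j - s_y + \Delta\right), \]
where \(\Delta\) denotes the margin. The margin is a hyperparameter (margin = 0.15 in this tutorial), and the cost minimized during training is the average of \(L(x, y)\) over a batch of samples. For more information, see Multiclass Linear SVM.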
def multiclass_svm_loss(q_circuits, all_params, feature_vecs, true_labels):
    loss = 0
    num_samples = len(true_labels)
    for i, feature_vec in enumerate(feature_vecs):
        # Compute the score given to this sample by the classifier corresponding to the
        # true label. So for a true label of 1, get the score computed by classifier 1,
        # which distinguishes between "class 1" and "not class 1".
        s_true = variational_classifier(
            q_circuits[int(true_labels[i])],
            (all_params[0][int(true_labels[i])], all_params[1][int(true_labels[i])]),
            feature_vec,
        )
        s_true = s_true.float()
        li = 0
        # Get the scores computed for this sample by the other classifiers
        for j in range(num_classes):
            if j != int(true_labels[i]):
                s_j = variational_classifier(
                    q_circuits[j], (all_params[0][j], all_params[1][j]), feature_vec
                )
                s_j = s_j.float()
                # Hinge term: max(0, s_j - s_true + margin)
                li += torch.max(torch.zeros(1).float(), s_j - s_true + margin)
        loss += li
    return loss / num_samples
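As a worked example of the margin-loss formula, the NumPy sketch below evaluates the same averaged hinge loss on a precomputed matrix of scores instead of looping over QNodes. The function name margin_loss and the score values are made up for illustration; they are not outputs of the trained model. Note that the \(j = y\) entry must be excluded explicitly, since it would otherwise always contribute exactly \(\Delta\).

```python
import numpy as np


def margin_loss(scores, labels, margin=0.15):
    """Mean multiclass margin loss over a batch.

    scores: shape (num_samples, num_classes), scores[n, j] = s_j for sample n
    labels: integer class labels y, shape (num_samples,)
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    rows = np.arange(len(labels))
    s_true = scores[rows, labels][:, None]  # s_y for each sample, as a column
    hinge = np.maximum(0.0, scores - s_true + margin)
    hinge[rows, labels] = 0.0  # the sum runs over j != y only
    return float(hinge.sum(axis=1).mean())


# Hypothetical scores for two samples
scores = [[0.6, 0.5, -0.2],   # true class 0: only class 1 is within the margin
          [0.1, 0.3, 0.2]]    # true class 2: classes 0 and 1 both contribute
print(round(margin_loss(scores, [0, 2]), 4))  # 0.175
```

The first sample contributes \(\max(0, 0.5 - 0.6 + 0.15) = 0.05\), the second contributes \(0.05 + 0.25 = 0.30\), and their mean is 0.175.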
Classification Function
Next, we use the learned models to classify our samples. For a given sample, we compute the score assigned to it by each classifier \(i\), which indicates how strongly that classifier associates the sample with class \(i\). The predicted class is the one with the highest score.
def classify(q_circuits, all_params, feature_vecs, labels):
    predicted_labels = []
    for i, feature_vec in enumerate(feature_vecs):
        scores = np.zeros(num_classes)
        for c in range(num_classes):
            score = variational_classifier(
                q_circuits[c], (all_params[0][c], all_params[1][c]), feature_vec
            )
            scores[c] = float(score)
        pred_class = np.argmax(scores)
        predicted_labels.append(pred_class)
    return predicted_labels
def accuracy(labels, hard_predictions):
    # Fraction of samples whose predicted class matches the true label
    correct = 0
    for l, p in zip(labels, hard_predictions):
        if torch.abs(l - p) < 1e-5:
            correct = correct + 1
    return correct / labels.shape[0]
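The argmax rule and the accuracy computation can also be written in vectorized form once all scores are collected in a matrix. The sketch below is a NumPy illustration with hypothetical helper names (predict, accuracy_np) and made-up scores. Because the labels are whole numbers stored as floats (e.g. 2.0), an exact equality test gives the same result as the tolerance check used in accuracy.

```python
import numpy as np


def predict(scores):
    """Index of the highest-scoring classifier per row (ties -> lowest index)."""
    return np.argmax(np.asarray(scores), axis=1)


def accuracy_np(labels, predictions):
    """Fraction of exact matches between labels and predicted classes."""
    return float(np.mean(np.asarray(labels) == np.asarray(predictions)))


scores = [[0.6, 0.5, -0.2],
          [0.1, 0.3, 0.2],
          [-0.4, 0.7, 0.0]]
preds = predict(scores)
print(preds)                                 # [0 1 1]
print(accuracy_np([0.0, 2.0, 1.0], preds))  # 0.6666666666666666
```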
Data Loading and Processing
Now we load the Iris dataset and normalize each feature vector so that its squared elements sum to 1 (its \(\ell_2\) norm is 1).
def load_and_process_data():
    data = np.loadtxt("../_static/demonstration_assets/multiclass_classification/iris.csv", delimiter=",")
    X = torch.tensor(data[:, 0:feature_size])
    print("First X sample, original :", X[0])
    # normalize each input to unit l2 norm
    normalization = torch.sqrt(torch.sum(X ** 2, dim=1))
    X_norm = X / normalization.reshape(len(X), 1)
    print("First X sample, normalized:", X_norm[0])
    Y = torch.tensor(data[:, -1])
    # return the normalized features (AmplitudeEmbedding also normalizes, since normalize=True)
    return X_norm, Y
# Create a train and test split.
def split_data(feature_vecs, Y):
    num_data = len(Y)
    num_train = int(train_split * num_data)
    index = np.random.permutation(range(num_data))
    feat_vecs_train = feature_vecs[index[:num_train]]
    Y_train = Y[index[:num_train]]
    feat_vecs_test = feature_vecs[index[num_train:]]
    Y_test = Y[index[num_train:]]
    return feat_vecs_train, feat_vecs_test, Y_train, Y_test
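The sketch below reproduces the two preprocessing steps with plain NumPy: row-wise \(\ell_2\) normalization and a shuffled 75/25 split. The helper names normalize_rows and split_indices are hypothetical, and it uses NumPy's Generator API rather than np.random.permutation, so the exact indices will differ from those produced by split_data. With the 150 Iris samples, int(0.75 * 150) gives 112 training and 38 test samples, which is why the accuracies in the training log are multiples of 1/112 and 1/38 (for example, 0.3214286 = 36/112 and 0.3684211 = 14/38).

```python
import numpy as np


def normalize_rows(X):
    """Scale each row (sample) of X to unit l2 norm."""
    X = np.asarray(X, dtype=float)
    return X / np.linalg.norm(X, axis=1, keepdims=True)


def split_indices(num_data, train_split=0.75, seed=0):
    """Shuffled train/test index split, in the spirit of split_data."""
    rng = np.random.default_rng(seed)
    index = rng.permutation(num_data)
    num_train = int(train_split * num_data)
    return index[:num_train], index[num_train:]


X = normalize_rows([[5.1, 3.5, 1.4, 0.2], [7.0, 3.2, 4.7, 1.4]])
print(np.linalg.norm(X, axis=1))  # both norms are 1 up to rounding

train_idx, test_idx = split_indices(150)
print(len(train_idx), len(test_idx))  # 112 38
```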
Training Procedure
In the training procedure, we begin by randomly initializing the parameters we wish to learn (variational circuit weights and classical bias). As these are the variables we wish to optimize, we set the requires_grad flag to True. We use minibatch training: at each iteration, the average loss over a randomly drawn batch of samples is computed, and the optimization step is based on it. Training with the default parameters takes several minutes; the run recorded below took about 8.5 minutes.
def training(features, Y):
    feat_vecs_train, feat_vecs_test, Y_train, Y_test = split_data(features, Y)
    num_train = Y_train.shape[0]
    q_circuits = qnodes
    # Initialize the parameters: one weight tensor and one bias per classifier
    all_weights = [
        Variable(0.1 * torch.randn(num_layers, num_qubits, 3), requires_grad=True)
        for i in range(num_classes)
    ]
    all_bias = [Variable(0.1 * torch.ones(1), requires_grad=True) for i in range(num_classes)]
    optimizer = optim.Adam(all_weights + all_bias, lr=lr_adam)
    params = (all_weights, all_bias)
    # 3 Rot angles per qubit per layer, plus one bias, for each classifier
    print("Num params: ", num_classes * (num_layers * num_qubits * 3 + 1))
    costs, train_acc, test_acc = [], [], []
    # train the variational classifier
    for it in range(total_iterations):
        # Sample a minibatch of training indices (with replacement)
        batch_index = np.random.randint(0, num_train, (batch_size,))
        feat_vecs_train_batch = feat_vecs_train[batch_index]
        Y_train_batch = Y_train[batch_index]
        optimizer.zero_grad()
        curr_cost = multiclass_svm_loss(q_circuits, params, feat_vecs_train_batch, Y_train_batch)
        curr_cost.backward()
        optimizer.step()
        # Compute predictions on train and test set
        predictions_train = classify(q_circuits, params, feat_vecs_train, Y_train)
        predictions_test = classify(q_circuits, params, feat_vecs_test, Y_test)
        acc_train = accuracy(Y_train, predictions_train)
        acc_test = accuracy(Y_test, predictions_test)
        print(
            "Iter: {:5d} | Cost: {:0.7f} | Acc train: {:0.7f} | Acc test: {:0.7f} "
            "".format(it + 1, curr_cost.item(), acc_train, acc_test)
        )
        costs.append(curr_cost.item())
        train_acc.append(acc_train)
        test_acc.append(acc_test)
    return costs, train_acc, test_acc
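Two details of the training loop are easy to check by hand. The number of trainable parameters is \(3 \times (6 \times 2 \times 3 + 1) = 111\), matching the "Num params" line in the log. The minibatch indices are drawn with np.random.randint, which samples with replacement, so a batch can contain repeated samples; over 100 iterations of 10 samples each, the optimizer sees 1000 sample draws, roughly 8.9 passes over the 112 training samples. The sketch below uses a hypothetical num_params helper and NumPy's Generator API; the sampling-without-replacement variant is shown only as an alternative and is not what the tutorial does.

```python
import numpy as np


def num_params(num_classes=3, num_layers=6, num_qubits=2, angles_per_rot=3):
    # Per classifier: one Rot gate (3 angles) per qubit per layer, plus one bias
    return num_classes * (num_layers * num_qubits * angles_per_rot + 1)


print(num_params())  # 111

rng = np.random.default_rng(0)
num_train, batch_size = 112, 10
# With replacement, like np.random.randint in training(): repeats are possible
batch_with = rng.integers(0, num_train, size=batch_size)
# Alternative without replacement: every index in the batch is distinct
batch_without = rng.choice(num_train, size=batch_size, replace=False)
print(len(set(batch_without.tolist())))  # 10
```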
# We now run our training algorithm and plot the results. Note that
# for plotting, the matplotlib library is required
features, Y = load_and_process_data()
costs, train_acc, test_acc = training(features, Y)
import matplotlib.pyplot as plt
fig, ax1 = plt.subplots()
iters = np.arange(0, total_iterations, 1)
colors = ["tab:red", "tab:blue"]
ax1.set_xlabel("Iteration", fontsize=17)
ax1.set_ylabel("Cost", fontsize=17, color=colors[0])
ax1.plot(iters, costs, color=colors[0], linewidth=4)
ax1.tick_params(axis="y", labelsize=14, labelcolor=colors[0])
ax2 = ax1.twinx()
ax2.set_ylabel("Test Acc.", fontsize=17, color=colors[1])
ax2.plot(iters, test_acc, color=colors[1], linewidth=4)
ax1.tick_params(axis="x", labelsize=14)
ax2.tick_params(axis="y", labelsize=14, labelcolor=colors[1])
plt.grid(False)
plt.tight_layout()
plt.show()

First X sample, original : tensor([5.1000, 3.5000, 1.4000, 0.2000], dtype=torch.float64)
First X sample, normalized: tensor([0.8038, 0.5516, 0.2206, 0.0315], dtype=torch.float64)
Num params: 111
Iter: 1 | Cost: 0.3475123 | Acc train: 0.3214286 | Acc test: 0.3684211
Iter: 2 | Cost: 0.2533208 | Acc train: 0.3214286 | Acc test: 0.3684211
Iter: 3 | Cost: 0.2943569 | Acc train: 0.3214286 | Acc test: 0.3684211
Iter: 4 | Cost: 0.3344342 | Acc train: 0.3214286 | Acc test: 0.3684211
Iter: 5 | Cost: 0.2200930 | Acc train: 0.3214286 | Acc test: 0.3684211
Iter: 6 | Cost: 0.2718903 | Acc train: 0.4910714 | Acc test: 0.5789474
Iter: 7 | Cost: 0.2201053 | Acc train: 0.4821429 | Acc test: 0.4473684
Iter: 8 | Cost: 0.1825284 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter: 9 | Cost: 0.1096408 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter: 10 | Cost: 0.2361710 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter: 11 | Cost: 0.2656708 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter: 12 | Cost: 0.1090595 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter: 13 | Cost: 0.0562117 | Acc train: 0.6875000 | Acc test: 0.6315789
Iter: 14 | Cost: 0.1491400 | Acc train: 0.7946429 | Acc test: 0.8684211
Iter: 15 | Cost: 0.1158819 | Acc train: 0.9017857 | Acc test: 0.9473684
Iter: 16 | Cost: 0.1077839 | Acc train: 0.8839286 | Acc test: 0.8684211
Iter: 17 | Cost: 0.1172837 | Acc train: 0.7589286 | Acc test: 0.7894737
Iter: 18 | Cost: 0.1232261 | Acc train: 0.7589286 | Acc test: 0.7631579
Iter: 19 | Cost: 0.0856634 | Acc train: 0.6964286 | Acc test: 0.7105263
Iter: 20 | Cost: 0.1289215 | Acc train: 0.7142857 | Acc test: 0.7631579
Iter: 21 | Cost: 0.0514024 | Acc train: 0.7142857 | Acc test: 0.7368421
Iter: 22 | Cost: 0.0924992 | Acc train: 0.7142857 | Acc test: 0.7368421
Iter: 23 | Cost: 0.0755414 | Acc train: 0.7321429 | Acc test: 0.7631579
Iter: 24 | Cost: 0.0724914 | Acc train: 0.6964286 | Acc test: 0.7105263
Iter: 25 | Cost: 0.0919957 | Acc train: 0.6785714 | Acc test: 0.6842105
Iter: 26 | Cost: 0.1054716 | Acc train: 0.6785714 | Acc test: 0.6842105
Iter: 27 | Cost: 0.0977640 | Acc train: 0.6785714 | Acc test: 0.7105263
Iter: 28 | Cost: 0.0822038 | Acc train: 0.6875000 | Acc test: 0.7105263
Iter: 29 | Cost: 0.0748657 | Acc train: 0.6785714 | Acc test: 0.7105263
Iter: 30 | Cost: 0.0872695 | Acc train: 0.6607143 | Acc test: 0.6842105
Iter: 31 | Cost: 0.1019797 | Acc train: 0.6607143 | Acc test: 0.6842105
Iter: 32 | Cost: 0.0757497 | Acc train: 0.6607143 | Acc test: 0.6842105
Iter: 33 | Cost: 0.1152469 | Acc train: 0.6607143 | Acc test: 0.6842105
Iter: 34 | Cost: 0.1455490 | Acc train: 0.7142857 | Acc test: 0.7105263
Iter: 35 | Cost: 0.1052271 | Acc train: 0.8035714 | Acc test: 0.7894737
Iter: 36 | Cost: 0.0885720 | Acc train: 0.8928571 | Acc test: 0.8947368
Iter: 37 | Cost: 0.0849844 | Acc train: 0.9107143 | Acc test: 0.9473684
Iter: 38 | Cost: 0.0836692 | Acc train: 0.8482143 | Acc test: 0.8684211
Iter: 39 | Cost: 0.0610897 | Acc train: 0.8214286 | Acc test: 0.8684211
Iter: 40 | Cost: 0.0874969 | Acc train: 0.7678571 | Acc test: 0.8421053
Iter: 41 | Cost: 0.0446037 | Acc train: 0.6875000 | Acc test: 0.6578947
Iter: 42 | Cost: 0.0769979 | Acc train: 0.6875000 | Acc test: 0.6315789
Iter: 43 | Cost: 0.0590819 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter: 44 | Cost: 0.0429792 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter: 45 | Cost: 0.1355981 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter: 46 | Cost: 0.0787131 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter: 47 | Cost: 0.0289592 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter: 48 | Cost: 0.0553782 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter: 49 | Cost: 0.1169677 | Acc train: 0.6875000 | Acc test: 0.6578947
Iter: 50 | Cost: 0.0723401 | Acc train: 0.7142857 | Acc test: 0.7105263
Iter: 51 | Cost: 0.0733640 | Acc train: 0.8303571 | Acc test: 0.8684211
Iter: 52 | Cost: 0.0782777 | Acc train: 0.9196429 | Acc test: 0.9473684
Iter: 53 | Cost: 0.0851140 | Acc train: 0.9285714 | Acc test: 0.9473684
Iter: 54 | Cost: 0.0569319 | Acc train: 0.9017857 | Acc test: 0.8947368
Iter: 55 | Cost: 0.0681021 | Acc train: 0.8035714 | Acc test: 0.7894737
Iter: 56 | Cost: 0.0538335 | Acc train: 0.7142857 | Acc test: 0.7105263
Iter: 57 | Cost: 0.0801482 | Acc train: 0.6785714 | Acc test: 0.6842105
Iter: 58 | Cost: 0.1502000 | Acc train: 0.6785714 | Acc test: 0.6842105
Iter: 59 | Cost: 0.0810743 | Acc train: 0.6785714 | Acc test: 0.7105263
Iter: 60 | Cost: 0.1178780 | Acc train: 0.7142857 | Acc test: 0.7105263
Iter: 61 | Cost: 0.0912379 | Acc train: 0.8035714 | Acc test: 0.8157895
Iter: 62 | Cost: 0.0734926 | Acc train: 0.9107143 | Acc test: 0.8947368
Iter: 63 | Cost: 0.0350486 | Acc train: 0.9464286 | Acc test: 0.9736842
Iter: 64 | Cost: 0.1046237 | Acc train: 0.8571429 | Acc test: 0.8947368
Iter: 65 | Cost: 0.0695619 | Acc train: 0.7946429 | Acc test: 0.8684211
Iter: 66 | Cost: 0.0852931 | Acc train: 0.7232143 | Acc test: 0.8157895
Iter: 67 | Cost: 0.0417796 | Acc train: 0.7053571 | Acc test: 0.6842105
Iter: 68 | Cost: 0.0670566 | Acc train: 0.6964286 | Acc test: 0.6842105
Iter: 69 | Cost: 0.0339224 | Acc train: 0.6964286 | Acc test: 0.6842105
Iter: 70 | Cost: 0.0543718 | Acc train: 0.6964286 | Acc test: 0.6578947
Iter: 71 | Cost: 0.0060147 | Acc train: 0.6964286 | Acc test: 0.6578947
Iter: 72 | Cost: 0.0925488 | Acc train: 0.6964286 | Acc test: 0.6578947
Iter: 73 | Cost: 0.0831473 | Acc train: 0.7232143 | Acc test: 0.7368421
Iter: 74 | Cost: 0.0676376 | Acc train: 0.7946429 | Acc test: 0.8684211
Iter: 75 | Cost: 0.0768720 | Acc train: 0.8482143 | Acc test: 0.8947368
Iter: 76 | Cost: 0.0659078 | Acc train: 0.9196429 | Acc test: 0.9210526
Iter: 77 | Cost: 0.0461330 | Acc train: 0.9464286 | Acc test: 0.9736842
Iter: 78 | Cost: 0.0674305 | Acc train: 0.9464286 | Acc test: 0.9736842
Iter: 79 | Cost: 0.0276815 | Acc train: 0.9464286 | Acc test: 0.9736842
Iter: 80 | Cost: 0.0586605 | Acc train: 0.9464286 | Acc test: 0.9736842
Iter: 81 | Cost: 0.0590447 | Acc train: 0.9464286 | Acc test: 0.9736842
Iter: 82 | Cost: 0.0479531 | Acc train: 0.9375000 | Acc test: 0.9473684
Iter: 83 | Cost: 0.0823874 | Acc train: 0.9464286 | Acc test: 0.9736842
Iter: 84 | Cost: 0.0480286 | Acc train: 0.9464286 | Acc test: 0.9736842
Iter: 85 | Cost: 0.0391624 | Acc train: 0.9464286 | Acc test: 0.9736842
Iter: 86 | Cost: 0.0964098 | Acc train: 0.9553571 | Acc test: 0.9736842
Iter: 87 | Cost: 0.0150623 | Acc train: 0.9375000 | Acc test: 0.9210526
Iter: 88 | Cost: 0.0412913 | Acc train: 0.9196429 | Acc test: 0.9210526
Iter: 89 | Cost: 0.0758920 | Acc train: 0.9196429 | Acc test: 0.9210526
Iter: 90 | Cost: 0.0552046 | Acc train: 0.9196429 | Acc test: 0.9210526
Iter: 91 | Cost: 0.0351689 | Acc train: 0.9107143 | Acc test: 0.9210526
Iter: 92 | Cost: 0.0555152 | Acc train: 0.9017857 | Acc test: 0.9210526
Iter: 93 | Cost: 0.0339651 | Acc train: 0.8750000 | Acc test: 0.8684211
Iter: 94 | Cost: 0.0638086 | Acc train: 0.8750000 | Acc test: 0.8684211
Iter: 95 | Cost: 0.0358330 | Acc train: 0.8482143 | Acc test: 0.8947368
Iter: 96 | Cost: 0.0477648 | Acc train: 0.8035714 | Acc test: 0.8947368
Iter: 97 | Cost: 0.0946534 | Acc train: 0.8035714 | Acc test: 0.8947368
Iter: 98 | Cost: 0.0701062 | Acc train: 0.8839286 | Acc test: 0.8684211
Iter: 99 | Cost: 0.0827178 | Acc train: 0.9642857 | Acc test: 0.9473684
Iter: 100 | Cost: 0.0595838 | Acc train: 0.9642857 | Acc test: 0.9736842
Total running time of the script: (8 minutes 32.365 seconds)