Multiclass margin classifier

In this tutorial, we show how to use the PyTorch interface for PennyLane to implement a multiclass variational classifier. We consider the Iris dataset from the UCI Machine Learning Repository, which has 4 features and 3 classes. We use multiple one-vs-all classifiers with a margin loss (see Multiclass Linear SVM) to classify the data. Each classifier is implemented on an individual variational circuit, whose architecture is inspired by Farhi and Neven (2018) as well as Schuld et al. (2018).



Initial Setup

We import PennyLane, the PennyLane-provided version of NumPy, relevant torch modules, and define the constants that will be used in this tutorial.

Our feature size is 4, and we will use amplitude embedding. This means that each possible amplitude (in the computational basis) will correspond to a single feature. With 2 qubits (wires), there are 4 possible states, and as such, we can encode a feature vector of size 4.
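
Concretely, for a normalized feature vector \(x = (x_0, x_1, x_2, x_3)\), amplitude embedding prepares the two-qubit state

\[|\psi(x)\rangle = x_0 \left|00\right\rangle + x_1 \left|01\right\rangle + x_2 \left|10\right\rangle + x_3 \left|11\right\rangle.\]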

import pennylane as qml
import torch
import numpy as np
from torch.autograd import Variable
import torch.optim as optim

num_classes = 3
margin = 0.15
feature_size = 4
batch_size = 10
lr_adam = 0.01
train_split = 0.75
num_qubits = 2
num_layers = 6
total_iterations = 100

dev = qml.device("default.qubit", wires=num_qubits)

Quantum Circuit

We first create the layer that will be repeated in our variational quantum circuits. It consists of an arbitrary rotation on each qubit, followed by a CNOT gate that entangles the two qubits:

def layer(W):
    qml.Rot(W[0, 0], W[0, 1], W[0, 2], wires=0)
    qml.Rot(W[1, 0], W[1, 1], W[1, 2], wires=1)
    qml.CNOT(wires=[0, 1])

We now define the quantum nodes that will be used. As we are implementing our multiclass classifier as multiple one-vs-all classifiers, we will use 3 QNodes, each representing one such classifier. That is, qnode1 predicts whether or not a sample belongs to the first class, and so on. The circuit architecture is the same for all 3 QNodes. We use the PyTorch interface for the QNodes. Data is embedded in each circuit using amplitude embedding:

def circuit(weights, feat=None):
    qml.templates.embeddings.AmplitudeEmbedding(feat, [0, 1], pad=0.0, normalize=True)
    for W in weights:
        layer(W)
    return qml.expval(qml.PauliZ(0))

qnode1 = qml.QNode(circuit, dev).to_torch()
qnode2 = qml.QNode(circuit, dev).to_torch()
qnode3 = qml.QNode(circuit, dev).to_torch()

The variational quantum circuit is parametrized by the weights. We also use a classical bias term that is added to the quantum circuit's output. Both the variational circuit weights and the classical bias term are optimized.

def variational_classifier(q_circuit, params, feat):
    weights = params[0]
    bias = params[1]
    return q_circuit(weights, feat=feat) + bias
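
As a quick sanity check, the classifier can be scored on a single sample before any training. The following is a minimal, illustrative sketch with randomly chosen parameters (the weight shape (num_layers, num_qubits, 3) and the scalar bias match the initialization used in the training procedure below); it is not part of the training loop.

# Illustrative only: score one (roughly normalized) sample with random parameters.
# AmplitudeEmbedding is called with normalize=True, so exact normalization is not required here.
sample_weights = 0.1 * torch.randn(num_layers, num_qubits, 3)
sample_bias = torch.zeros(1)
sample_feat = torch.tensor([0.8, 0.55, 0.22, 0.03])

# The result is the Pauli-Z expectation value (in [-1, 1]) plus the classical bias
score = variational_classifier(qnode1, (sample_weights, sample_bias), sample_feat)
print(score)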

Loss Function

Implementing a multiclass classifier as a number of one-vs-all classifiers naturally lends itself to a margin loss. The output of the \(i\)-th classifier, \(c_i\), on input \(x\) is interpreted as a score \(s_i\) in \([-1, 1]\). More concretely, we have:

\[s_i = c_i(x; \theta)\]

The multiclass margin loss attempts to ensure that the score for the correct class is higher than that of the incorrect classes by some margin. For a sample \((x, y)\), where \(y\) denotes the class label, we can express the multiclass loss on this sample as:

\[L(x,y) = \sum_{j \ne y}{\max\left(0, s_j - s_y + \Delta\right)}\]

where \(\Delta\) denotes the margin. The margin parameter is chosen as a hyperparameter. For more information, see Multiclass Linear SVM.
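
For intuition, here is a toy numerical check of this loss (with made-up scores, not produced by the circuits): suppose a sample with true label \(y = 0\) receives scores \(s = (0.4, 0.1, 0.5)\) and the margin is \(\Delta = 0.15\). Then

\[L(x, 0) = \max(0,\, 0.1 - 0.4 + 0.15) + \max(0,\, 0.5 - 0.4 + 0.15) = 0 + 0.25 = 0.25.\]

Only classifiers that score the sample too close to (or above) the true class's score contribute to the loss.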

def multiclass_svm_loss(q_circuits, all_params, feature_vecs, true_labels):
    loss = 0
    num_samples = len(true_labels)
    for i, feature_vec in enumerate(feature_vecs):
        # Compute the score given to this sample by the classifier corresponding to the
        # true label. So for a true label of 1, get the score computed by classifier 1,
        # which distinguishes between "class 1" or "not class 1".
        s_true = variational_classifier(
            q_circuits[int(true_labels[i])],
            (all_params[0][int(true_labels[i])], all_params[1][int(true_labels[i])]),
            feature_vec,
        )
        s_true = s_true.float()
        li = 0

        # Get the scores computed for this sample by the other classifiers
        for j in range(num_classes):
            if j != int(true_labels[i]):
                s_j = variational_classifier(
                    q_circuits[j], (all_params[0][j], all_params[1][j]), feature_vec
                )
                s_j = s_j.float()
                li += torch.max(torch.zeros(1).float(), s_j - s_true + margin)
        loss += li

    return loss / num_samples

Classification Function

Next, we use the learned models to classify our samples. For a given sample, we compute the score assigned to it by classifier \(i\), which quantifies how likely it is that the sample belongs to class \(i\). The predicted class is the one with the highest score.
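
In other words, the predicted label is

\[\hat{y} = \arg\max_{i} \; s_i(x).\]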

def classify(q_circuits, all_params, feature_vecs, labels):
    predicted_labels = []
    for i, feature_vec in enumerate(feature_vecs):
        scores = [0, 0, 0]
        for c in range(num_classes):
            score = variational_classifier(
                q_circuits[c], (all_params[0][c], all_params[1][c]), feature_vec
            )
            scores[c] = float(score)
        pred_class = np.argmax(scores)
        predicted_labels.append(pred_class)
    return predicted_labels

def accuracy(labels, hard_predictions):
    loss = 0
    for l, p in zip(labels, hard_predictions):
        if torch.abs(l - p) < 1e-5:
            loss = loss + 1
    loss = loss / labels.shape[0]
    return loss

Data Loading and Processing

Now we load in the iris dataset and normalize the features so that the sum of the feature elements squared is 1 (\(\ell_2\) norm is 1).

def load_and_process_data():
    data = np.loadtxt("multiclass_classification/iris.csv", delimiter=",")
    X = torch.tensor(data[:, 0:feature_size])
    print("First X sample (original)  :", X[0])

    # normalize each input
    normalization = torch.sqrt(torch.sum(X ** 2, dim=1))
    X_norm = X / normalization.reshape(len(X), 1)
    print("First X sample (normalized):", X_norm[0])

    Y = torch.tensor(data[:, -1])
    return X_norm, Y

# Create a train and test split. Use a seed for reproducibility
def split_data(feature_vecs, Y):
    np.random.seed(0)
    num_data = len(Y)
    num_train = int(train_split * num_data)
    index = np.random.permutation(range(num_data))
    feat_vecs_train = feature_vecs[index[:num_train]]
    Y_train = Y[index[:num_train]]
    feat_vecs_test = feature_vecs[index[num_train:]]
    Y_test = Y[index[num_train:]]
    return feat_vecs_train, feat_vecs_test, Y_train, Y_test

Training Procedure

In the training procedure, we begin by randomly initializing the parameters we wish to learn (the variational circuit weights and classical biases). As these are the variables we wish to optimize, we set their requires_grad flag to True. We use minibatch training: the loss is averaged over a batch of samples, and the optimization step is based on this average. Total training time with the default parameters is roughly 15 to 20 minutes.

def training(features, Y):
    num_data = Y.shape[0]
    feat_vecs_train, feat_vecs_test, Y_train, Y_test = split_data(features, Y)
    num_train = Y_train.shape[0]
    q_circuits = [qnode1, qnode2, qnode3]

    # Initialize the parameters
    all_weights = [
        Variable(0.1 * torch.randn(num_layers, num_qubits, 3), requires_grad=True)
        for i in range(num_classes)
    ]
    all_bias = [Variable(0.1 * torch.ones(1), requires_grad=True) for i in range(num_classes)]
    optimizer = optim.Adam(all_weights + all_bias, lr=lr_adam)
    params = (all_weights, all_bias)
    print("Num params: ", 3 * num_layers * num_qubits * 3 + 3)

    costs, train_acc, test_acc = [], [], []

    # train the variational classifier
    for it in range(total_iterations):
        batch_index = np.random.randint(0, num_train, (batch_size,))
        feat_vecs_train_batch = feat_vecs_train[batch_index]
        Y_train_batch = Y_train[batch_index]

        optimizer.zero_grad()
        curr_cost = multiclass_svm_loss(q_circuits, params, feat_vecs_train_batch, Y_train_batch)
        curr_cost.backward()
        optimizer.step()

        # Compute predictions on train and validation set
        predictions_train = classify(q_circuits, params, feat_vecs_train, Y_train)
        predictions_test = classify(q_circuits, params, feat_vecs_test, Y_test)
        acc_train = accuracy(Y_train, predictions_train)
        acc_test = accuracy(Y_test, predictions_test)

        print(
            "Iter: {:5d} | Cost: {:0.7f} | Acc train: {:0.7f} | Acc test: {:0.7f} "
            "".format(it + 1, curr_cost.item(), acc_train, acc_test)
        )

        costs.append(curr_cost.item())
        train_acc.append(acc_train)
        test_acc.append(acc_test)

    return costs, train_acc, test_acc


# We now run our training algorithm and plot the results. Note that
# for plotting, the matplotlib library is required

features, Y = load_and_process_data()
costs, train_acc, test_acc = training(features, Y)

import matplotlib.pyplot as plt

fig, ax1 = plt.subplots()
iters = np.arange(0, total_iterations, 1)
colors = ["tab:red", "tab:blue"]
ax1.set_xlabel("Iteration", fontsize=17)
ax1.set_ylabel("Cost", fontsize=17, color=colors[0])
ax1.plot(iters, costs, color=colors[0], linewidth=4)
ax1.tick_params(axis="y", labelsize=14, labelcolor=colors[0])

ax2 = ax1.twinx()
ax2.set_ylabel("Test Acc.", fontsize=17, color=colors[1])
ax2.plot(iters, test_acc, color=colors[1], linewidth=4)

ax2.tick_params(axis="x", labelsize=14)
ax2.tick_params(axis="y", labelsize=14, labelcolor=colors[1])

plt.grid(False)
plt.tight_layout()
plt.show()
[Output figure: cost (left axis) and test accuracy (right axis) as a function of the training iteration]

Out:

First X sample (original)  : tensor([5.1000, 3.5000, 1.4000, 0.2000], dtype=torch.float64)
First X sample (normalized): tensor([0.8038, 0.5516, 0.2206, 0.0315], dtype=torch.float64)
Num params:  111
Iter:     1 | Cost: 0.3277987 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:     2 | Cost: 0.3546013 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:     3 | Cost: 0.3811063 | Acc train: 0.2589286 | Acc test: 0.3157895
Iter:     4 | Cost: 0.4543785 | Acc train: 0.3214286 | Acc test: 0.3684211
Iter:     5 | Cost: 0.3295055 | Acc train: 0.3214286 | Acc test: 0.3684211
Iter:     6 | Cost: 0.3723806 | Acc train: 0.3214286 | Acc test: 0.3684211
Iter:     7 | Cost: 0.3338238 | Acc train: 0.3660714 | Acc test: 0.3947368
Iter:     8 | Cost: 0.2610437 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:     9 | Cost: 0.1998721 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    10 | Cost: 0.3411241 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    11 | Cost: 0.1953842 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    12 | Cost: 0.3015879 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    13 | Cost: 0.2125721 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    14 | Cost: 0.3646154 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    15 | Cost: 0.2418376 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    16 | Cost: 0.4205813 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    17 | Cost: 0.3027744 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    18 | Cost: 0.2799583 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    19 | Cost: 0.2892894 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    20 | Cost: 0.2092953 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    21 | Cost: 0.3196099 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    22 | Cost: 0.2578963 | Acc train: 0.6517857 | Acc test: 0.6315789
Iter:    23 | Cost: 0.2401980 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    24 | Cost: 0.3013025 | Acc train: 0.3928571 | Acc test: 0.3947368
Iter:    25 | Cost: 0.2767723 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    26 | Cost: 0.2512605 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    27 | Cost: 0.2658825 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    28 | Cost: 0.2444023 | Acc train: 0.3392857 | Acc test: 0.3157895
Iter:    29 | Cost: 0.2901598 | Acc train: 0.4107143 | Acc test: 0.4210526
Iter:    30 | Cost: 0.2737681 | Acc train: 0.6607143 | Acc test: 0.6842105
Iter:    31 | Cost: 0.2528875 | Acc train: 0.6607143 | Acc test: 0.6842105
Iter:    32 | Cost: 0.2246395 | Acc train: 0.6607143 | Acc test: 0.6842105
Iter:    33 | Cost: 0.2712515 | Acc train: 0.6607143 | Acc test: 0.6842105
Iter:    34 | Cost: 0.3168977 | Acc train: 0.6517857 | Acc test: 0.6842105
Iter:    35 | Cost: 0.2271881 | Acc train: 0.6250000 | Acc test: 0.6578947
Iter:    36 | Cost: 0.2259153 | Acc train: 0.5357143 | Acc test: 0.6052632
Iter:    37 | Cost: 0.1729273 | Acc train: 0.4553571 | Acc test: 0.4736842
Iter:    38 | Cost: 0.2158889 | Acc train: 0.3839286 | Acc test: 0.4473684
Iter:    39 | Cost: 0.2238514 | Acc train: 0.4553571 | Acc test: 0.5263158
Iter:    40 | Cost: 0.1908895 | Acc train: 0.4464286 | Acc test: 0.4210526
Iter:    41 | Cost: 0.1884072 | Acc train: 0.5089286 | Acc test: 0.4473684
Iter:    42 | Cost: 0.1739708 | Acc train: 0.5357143 | Acc test: 0.5263158
Iter:    43 | Cost: 0.1856364 | Acc train: 0.5982143 | Acc test: 0.5789474
Iter:    44 | Cost: 0.1756145 | Acc train: 0.6339286 | Acc test: 0.6315789
Iter:    45 | Cost: 0.2225707 | Acc train: 0.6607143 | Acc test: 0.6315789
Iter:    46 | Cost: 0.1846925 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    47 | Cost: 0.0820714 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    48 | Cost: 0.1157356 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    49 | Cost: 0.2574283 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    50 | Cost: 0.1717379 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    51 | Cost: 0.1648125 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    52 | Cost: 0.2141991 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    53 | Cost: 0.1550635 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    54 | Cost: 0.1496530 | Acc train: 0.6875000 | Acc test: 0.6578947
Iter:    55 | Cost: 0.1631034 | Acc train: 0.8392857 | Acc test: 0.8684211
Iter:    56 | Cost: 0.1433549 | Acc train: 0.6696429 | Acc test: 0.6842105
Iter:    57 | Cost: 0.1269250 | Acc train: 0.6607143 | Acc test: 0.6842105
Iter:    58 | Cost: 0.1704179 | Acc train: 0.6339286 | Acc test: 0.6842105
Iter:    59 | Cost: 0.1486499 | Acc train: 0.6160714 | Acc test: 0.6842105
Iter:    60 | Cost: 0.1549031 | Acc train: 0.6071429 | Acc test: 0.6578947
Iter:    61 | Cost: 0.1620971 | Acc train: 0.6160714 | Acc test: 0.6842105
Iter:    62 | Cost: 0.1493890 | Acc train: 0.6428571 | Acc test: 0.6842105
Iter:    63 | Cost: 0.1229829 | Acc train: 0.6785714 | Acc test: 0.6842105
Iter:    64 | Cost: 0.1526003 | Acc train: 0.7589286 | Acc test: 0.7894737
Iter:    65 | Cost: 0.1286933 | Acc train: 0.9107143 | Acc test: 0.8421053
Iter:    66 | Cost: 0.1236174 | Acc train: 0.8660714 | Acc test: 0.9473684
Iter:    67 | Cost: 0.0841795 | Acc train: 0.7053571 | Acc test: 0.7368421
Iter:    68 | Cost: 0.0965735 | Acc train: 0.6785714 | Acc test: 0.6578947
Iter:    69 | Cost: 0.0742856 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    70 | Cost: 0.1114441 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    71 | Cost: 0.0363118 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    72 | Cost: 0.1416280 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    73 | Cost: 0.1509206 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    74 | Cost: 0.1117348 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    75 | Cost: 0.1382995 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    76 | Cost: 0.1019208 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    77 | Cost: 0.0437695 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    78 | Cost: 0.0688111 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    79 | Cost: 0.0407074 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    80 | Cost: 0.1088503 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    81 | Cost: 0.0893265 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    82 | Cost: 0.0555995 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    83 | Cost: 0.1319041 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    84 | Cost: 0.0716266 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    85 | Cost: 0.0415196 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    86 | Cost: 0.1833309 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    87 | Cost: 0.0758161 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    88 | Cost: 0.0603824 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    89 | Cost: 0.0855087 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    90 | Cost: 0.0360422 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    91 | Cost: 0.0951684 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    92 | Cost: 0.0567321 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    93 | Cost: 0.0744555 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    94 | Cost: 0.0895449 | Acc train: 0.6785714 | Acc test: 0.6315789
Iter:    95 | Cost: 0.1014840 | Acc train: 0.7410714 | Acc test: 0.7894737
Iter:    96 | Cost: 0.1029355 | Acc train: 0.8928571 | Acc test: 0.9210526
Iter:    97 | Cost: 0.0995903 | Acc train: 0.9107143 | Acc test: 0.9210526
Iter:    98 | Cost: 0.1036247 | Acc train: 0.9464286 | Acc test: 0.9473684
Iter:    99 | Cost: 0.0783416 | Acc train: 0.9196429 | Acc test: 0.9210526
Iter:   100 | Cost: 0.0601794 | Acc train: 0.9107143 | Acc test: 0.9473684

Total running time of the script: ( 18 minutes 30.245 seconds)
