Identifying Handwritten Digits (MNIST) Using PyTorch

Source: MNIST database (Wikipedia)

We will use the famous MNIST Handwritten Digits Database as our training dataset. It consists of 28px by 28px grayscale images of handwritten digits (0–9), along with a label for each image indicating which digit it represents. MNIST stands for Modified National Institute of Standards and Technology. PyTorch is an optimized tensor library for deep learning using GPUs and CPUs.

Quick Navigation

1. Brief about PyTorch

2. Working with images in PyTorch (using the MNIST dataset)

3. Splitting a dataset into training, validation, and test sets

4. Creating PyTorch models with custom logic by extending the nn.Module class

5. Interpreting model outputs as probabilities using softmax, and picking predicted labels

6. Picking a good evaluation metric (accuracy) and loss function (cross entropy) for classification problems

7. Setting up a training loop that also evaluates the model using the validation set

8. Testing the model manually on randomly picked examples

9. Saving and loading model checkpoints to avoid retraining from scratch

10. References

## Imports
import torch
import torch.nn as nn                        ## Building blocks such as nn.Module and nn.Linear
import torch.nn.functional as F              ## Functional API (softmax, cross_entropy)
import torchvision                           ## Contains some utilities for working with image data
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import random_split, DataLoader
import matplotlib.pyplot as plt
#%matplotlib inline

We will import torchvision, which contains some utility functions for working with image data. It also contains helper classes to automatically download and import popular datasets such as MNIST.

The MNIST dataset has 60,000 images that can be used to train the model. There is also an additional test set of 10,000 images, which can be loaded by passing train = False to the MNIST class.
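As a small illustrative snippet (the root path simply mirrors the one used later in this article), the held-out test split could be loaded like this:

## Download/load the 10,000-image test split by passing train = False
test_dataset = MNIST(root = 'data/', train = False, transform = transforms.ToTensor(), download = True)
print(len(test_dataset))  ## 10000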

Loading the MNIST dataset

dataset = MNIST(root = 'data/', download = True)
print(len(dataset))
image, label = dataset[10]
plt.imshow(image, cmap = 'gray')
print('Label:', label)

These images are small in size, and recognizing the digits can sometimes be hard. PyTorch models can't work with PIL images directly; we need to convert the images into tensors. We can do this by specifying a transform while creating our dataset.

PyTorch datasets allow us to specify one or more transformation functions that are applied to the images as they are loaded.

torchvision.transforms contains many such predefined transforms; we will use the ToTensor transform to convert images into PyTorch tensors.

Loading the MNIST data with transformation applied while loading

## MNIST dataset(images and labels)
mnist_dataset = MNIST(root = 'data/', train = True, transform = transforms.ToTensor())
print(mnist_dataset)
image_tensor, label = mnist_dataset[0]
print(image_tensor.shape, label)

The image is now converted to a 1 x 28 x 28 tensor. The first dimension keeps track of the color channels. Since images in the MNIST dataset are grayscale, there is just one channel. Other datasets have color images, in which case there would be 3 channels (red, green, blue).
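Because matplotlib expects a 2-D array rather than a channels-first tensor, one way to display the transformed image is to select the single channel first; a quick sketch using the image_tensor loaded above:

## Drop the channel dimension (1 x 28 x 28 -> 28 x 28) before plotting
plt.imshow(image_tensor[0], cmap = 'gray')
print('Label:', label)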

Training and Validation Datasets

train_data, validation_data = random_split(mnist_dataset, [50000, 10000])
## Print the length of train and validation datasets
print("length of Train Datasets: ", len(train_data))
print("length of Validation Datasets: ", len(validation_data))

While building machine learning/deep learning models, it is common to split the dataset into three parts:

  1. Training set — the part of the data used to train the model, i.e. compute the loss and adjust the model's weights using gradient descent.
  2. Validation set — the part of the dataset used to evaluate the model while training, adjust hyperparameters, and pick the best version of the model.
  3. Test set — the part of the dataset used for a final check of the model's predictions on new, unseen data, to evaluate how well the model performs.
batch_size = 128
train_loader = DataLoader(train_data, batch_size, shuffle = True)
val_loader = DataLoader(validation_data, batch_size, shuffle = False)

Here we use DataLoaders to help us load the data in batches, with a batch size of 128. We set shuffle = True for the training dataloader so that the batches generated in each epoch are different; this randomization helps the model generalize and speeds up training.

Since the validation dataloader is used only for evaluating the model, there is no need to shuffle the images.

Defining the Logistic Model

input_size = 28 * 28   ## 784 pixels per image
num_classes = 10       ## digits 0-9

class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, 784)   ## Flatten each image into a 784-element vector
        out = self.linear(xb)
        return out

model = MnistModel()
print(model.linear.weight.shape, model.linear.bias.shape)
list(model.parameters())

Inside the __init__ constructor method, we instantiate the weights and biases using nn.Linear. Inside the forward method, which is invoked when we pass a batch of inputs to the model, we flatten the input tensor and then pass it to self.linear.

xb.reshape(-1, 28 * 28) indicates to PyTorch that we want a view of the xb tensor with two dimensions, where the length along the 2nd dimension is 28 * 28 (i.e. 784). One argument to .reshape can be set to -1 (in this case the first dimension) to let PyTorch figure it out automatically based on the shape of the original tensor.
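As a quick illustration (the batch size of 128 here just mirrors the dataloader above):

## A batch of 128 single-channel 28 x 28 images becomes a 128 x 784 matrix
xb = torch.randn(128, 1, 28, 28)
print(xb.reshape(-1, 784).shape)   ## torch.Size([128, 784])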

Note that the model no longer has .weight and .bias attributes (they are now inside the .linear attribute), but it does have a .parameters method which returns the weights and biases and can be used by a PyTorch optimizer.

for images, labels in train_loader:
    outputs = model(images)
    break

print('outputs shape: ', outputs.shape)
print('Sample outputs: \n', outputs[:2].data)

For each of the 128 input images in the batch, we get 10 outputs, one for each class. We would like these outputs to represent probabilities, but for that each output row should lie between 0 and 1 and add up to 1.

To convert the outputs into probabilities that lie between 0 and 1, we use the softmax function.

What is the Softmax Function?

The softmax function is a function that turns a vector of K real values into a vector of K real values that sum to 1. The input values can be positive, negative, zero, or greater than one, but the softmax transforms them into values between 0 and 1, so that they can be interpreted as probabilities. If one of the inputs is small or negative, the softmax turns it into a small probability, and if an input is large, then it turns it into a large probability, but it will always remain between 0 and 1.
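For completeness, the standard definition can be written out: for a vector of scores $z = (z_1, \dots, z_K)$, the $i$-th softmax output is

$$\mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}$$

Exponentiating makes every entry positive, and dividing by the sum normalizes the entries so they add up to 1.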

The softmax function is sometimes called the softargmax function, or multi-class logistic regression. This is because the softmax is a generalization of logistic regression that can be used for multi-class classification, and its formula is very similar to the sigmoid function which is used for logistic regression. The softmax function can be used in a classifier only when the classes are mutually exclusive.

Many multi-layer neural networks end in a penultimate layer which outputs real-valued scores that are not conveniently scaled and which may be difficult to work with. Here the softmax is very useful because it converts the scores to a normalized probability distribution, which can be displayed to a user or used as input to other systems. For this reason it is usual to append a softmax function as the final layer of the neural network.

## Apply softmax for each output row
probs = F.softmax(outputs, dim = 1)

## Checking sample probabilities
print("Sample probabilities:\n", probs[:2].data)

print("\n")
## Add up the probabilities of an output row
print("Sum: ", torch.sum(probs[0]).item())
max_probs, preds = torch.max(probs, dim = 1)
print("\n")
print(preds)
print("\n")
print(max_probs)

Evaluation Metric and Loss Function

Here we evaluate our model by finding the percentage of labels that were predicted correctly i.e. the accuracy of the predictions.

The == operator performs an element-wise comparison of two tensors with the same shape and returns a tensor of the same shape, containing False (0) for unequal elements and True (1) for equal elements. Passing the result to torch.sum returns the number of labels that were predicted correctly. Finally, we divide by the total number of images to get the accuracy.

def accuracy(outputs, labels):
    _, preds = torch.max(outputs, dim = 1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))

print("Accuracy: ",accuracy(outputs, labels))
print("\n")
loss_fn = F.cross_entropy
print("Loss Function: ",loss_fn)
print("\n")
## Loss for the current batch
loss = loss_fn(outputs, labels)
print(loss)

While accuracy is a great way to evaluate the model, it can't be used as a loss function for optimizing the model with gradient descent, for the following reasons:

  • It is not a differentiable function: it changes abruptly only when predicted labels flip, so it can't provide useful gradients.

  • It does not take into account the actual probabilities predicted by the model, so it can't provide sufficient feedback for incremental improvements.

For these reasons, accuracy is a great evaluation metric for classification, but not a good loss function. A commonly used loss function for classification problems is cross entropy.

What is Cross-Entropy?

Cross-entropy is commonly used to quantify the difference between two probability distributions. Usually the "true" distribution (the one that your machine learning algorithm is trying to match) is expressed as a one-hot distribution.
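In formula form (the standard definition, with p the true one-hot distribution and q the predicted distribution over the classes):

$$H(p, q) = -\sum_{x} p(x)\,\log q(x)$$

With a one-hot target, only the term for the correct class is non-zero, so the loss reduces to the negative log of the probability the model assigns to the true label.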

Training the Model

class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, 784)
        out = self.linear(xb)
        return out

    def training_step(self, batch):
        images, labels = batch
        out = self(images)                      ## Generate predictions
        loss = F.cross_entropy(out, labels)     ## Calculate the loss
        return loss

    def validation_step(self, batch):
        images, labels = batch
        out = self(images)                      ## Generate predictions
        loss = F.cross_entropy(out, labels)     ## Calculate the loss
        acc = accuracy(out, labels)             ## Calculate accuracy
        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   ## Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      ## Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['val_loss'], result['val_acc']))


model = MnistModel()

def evaluate(model, val_loader):
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def fit(epochs, lr, model, train_loader, val_loader, opt_func = torch.optim.SGD):
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):

        ## Training phase
        for batch in train_loader:
            loss = model.training_step(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        ## Validation phase
        result = evaluate(model, val_loader)
        model.epoch_end(epoch, result)
        history.append(result)
    return history
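The result0 and history1 through history4 variables used below are assumed to come from an initial evaluation of the untrained model followed by a few calls to fit; a sketch of how they might be produced (the learning rate and epoch counts here are illustrative, not prescriptive):

result0 = evaluate(model, val_loader)                          ## Accuracy/loss before any training
history1 = fit(5, 0.001, model, train_loader, val_loader)      ## First 5 epochs
history2 = fit(5, 0.001, model, train_loader, val_loader)      ## 5 more epochs, and so on
history3 = fit(5, 0.001, model, train_loader, val_loader)
history4 = fit(5, 0.001, model, train_loader, val_loader)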
## Replace these values with your result
history = [result0] + history1 + history2 + history3 + history4
accuracies = [result['val_acc'] for result in history]
plt.plot(accuracies, '-x')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.title('Accuracy Vs. No. of epochs')

Credits

1. https://jovian.ai/aakashns/03-logistic-regression

2. https://deepai.org/machine-learning-glossary-and-terms/softmax-layer

3. https://stackoverflow.com/questions/41990250/what-is-cross-entropy

4. https://en.wikipedia.org/wiki/MNIST_database

5. https://github.com/pytorch/pytorch
