Identifying hand-written digits (MNIST) using PyTorch

Source: MNIST database (Wikipedia)

Quick Navigation

1. A brief introduction to PyTorch

2. Working with images in PyTorch (using the MNIST dataset)

3. Splitting a dataset into training, validation and test sets

4. Creating PyTorch models with custom logic by extending the nn.Module class

5. Interpreting model outputs as probabilities using softmax, and picking predicted labels

6. Picking a good evaluation metric (accuracy) and loss function (cross-entropy) for classification problems

7. Setting up a training loop that also evaluates the model using the validation set

8. Testing the model manually on randomly picked examples

9. Saving and loading model checkpoints to avoid retraining from scratch

10. References

## Imports
import torch
import torchvision ## Contains some utilities for working with the image data
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt
#%matplotlib inline
import torchvision.transforms as transforms
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

Loading the MNIST dataset

dataset = MNIST(root = 'data/', download = True)
image, label = dataset[10]
plt.imshow(image, cmap = 'gray')
print('Label:', label)

Loading the MNIST data with a transformation applied at load time

## MNIST dataset(images and labels)
mnist_dataset = MNIST(root = 'data/', train = True, transform = transforms.ToTensor())
image_tensor, label = mnist_dataset[0]
print(image_tensor.shape, label)
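`transforms.ToTensor()` turns each 28x28 grayscale PIL image into a float tensor of shape `[1, 28, 28]` (channel, height, width) with pixel values scaled from the 0 to 255 range down to 0.0 to 1.0, which is why the printed shape starts with a 1. The sketch below mimics that conversion in plain Python for one image; the helper name `to_tensor_like` is hypothetical and this only illustrates the idea, it is not torchvision's implementation.

```python
def to_tensor_like(pixels):
    """Mimic ToTensor for one grayscale image (hypothetical helper).

    pixels: list of rows, each a list of ints in 0..255 (height x width).
    Returns a nested list shaped [1][height][width] with floats in [0.0, 1.0].
    """
    scaled = [[p / 255.0 for p in row] for row in pixels]  # 0..255 -> 0.0..1.0
    return [scaled]  # add the leading channel dimension


# A fake 28x28 image: black background with one white pixel in the middle
image = [[0] * 28 for _ in range(28)]
image[14][14] = 255

t = to_tensor_like(image)
shape = (len(t), len(t[0]), len(t[0][0]))
print(shape)          # (1, 28, 28)
print(t[0][14][14])   # 1.0
```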

Training and Validation Datasets

train_data, validation_data = random_split(mnist_dataset, [50000, 10000])
## Print the length of train and validation datasets
print("length of Train Datasets: ", len(train_data))
print("length of Validation Datasets: ", len(validation_data))
batch_size = 128
train_loader = DataLoader(train_data, batch_size, shuffle = True)
val_loader = DataLoader(validation_data, batch_size, shuffle = False)
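`random_split` shuffles the indices of the 60,000 training images and slices them into non-overlapping subsets of the requested sizes, and `DataLoader` then serves each subset in mini-batches of `batch_size`. The plain-Python sketch below illustrates both ideas; `split_indices` and its seed are hypothetical, and the real `random_split` draws its permutation from PyTorch's own random generator (passing `generator=torch.Generator().manual_seed(...)` makes the split reproducible).

```python
import math
import random


def split_indices(n, lengths, seed=42):
    """Shuffle range(n) and cut it into consecutive chunks of the given lengths (hypothetical helper)."""
    assert sum(lengths) == n, "lengths must add up to the dataset size"
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    splits, start = [], 0
    for length in lengths:
        splits.append(indices[start:start + length])
        start += length
    return splits


train_idx, val_idx = split_indices(60000, [50000, 10000])
print(len(train_idx), len(val_idx))   # 50000 10000

# With batch_size = 128 the last batch of each loader is smaller than the rest
batch_size = 128
print(math.ceil(len(train_idx) / batch_size))  # 391 training batches
print(math.ceil(len(val_idx) / batch_size))    # 79 validation batches
print(len(val_idx) % batch_size)               # 16 images in the final validation batch
```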

Defining the Logistic Model

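input_size = 28 * 28   ## Each 28x28 image is flattened into a 784-element vector
num_classes = 10       ## Digits 0-9

class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, xb):
        xb = xb.reshape(-1, 784)   ## Flatten each 1x28x28 image
        out = self.linear(xb)
        return out

model = MnistModel()
print(model.linear.weight.shape, model.linear.bias.shape)

## Run the model on the first batch only
for images, labels in train_loader:
    outputs = model(images)
    break

print('outputs shape: ', outputs.shape)
print('Sample outputs: \n', outputs[:2].data)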
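`nn.Linear(784, 10)` stores a weight matrix of shape `[10, 784]` and a bias of shape `[10]`, and computes `out = x @ weight.T + bias`, so each of the 10 outputs is a weighted sum of all 784 pixels plus an offset. That gives 784 x 10 + 10 = 7,850 trainable parameters. The pure-Python sketch below spells out the computation for one flattened image; the helper name `linear_forward` and the zero weights are illustrative assumptions (PyTorch initialises the weights randomly).

```python
def linear_forward(x, weight, bias):
    """Compute weight @ x + bias for one input vector (the same maths as nn.Linear)."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) + b_i
            for row, b_i in zip(weight, bias)]


input_size, num_classes = 28 * 28, 10

# Flatten a 28x28 image into 784 values, as xb.reshape(-1, 784) does for a batch
image = [[0.5] * 28 for _ in range(28)]
x = [p for row in image for p in row]

weight = [[0.0] * input_size for _ in range(num_classes)]  # shape [10, 784]
bias = [float(k) for k in range(num_classes)]              # shape [10]

out = linear_forward(x, weight, bias)
print(len(x), len(out))                         # 784 10
print(out[:3])                                  # [0.0, 1.0, 2.0]
print(input_size * num_classes + num_classes)   # 7850 parameters
```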

What is the Softmax Function?

The softmax function turns a vector of K real values into a vector of K real values that sum to 1. The inputs can be positive, negative, zero, or greater than one, but softmax maps each of them to a value strictly between 0 and 1, so the outputs can be interpreted as probabilities. Inputs that are small or negative relative to the others become small probabilities and the largest inputs become the largest probabilities, but every output stays between 0 and 1.

The softmax function is sometimes called the softargmax function or the normalized exponential function, and a linear model followed by softmax is known as multinomial (multi-class) logistic regression. The name reflects the fact that softmax generalizes the sigmoid function used in binary logistic regression: with two classes, softmax reduces to the sigmoid. Softmax is appropriate in a classifier only when the classes are mutually exclusive, because it spreads a single probability distribution over all of them.

Many multi-layer neural networks produce, just before the output, real-valued scores (logits) that are not conveniently scaled and can be hard to interpret. Softmax is useful here because it converts these scores into a normalized probability distribution that can be displayed to a user or passed to other systems, so it is common to append a softmax as the final step of a classification network. In PyTorch, however, F.cross_entropy applies log-softmax internally, so the models in this post return raw scores and softmax is applied separately only when probabilities are needed.
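In symbols, for a vector of K scores z = (z_1, ..., z_K), softmax exponentiates each score and divides by the total. Fixing K = 2 with the second score set to 0 recovers the sigmoid σ, which is the sense in which softmax generalizes logistic regression:

```latex
\operatorname{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}, \qquad i = 1, \dots, K

\operatorname{softmax}\big((z, 0)\big)_1 = \frac{e^{z}}{e^{z} + e^{0}} = \frac{1}{1 + e^{-z}} = \sigma(z)
```

Because every term e^{z_j} is positive, each output lies strictly between 0 and 1, and the outputs sum to 1 by construction.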

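## Apply softmax to each output row
probs = F.softmax(outputs, dim = 1)

## Checking sample probabilities
print("Sample probabilities:\n", probs[:2].data)

## Add up the probabilities of an output row (should be 1)
print("Sum: ", torch.sum(probs[0]).item())

## The predicted label is the index of the highest probability in each row
max_probs, preds = torch.max(probs, dim = 1)
print("Predicted labels: ", preds[:10])
print("Actual labels:    ", labels[:10])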
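To see what `F.softmax(outputs, dim=1)` and `torch.max(probs, dim=1)` compute for each row, here is a plain-Python sketch for a single row of scores. Subtracting the row maximum before exponentiating is a standard trick to avoid overflow; it leaves the result unchanged because the common factor e^(-max) cancels in the division. The helper names are hypothetical and this is not PyTorch's implementation.

```python
import math


def softmax(scores):
    """Numerically stable softmax for one row of raw scores (logits)."""
    m = max(scores)                          # shift so the largest exponent is e^0
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value, i.e. the predicted label for one row."""
    return max(range(len(values)), key=lambda i: values[i])


row = [2.0, 1.0, 0.1]
probs = softmax(row)
print([round(p, 3) for p in probs])   # [0.659, 0.242, 0.099]
print(round(sum(probs), 6))           # 1.0
print(argmax(probs), max(probs) == probs[argmax(row)])  # 0 True

# Large scores would overflow math.exp without the max-shift
print(softmax([1000.0, 1000.0]))      # [0.5, 0.5]
```

Since softmax preserves the ordering of the scores, taking the argmax of the probabilities gives the same label as taking the argmax of the raw outputs.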

Evaluation Metric and Loss Function

def accuracy(outputs, labels):
    ## The predicted class is the index of the largest output in each row
    _, preds = torch.max(outputs, dim = 1)
    return torch.tensor(torch.sum(preds == labels).item() / len(preds))
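The `accuracy` function takes the highest-scoring class in each row as the prediction and counts how often it matches the label. Since `torch.max` picks the same index whether it is applied to raw outputs or to softmax probabilities, accuracy can be computed directly on the model outputs. A plain-Python sketch of the same logic (the function name and sample scores are hypothetical):

```python
def accuracy_from_scores(outputs, labels):
    """Fraction of rows whose highest-scoring class equals the true label."""
    preds = [max(range(len(row)), key=row.__getitem__) for row in outputs]
    correct = sum(p == y for p, y in zip(preds, labels))
    return correct / len(preds)


outputs = [
    [0.1, 2.3, -1.0],   # predicts class 1
    [1.5, 0.2, 0.3],    # predicts class 0
    [-0.4, 0.0, 0.9],   # predicts class 2
    [0.7, 0.6, 0.5],    # predicts class 0
]
labels = [1, 0, 1, 2]
print(accuracy_from_scores(outputs, labels))  # 0.5
```

Nudging a weight slightly usually does not flip any prediction, so accuracy stays the same and provides no useful gradient; this is why it serves as an evaluation metric while cross-entropy is used as the loss.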

print("Accuracy: ", accuracy(outputs, labels))
loss_fn = F.cross_entropy
print("Loss Function: ", loss_fn)
## Loss for the current batch (F.cross_entropy expects raw outputs, not probabilities)
loss = loss_fn(outputs, labels)
print("Loss: ", loss)

What is Cross-Entropy?
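Cross-entropy measures how much probability the model assigns to the correct class: for each example it takes the negative log of the softmax probability of the true label, then averages over the batch. `F.cross_entropy` expects raw scores (logits) and applies log-softmax internally, which is why the model returns `out` without a softmax layer. The plain-Python sketch below reproduces that definition; the helper name is hypothetical. A useful sanity check: a model that gives equal scores to all 10 classes has a loss of log(10), about 2.303, so an untrained MNIST model typically starts near that value.

```python
import math


def cross_entropy(logits_batch, labels):
    """Mean of -log(softmax(logits)[label]), the default definition behind F.cross_entropy."""
    total = 0.0
    for logits, y in zip(logits_batch, labels):
        m = max(logits)
        # Stable log-sum-exp: log(sum(exp(z))) = m + log(sum(exp(z - m)))
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[y]   # equals -log(softmax(logits)[y])
    return total / len(labels)


# Equal scores for all 10 classes: every class gets probability 0.1
uniform = [[0.0] * 10]
print(round(cross_entropy(uniform, [3]), 4))             # 2.3026 (= log 10)

# A confident, correct prediction gives a loss close to 0
print(round(cross_entropy([[10.0, 0.0, 0.0]], [0]), 4))  # 0.0001
```

A confident but wrong prediction is punished heavily (the loss grows with how far the correct class's score falls below the others), which gives the optimizer a strong signal even when accuracy does not change.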

Training the Model

class MnistModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(784, 10)   ## 784 = 28 x 28 pixels in, 10 digit classes out

    def forward(self, xb):
        xb = xb.reshape(-1, 784)
        out = self.linear(xb)
        return out

    def training_step(self, batch):
        images, labels = batch
        out = self(images)                   ## Generate predictions
        loss = F.cross_entropy(out, labels)  ## Calculate the loss
        return loss

    def validation_step(self, batch):
        images, labels = batch
        out = self(images)                   ## Generate predictions
        loss = F.cross_entropy(out, labels)  ## Calculate the loss
        acc = accuracy(out, labels)          ## Calculate the accuracy
        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()   ## Combine losses
        batch_accs = [x['val_acc'] for x in outputs]
        epoch_acc = torch.stack(batch_accs).mean()      ## Combine accuracies
        return {'val_loss': epoch_loss.item(), 'val_acc': epoch_acc.item()}

    def epoch_end(self, epoch, result):
        print("Epoch [{}], val_loss: {:.4f}, val_acc: {:.4f}".format(epoch, result['val_loss'], result['val_acc']))

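model = MnistModel()

def evaluate(model, val_loader):
    ## Run the validation step on every batch (no gradients needed), then combine the results
    with torch.no_grad():
        outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)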
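`validation_epoch_end` averages the per-batch values with equal weight. With 10,000 validation images and `batch_size = 128`, the last batch holds only 16 images, yet it counts as much as each full batch of 128, so the reported figure can differ slightly from the exact accuracy over all images (the same applies to the loss). The sketch below contrasts the two averages on made-up batch results (the counts are hypothetical); weighting by batch size recovers the exact figure.

```python
def unweighted_mean(values):
    """Plain average of per-batch metrics, as validation_epoch_end does."""
    return sum(values) / len(values)


def weighted_mean(values, sizes):
    """Average that weights each batch's metric by its number of examples."""
    return sum(v * n for v, n in zip(values, sizes)) / sum(sizes)


# Hypothetical counts: 78 full batches with 115/128 correct, then 8/16 correct in the last batch
sizes = [128] * 78 + [16]
correct = [115] * 78 + [8]
batch_accs = [c / n for c, n in zip(correct, sizes)]

exact = sum(correct) / sum(sizes)                  # 8978 / 10000
print(round(unweighted_mean(batch_accs), 4))       # 0.8934
print(round(weighted_mean(batch_accs, sizes), 4))  # 0.8978
print(exact)                                       # 0.8978
```

For MNIST the gap is usually small, so the equal-weight average is a reasonable shortcut; weighting matters more when the final batch is large relative to the others or its accuracy is unusual.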

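def fit(epochs, lr, model, train_loader, val_loader, opt_func = torch.optim.SGD):
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):

        ## Training phase
        for batch in train_loader:
            loss = model.training_step(batch)
            loss.backward()        ## Compute gradients
            optimizer.step()       ## Update the weights
            optimizer.zero_grad()  ## Reset gradients for the next batch

        ## Validation phase
        result = evaluate(model, val_loader)
        model.epoch_end(epoch, result)
        history.append(result)
    return history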
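Inside `fit`, `loss.backward()` computes the gradient of the cross-entropy loss with respect to the weights and biases, `optimizer.step()` moves each parameter a small step against its gradient (scaled by `lr`), and `optimizer.zero_grad()` clears the gradients because PyTorch accumulates them by default. For softmax followed by cross-entropy, the gradient with respect to each logit has a simple closed form (softmax probability minus the one-hot label), so the whole loop can be sketched in plain Python. The toy two-feature, three-class data, the learning rate, and the helper names below are hypothetical; this illustrates plain gradient descent on a logistic regression model, not PyTorch's autograd.

```python
import math


def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]


def forward(W, b, x):
    """Logits of a linear model: one score per class."""
    return [sum(w * xi for w, xi in zip(row, x)) + bk for row, bk in zip(W, b)]


def loss_and_grads(W, b, batch):
    """Mean cross-entropy over the batch and its gradients w.r.t. W and b."""
    gW = [[0.0] * len(W[0]) for _ in W]
    gb = [0.0] * len(b)
    loss = 0.0
    for x, y in batch:
        p = softmax(forward(W, b, x))
        loss -= math.log(p[y])
        for k in range(len(b)):
            d = (p[k] - (1.0 if k == y else 0.0)) / len(batch)  # dLoss/dlogit_k
            gb[k] += d
            for j, xj in enumerate(x):
                gW[k][j] += d * xj
    return loss / len(batch), gW, gb


# Hypothetical toy data: 2 features, 3 well-separated classes
batch = [([1.0, 0.0], 0), ([0.0, 1.0], 1), ([-1.0, -1.0], 2), ([0.9, 0.1], 0)]
W = [[0.0, 0.0] for _ in range(3)]
b = [0.0, 0.0, 0.0]
lr = 0.5

history = []
for epoch in range(20):
    loss, gW, gb = loss_and_grads(W, b, batch)   # plays the role of loss.backward()
    history.append(loss)
    for k in range(3):                           # plays the role of optimizer.step()
        b[k] -= lr * gb[k]
        for j in range(2):
            W[k][j] -= lr * gW[k][j]
    # gW and gb are rebuilt from zero on every call, like optimizer.zero_grad()

print(round(history[0], 4), history[-1] < history[0])  # 1.0986 True
```

The first loss is log(3) because all-zero weights give every class the same score; each step then lowers the loss as the weights move toward the correct classes.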
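## Replace these values with your results. result0: untrained baseline; history1-4: four fit() runs (5 epochs, lr = 0.001 each, illustrative)
result0 = evaluate(model, val_loader)
history1 = fit(5, 0.001, model, train_loader, val_loader)
history2 = fit(5, 0.001, model, train_loader, val_loader)
history3 = fit(5, 0.001, model, train_loader, val_loader)
history4 = fit(5, 0.001, model, train_loader, val_loader)
history = [result0] + history1 + history2 + history3 + history4
accuracies = [result['val_acc'] for result in history]
plt.plot(accuracies, '-x')
plt.title('Accuracy vs. No. of epochs')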







Data Analyst @Novartis | Researcher | Full-time Learner
