In this section, we will walk through a complete example of implementing a CNN for digit classification on the MNIST dataset. We will build a simple model consisting of two convolutional layers followed by fully connected layers.
Let's start off by importing the libraries that will be needed for this implementation:
%matplotlib inline
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from sklearn.metrics import confusion_matrix
import math
Next, we will use TensorFlow helper functions to download and preprocess the MNIST dataset as follows:
from tensorflow.examples.tutorials.mnist import input_data
mnist_data = input_data.read_data_sets('data/MNIST/', one_hot=True)
The dataset is split into three disjoint sets: training, validation, and testing. So, let's print the number of images in each set:
print("- Number of images in the training set: {}".format(len(mnist_data.train.labels)))
print("- Number of images in the test set: {}".format(len(mnist_data.test.labels)))
print("- Number of images in the validation set: {}".format(len(mnist_data.validation.labels)))
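The preprocessing that the helper performs can be approximated by hand, which may be useful if the tensorflow.examples module is not available in your TensorFlow installation (to our knowledge it was dropped in the 2.x releases). The following is only a minimal sketch under stated assumptions: the function name prepare_mnist_like is hypothetical, the input is assumed to be an array of 28 x 28 unsigned 8-bit images with integer labels, and random stand-in data is used here instead of the real dataset.

```python
import numpy as np

def prepare_mnist_like(images, labels, num_classes=10, validation_size=5000):
    """Hypothetical helper: flatten, scale, one-hot encode, and split the data."""
    # Flatten each 2-D image into a vector and scale pixel values to [0, 1].
    flat = images.reshape(len(images), -1).astype(np.float32) / 255.0
    # One-hot encode the integer labels by indexing rows of an identity matrix.
    one_hot = np.eye(num_classes, dtype=np.float32)[labels]
    # Assumption: the first `validation_size` examples form the validation set.
    validation = (flat[:validation_size], one_hot[:validation_size])
    train = (flat[validation_size:], one_hot[validation_size:])
    return train, validation

# Random stand-in data with the same layout as MNIST (not the real images).
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(100, 28, 28), dtype=np.uint8)
fake_labels = rng.integers(0, 10, size=100)

(train_x, train_y), (val_x, val_y) = prepare_mnist_like(
    fake_images, fake_labels, validation_size=20)
print(train_x.shape, train_y.shape)
print(val_x.shape, val_y.shape)
```

With real data, the same function could be applied to the raw training arrays, leaving the test set untouched apart from flattening and scaling.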
The labels are stored in one-hot encoded format: each label is an array of 10 values that are all zero except for a one at the index of the class the image represents. For later use, we also need the class numbers as integers, which we compute here for the test set:
mnist_data.test.cls_integer = np.argmax(mnist_data.test.labels, axis=1)
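To make the conversion concrete, here is a small standalone illustration with three hand-written one-hot labels (the array one_hot_labels is made up for this example and is not taken from the dataset). np.argmax returns the position of the largest value in each row, which for a one-hot vector is the class number:

```python
import numpy as np

# Three illustrative one-hot labels for the digits 7, 2, and 1.
one_hot_labels = np.array([
    [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
    [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
])

# axis=1 takes the argmax across the 10 class positions of each row.
class_integers = np.argmax(one_hot_labels, axis=1)
print(class_integers)  # [7 2 1]

# Going back: indexing an identity matrix rebuilds the one-hot rows.
rebuilt = np.eye(10, dtype=int)[class_integers]
```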
Let's define some known variables to be used later in our implementation:
# Width and height (in pixels) of the monochrome MNIST input images
image_size = 28
# Each image is stored as a flat vector of this length.
image_size_flat = image_size * image_size
# The shape of each image
image_shape = (image_size, image_size)
# All the images in the MNIST dataset are monochrome, with only 1 channel
num_channels = 1
# Number of classes in the MNIST dataset: the digits 0 through 9
num_classes = 10
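As a quick sanity check on how these constants relate, the sketch below reshapes a dummy batch of flat vectors back into 2-D images and into a 4-D array of the form [batch, height, width, channels], which is the layout convolutional layers in TensorFlow expect by default. The batch of zeros is a placeholder assumed for illustration only, and the constants are repeated so the snippet runs on its own.

```python
import numpy as np

image_size = 28
image_size_flat = image_size * image_size
image_shape = (image_size, image_size)
num_channels = 1

# Placeholder batch of 5 flattened images (all zeros, for illustration).
flat_batch = np.zeros((5, image_size_flat), dtype=np.float32)

# A single flat vector becomes a 28 x 28 image, as needed for plotting.
single_image = flat_batch[0].reshape(image_shape)

# The whole batch becomes a 4-D array; -1 lets NumPy infer the batch size.
conv_input = flat_batch.reshape(-1, image_size, image_size, num_channels)

print(image_size_flat)     # 784
print(single_image.shape)  # (28, 28)
print(conv_input.shape)    # (5, 28, 28, 1)
```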
Next, we need to define a helper function to plot some images from the dataset. This helper function will plot the images in a grid of nine subplots:
def plot_imgs(imgs, cls_actual, cls_predicted=None):
    assert len(imgs) == len(cls_actual) == 9
    # Create a figure with 9 subplots to plot the images.
    fig, axes = plt.subplots(3, 3)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)
    for i, ax in enumerate(axes.flat):
        # Plot the image at the ith index.
        ax.imshow(imgs[i].reshape(image_shape), cmap='binary')
        # Label the image with the actual and, if given, predicted class.
        if cls_predicted is None:
            xlabel = "True: {0}".format(cls_actual[i])
        else:
            xlabel = "True: {0}, Pred: {1}".format(cls_actual[i], cls_predicted[i])
        # Remove ticks from the plot.
        ax.set_xticks([])
        ax.set_yticks([])
        # Show the classes as the label on the x-axis.
        ax.set_xlabel(xlabel)
    plt.show()
Let's plot some images from the test set and see what they look like:
# Visualize 9 images from the test set.
imgs = mnist_data.test.images[0:9]
# Get the actual classes of these 9 images.
cls_actual = mnist_data.test.cls_integer[0:9]
# Plot the images.
plot_imgs(imgs=imgs, cls_actual=cls_actual)
Here is the output: