In this implementation, we'll use CIFAR-10, one of the most widely used datasets for image classification. Let's start by defining a small helper class that reports download progress, and then download and extract the CIFAR-10 dataset if it isn't already on disk:
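from os.path import isfile, isdir
from urllib.request import urlretrieve
import tarfile
from tqdm import tqdm

cifar10_batches_dir_path = 'cifar-10-batches-py'
tar_gz_filename = 'cifar-10-python.tar.gz'

class DLProgress(tqdm):
    """tqdm progress bar that can be passed to urlretrieve as a reporthook."""
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        # urlretrieve passes the cumulative block count, so advance the bar
        # only by the bytes received since the previous call
        self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num

# Download the archive only if it is not already present
if not isfile(tar_gz_filename):
    with DLProgress(unit='B', unit_scale=True, miniters=1, desc='CIFAR-10 Python Images Batches') as pbar:
        urlretrieve(
            'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz',
            tar_gz_filename,
            pbar.hook)

# Extract only if the batches directory does not exist yet;
# the with-block closes the archive, so no explicit tar.close() is needed
if not isdir(cifar10_batches_dir_path):
    with tarfile.open(tar_gz_filename) as tar:
        tar.extractall()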
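The progress hook relies on the reporthook contract of `urlretrieve`: it is called once when the connection is established (with a block count of 0) and once after each block is read, always receiving the cumulative block count, the block size, and the total file size. The sketch below replays that arithmetic without any network access or tqdm dependency; `ProgressCounter` is a hypothetical stand-in for the tqdm bar, and the block and file sizes are illustrative values only.

```python
class ProgressCounter:
    """Hypothetical stand-in for tqdm: it only accumulates increments."""

    def __init__(self):
        self.n = 0
        self.total = None
        self.last_block = 0

    def update(self, delta):
        self.n += delta

    def hook(self, block_num=1, block_size=1, total_size=None):
        # Same logic as DLProgress.hook: convert the cumulative block
        # count into the number of bytes received since the last call
        self.total = total_size
        self.update((block_num - self.last_block) * block_size)
        self.last_block = block_num


# Simulate urlretrieve: block 0 on connection, then blocks 1..3 of 8192 bytes
counter = ProgressCounter()
for block_num in range(4):
    counter.hook(block_num, 8192, 24576)

print(counter.n, counter.total)  # 24576 24576
```

Because the count is computed from whole blocks, the reported byte total can slightly overshoot the real file size when the last block is only partially filled; for a progress bar that is harmless.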
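After downloading and extracting the CIFAR-10 dataset, you'll find that the 50,000 training images are already split into five batches (`data_batch_1` to `data_batch_5`) of 10,000 images each, with the 10,000 test images in a separate `test_batch` file. CIFAR-10 contains 32x32 color images in 10 categories/classes: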
- airplane
- automobile
- bird
- cat
- deer
- dog
- frog
- horse
- ship
- truck
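To make the batch layout concrete, here is a minimal sketch of loading one training batch, assuming the Python version of the dataset: each batch file is a pickled dictionary whose `data` entry is a 10000x3072 `uint8` array (the first 1024 values of a row are the red channel, then green, then blue, each in row-major order) and whose `labels` entry is a list of integers 0 to 9. The function name `load_cifar10_batch` and the `LABEL_NAMES` constant are illustrative choices, not part of the original code; the same names also appear in the `label_names` entry of the extracted `batches.meta` file.

```python
import pickle
from os.path import join

import numpy as np

# Hypothetical constant: label index i corresponds to LABEL_NAMES[i]
LABEL_NAMES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']


def load_cifar10_batch(batches_dir, batch_id):
    """Hypothetical helper: load data_batch_<batch_id> as (features, labels)."""
    path = join(batches_dir, 'data_batch_' + str(batch_id))
    with open(path, mode='rb') as f:
        # The batches were pickled under Python 2; 'latin1' decodes them
        # so that the dictionary keys come back as str in Python 3
        batch = pickle.load(f, encoding='latin1')
    data = np.asarray(batch['data'])
    # Each row is 3 channel planes of 32x32 values: reshape to
    # (N, 3, 32, 32), then move channels last to get (N, 32, 32, 3)
    features = data.reshape((len(data), 3, 32, 32)).transpose(0, 2, 3, 1)
    labels = batch['labels']
    return features, labels
```

Usage, assuming the extraction step above has run: `features, labels = load_cifar10_batch(cifar10_batches_dir_path, 1)` gives an array of shape `(10000, 32, 32, 3)`, and `LABEL_NAMES[labels[0]]` names the class of the first image. Moving the channel axis last is a common convention for plotting images and for channels-last convolution layers.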
Before we dive into building the core of the network, let's do some data analysis and preprocessing.