Chapter 7. Semi-Supervised GAN

This chapter covers

  • The booming field of innovations based on the original GAN model
  • Semi-supervised learning and its immense practical importance
  • Semi-Supervised GANs (SGANs)
  • Implementation of an SGAN model

Congratulations—you have made it more than halfway through this book. By now, you not only have learned what GANs are and how they function, but also had an opportunity to implement two of the most canonical implementations: the original GAN that started it all and the DCGAN that laid the foundation for the bulk of the advanced GAN variants, including the Progressive GAN introduced in the previous chapter.

However, as with many fields, just when you think you are beginning to get a real hang of it, you uncover that the domain is much larger and more complex than initially thought. What might have seemed like a thorough understanding turns out to be no more than the tip of the iceberg.

GANs are no exception. Since their invention, they have remained an active area of research with countless variations added every year. An unofficial list—aptly named “The GAN Zoo” (https://github.com/hindupuravinash/the-gan-zoo)—which seeks to track all named GAN variants (GAN implementations with distinct names coined by the researchers who authored them) has grown to well over 300 at the time of this writing. However, judging from the fact that the original GAN paper has been cited more than 9,000 times to date (July 2019) and ranks among the most cited research papers in recent years in all of deep learning, the true number of GAN variations invented by the research community is likely even higher.[1] See figure 7.1.

1

According to a tracker from the Microsoft Academic (MA) search engine: http://mng.bz/qXXJ. See also “Top 20 Research Papers on Machine Learning and Deep Learning,” by Thuy T. Pham, 2017, http://mng.bz/E1eq.

Figure 7.1. This graph approximates the monthly cumulative count of unique GAN implementations published by the research community, starting from GAN’s invention in 2014 until the first few months of 2018. As the chart makes clear, the field of generative adversarial learning has been growing exponentially since its inception, and there is no end in sight to this growth in interest and popularity.

(Source: “The GAN Zoo,” by Avinash Hindupur, 2017, https://github.com/hindupuravinash/the-gan-zoo.)

This, however, is no reason to despair. Although it is impossible to cover all these GAN variants in this book, or any book for that matter, we can cover a few of the key innovations that will give you a good idea of what’s out there as well as the unique contributions each of these variations provides to the field of generative adversarial learning.

It is worth noting that not all of these named variants diverge drastically from the original GAN. Indeed, many of them are quite similar to the original model at a high level, such as the DCGAN in chapter 4. Even more complex innovations, such as the Wasserstein GAN (discussed in chapter 5), focus primarily on improving the performance and stability of the original GAN model or one similar to it.

In this and the following two chapters, we will focus on GAN variants that diverge from the original GAN not only in the architecture and underlying mathematics of their model implementations but also in their motivations and objectives. In particular, we will cover the following three GAN models:

  • Semi-Supervised GAN (SGAN), the subject of this chapter
  • Conditional GAN (CGAN), covered in chapter 8
  • CycleGAN, covered in chapter 9

For each of these GAN variants, you will learn about their objectives and what motivated them, their model architectures, and how their networks train and work. These topics will be covered both conceptually and through concrete examples. We will also provide tutorials with full working implementations of each of these models so that you can experience them firsthand.

So, without further ado, let’s dive in!

7.1. Introducing the Semi-Supervised GAN

Semi-supervised learning is one of the most promising areas of practical application of GANs. Unlike supervised learning, in which we need a label for every example in our dataset, and unsupervised learning, in which no labels are used, semi-supervised learning has a class label for only a small subset of the training dataset. By internalizing hidden structures in the data, semi-supervised learning strives to generalize from the small subset of labeled data points to effectively classify new, previously unseen examples. Importantly, for semi-supervised learning to work, the labeled and unlabeled data must come from the same underlying distribution.

The lack of labeled datasets is one of the main bottlenecks in machine learning research and practical applications. Although unlabeled data is abundant (the internet is a virtually limitless source of unlabeled images, videos, and text), assigning class labels to it is often prohibitively expensive, impractical, and time-consuming. It took two and a half years to hand-annotate the original 3.2 million images in ImageNet, the database of labeled images that helped enable many of the advances in image processing and computer vision in the last decade.[2]

2

See “The Data That Transformed AI Research—and Possibly the World,” by Dave Gershgorn, 2017, http://mng.bz/DNVy.

Andrew Ng, a deep learning pioneer, Stanford professor, and former chief scientist of the Chinese internet giant Baidu, identified the enormous amounts of labeled data needed for training as the Achilles’ heel of supervised learning, which is used for the vast majority of today’s AI applications in industry.[3] One of the industries that suffers most from a lack of large labeled datasets is medicine, for which obtaining data (for example, outcomes from clinical trials) often requires great effort and expenditure, not to mention the even more important issues of ethics and privacy.[4] Accordingly, improving the ability of algorithms to learn from ever-smaller quantities of labeled examples has immense practical importance.

3

See “What Artificial Intelligence Can and Can’t Do Right Now,” by Andrew Ng, 2016, http://mng.bz/lopj.

4

See “What AI Can and Can’t Do (Yet) for Your Business,” by Michael Chui et al., 2018, http://mng.bz/BYDv.

Interestingly, semi-supervised learning may also be one of the closest machine learning analogs to the way humans learn. When schoolchildren learn to read and write, the teacher does not have to take them on a road trip to see tens of thousands of examples of letters and numbers, ask them to identify these symbols, and correct them as needed—similarly to the way a supervised learning algorithm would operate. Instead, a single set of examples is all that is needed for children to learn letters and numerals and then be able to recognize them regardless of font, size, angle, lighting conditions, and many other factors. Semi-supervised learning aims to teach machines in a similarly efficient manner.

By serving as a source of additional information that can be used for training, generative models have proved useful in improving the accuracy of semi-supervised classifiers, and GANs have proven the most promising among them. In 2016, Tim Salimans, Ian Goodfellow, and their colleagues at OpenAI achieved almost 94% accuracy on the Street View House Numbers (SVHN) benchmark dataset using only 2,000 labeled examples.[5] For comparison, the best fully supervised algorithm at the time, which used labels for all 73,257 images in the SVHN training set, achieved an accuracy of around 98.40%.[6] In other words, the Semi-Supervised GAN achieved overall accuracy remarkably close to the fully supervised benchmark while using fewer than 3% of the labels for training.

5

See “Improved Techniques for Training GANs,” by Tim Salimans et al., 2016, https://arxiv.org/abs/1606.03498.

6

See “Densely Connected Convolutional Networks,” by Gao Huang et al., 2016, https://arxiv.org/abs/1608.06993.

Let’s find out how Salimans and his colleagues accomplished so much from so little.

7.1.1. What is a Semi-Supervised GAN?

Semi-Supervised GAN (SGAN) is a Generative Adversarial Network whose Discriminator is a multiclass classifier. Instead of distinguishing between only two classes (real and fake), it learns to distinguish between N + 1 classes, where N is the number of classes in the training dataset, with one added for the fake examples produced by the Generator.

For example, the MNIST dataset of handwritten digits has 10 labels (one label for each numeral, 0 to 9), so the SGAN Discriminator trained on this dataset would classify its inputs into 10 + 1 = 11 classes. In our implementation, the output of the SGAN Discriminator will be represented as a vector of 10 class probabilities (that sum up to 1.0) plus another probability that represents whether the image is real or fake.

Turning the Discriminator from a binary to a multiclass classifier may seem like a trivial change, but its implications are more far-reaching than may appear at first glance. Let’s start with a diagram. Figure 7.2 shows the SGAN architecture.

Figure 7.2. In this Semi-Supervised GAN, the Generator takes in a random noise vector z and produces a fake example x*. The Discriminator receives three kinds of data inputs: fake data from the Generator, real unlabeled examples x, and real labeled examples (x, y), where y is the label corresponding to the given example. The Discriminator then outputs a classification; its goal is to distinguish fake examples from the real ones and, for the real examples, identify the correct class. Notice that the portion of examples with labels is much smaller than the portion of the unlabeled data. In practice, the contrast is even starker than the one shown, with labeled data forming only a tiny fraction (often as little as 1–2%) of the training data.

As figure 7.2 indicates, the task of distinguishing between multiple classes not only impacts the Discriminator itself, but also adds complexity to the SGAN architecture, its training process, and its training objectives, as compared to the traditional GAN.

7.1.2. Architecture

The SGAN Generator’s purpose is the same as in the original GAN: it takes in a vector of random numbers and produces fake examples whose goal is to be indistinguishable from the training dataset—no change here.

The SGAN Discriminator, however, diverges considerably from the original GAN implementation. Instead of two, it receives three kinds of inputs: fake examples produced by the Generator (x*), real examples without labels from the training dataset (x), and real examples with labels from the training dataset (x, y), where y denotes the label for the given example x. Instead of binary classification, the SGAN Discriminator’s goal is to correctly categorize the input example into its corresponding class if the example is real, or reject the example as fake (which can be thought of as a special additional class).

Table 7.1 summarizes the key takeaways about the two SGAN subnetworks.

Table 7.1. SGAN Generator and Discriminator networks
 

Generator

  • Input: a vector of random numbers (z)
  • Output: fake examples (x*) that strive to be as convincing as possible
  • Goal: generate fake examples that are indistinguishable from members of the training dataset by fooling the Discriminator into classifying them as real

Discriminator

  • Input: three kinds of data: unlabeled real examples (x) coming from the training dataset, labeled real examples (x, y) coming from the training dataset, and fake examples (x*) produced by the Generator
  • Output: probabilities indicating the likelihood that the input example belongs either to one of the N real classes or to the fake class
  • Goal: learn to assign the correct class label to real examples while rejecting all examples coming from the Generator as fake

7.1.3. Training process

Recall that in a regular GAN, we train the Discriminator by computing the loss for D(x) and D(x*) and backpropagating the total loss to update the Discriminator’s trainable parameters to minimize the loss. The Generator is trained by backpropagating the Discriminator’s loss for D(x*), seeking to maximize it, so that the fake examples it synthesizes are misclassified as real.

To train the SGAN, in addition to D(x) and D(x*), we also have to compute the loss for the supervised training examples: D((x, y)). These losses correspond to the dual learning objective that the SGAN Discriminator has to grapple with: distinguishing real examples from the fake ones while also learning to classify real examples to their correct classes. Using the terminology from the original paper, these dual objectives correspond to two kinds of losses: the supervised loss and the unsupervised loss.[7]

7

See “Improved Techniques for Training GANs,” by Tim Salimans et al., 2016, https://arxiv.org/abs/1606.03498.
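
The two losses can also be written down explicitly. The following is a sketch of the decomposition used in “Improved Techniques for Training GANs,” adapted to this chapter’s convention of N real classes with the fake class indexed as N + 1; the symbol p_model (the Discriminator’s predicted class distribution) is notation introduced here, not in the chapter text:

L_D = L_{\text{supervised}} + L_{\text{unsupervised}}

L_{\text{supervised}} = -\,\mathbb{E}_{(x, y) \sim p_{\text{data}}}\, \log p_{\text{model}}(y \mid x,\; y < N + 1)

L_{\text{unsupervised}} = -\,\mathbb{E}_{x \sim p_{\text{data}}}\, \log\left[1 - p_{\text{model}}(y = N + 1 \mid x)\right] - \mathbb{E}_{x \sim G}\, \log p_{\text{model}}(y = N + 1 \mid x)

The supervised term is ordinary multiclass cross-entropy computed only on the labeled examples; the unsupervised term is the familiar real-versus-fake GAN loss, with “fake” treated as the extra (N + 1)-th class.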

7.1.4. Training objective

All the GAN variants you have seen so far are generative models. Their goal is to produce realistic-looking data samples; hence, the Generator network has been of primary interest. The main purpose of the Discriminator network has been to help the Generator improve the quality of images it produces. At the end of the training, we often disregard the Discriminator and use only the fully trained Generator to create realistic-looking synthetic data.

In contrast, in an SGAN, we care primarily about the Discriminator. The goal of the training process is to make this network into a semi-supervised classifier whose accuracy is as close as possible to that of a fully supervised classifier (one that has labels available for each example in the training dataset), while using only a small fraction of the labels. The Generator’s goal is to aid this process by serving as a source of additional information (the fake data it produces) that helps the Discriminator learn the relevant patterns in the data, enhancing its classification accuracy. At the end of the training, the Generator gets discarded, and we use the trained Discriminator as a classifier.

Now that you’ve learned what motivated the creation of the SGAN and we’ve explained how the model works, it is time to see the model in action by implementing one.

7.2. Tutorial: Implementing a Semi-Supervised GAN

In this tutorial, we implement an SGAN model that learns to classify handwritten digits in the MNIST dataset by using only 100 training examples. At the end of the tutorial, we compare the model’s classification accuracy to an equivalent fully supervised model to see for ourselves the improvement achieved by semi-supervised learning.

7.2.1. Architecture diagram

Figure 7.3 shows a high-level diagram of the SGAN model implemented in this tutorial. It is a bit more complex than the general, conceptual diagram we introduced at the beginning of this chapter. After all, the devil is in the (implementation) details.

Figure 7.3. This SGAN diagram is a high-level illustration of the SGAN we implement in this chapter’s tutorial. The Generator turns random noise into fake examples. The Discriminator receives real images with labels (x, y), real images without labels (x), and fake images produced by the Generator (x*). To distinguish real examples from fake ones, the Discriminator uses the sigmoid function. To distinguish between the real classes, the Discriminator uses the softmax function.

To solve the multiclass classification problem of distinguishing between the real labels, the Discriminator uses the softmax function, which gives a probability distribution over a specified number of classes—in our case, 10. The higher the probability assigned to a given label, the more confident the Discriminator is that the example belongs to the given class. To compute the classification error, we use cross-entropy loss, which measures the difference between the output probabilities and the target, one-hot-encoded labels.

To output the real-versus-fake probability, the Discriminator uses the sigmoid activation function and trains its parameters by backpropagating the binary cross-entropy loss—the same as the GANs we implemented in chapters 3 and 4.
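
As a quick reference, here are the standard definitions behind these two output heads (the symbols below are notation introduced here, not in the listings): given logits l_1, ..., l_10 from the Discriminator’s final Dense layer, a one-hot target vector y, and a binary real/fake target t,

\hat{p}_i = \frac{e^{l_i}}{\sum_{j=1}^{10} e^{l_j}} \qquad L_{\text{supervised}} = -\sum_{i=1}^{10} y_i \log \hat{p}_i

\hat{d} = \sigma(a) = \frac{1}{1 + e^{-a}} \qquad L_{\text{unsupervised}} = -\,t \log \hat{d} - (1 - t)\log\left(1 - \hat{d}\right)

The first pair is the softmax and multiclass cross-entropy used for the supervised head; the second pair is the sigmoid and binary cross-entropy used for the real-versus-fake head.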

7.2.2. Implementation

As you may notice, much of our SGAN implementation is adapted from the DCGAN model from chapter 4. This is not out of laziness (well, maybe a little . . .), but rather so that you can better see the distinct modifications needed for SGAN without any distractions from implementation details in unrelated parts of the network.

A Jupyter notebook with the full implementation, including added visualizations of the training progress, is available in our GitHub repository (https://github.com/GANs-in-Action/gans-in-action), under the chapter-7 folder. The code was tested with Python 3.6.0, Keras 2.1.6, and TensorFlow 1.8.0. To speed up the training time, we recommend running the model on a GPU.

7.2.3. Setup

As usual, we start off by importing all the modules and libraries needed to run the model, as shown in the following listing.

Listing 7.1. Import statements
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

from keras import backend as K

from keras.datasets import mnist
from keras.layers import (Activation, BatchNormalization, Concatenate, Dense,
                          Dropout, Flatten, Input, Lambda, Reshape)
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Model, Sequential
from keras.optimizers import Adam
from keras.utils import to_categorical

We also specify the input image size, the size of the noise vector z, and the number of real classes for the semi-supervised classification (one for each numeral our Discriminator will learn to identify), as shown in the following listing.

Listing 7.2. Model input dimensions
img_rows = 28
img_cols = 28
channels = 1

img_shape = (img_rows, img_cols, channels)    1

z_dim = 100                                   2

num_classes = 10                              3

  • 1 Input image dimensions
  • 2 Size of the noise vector, used as input to the Generator
  • 3 Number of classes in the dataset

7.2.4. The dataset

Although the MNIST training dataset has 60,000 labeled training images, we will use only a small fraction of them (specified by the num_labeled parameter) for training and pretend that all the remaining ones are unlabeled. We accomplish this by sampling only from the first num_labeled images when generating batches of labeled data and from the remaining (60,000 – num_labeled) images when generating batches of unlabeled examples.

The Dataset object (shown in listing 7.3) also provides a function to return all the num_labeled training examples along with their labels as well as a function to return all 10,000 labeled test images in the MNIST dataset. After training, we will use the test set to evaluate how well the model’s classifications generalize to previously unseen examples.

Listing 7.3. Dataset for training and testing
class Dataset:
    def __init__(self, num_labeled):

        self.num_labeled = num_labeled                                   1

        (self.x_train, self.y_train), (self.x_test,                      2
                                       self.y_test) = mnist.load_data()

        def preprocess_imgs(x):
            x = (x.astype(np.float32) - 127.5) / 127.5                   3
            x = np.expand_dims(x, axis=3)                                4
            return x

        def preprocess_labels(y):
            return y.reshape(-1, 1)

        self.x_train = preprocess_imgs(self.x_train)                     5
        self.y_train = preprocess_labels(self.y_train)

        self.x_test = preprocess_imgs(self.x_test)                       6
        self.y_test = preprocess_labels(self.y_test)

    def batch_labeled(self, batch_size):
        idx = np.random.randint(0, self.num_labeled, batch_size)         7
        imgs = self.x_train[idx]
        labels = self.y_train[idx]
        return imgs, labels

    def batch_unlabeled(self, batch_size):
        idx = np.random.randint(self.num_labeled, self.x_train.shape[0], 8
                                batch_size)
        imgs = self.x_train[idx]
        return imgs

    def training_set(self):
        x_train = self.x_train[range(self.num_labeled)]
        y_train = self.y_train[range(self.num_labeled)]
        return x_train, y_train

    def test_set(self):
        return self.x_test, self.y_test

  • 1 Number of labeled examples to use for training
  • 2 Loads the MNIST dataset
  • 3 Rescales [0, 255] grayscale pixel values to [–1, 1]
  • 4 Expands image dimensions to width × height × channels
  • 5 Training data
  • 6 Testing data
  • 7 Gets a random batch of labeled images and their labels
  • 8 Gets a random batch of unlabeled images

In this tutorial, we will pretend that we have only 100 labeled MNIST images for training:

num_labeled = 100             1

dataset = Dataset(num_labeled)

  • 1 Number of labeled examples to use (the rest will be used as unlabeled)
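
Before moving on, you may want to sanity-check the Dataset object. The following optional snippet is not part of the original tutorial; it simply prints the shapes you should expect from the methods defined in listing 7.3:

imgs, labels = dataset.batch_labeled(32)        # Random batch of labeled images
print(imgs.shape, labels.shape)                 # (32, 28, 28, 1) (32, 1)

imgs_unlabeled = dataset.batch_unlabeled(32)    # Random batch of unlabeled images
print(imgs_unlabeled.shape)                     # (32, 28, 28, 1)

x_test, y_test = dataset.test_set()             # Full MNIST test set
print(x_test.shape, y_test.shape)               # (10000, 28, 28, 1) (10000, 1)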

7.2.5. The Generator

The Generator network is the same as the one we implemented for the DCGAN in chapter 4. Using transposed convolution layers, the Generator transforms the input random noise vector into a 28 × 28 × 1 image; see the following listing.

Listing 7.4. SGAN Generator
def build_generator(z_dim):

    model = Sequential()
    model.add(Dense(256 * 7 * 7, input_dim=z_dim))                           1
    model.add(Reshape((7, 7, 256)))

    model.add(Conv2DTranspose(128, kernel_size=3, strides=2, padding='same'))2

    model.add(BatchNormalization())                                          3

    model.add(LeakyReLU(alpha=0.01))                                         4

    model.add(Conv2DTranspose(64, kernel_size=3, strides=1, padding='same')) 5

    model.add(BatchNormalization())                                          3

    model.add(LeakyReLU(alpha=0.01))                                         4

    model.add(Conv2DTranspose(1, kernel_size=3, strides=2, padding='same'))  6

    model.add(Activation('tanh'))                                            7

    return model

  • 1 Reshapes input into a 7 × 7 × 256 tensor via a fully connected layer
  • 2 Transposed convolution layer, from 7 × 7 × 256 to 14 × 14 × 128 tensor
  • 3 Batch normalization
  • 4 Leaky ReLU activation
  • 5 Transposed convolution layer, from 14 × 14 × 128 to 14 × 14 × 64 tensor
  • 6 Transposed convolution layer, from 14 × 14 × 64 to 28 × 28 × 1 tensor
  • 7 Output layer with tanh activation

7.2.6. The Discriminator

The Discriminator is the most complex part of the SGAN model. Recall that the SGAN Discriminator has a dual objective:

  • Distinguish real examples from fake ones. For this, the SGAN Discriminator uses the sigmoid function, outputting a single output probability for binary classification.
  • For the real examples, accurately classify their label. For this, the SGAN Discriminator uses the softmax function, outputting a vector of probabilities, one for each of the target classes.
The core Discriminator network

We start by defining the core Discriminator network. As you may notice, the model in listing 7.5 is similar to the ConvNet-based Discriminator we implemented in chapter 4; in fact, it is exactly the same up to and including the 3 × 3 × 128 convolutional layer, its batch normalization, and its Leaky ReLU activation.

After that layer, we add dropout, a regularization technique that helps prevent overfitting by randomly dropping neurons and their connections from the neural network during training.[8] This forces the remaining neurons to reduce their codependence and develop a more general representation of the underlying data. The fraction of the neurons to be randomly dropped is specified by the rate parameter, which is set to 0.5 in our implementation: model.add(Dropout(0.5)). We add dropout because of the increased complexity of the SGAN classification task and to improve the model’s ability to generalize from only 100 labeled examples.

8

See “Improving Neural Networks by Preventing Co-Adaptation of Feature Detectors,” by Geoffrey E. Hinton et al., 2012, https://arxiv.org/abs/1207.0580. See also “Dropout: A Simple Way to Prevent Neural Networks from Overfitting,” by Nitish Srivastava et al., 2014, Journal of Machine Learning Research 15, 1929–1958.

Listing 7.5. SGAN Discriminator
def build_discriminator_net(img_shape):

    model = Sequential()

    model.add(                                 1
        Conv2D(32,
               kernel_size=3,
               strides=2,
               input_shape=img_shape,
               padding='same'))

    model.add(LeakyReLU(alpha=0.01))           2

    model.add(                                 3
        Conv2D(64,
               kernel_size=3,
               strides=2,
               input_shape=img_shape,
               padding='same'))

    model.add(BatchNormalization())            4

    model.add(LeakyReLU(alpha=0.01))           5

    model.add(                                 6
        Conv2D(128,
               kernel_size=3,
               strides=2,
               input_shape=img_shape,
               padding='same'))

    model.add(BatchNormalization())            4

    model.add(LeakyReLU(alpha=0.01))           5

    model.add(Dropout(0.5))                    7

    model.add(Flatten())                       8

    model.add(Dense(num_classes))              9

    return model

  • 1 Convolutional layer, from 28 × 28 × 1 into 14 × 14 × 32 tensor
  • 2 Leaky ReLU activation
  • 3 Convolutional layer, from 14 × 14 × 32 into 7 × 7 × 64 tensor
  • 4 Batch normalization
  • 5 Leaky ReLU activation
  • 6 Convolutional layer, from 7 × 7 × 64 tensor into 3 × 3 × 128 tensor
  • 7 Dropout
  • 8 Flattens the tensor
  • 9 Fully connected layer with num_classes neurons

Note that the dropout layer is added after batch normalization and not the other way around; this ordering has been shown to yield superior performance because of the interplay between the two techniques.[9]

9

See “Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift,” by Xiang Li et al., 2018, https://arxiv.org/abs/1801.05134.

Also, notice that the preceding network ends with a fully connected layer with 10 neurons. Next, we need to define the two Discriminator outputs computed from these neurons: one for the supervised, multiclass classification (using softmax) and the other for the unsupervised, binary classification (using sigmoid).

The supervised Discriminator

In the following listing, we take the core Discriminator network implemented previously and use it to build the supervised portion of the Discriminator model.

Listing 7.6. SGAN Discriminator: supervised
def build_discriminator_supervised(discriminator_net):

    model = Sequential()

    model.add(discriminator_net)

    model.add(Activation('softmax'))      1

    return model

  • 1 Softmax activation, outputs predicted probability distribution over the real classes
The unsupervised Discriminator

The following listing implements the unsupervised portion of the Discriminator model on top of the core Discriminator network. Notice the predict(x) function, in which we transform the output of the 10 neurons (from the core Discriminator network) into a binary, real-versus-fake prediction.

Listing 7.7. SGAN Discriminator: unsupervised
def build_discriminator_unsupervised(discriminator_net):

    model = Sequential()

    model.add(discriminator_net)

    def predict(x):
        prediction = 1.0 - (1.0 /                                          1
                            (K.sum(K.exp(x), axis=-1, keepdims=True) + 1.0))

        return prediction

    model.add(Lambda(predict))                                             2

    return model

  • 1 Transforms distribution over real classes into binary real-versus-fake probability
  • 2 Lambda layer that applies the predict function defined above, producing the real-versus-fake output
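
The formula inside predict(x) is the transformation used in “Improved Techniques for Training GANs”: with l_1, ..., l_10 denoting the 10 logits produced by the core Discriminator network (notation introduced here for clarity), the Lambda layer computes

D(x) = 1 - \frac{1}{Z(x) + 1} = \frac{Z(x)}{Z(x) + 1}, \qquad Z(x) = \sum_{k=1}^{10} \exp\left(l_k(x)\right)

so the stronger the (unnormalized) evidence for any of the real classes, the closer the real-versus-fake output gets to 1.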

7.2.7. Building the model

Next, we build and compile the Discriminator and Generator models. Notice the use of categorical_crossentropy and binary_crossentropy loss functions for the supervised loss and the unsupervised loss, respectively.

Listing 7.8. Building the models
def build_gan(generator, discriminator):

    model = Sequential()

    model.add(generator)                                                    1
    model.add(discriminator)

    return model

discriminator_net = build_discriminator_net(img_shape)                      2

discriminator_supervised = build_discriminator_supervised(discriminator_net)3
discriminator_supervised.compile(loss='categorical_crossentropy',           3
                                 metrics=['accuracy'],                      3
                                 optimizer=Adam())                          3

discriminator_unsupervised = build_discriminator_unsupervised(              4
                                 discriminator_net)                         4
discriminator_unsupervised.compile(loss='binary_crossentropy',              4
                                   optimizer=Adam())                        4
generator = build_generator(z_dim)                                          5
discriminator_unsupervised.trainable = False                                6
gan = build_gan(generator, discriminator_unsupervised)                      7
gan.compile(loss='binary_crossentropy', optimizer=Adam())                   7

  • 1 Combined Generator + Discriminator model
  • 2 Core Discriminator network: these layers are shared during supervised and unsupervised training.
  • 3 Builds and compiles the Discriminator for supervised training
  • 4 Builds and compiles the Discriminator for unsupervised training
  • 5 Builds the Generator
  • 6 Keeps Discriminator’s parameters constant for Generator training
  • 7 Builds and compiles GAN model with fixed Discriminator to train the Generator. Note: uses Discriminator version with unsupervised output.

7.2.8. Training

The following pseudocode outlines the SGAN training algorithm.

SGAN training algorithm

For each training iteration do

  1. Train the Discriminator (supervised):

    1. Take a random mini-batch of labeled real examples (x, y).
    2. Compute D((x, y)) for the given mini-batch and backpropagate the multiclass classification loss to update θ(D) to minimize the loss.
  2. Train the Discriminator (unsupervised):

    1. Take a random mini-batch of unlabeled real examples x.
    2. Compute D(x) for the given mini-batch and backpropagate the binary classification loss to update θ(D) to minimize the loss.
    3. Take a mini-batch of random noise vectors z and generate a mini-batch of fake examples: G(z) = x*.
    4. Compute D(x*) for the given mini-batch and backpropagate the binary classification loss to update θ(D) to minimize the loss.
  3. Train the Generator:

    1. Take a mini-batch of random noise vectors z and generate a mini-batch of fake examples: G(z) = x*.
    2. Compute D(x*) for the given mini-batch and backpropagate the binary classification loss to update θ(G) to maximize the loss.

End for

The following listing implements the SGAN training algorithm.

Listing 7.9. SGAN training algorithm
supervised_losses = []
iteration_checkpoints = []
def train(iterations, batch_size, sample_interval):

    real = np.ones((batch_size, 1))                                            1

    fake = np.zeros((batch_size, 1))                                           2

    for iteration in range(iterations):


        imgs, labels = dataset.batch_labeled(batch_size)                       3

        labels = to_categorical(labels, num_classes=num_classes)               4

        imgs_unlabeled = dataset.batch_unlabeled(batch_size)                   5

        z = np.random.normal(0, 1, (batch_size, z_dim))                        6
        gen_imgs = generator.predict(z)

        d_loss_supervised, accuracy = discriminator_supervised.train_on_batch( 7
            imgs, labels)

        d_loss_real = discriminator_unsupervised.train_on_batch(               8
            imgs_unlabeled, real)

        d_loss_fake = discriminator_unsupervised.train_on_batch(gen_imgs, fake)9

        d_loss_unsupervised = 0.5 * np.add(d_loss_real, d_loss_fake)


        z = np.random.normal(0, 1, (batch_size, z_dim))                        10
        gen_imgs = generator.predict(z)

        g_loss = gan.train_on_batch(z, np.ones((batch_size, 1)))               11

        if (iteration + 1) % sample_interval == 0:

            supervised_losses.append(d_loss_supervised)                        12
            iteration_checkpoints.append(iteration + 1)

            print(                                                             13
                "%d [D loss supervised: %.4f, acc.: %.2f%%] "
                "[D loss unsupervised: %.4f] [G loss: %f]"
                % (iteration + 1, d_loss_supervised, 100 * accuracy,
                   d_loss_unsupervised, g_loss))

  • 1 Labels for real images: all 1s
  • 2 Labels for fake images: all 0s
  • 3 Gets labeled examples
  • 4 One-hot-encoded labels
  • 5 Gets unlabeled examples
  • 6 Generates a batch of fake images
  • 7 Trains on real labeled examples
  • 8 Trains on real unlabeled examples
  • 9 Trains on fake examples
  • 10 Generates a batch of fake images
  • 11 Trains the Generator
  • 12 Saves the Discriminator’s supervised classification loss to be plotted after training
  • 13 Outputs training progress
Training the model

We use a small batch size because we have only 100 labeled examples for training. The number of iterations is determined by trial and error: we keep increasing it until the Discriminator’s supervised loss plateaus, but not too far past that point (to reduce the risk of overfitting):

Listing 7.10. Training the model
iterations = 8000                                 1
batch_size = 32
sample_interval = 800

train(iterations, batch_size, sample_interval)    2

  • 1 Sets hyperparameters
  • 2 Trains the SGAN for the specified number of iterations
Model training and test accuracy

And now for the moment we have all been waiting for—let’s find out how our SGAN performs as a classifier. During training, we see that we achieved a supervised accuracy of 100%. Although this may seem impressive, remember that we have only 100 labeled examples from which to sample for supervised training. Perhaps our model just memorized the training dataset. What matters is how well our classifier can generalize to the previously unseen data in the test set, as shown in the following listing.

Listing 7.11. Checking the accuracy
x, y = dataset.test_set()
y = to_categorical(y, num_classes=num_classes)

_, accuracy = discriminator_supervised.evaluate(x, y)      1
print("Test Accuracy: %.2f%%" % (100 * accuracy))

  • 1 Computes classification accuracy on the test set

Drum roll, please.

Our SGAN is able to accurately classify about 89% of the examples in the test set. To see how remarkable this is, let’s compare its performance to a fully supervised classifier.

7.3. Comparison to a fully supervised classifier

To make the comparison as fair as possible, we use the same network architecture for the fully supervised classifier as the one used for the supervised Discriminator training, as shown in the following listing. The idea is that this will allow us to isolate the improvement to the classifier’s ability to generalize that was achieved through the GAN-enabled semi-supervised learning.

Listing 7.12. Fully supervised classifier
mnist_classifier = build_discriminator_supervised(
                         build_discriminator_net(img_shape))     1
mnist_classifier.compile(loss='categorical_crossentropy',
                         metrics=['accuracy'],
                         optimizer=Adam())

  • 1 Fully supervised classifier with the same network architecture as the SGAN Discriminator

We train the fully supervised classifier by using the same 100 training examples we used to train our SGAN. For brevity, the training code and the code outputting the training and test accuracy are not shown here. You can find the code in our GitHub repository, in the SGAN Jupyter notebook under the chapter-7 folder.
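
If you want to reproduce the comparison without opening the notebook, a minimal sketch along the following lines should work; it reuses the Dataset object and imports from the SGAN tutorial, and the hyperparameters (30 epochs, batch size of 32) are illustrative choices rather than necessarily the ones used in the repository:

imgs, labels = dataset.training_set()                       # The same 100 labeled examples
labels = to_categorical(labels, num_classes=num_classes)    # One-hot-encode the labels

mnist_classifier.fit(x=imgs,                                # Train the fully supervised baseline
                     y=labels,
                     batch_size=32,
                     epochs=30,
                     verbose=1)

x, y = dataset.test_set()                                   # Evaluate on the held-out test set
y = to_categorical(y, num_classes=num_classes)
_, accuracy = mnist_classifier.evaluate(x, y)
print("Test Accuracy: %.2f%%" % (100 * accuracy))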

Like the SGAN Discriminator, the fully supervised classifier achieved 100% accuracy on the training dataset. On the test set, however, it was able to correctly classify only about 70% of the examples—a whopping 20 percentage points worse than our SGAN. Put differently, the SGAN improved the test accuracy by almost 30%!

With a lot more training data, the fully supervised classifier’s ability to generalize improves dramatically. Using the same setup and training the fully supervised classifier with 10,000 labeled examples (100 times as many as we originally used), we achieve an accuracy of about 98%. But that would no longer be a semi-supervised setting.

7.4. Conclusion

In this chapter, we explored how GANs can be used for semi-supervised learning by teaching the Discriminator to output class labels for real examples. You saw that the SGAN-trained classifier’s ability to generalize from a small number of training examples is significantly better than a comparable, fully supervised classifier.

From a GAN innovation perspective, a key distinguishing feature of the SGAN is the use of labels for Discriminator training. You may be wondering whether labels can be leveraged for Generator training as well. Funny you should ask—that is what the GAN variant in the next chapter (Conditional GAN) is all about.

Summary

  • Semi-Supervised GAN (SGAN) is a Generative Adversarial Network whose Discriminator learns to do the following:

    • Distinguish fake examples from real ones
    • Assign the correct class label to the real examples
  • The purpose of an SGAN is to train the Discriminator into a classifier that can achieve superior classification accuracy from as few labeled examples as possible, thereby reducing the dependency of classification tasks on enormous labeled datasets.
  • In our implementation, we used softmax and multiclass cross-entropy loss for the supervised task of assigning real labels, and sigmoid and binary cross-entropy for the task of distinguishing between real and fake data.
  • We demonstrated that SGAN’s classification accuracy on the previously unseen data in the test set is far superior to a comparable fully supervised classifier trained on the same number of labeled training examples.