In the previous chapter, you learned about the SGAN, which introduced you to the idea of using labels in GAN training. SGANs use labels to train the Discriminator into a powerful semi-supervised classifier. In this chapter, you’ll learn about the Conditional GAN (CGAN), which uses labels to train both the Generator and the Discriminator. Thanks to this innovation, a Conditional GAN allows us to direct the Generator to synthesize the kind of fake examples we want.
As you have seen throughout this book, GANs are capable of producing examples ranging from simple handwritten digits to photorealistic images of human faces. However, although we could control the domain of examples our GAN learned to emulate by our selection of the training dataset, we could not specify any of the characteristics of the data samples the GAN would generate. For instance, the DCGAN we implemented in chapter 4 could synthesize realistic-looking handwritten digits, but we could not control whether it would produce, say, the number 7 rather than the number 9 at any given time.
On simple datasets like MNIST, in which examples belong to only one of 10 classes, this concern may seem trivial. If, for instance, our goal is to produce the number 9, we can just keep generating examples until we get the number we want. On more complex data-generation tasks, however, the domain of possible answers gets too large for such a brute-force solution to be practical. Take, for example, the task of generating human faces. As impressive as the images produced by the Progressive GAN from chapter 6 are, we have no control over what face will get produced. There is no way to direct the Generator to synthesize, say, a male or a female face, let alone other features such as age or facial expression.
The ability to decide what kind of data will be generated opens the door to a vast array of applications. As a somewhat contrived example, imagine that we are detectives solving a murder mystery, and a witness describes the killer as a middle-aged woman with long red hair and green eyes. It would greatly expedite the process if instead of hiring a sketch artist (who can produce only one sketch at a time), we could enter the descriptive features into a computer program and have it output a range of faces matching the criteria. Our witness then could point us to the one that resembles the criminal most closely.
We are sure you can think of many other practical applications for which the ability to generate an image that matches the criteria of our choice would be a game-changer. In medical research, we could guide the creation of new drug compounds; in filmmaking and computer-generated imagery (CGI), we could create the exact scene we want with minimal input from human animators. The list goes on.
The CGAN is one of the first GAN innovations that made targeted data generation possible, and arguably the most influential one. In the remainder of this chapter, you will learn how CGANs work and implement a small-scale version by using (you guessed it!) the MNIST dataset.
Introduced in 2014 by University of Montreal PhD student Mehdi Mirza and Flickr AI architect Simon Osindero, Conditional GAN is a generative adversarial network whose Generator and Discriminator are conditioned during training by using some additional information.[1] This auxiliary information could be, in theory, anything, such as a class label, a set of tags, or even a written description. For clarity and simplicity, we will use labels as the conditioning information as we explain how CGAN works.
[1] See “Conditional Generative Adversarial Nets,” by Mehdi Mirza and Simon Osindero, 2014, https://arxiv.org/abs/1411.1784.
During CGAN training, the Generator learns to produce realistic examples for each label in the training dataset, and the Discriminator learns to distinguish fake example-label pairs from real example-label pairs. In contrast to the Semi-Supervised GAN from the previous chapter, whose Discriminator learns to assign the correct label to each real example (in addition to distinguishing real examples from fake ones), the Discriminator in a CGAN does not learn to identify which class is which. It learns only to accept real, matching pairs while rejecting pairs that are mismatched and pairs in which the example is fake.
For example, the CGAN Discriminator should learn to reject the pair (image of a handwritten 3, 4), regardless of whether the example (handwritten numeral 3) is real or fake, because it does not match the label, 4. The CGAN Discriminator should also learn to reject all image-label pairs in which the image is fake, even if the label matches the image.
Accordingly, in order to fool the Discriminator, it is not enough for the CGAN Generator to produce realistic-looking data. The examples it generates also need to match their labels. After the Generator is fully trained, this then allows us to specify what example we want the CGAN to synthesize by passing it the desired label.
To formalize things a bit, let’s call the conditioning label y. The Generator uses the noise vector z and the label y to synthesize a fake example G(z, y) = x*|y (read as “x* given that, or conditioned on, y”). The goal of this fake example is to look (in the eyes of the Discriminator) as close as possible to a real example for the given label. Figure 8.1 illustrates the Generator.
The Discriminator receives real examples with labels (x, y), and fake examples with the label used to generate them, (x*|y, y). On the real example-label pairs, the Discriminator learns how to recognize real data and how to recognize matching pairs. On the Generator-produced examples, it learns to recognize fake image-label pairs, thereby learning to tell them apart from the real ones.
The Discriminator outputs a single probability indicating its conviction that the input is a real, matching pair. The Discriminator’s goal is to learn to reject all fake examples and all examples that fail to match their label, while accepting all real example-label pairs, as shown in figure 8.2.
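To make the role of that single probability concrete, here is a minimal sketch of the binary cross-entropy loss that the implementation later in this chapter selects with loss='binary_crossentropy'. The helper name binary_cross_entropy and the example scores are illustrative assumptions, not part of the chapter's code.

```python
import math

def binary_cross_entropy(target, prob, eps=1e-7):
    """Loss for one pair: target is 1 (real, matching pair) or 0 (fake pair)."""
    prob = min(max(prob, eps), 1.0 - eps)  # clip to avoid log(0)
    return -(target * math.log(prob) + (1 - target) * math.log(1.0 - prob))

# A real, matching pair that the Discriminator scores 0.9 costs little...
loss_real_confident = binary_cross_entropy(1, 0.9)

# ...while a fake pair that receives the same score of 0.9 costs a lot.
loss_fake_fooled = binary_cross_entropy(0, 0.9)
```

The asymmetry is what pushes the Discriminator toward high scores for real, matching pairs and low scores for everything else.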
The two CGAN subnetworks, their inputs, outputs, and objectives are summarized in table 8.1.
| | Generator | Discriminator |
|---|---|---|
| Input | A vector of random numbers and a label: (z, y) | Real examples with their labels from the training dataset, (x, y), and fake examples with the labels used to generate them, (x*\|y, y) |
| Output | Fake examples that strive to be as convincing as possible as matches for their labels: G(z, y) = x*\|y | A single probability indicating whether the input example is a real, matching example-label pair |
| Goal | Generate realistic-looking fake data that match their labels | Distinguish between fake example-label pairs coming from the Generator and real example-label pairs coming from the training dataset |
Putting it all together, figure 8.3 shows a high-level architecture diagram of a CGAN. Notice that for each fake example, the same label y is passed to both the Generator and the Discriminator. Also, note that the Discriminator is never explicitly trained to reject mismatched pairs by being trained on real examples with mismatching labels; its ability to identify mismatched pairs is a by-product of being trained to accept only real matching pairs.
You may have noticed a pattern: for almost every GAN variant, we present you with a table summarizing the inputs, outputs, and objectives of the Discriminator and Generator networks, and with a network architecture diagram. This is not by accident; indeed, one of the main goals of these chapters is to give you a mental template (a reusable framework of sorts) for the kind of things to look for when you encounter GAN implementations that diverge from the original GAN. Analyzing the Generator and Discriminator networks and the overall model architecture is often the best first step.
The CGAN Discriminator receives fake labeled examples (x*|y, y) produced by the Generator and real labeled examples (x, y), and it learns to tell whether a given example-label pair is real or fake.
Enough theory. It’s time we put what you have learned into practice and implement our own CGAN model.
In this tutorial, we will implement a CGAN model that learns to generate handwritten digits of our choice. At the end, we will generate a sample of images for each numeral to see how well the model learned to generate targeted data.
Our implementation is inspired by the CGAN in the open source GitHub repository of GAN models in Keras (the same one we used in chapters 3 and 4).[2] In particular, we use the repository’s approach of using Embedding layers to combine examples and labels into joint hidden representations (more on this later).
[2] See Erik Linder-Norén’s Keras-GAN GitHub repository, 2017, https://github.com/eriklindernoren/Keras-GAN.
The rest of our CGAN model, however, diverges from the one found in the Keras-GAN repository. We refactored the embedding implementation to be more readable and added detailed explanatory comments. Crucially, we also adapted our CGAN to use convolutional neural networks, which yield significantly more realistic examples—recall the difference between the images produced by the GAN in chapter 3 and the DCGAN in chapter 4!
A Jupyter notebook with the full implementation, including added visualizations of the training progress, is available in our GitHub repository, under the chapter-8 folder: https://github.com/GANs-in-Action/gans-in-action. The code was tested with Python 3.6.0, Keras 2.1.6, and TensorFlow 1.8.0. To speed up the training time, we recommend running the model on a GPU.
You guessed it—the first step is to import all the modules and libraries needed for our model, as shown in the following listing.
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

from keras.datasets import mnist
from keras.layers import (
    Activation, BatchNormalization, Concatenate, Dense,
    Embedding, Flatten, Input, Multiply, Reshape)
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Model, Sequential
from keras.optimizers import Adam
Just as before, we also specify the input image size, the size of the noise vector z, and the number of classes in our dataset, as shown here.
img_rows = 28
img_cols = 28
channels = 1

# Input image dimensions
img_shape = (img_rows, img_cols, channels)

# Size of the noise vector z
z_dim = 100

# Number of classes in the dataset
num_classes = 10
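In this section, we implement the CGAN Generator. By now, you should be familiar with much of this network from chapters 4 and 7. The modifications made for the CGAN center around input handling, where we use embedding and element-wise multiplication to combine the random noise vector z and the label y into a joint representation. Let’s walk through what the code does:
1. Take the label y (an integer from 0 to 9) and turn it into a dense vector of size z_dim by using the Keras Embedding layer.
2. Combine the label embedding with the noise vector z into a joint representation by using the Keras Multiply layer, which multiplies the corresponding entries of the two equal-length vectors.
3. Feed the resulting vector as input into the rest of the CGAN Generator network to synthesize an image.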
Figure 8.4 illustrates the process, using the label 7 as an example.
First, we embed the label into a vector of the same size as z. Second, we multiply the corresponding elements of the embedded label and z (that is, we perform element-wise multiplication). The resulting joined representation is then used as input into the CGAN Generator network.
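The two steps described above can be sketched in plain NumPy. This is only an illustration of the arithmetic, not the Keras implementation: the randomly filled embedding_table stands in for the weights that the Embedding layer would learn, and the batch of labels is made up.

```python
import numpy as np

rng = np.random.default_rng(0)
z_dim = 100
num_classes = 10
batch_size = 4

# Stand-in for the Embedding layer's weights: one row of size z_dim per class.
# In the real model these weights are learned; here they are random.
embedding_table = rng.normal(size=(num_classes, z_dim))

z = rng.normal(size=(batch_size, z_dim))   # random noise vectors
labels = np.array([7, 7, 3, 0])            # hypothetical conditioning labels

# Embedding is a table lookup: each label selects its row.
label_embedding = embedding_table[labels]            # shape (batch_size, z_dim)

# Element-wise product of the noise and the embedded label.
joined_representation = z * label_embedding          # shape (batch_size, z_dim)
```

Because the first two rows share the label 7, they are multiplied by the same embedding row; only their noise vectors differ, which is what lets the Generator produce different renderings of the same digit.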
And finally, the following listing shows what it all looks like in Python/Keras code.
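def build_generator(z_dim):

    model = Sequential()

    # Reshape the input into a 7 × 7 × 256 tensor via a fully connected layer
    model.add(Dense(256 * 7 * 7, input_dim=z_dim))
    model.add(Reshape((7, 7, 256)))

    # Transposed convolution: 7 × 7 × 256 -> 14 × 14 × 128
    model.add(Conv2DTranspose(128, kernel_size=3, strides=2, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.01))

    # Transposed convolution: 14 × 14 × 128 -> 14 × 14 × 64
    model.add(Conv2DTranspose(64, kernel_size=3, strides=1, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.01))

    # Transposed convolution: 14 × 14 × 64 -> 28 × 28 × 1
    model.add(Conv2DTranspose(1, kernel_size=3, strides=2, padding='same'))

    # Output layer with tanh activation
    model.add(Activation('tanh'))

    return model


def build_cgan_generator(z_dim):

    # Random noise vector z
    z = Input(shape=(z_dim, ))

    # Conditioning label: an integer 0-9 specifying the digit to generate
    label = Input(shape=(1, ), dtype='int32')

    # Label embedding: turns the label into a dense vector of size z_dim;
    # produces a 3D tensor with shape (batch_size, 1, z_dim)
    label_embedding = Embedding(num_classes, z_dim, input_length=1)(label)

    # Flatten the embedding into a 2D tensor with shape (batch_size, z_dim)
    label_embedding = Flatten()(label_embedding)

    # Element-wise product of the vectors z and the label embeddings
    joined_representation = Multiply()([z, label_embedding])

    generator = build_generator(z_dim)

    # Generate an image for the given label
    conditioned_img = generator(joined_representation)

    return Model([z, label], conditioned_img)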
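As a quick sanity check on the Generator's tensor shapes, the following sketch traces the spatial size through the three Conv2DTranspose layers in build_generator. It relies on the rule that a transposed convolution with padding='same' multiplies the spatial size by its stride; the helper name transposed_conv_same is an illustrative assumption.

```python
def transposed_conv_same(size, stride):
    """Spatial output size of a transposed convolution with padding='same'."""
    return size * stride

size = 7           # spatial size after Reshape((7, 7, 256))
trace = [size]
for stride in (2, 1, 2):   # strides of the three Conv2DTranspose layers
    size = transposed_conv_same(size, stride)
    trace.append(size)

# trace now holds the spatial size after each stage: [7, 14, 14, 28]
```

The final size of 28 matches the 28 × 28 MNIST images that the Discriminator expects.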
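Next, we implement the CGAN Discriminator. Just as in the previous section, the network architecture should look familiar to you, except for the piece where we handle the input image and its label. Here, too, we use the Keras Embedding layer to turn input labels into dense vectors. However, unlike the Generator, where the model input is a flat vector, the Discriminator receives three-dimensional images. This necessitates customized handling, described in the following steps:
1. Take a label (an integer from 0 to 9) and, by using the Keras Embedding layer, turn it into a dense vector of size 28 × 28 × 1 = 784 (the length of a flattened image).
2. Reshape the label embedding into the image dimensions (28 × 28 × 1).
3. Concatenate the reshaped label embedding onto the corresponding image, creating a joint representation with the shape (28 × 28 × 2).
4. Feed the image-label joint representation as input into the CGAN Discriminator network.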
Again, to make it less abstract, let’s see what the process looks like visually, using the label 7 as an example; see figure 8.5.
First, we embed the label into a vector the size of a flattened image (28 × 28 × 1 = 784). Second, we reshape the embedded label into a tensor with the same shape as the input image (28 × 28 × 1). Third, we concatenate the reshaped label embedding onto the corresponding image. This joined representation is then passed as input into the CGAN Discriminator network.
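The embed, reshape, and concatenate steps can be sketched in plain NumPy as follows. As before, this only illustrates the shapes involved: embedding_table is a random stand-in for learned Embedding weights, and the images and labels are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
img_shape = (28, 28, 1)
num_classes = 10
batch_size = 4

# Stand-in for learned Embedding weights: one row of 784 values per class.
embedding_table = rng.normal(size=(num_classes, int(np.prod(img_shape))))

imgs = rng.uniform(-1, 1, size=(batch_size,) + img_shape)  # fake "images"
labels = np.array([7, 7, 3, 0])                            # hypothetical labels

# Step 1: embed each label as a vector of length 784.
label_embedding = embedding_table[labels]                       # (4, 784)

# Step 2: reshape each embedding to the image dimensions.
label_embedding = label_embedding.reshape((-1,) + img_shape)    # (4, 28, 28, 1)

# Step 3: stack image and label embedding along the channel axis.
concatenated = np.concatenate([imgs, label_embedding], axis=-1) # (4, 28, 28, 2)
```

Channel 0 of the result is the untouched image and channel 1 is the label embedding, which is why the Discriminator's first layer must accept two input channels.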
In addition to the preprocessing steps, we have to make a few additional adjustments to the Discriminator network compared to the one in chapter 4. (As in the previous chapter, basing the model on our DCGAN implementation should make it easier to see the CGAN-specific changes without distractions from implementation details in unrelated parts of the model.) First, we have to adjust the model input dimensions to (28 × 28 × 2) to reflect the new input shape.
Second, we increase the depth of the first convolutional layer from 32 to 64. The reasoning behind this change is that there is more information to encode because of the concatenated label embedding; this network architecture indeed yielded better results experimentally.
At the output layer, we use the sigmoid activation function to produce a probability that the input image-label pair is real rather than fake—no change here. And finally, the following listing is our CGAN Discriminator implementation.
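def build_discriminator(img_shape):

    model = Sequential()

    # Convolutional layer: 28 × 28 × 2 -> 14 × 14 × 64
    # (the extra input channel holds the reshaped label embedding)
    model.add(
        Conv2D(64,
               kernel_size=3,
               strides=2,
               input_shape=(img_shape[0], img_shape[1], img_shape[2] + 1),
               padding='same'))
    model.add(LeakyReLU(alpha=0.01))

    # Convolutional layer: 14 × 14 × 64 -> 7 × 7 × 64
    # (input_shape is needed only on the first layer of a Sequential model)
    model.add(Conv2D(64, kernel_size=3, strides=2, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.01))

    # Convolutional layer: 7 × 7 × 64 -> 4 × 4 × 128
    model.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.01))

    # Output layer with sigmoid activation
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))

    return model


def build_cgan_discriminator(img_shape):

    # Input image
    img = Input(shape=img_shape)

    # Label for the input image
    label = Input(shape=(1, ), dtype='int32')

    # Label embedding: turns the label into a dense vector of size
    # 28 × 28 × 1 = 784; produces a 3D tensor with shape (batch_size, 1, 784)
    label_embedding = Embedding(num_classes,
                                np.prod(img_shape),
                                input_length=1)(label)

    # Flatten the embedding into a 2D tensor with shape (batch_size, 784)
    label_embedding = Flatten()(label_embedding)

    # Reshape the label embedding to the dimensions of the input image
    label_embedding = Reshape(img_shape)(label_embedding)

    # Concatenate the image with its label embedding along the channel axis
    concatenated = Concatenate(axis=-1)([img, label_embedding])

    discriminator = build_discriminator(img_shape)

    # Classify the image-label pair
    classification = discriminator(concatenated)

    return Model([img, label], classification)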
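The spatial sizes inside build_discriminator can be checked by hand with a small sketch. It relies on the rule that a convolution with padding='same' produces an output of size ceil(input / stride); the helper name conv_same is an illustrative assumption.

```python
import math

def conv_same(size, stride):
    """Spatial output size of a convolution with padding='same'."""
    return math.ceil(size / stride)

size = 28          # spatial size of the 28 × 28 × 2 joint representation
trace = [size]
for stride in (2, 2, 2):   # strides of the three Conv2D layers
    size = conv_same(size, stride)
    trace.append(size)

# Number of values that Flatten() hands to the final Dense layer,
# given the 128 filters of the last convolution.
flattened_units = size * size * 128
```

Under that rule the trace is 28, 14, 7, 4, so the last convolution yields a 4 × 4 × 128 tensor and the Dense layer sees 2,048 inputs.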
Next, we build and compile the CGAN Discriminator and Generator models, as shown in the following listing. Notice that in the combined model used to train the Generator, the same input label is passed to the Generator (to generate a sample) and to the Discriminator (to make a prediction).
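def build_cgan(generator, discriminator):

    # Random noise vector z
    z = Input(shape=(z_dim, ))

    # Image label
    label = Input(shape=(1, ))

    # Generated image for that label
    img = generator([z, label])

    classification = discriminator([img, label])

    # Combined Generator -> Discriminator model
    model = Model([z, label], classification)

    return model


# Build and compile the Discriminator
discriminator = build_cgan_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(),
                      metrics=['accuracy'])

# Build the Generator
generator = build_cgan_generator(z_dim)

# Keep the Discriminator's parameters constant for Generator training
discriminator.trainable = False

# Build and compile the CGAN model with the fixed Discriminator
# to train the Generator
cgan = build_cgan(generator, discriminator)
cgan.compile(loss='binary_crossentropy', optimizer=Adam())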
For the CGAN training algorithm, the details of each training iteration are as follows.
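For each training iteration do
1. Train the Discriminator:
   a. Take a random mini-batch of real examples and their labels (x, y).
   b. Compute D((x, y)) for the mini-batch and backpropagate the binary classification loss to update the Discriminator's weights.
   c. Take a mini-batch of random noise vectors and class labels (z, y) and generate a mini-batch of fake examples: G(z, y) = x*|y. (The listing that follows reuses the labels from step a.)
   d. Compute D(x*|y, y) for the mini-batch and backpropagate the binary classification loss to update the Discriminator's weights.
2. Train the Generator:
   a. Take a mini-batch of random noise vectors and class labels (z, y) and generate a mini-batch of fake examples: G(z, y) = x*|y.
   b. Compute D(x*|y, y) for the mini-batch and backpropagate the binary classification loss to update the Generator's weights so that the Discriminator classifies the fake pairs as real.
End for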
The following listing implements this CGAN training algorithm.
accuracies = []
losses = []


def train(iterations, batch_size, sample_interval):

    # Load the MNIST dataset
    (X_train, y_train), (_, _) = mnist.load_data()

    # Rescale [0, 255] grayscale pixel values to [-1, 1]
    X_train = X_train / 127.5 - 1.
    X_train = np.expand_dims(X_train, axis=3)

    # Labels for real images: all 1s
    real = np.ones((batch_size, 1))

    # Labels for fake images: all 0s
    fake = np.zeros((batch_size, 1))

    for iteration in range(iterations):

        # Get a random batch of real images and their labels
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs, labels = X_train[idx], y_train[idx]

        # Generate a batch of fake images for the same labels
        z = np.random.normal(0, 1, (batch_size, z_dim))
        gen_imgs = generator.predict([z, labels])

        # Train the Discriminator
        d_loss_real = discriminator.train_on_batch([imgs, labels], real)
        d_loss_fake = discriminator.train_on_batch([gen_imgs, labels], fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Generate a batch of noise vectors
        z = np.random.normal(0, 1, (batch_size, z_dim))

        # Get a batch of random labels
        labels = np.random.randint(0, num_classes, batch_size).reshape(-1, 1)

        # Train the Generator
        g_loss = cgan.train_on_batch([z, labels], real)

        if (iteration + 1) % sample_interval == 0:

            # Output training progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %
                  (iteration + 1, d_loss[0], 100 * d_loss[1], g_loss))

            # Save losses and accuracies so they can be plotted after training
            losses.append((d_loss[0], g_loss))
            accuracies.append(100 * d_loss[1])

            # Output a sample of generated images
            sample_images()
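Two small pieces of arithmetic in train() are worth checking by hand: the pixel rescaling and the averaging of the Discriminator's results. The sketch below uses NumPy only; the loss and accuracy numbers are invented for illustration. Because the Discriminator is compiled with metrics=['accuracy'], each train_on_batch call on it returns a loss and an accuracy, and np.add averages them element-wise.

```python
import numpy as np

# Pixel rescaling used before training: [0, 255] -> [-1, 1],
# the same range as the Generator's tanh output.
pixels = np.array([0, 127.5, 255])
scaled = pixels / 127.5 - 1.0

# Hypothetical [loss, accuracy] pairs from the two Discriminator updates.
d_loss_real = [0.60, 0.75]
d_loss_fake = [0.80, 0.50]

# Element-wise average: d_loss[0] is the mean loss, d_loss[1] the mean accuracy.
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
```

This is why the progress message can read the loss from d_loss[0] and the accuracy from d_loss[1].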
You may recognize the next function from chapters 3 and 4. We used it to examine how the quality of the Generator-produced images improved as the training progressed. The function in listing 8.7 is indeed similar, but a few crucial differences exist.
First, instead of a 4 × 4 grid of random handwritten digits, we are generating a 2 × 5 grid of numbers, 0 through 4 in the first row and 5 through 9 in the second row. This allows us to inspect how well the CGAN Generator is learning to produce specific numerals. Second, we are displaying the label for each example by using the set_title() method.
def sample_images(image_grid_rows=2, image_grid_columns=5):

    # Sample random noise
    z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))

    # Get image labels 0-9
    labels = np.arange(0, 10).reshape(-1, 1)

    # Generate images from the random noise and the labels
    gen_imgs = generator.predict([z, labels])

    # Rescale image pixel values from [-1, 1] to [0, 1]
    gen_imgs = 0.5 * gen_imgs + 0.5

    # Set up the image grid
    fig, axs = plt.subplots(image_grid_rows,
                            image_grid_columns,
                            figsize=(10, 4),
                            sharey=True,
                            sharex=True)

    cnt = 0
    for i in range(image_grid_rows):
        for j in range(image_grid_columns):
            # Output the grid of images, each titled with its label
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            axs[i, j].set_title("Digit: %d" % labels[cnt, 0])
            cnt += 1
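To see which digit lands in which grid cell, and what the display rescaling does, here is a small NumPy sketch that mirrors the label array and the row-major cnt counter in sample_images(). The variable names grid, tanh_outputs, and displayable are illustrative assumptions.

```python
import numpy as np

rows, cols = 2, 5
labels = np.arange(0, 10).reshape(-1, 1)   # labels 0-9, one per generated image

# The nested loop fills the grid row by row, so a row-major reshape
# reproduces the layout of the titles.
grid = labels.reshape(rows, cols)

# Display rescaling: the Generator's tanh output in [-1, 1] is mapped to [0, 1].
tanh_outputs = np.array([-1.0, 0.0, 1.0])
displayable = 0.5 * tanh_outputs + 0.5
```

The first row of grid holds the labels 0 through 4 and the second row holds 5 through 9.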
Figure 8.6 shows sample output from this function and illustrates the improvement to the CGAN-produced numerals over the course of training.
And finally, let’s run the model we just implemented:
# Set hyperparameters
iterations = 12000
batch_size = 32
sample_interval = 1000

# Train the CGAN for the specified number of iterations
train(iterations, batch_size, sample_interval)
Figure 8.7 shows the images of digits produced by the CGAN Generator after it is fully trained. In each row, we instruct the Generator to synthesize a different numeral, from 0 to 9. Notice that each numeral is rendered in a different writing style, attesting to CGAN’s ability not only to learn to produce examples matching every label in the training dataset, but also to capture the full diversity of the training data.
In this chapter, you saw how labels could be used to guide the training of the Generator and the Discriminator to teach a GAN to produce fake examples of our choice. Along with the DCGAN, CGAN is one of the most influential early GAN variants, and it has inspired countless new research directions.
Perhaps the most impactful and promising of these is the use of conditional adversarial networks as a general-purpose solution to image-to-image translation problems. This is a class of problems that seeks to translate images from one modality into another. Applications of image-to-image translation range from colorizing black-and-white photos to turning a daytime scene into nighttime and synthesizing a satellite view from a map view.
One of the most successful early implementations based on the Conditional GAN paradigm is pix2pix, which uses pairs of images (one as the input and the other as the label) to learn to translate from one domain into another. Recall that, in theory as well as in practice, the conditioning information used to train a CGAN can be much richer than a class label, which allows for more complex use cases and scenarios. For example, for colorization tasks, an image pair would be a black-and-white photo (the input) and a colored version of the same photo (the label). You will see these illustrated in the following chapter.
We do not cover pix2pix in detail because only about a year after its publication, it was eclipsed by another GAN variant that not only outperformed pix2pix on image-to-image translation tasks but also did so without the need for paired images. The Cycle-Consistent Adversarial Network (or CycleGAN, as the technique came to be known) needs only two groups of images representing the two domains (for example, a group of black-and-white photos and a group of colored photos). You will learn all about this remarkable GAN variant in the following chapter.