Now it's time to define the model losses. First, the discriminator loss is divided into two parts:
- The unsupervised loss, which represents the standard GAN problem of telling real images from generated ones
- The supervised loss, which is computed from the class probabilities that the discriminator predicts for the labeled images
For the unsupervised loss, the discriminator has to discriminate between actual training images and images produced by the generator.
As in a regular GAN, half of the time the discriminator receives unlabeled images from the training set as input, and the other half of the time it receives fake, unlabeled images from the generator.
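To make the unsupervised part concrete, here is a minimal sketch in plain Python rather than TensorFlow. It uses the numerically stable form of sigmoid cross entropy that TensorFlow documents for `tf.nn.sigmoid_cross_entropy_with_logits`. The helper names and the example logits are hypothetical and chosen only for illustration.

```python
import math


def sigmoid_cross_entropy(logit, label):
    """Binary cross entropy on a raw logit; label is 0.0 (fake) or 1.0 (real)."""
    # Stable form: max(x, 0) - x * z + log(1 + exp(-|x|))
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def unsupervised_disc_loss(real_logits, fake_logits):
    """Mean loss on real images (label 1) plus mean loss on generated images (label 0)."""
    loss_real = sum(sigmoid_cross_entropy(x, 1.0) for x in real_logits) / len(real_logits)
    loss_fake = sum(sigmoid_cross_entropy(x, 0.0) for x in fake_logits) / len(fake_logits)
    return loss_real + loss_fake


# Hypothetical logits: the discriminator is confidently right on both batches.
confident = unsupervised_disc_loss([4.0, 3.0], [-4.0, -3.0])
# Same logits with the batches swapped: the discriminator is confidently wrong.
confused = unsupervised_disc_loss([-4.0, -3.0], [4.0, 3.0])
```

The loss is small when real images get large positive logits and generated images get large negative ones, and it grows quickly when the discriminator is confidently wrong.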
The second part of the discriminator loss, the supervised loss, builds on the class logits from the discriminator. Because this is a multi-class classification problem, we use softmax cross entropy, and we average it only over the examples that actually have labels.
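The supervised part can be sketched the same way in plain Python. The sketch assumes a 0/1 label mask that marks which examples in the batch are labeled, so that unlabeled examples contribute nothing to the loss. The function names and the toy batch are hypothetical.

```python
import math


def softmax_cross_entropy(logits, target):
    """Cross entropy between softmax(logits) and a one-hot target at index `target`."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]


def masked_class_loss(batch_logits, targets, label_mask):
    """Average the per-example loss over labeled examples only (mask value 1)."""
    per_example = [softmax_cross_entropy(l, t) for l, t in zip(batch_logits, targets)]
    total = sum(m * ce for m, ce in zip(label_mask, per_example))
    # max(1, ...) avoids dividing by zero when a batch has no labeled examples
    return total / max(1.0, float(sum(label_mask)))


# Hypothetical batch of three examples and three classes; the last one is unlabeled.
batch_logits = [[5.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 5.0]]
targets = [0, 1, 0]      # the third target is a placeholder and is masked out
label_mask = [1, 1, 0]
loss = masked_class_loss(batch_logits, targets, label_mask)
```

Dividing by the number of labeled examples, rather than by the batch size, keeps the scale of the supervised loss stable when only a small fraction of each batch carries labels.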
As mentioned in the Improved Techniques for Training GANs paper, we should use feature matching for the generator loss. As the authors describe:
And finally, the model loss function will look like this:
import tensorflow as tf  # TensorFlow 1.x API


def model_losses(input_actual, input_latent_z, output_dim, target, num_classes, label_mask,
                 leaky_alpha=0.2, drop_out_rate=0.):
    # These numbers multiply the size of each layer of the generator and the discriminator,
    # respectively. You can reduce them to run your code faster for debugging purposes.
    gen_size_mult = 32
    disc_size_mult = 64

    # Here we run the generator and the discriminator
    gen_model = generator(input_latent_z, output_dim, leaky_alpha=leaky_alpha,
                          size_mult=gen_size_mult)
    disc_on_data = discriminator(input_actual, leaky_alpha=leaky_alpha,
                                 drop_out_rate=drop_out_rate, size_mult=disc_size_mult)
    disc_model_real, class_logits_on_data, gan_logits_on_data, data_features = disc_on_data
    disc_on_samples = discriminator(gen_model, reuse_vars=True, leaky_alpha=leaky_alpha,
                                    drop_out_rate=drop_out_rate, size_mult=disc_size_mult)
    disc_model_fake, class_logits_on_samples, gan_logits_on_samples, sample_features = disc_on_samples

    # Here we compute `disc_loss`, the loss for the discriminator.
    # Unsupervised part: real images are labeled 1 and generated images are labeled 0.
    disc_loss_actual = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=gan_logits_on_data,
                                                labels=tf.ones_like(gan_logits_on_data)))
    disc_loss_fake = tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=gan_logits_on_samples,
                                                labels=tf.zeros_like(gan_logits_on_samples)))

    # Supervised part: softmax cross entropy on the class logits of the real images.
    # NOTE: `extra_class` is not defined inside this function; it is assumed to be a
    # module-level constant defined earlier.
    target = tf.squeeze(target)
    classes_cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
        logits=class_logits_on_data,
        labels=tf.one_hot(target, num_classes + extra_class, dtype=tf.float32))
    classes_cross_entropy = tf.squeeze(classes_cross_entropy)

    # Average the class loss over the labeled examples only; the maximum guards against
    # division by zero when a batch contains no labeled examples.
    label_m = tf.squeeze(tf.cast(label_mask, tf.float32))
    disc_loss_class = (tf.reduce_sum(label_m * classes_cross_entropy)
                       / tf.maximum(1., tf.reduce_sum(label_m)))
    disc_loss = disc_loss_class + disc_loss_actual + disc_loss_fake

    # Here we set `gen_loss` to the "feature matching" loss invented by Tim Salimans.
    sample_moments = tf.reduce_mean(sample_features, axis=0)
    data_moments = tf.reduce_mean(data_features, axis=0)
    gen_loss = tf.reduce_mean(tf.abs(data_moments - sample_moments))

    # Count correct class predictions over the whole batch and over the labeled examples.
    prediction_class = tf.cast(tf.argmax(class_logits_on_data, 1), tf.int32)
    check_prediction = tf.cast(tf.equal(target, prediction_class), tf.float32)
    correct = tf.reduce_sum(check_prediction)
    masked_correct = tf.reduce_sum(label_m * check_prediction)

    return disc_loss, gen_loss, correct, masked_correct, gen_model
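
The feature matching loss computed near the end of `model_losses` can be illustrated with a small plain-Python sketch. It mirrors what the listing does: average the discriminator features over the batch for real and generated images, then take the mean absolute difference between the two averages. The helper names and the toy feature values are hypothetical.

```python
def column_means(features):
    """Mean of each feature over the batch (rows are examples, columns are features)."""
    n = len(features)
    return [sum(col) / n for col in zip(*features)]


def feature_matching_loss(data_features, sample_features):
    """Mean absolute difference between batch-averaged real and generated features."""
    data_moments = column_means(data_features)
    sample_moments = column_means(sample_features)
    diffs = [abs(d - s) for d, s in zip(data_moments, sample_moments)]
    return sum(diffs) / len(diffs)


# Hypothetical batches of two examples with two features each.
data_features = [[1.0, 2.0], [3.0, 4.0]]      # batch means: [2.0, 3.0]
sample_features = [[0.0, 3.0], [2.0, 5.0]]    # batch means: [1.0, 4.0]
loss = feature_matching_loss(data_features, sample_features)
```

Because only batch averages are compared, the generator is pushed to match the statistics of real data in the discriminator's feature space instead of trying to fool the discriminator's final output directly.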