Model training

Now let's kick off the training process and see how well the GAN learns to generate images similar to the MNIST ones:

```python
import os
import pickle as pkl

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.Saver)

train_batch_size = 100
num_epochs = 100
generated_samples = []
model_losses = []

# Save only the generator variables
saver = tf.train.Saver(var_list=gen_vars)
# Saver.save() fails if the target directory does not exist
os.makedirs('./checkpoints', exist_ok=True)

# Build the sampling op once, reusing the generator variables,
# instead of adding new ops to the graph at every epoch
sample_generator = generator(generator_input_z, input_img_size, reuse_vars=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for e in range(num_epochs):
        for ii in range(mnist_dataset.train.num_examples // train_batch_size):
            input_batch = mnist_dataset.train.next_batch(train_batch_size)

            # Get images, reshape and rescale from [0, 1] to [-1, 1] to pass to D
            input_batch_images = input_batch[0].reshape((train_batch_size, 784))
            input_batch_images = input_batch_images * 2 - 1

            # Sample random noise for G
            gen_batch_z = np.random.uniform(-1, 1, size=(train_batch_size, gen_z_size))

            # Run optimizers: one discriminator step, then one generator step
            _ = sess.run(disc_train_optimizer,
                         feed_dict={real_discrminator_input: input_batch_images,
                                    generator_input_z: gen_batch_z})
            _ = sess.run(gen_train_optimizer, feed_dict={generator_input_z: gen_batch_z})

        # At the end of each epoch, compute the losses on the last batch and print them out
        train_loss_disc = sess.run(disc_loss, {generator_input_z: gen_batch_z,
                                               real_discrminator_input: input_batch_images})
        train_loss_gen = gen_loss.eval({generator_input_z: gen_batch_z})

        print("Epoch {}/{}...".format(e + 1, num_epochs),
              "Disc Loss: {:.3f}...".format(train_loss_disc),
              "Gen Loss: {:.3f}".format(train_loss_gen))

        # Save losses to view after training
        model_losses.append((train_loss_disc, train_loss_gen))

        # Sample from the generator as we're training, for viewing afterwards
        gen_sample_z = np.random.uniform(-1, 1, size=(16, gen_z_size))
        generator_samples = sess.run(sample_generator,
                                     feed_dict={generator_input_z: gen_sample_z})

        generated_samples.append(generator_samples)
        saver.save(sess, './checkpoints/generator_ck.ckpt')

# Save training generator samples
with open('train_generator_samples.pkl', 'wb') as f:
    pkl.dump(generated_samples, f)
```
Output:

```
...
Epoch 71/100... Disc Loss: 1.078... Gen Loss: 1.361
Epoch 72/100... Disc Loss: 1.037... Gen Loss: 1.555
Epoch 73/100... Disc Loss: 1.194... Gen Loss: 1.297
Epoch 74/100... Disc Loss: 1.120... Gen Loss: 1.730
Epoch 75/100... Disc Loss: 1.184... Gen Loss: 1.425
Epoch 76/100... Disc Loss: 1.054... Gen Loss: 1.534
Epoch 77/100... Disc Loss: 1.457... Gen Loss: 0.971
Epoch 78/100... Disc Loss: 0.973... Gen Loss: 1.688
Epoch 79/100... Disc Loss: 1.324... Gen Loss: 1.370
Epoch 80/100... Disc Loss: 1.178... Gen Loss: 1.710
Epoch 81/100... Disc Loss: 1.070... Gen Loss: 1.649
Epoch 82/100... Disc Loss: 1.070... Gen Loss: 1.530
Epoch 83/100... Disc Loss: 1.117... Gen Loss: 1.705
Epoch 84/100... Disc Loss: 1.042... Gen Loss: 2.210
Epoch 85/100... Disc Loss: 1.152... Gen Loss: 1.260
Epoch 86/100... Disc Loss: 1.327... Gen Loss: 1.312
Epoch 87/100... Disc Loss: 1.069... Gen Loss: 1.759
Epoch 88/100... Disc Loss: 1.001... Gen Loss: 1.400
Epoch 89/100... Disc Loss: 1.215... Gen Loss: 1.448
Epoch 90/100... Disc Loss: 1.108... Gen Loss: 1.342
Epoch 91/100... Disc Loss: 1.227... Gen Loss: 1.468
Epoch 92/100... Disc Loss: 1.190... Gen Loss: 1.328
Epoch 93/100... Disc Loss: 0.869... Gen Loss: 1.857
Epoch 94/100... Disc Loss: 0.946... Gen Loss: 1.740
Epoch 95/100... Disc Loss: 0.925... Gen Loss: 1.708
Epoch 96/100... Disc Loss: 1.067... Gen Loss: 1.427
Epoch 97/100... Disc Loss: 1.099... Gen Loss: 1.573
Epoch 98/100... Disc Loss: 0.972... Gen Loss: 1.884
Epoch 99/100... Disc Loss: 1.292... Gen Loss: 1.610
Epoch 100/100... Disc Loss: 1.103... Gen Loss: 1.736
```
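The training loop rescales the MNIST pixels with `input_batch_images*2 - 1`. A likely reason, assumed here rather than stated in this section, is that the generator ends in a `tanh` layer whose outputs lie in [-1, 1], so real and generated images should share that range. The sketch below shows that mapping, its inverse for displaying samples, and a hypothetical helper, `tile_samples`, that arranges the 16 samples stored per epoch in `train_generator_samples.pkl` into a single 4x4 grid, assuming each sample is a flattened 28x28 image (784 values).

```python
import numpy as np


def to_tanh_range(images):
    """Map pixel values from [0, 1] to [-1, 1], like `input_batch_images*2 - 1`."""
    return np.asarray(images, dtype=float) * 2.0 - 1.0


def from_tanh_range(images):
    """Map values from [-1, 1] back to [0, 1], e.g. before displaying samples."""
    return (np.asarray(images, dtype=float) + 1.0) / 2.0


def tile_samples(samples, rows=4, cols=4, img_size=28):
    """Arrange rows*cols flattened images into one (rows*img_size, cols*img_size) grid.

    Hypothetical helper: assumes each sample is a flattened img_size x img_size image.
    """
    samples = np.asarray(samples)
    expected = (rows * cols, img_size * img_size)
    if samples.shape != expected:
        raise ValueError("expected shape {}, got {}".format(expected, samples.shape))
    # (rows, cols, h, w) -> (rows, h, cols, w) so each grid row stacks pixel rows
    grid = samples.reshape(rows, cols, img_size, img_size).transpose(0, 2, 1, 3)
    return grid.reshape(rows * img_size, cols * img_size)


# Stand-in for one epoch's 16 generator samples, with values in [-1, 1]
fake_samples = np.random.uniform(-1, 1, size=(16, 784))
grid = tile_samples(from_tanh_range(fake_samples))
print(grid.shape, grid.min() >= 0.0, grid.max() <= 1.0)
```

To inspect a real epoch instead, you could load the list saved above with `pkl.load`, take an entry such as the last one, and pass `tile_samples(from_tanh_range(entry))` to `plt.imshow(..., cmap='gray')`.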

After 100 epochs, we have a trained generator that can produce images similar to the MNIST images we fed to the discriminator. Let's first plot the discriminator and generator losses recorded during training:

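```python
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
model_losses = np.array(model_losses)  # shape: (num_epochs, 2)
ax.plot(model_losses.T[0], label='Disc loss')
ax.plot(model_losses.T[1], label='Gen loss')
ax.set_title("Model Losses")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
plt.show()
```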

Output:

Figure 6: Discriminator and Generator Losses
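Each point in Figure 6 stands for one epoch, and the training loop records only the losses of that epoch's final batch. The arithmetic below shows how many optimizer steps lie between two consecutive points. It assumes `mnist_dataset` is TensorFlow's tutorial MNIST loader (`read_data_sets`), whose default training split holds 55,000 images; that size is not stated in this section.

```python
# Assumption: the default TensorFlow tutorial MNIST split, 55,000 training images
num_train_examples = 55000
train_batch_size = 100
num_epochs = 100

# Integer division, as in the training loop: leftover examples are skipped
steps_per_epoch = num_train_examples // train_batch_size
total_steps = steps_per_epoch * num_epochs

print(steps_per_epoch)  # 550
print(total_steps)      # 55000
```

Since every step runs one discriminator update and one generator update, 100 epochs amount to 55,000 updates of each network, while the plot shows only 100 single-batch snapshots. That sampling is one reason the curves look jagged.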

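As the preceding figure shows, the discriminator and generator losses (the Disc loss and Gen loss lines) are converging: rather than diverging, they settle and fluctuate within a bounded range, which the log for epochs 71 to 100 also reflects (discriminator loss roughly 0.9 to 1.5, generator loss roughly 1.0 to 2.2).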
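Because each recorded loss comes from a single batch, convergence is easier to judge on a smoothed curve. The sketch below applies a simple moving average, via a hypothetical helper `moving_average` that is not part of the original code, to the epoch 71 to 100 losses copied from the training log. With the full history you could pass `model_losses.T[0]` and `model_losses.T[1]` instead and plot the smoothed values next to the raw ones.

```python
def moving_average(values, window):
    """Return the simple moving average of `values` over `window` consecutive entries."""
    if not 1 <= window <= len(values):
        raise ValueError("window must be between 1 and len(values)")
    running_sum = sum(values[:window])
    averages = [running_sum / window]
    for i in range(window, len(values)):
        # Slide the window one step: add the newest value, drop the oldest
        running_sum += values[i] - values[i - window]
        averages.append(running_sum / window)
    return averages


# Losses copied from the "Epoch 71/100" to "Epoch 100/100" lines of the training log
disc_losses = [1.078, 1.037, 1.194, 1.120, 1.184, 1.054, 1.457, 0.973, 1.324, 1.178,
               1.070, 1.070, 1.117, 1.042, 1.152, 1.327, 1.069, 1.001, 1.215, 1.108,
               1.227, 1.190, 0.869, 0.946, 0.925, 1.067, 1.099, 0.972, 1.292, 1.103]
gen_losses = [1.361, 1.555, 1.297, 1.730, 1.425, 1.534, 0.971, 1.688, 1.370, 1.710,
              1.649, 1.530, 1.705, 2.210, 1.260, 1.312, 1.759, 1.400, 1.448, 1.342,
              1.468, 1.328, 1.857, 1.740, 1.708, 1.427, 1.573, 1.884, 1.610, 1.736]

smoothed_disc = moving_average(disc_losses, 10)
smoothed_gen = moving_average(gen_losses, 10)
print("Disc loss, 10-epoch average: {:.3f} to {:.3f}".format(min(smoothed_disc), max(smoothed_disc)))
print("Gen loss, 10-epoch average: {:.3f} to {:.3f}".format(min(smoothed_gen), max(smoothed_gen)))
```

A running sum keeps each step O(1); a window of 10 epochs is an arbitrary choice that trades responsiveness for smoothness.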
