Saving and restoring neural networks

There are two ways to store a trained neural network for future use and restore it later. We will see both of them at work in the convolutional neural network example.

The first mechanism lives in tf.train. A saver object is created with the following statement:

saver = tf.train.Saver(max_to_keep=10)

Then the state of the network at each training step can be saved with:

saver.save(sess, './classifier', global_step=step)

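Here the full graph is saved, but it is possible to save only part of it, for instance by passing a list of variables to tf.train.Saver. We save everything here, keep only the last 10 checkpoints (max_to_keep=10), and append the current step to the checkpoint name (global_step=step), which produces checkpoints named classifier-<step>.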
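The saving loop can be sketched as follows. This is a minimal, self-contained sketch rather than the book's classifier: it assumes TensorFlow's tf.compat.v1 API (so it runs on TensorFlow 1.x and 2.x), replaces the network with a hypothetical one-variable graph (w and train_op), and writes to a temporary directory instead of './'.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs and sessions

ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "classifier")

# Hypothetical stand-in for the network: one variable and an op that updates it.
w = tf.Variable(0.0, name="w")
train_op = tf.assign_add(w, 1.0, name="train_op")

# The Saver must be created after the variables it should save.
saver = tf.train.Saver(max_to_keep=10)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(15):
        sess.run(train_op)
        # Writes classifier-<step>.index/.data-*/.meta and updates the
        # "checkpoint" file; checkpoints older than the last 10 are deleted.
        saver.save(sess, prefix, global_step=step)
    kept = saver.last_checkpoints  # remaining checkpoints, oldest first

print([os.path.basename(p) for p in kept])
```

With 15 saves and max_to_keep=10, only the checkpoints for steps 5 to 14 should be left on disk.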

Let's say that we saved the final training step with saver.save(sess, './classifier-final'). To restore it, we first have to rebuild the graph from the saved meta graph file:

new_saver = tf.train.import_meta_graph("classifier-final.meta")

This does not restore the values of the variables; for that, we have to call:

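new_saver.restore(sess, tf.train.latest_checkpoint('./'))

tf.train.latest_checkpoint('./') returns the most recent checkpoint recorded in the current directory, here classifier-final. Be aware that only the graph and its variable values are restored. Python variables that pointed to nodes of the original graph are not: if you have such variables, you need to restore them before you can use them. This is true for placeholders and operations alike.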

We therefore also have to restore the Python variables we need, such as the reference to the is_training tensor:

graph = tf.get_default_graph()
training_tf = graph.get_tensor_by_name('is_training:0')

This is also a good reason to give proper names to all tensors (placeholders, operations, and so on), as we need their names to get a reference to them again when we restore the graph. The :0 suffix selects the first output of the operation named is_training.
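Putting the restore steps together, here is a minimal sketch under the same assumptions (tf.compat.v1, a temporary directory). The small graph with is_training, w, and output is hypothetical; it only stands in for the classifier so that the name-based lookup has something to find. Saving and restoring happen in two separate tf.Graph objects, which mimics restoring in a new process.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and sessions

ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "classifier-final")

# Training side: a hypothetical tiny graph with explicitly named nodes.
with tf.Graph().as_default():
    is_training = tf.placeholder_with_default(False, shape=(), name="is_training")
    w = tf.Variable(3.0, name="w")
    # Stand-in for a layer that behaves differently in training mode:
    # the scale is 2.0 when is_training is True and 1.0 otherwise.
    scale = tf.cast(is_training, tf.float32) + 1.0
    output = tf.multiply(w, scale, name="output")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.assign(w, 5.0))  # pretend that training changed w
        saver.save(sess, prefix)  # writes classifier-final.meta, .index, .data-*

# Restoring side: a fresh graph, so no Python handles from above are reused.
with tf.Graph().as_default() as graph:
    with tf.Session() as sess:
        # 1. Rebuild the graph structure from the .meta file.
        new_saver = tf.train.import_meta_graph(prefix + ".meta")
        # 2. Load the variable values from the latest checkpoint.
        new_saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
        # 3. Recover Python references by name (":0" = first output of the op).
        training_tf = graph.get_tensor_by_name("is_training:0")
        output_tf = graph.get_tensor_by_name("output:0")
        eval_value = sess.run(output_tf)
        train_value = sess.run(output_tf, feed_dict={training_tf: True})

print(eval_value, train_value)
```

Because the second graph is built only from classifier-final.meta, the names is_training:0 and output:0 are the only way back to those tensors.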

The other mechanism, tf.saved_model, builds on this one and is far more powerful, but we will only present the basic usage that mimics the simple one. We first create what is usually called a builder:

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

The export_dir folder is created by the builder here. If it already exists, you have to remove it before creating a new saved model.
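Since an existing export directory has to be removed first, a training script usually clears it before creating the builder. A minimal sketch in plain Python; prepare_export_dir is a hypothetical helper, not part of TensorFlow, and the leftover directory is simulated in a temporary folder.

```python
import os
import shutil
import tempfile


def prepare_export_dir(export_dir):
    """Hypothetical helper: delete export_dir (and everything in it) if it
    exists, so that SavedModelBuilder can create it from scratch."""
    if os.path.exists(export_dir):
        shutil.rmtree(export_dir)
    return export_dir


# Simulate a leftover saved model from a previous run.
export_dir = os.path.join(tempfile.mkdtemp(), "classifier_model")
os.makedirs(os.path.join(export_dir, "variables"))

prepare_export_dir(export_dir)
print(os.path.exists(export_dir))  # → False
```

The returned path can be passed straight to the builder, for example SavedModelBuilder(prepare_export_dir(export_dir)). Deleting the folder is irreversible, so point it only at a directory reserved for exports.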

Now, after training, we can call the builder to store the state of the network, and then write the saved model to disk with save():

builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING])
builder.save()

A builder can hold more than one meta graph, each identified by its own set of tags and carrying far more attributes (signatures, assets, and so on), but in our case we just need to call one function to restore the state:

tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], export_dir)
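An end-to-end sketch of this second mechanism, again assuming the tf.compat.v1 API, a temporary export directory, and a hypothetical one-variable graph (w and output). Note the call to builder.save(): the saved model is only complete once save() has written saved_model.pb. After loading, Python references still have to be recovered by name, exactly as with import_meta_graph.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs and sessions

# The export directory must not exist yet; the builder creates it.
export_dir = os.path.join(tempfile.mkdtemp(), "classifier_model")
tags = [tf.saved_model.tag_constants.TRAINING]

# Save: build a tiny stand-in graph and export it with its variables.
with tf.Graph().as_default():
    w = tf.Variable(3.0, name="w")
    output = tf.multiply(w, 2.0, name="output")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(sess, tags)
        builder.save()  # writes saved_model.pb into export_dir

# Restore: load graph and variables into a fresh graph with the same tags.
with tf.Graph().as_default() as graph:
    with tf.Session() as sess:
        tf.saved_model.loader.load(sess, tags, export_dir)
        output_tf = graph.get_tensor_by_name("output:0")  # look up by name
        restored = sess.run(output_tf)

print(restored)
```

The tags passed to loader.load must match the ones used when saving; this is how one meta graph is selected when the builder holds several.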