In this section, we will show you how to implement a relatively simple CNN architecture and how to train it to classify the images in the CIFAR-10 dataset.
Start by importing all the necessary libraries:
import fire
import numpy as np
import os
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
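The imported cifar10 module provides cifar10.load_data(), which returns (x_train, y_train), (x_test, y_test) as uint8 arrays: images of shape (N, 32, 32, 3) and integer labels of shape (N, 1). Before training, it is common to scale the pixels and one-hot encode the labels. The sketch below assumes that layout; the preprocess helper is hypothetical and not part of the original code, and it uses synthetic arrays in place of the real download so that it runs offline.

```python
import numpy as np

NUM_CLASSES = 10  # CIFAR-10 has ten classes


def preprocess(images, labels, num_classes=NUM_CLASSES):
    """Hypothetical helper: scale uint8 pixels to [0, 1] and one-hot encode labels.

    Assumes images of shape (N, 32, 32, 3) and labels of shape (N, 1),
    the layout returned by cifar10.load_data().
    """
    x = images.astype(np.float32) / 255.0
    # Row i of the identity matrix is the one-hot vector for class i.
    y = np.eye(num_classes, dtype=np.float32)[labels.reshape(-1)]
    return x, y


# Synthetic stand-in for real CIFAR-10 data, so the sketch runs offline.
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
fake_labels = rng.integers(0, NUM_CLASSES, size=(8, 1), dtype=np.uint8)

x, y = preprocess(fake_images, fake_labels)
print(x.shape, y.shape)  # (8, 32, 32, 3) (8, 10)
```

With the real dataset, you would pass x_train and y_train from cifar10.load_data() instead of the synthetic arrays.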
We will define a Python class, Train, that implements the training process. Besides its constructor, it has two methods: build_graph and train. The train method is called when the script is run as the main program:
class Train:
    __x_ = []
    __y_ = []
    __logits = []
    __loss = []
    __train_step = []
    __merged_summary_op = []
    __saver = []
    __session = []
    __writer = []
    __is_training = []
    __loss_val = []
    __train_summary = []
    __val_summary = []

    def __init__(self):
        pass

    def build_graph(self):
        [...]

    def train(self, save_dir='./save', batch_size=500):
        [...]


if __name__ == '__main__':
    cnn = Train()
    cnn.train()
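The train method takes a batch_size parameter (500 by default), but its body is elided here. As a minimal sketch of how such a method might split the training set into mini-batches, assuming NumPy arrays like those returned by cifar10.load_data(), consider the hypothetical helper iterate_minibatches below; it is not part of the Train class.

```python
import numpy as np


def iterate_minibatches(x, y, batch_size=500, shuffle=True, seed=None):
    """Hypothetical helper: yield (x_batch, y_batch) pairs covering the data once.

    The last batch is smaller if len(x) is not a multiple of batch_size.
    """
    indices = np.arange(len(x))
    if shuffle:
        # Shuffle once per epoch so batches differ between epochs.
        np.random.default_rng(seed).shuffle(indices)
    for start in range(0, len(x), batch_size):
        batch_idx = indices[start:start + batch_size]
        yield x[batch_idx], y[batch_idx]


# CIFAR-10 has 50,000 training images; with batch_size=500,
# one epoch is 50,000 / 500 = 100 training steps.
x_train = np.zeros((50000, 1), dtype=np.float32)  # placeholder features
y_train = np.arange(50000)                         # placeholder labels
batches = list(iterate_minibatches(x_train, y_train, batch_size=500, seed=0))
print(len(batches))  # 100
```

Inside a training loop, each (x_batch, y_batch) pair would be fed to one optimization step.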
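The attributes of Train use double-underscore names, so Python name-mangles them: inside the class body, self.__loss_val is rewritten to self._Train__loss_val, and the plain name __loss_val is not reachable from outside the class. The short sketch below, using a hypothetical record method that is not part of the original class, illustrates this. It also shows that assigning to the attribute on self creates an instance attribute and leaves the class-level list, which all instances share, untouched.

```python
class Train:
    __loss_val = []  # class-level attribute, shared by all instances

    def __init__(self):
        pass

    def record(self, value):
        # Hypothetical method. Here self.__loss_val is mangled to
        # self._Train__loss_val. Rebinding (rather than calling .append)
        # creates a per-instance list instead of mutating the shared one.
        self.__loss_val = self.__loss_val + [value]
        return self.__loss_val


cnn = Train()
print(cnn.record(0.5))               # [0.5]
print(hasattr(cnn, "__loss_val"))    # False: the unmangled name does not exist
print(cnn._Train__loss_val)          # [0.5]
print(Train._Train__loss_val)        # []: the class-level list is unchanged
```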