Simple linear regression example

To better digest all of these concepts, let's now create a simple linear regression model. First, we have to import all the libraries and set a random seed for both NumPy and TensorFlow (so that we all obtain the same results):

import tensorflow as tf
import numpy as np
from datetime import datetime

np.random.seed(10)
tf.set_random_seed(10)

Then, we can create a synthetic dataset consisting of 100 examples, as shown in the following plot:

Figure 2.4: Dataset used in the linear regression example

Because this is a linear regression example, the dataset follows the model y = W * X + b, where W and b are arbitrary values. In this example, we set W = 0.5 and b = 1.4. Additionally, we add some normally distributed random noise:

W, b = 0.5, 1.4
# create a dataset of 100 examples
X = np.linspace(0,100, num=100)
# add random noise to the y labels
y = np.random.normal(loc=W * X + b, scale=2.0, size=len(X))
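As a quick sanity check (our own addition, not part of the original example), we can fit the same synthetic data with NumPy's closed-form least-squares routine and verify that the underlying parameters are recoverable from the noisy labels:

```python
import numpy as np

np.random.seed(10)

# rebuild the same synthetic dataset
W, b = 0.5, 1.4
X = np.linspace(0, 100, num=100)
y = np.random.normal(loc=W * X + b, scale=2.0, size=len(X))

# ordinary least squares fit of a degree-1 polynomial: returns (slope, intercept)
w_hat, b_hat = np.polyfit(X, y, deg=1)
print('slope: %.3f, intercept: %.3f' % (w_hat, b_hat))
```

The recovered slope should land very close to 0.5; the intercept is estimated with more variance, which is also why the bias takes longer to converge during the gradient-based training below.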

The next step involves creating the placeholders for the input and the output, and the variables of the weight and bias of the linear model. During training, these two variables will be optimized to be as similar as possible to the weight and bias of the dataset:

# create the placeholders
x_ph = tf.placeholder(shape=[None,], dtype=tf.float32)
y_ph = tf.placeholder(shape=[None,], dtype=tf.float32)

# create the variables
v_weight = tf.get_variable("weight", shape=[1], dtype=tf.float32)
v_bias = tf.get_variable("bias", shape=[1], dtype=tf.float32)

Then, we build the computational graph defining the linear operation and the mean squared error (MSE) loss:

# linear computation
out = v_weight * x_ph + v_bias

# compute the mean squared error
loss = tf.reduce_mean((out - y_ph)**2)
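For clarity, tf.reduce_mean over the squared difference is exactly the familiar MSE formula; a standalone NumPy illustration (not part of the original listing) of the same computation:

```python
import numpy as np

pred = np.array([1.0, 2.0, 3.0])
target = np.array([1.0, 2.0, 5.0])

# mean squared error: average of the squared residuals
# residuals are [0, 0, -2], so the MSE is 4 / 3
mse = np.mean((pred - target) ** 2)
print(mse)
```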

We can now instantiate the optimizer and call minimize() to minimize the MSE loss. minimize() first computes the gradients of the loss with respect to the variables (v_weight and v_bias) and then applies them, updating the variables:

opt = tf.train.AdamOptimizer(0.4).minimize(loss)
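Under the hood, minimize() combines those two steps: computing the gradients and applying them. The following NumPy sketch shows a single vanilla gradient-descent step on the same MSE loss (the variable names and the plain SGD update are our own; Adam instead uses an adaptive, momentum-based update):

```python
import numpy as np

np.random.seed(10)
X = np.linspace(0, 100, num=100)
y = np.random.normal(loc=0.5 * X + 1.4, scale=2.0, size=len(X))

w, bias = 0.0, 0.0  # current variable values
lr = 1e-5           # small step size; the raw features span 0..100

residual = w * X + bias - y
loss_before = np.mean(residual ** 2)

# gradients of the MSE with respect to w and bias
grad_w = 2.0 * np.mean(residual * X)
grad_b = 2.0 * np.mean(residual)

# apply the gradients (vanilla SGD update)
w -= lr * grad_w
bias -= lr * grad_b

loss_after = np.mean((w * X + bias - y) ** 2)
print(loss_after < loss_before)  # the step reduces the loss
```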

Now, let's create a session and initialize all the variables:

session = tf.Session()
session.run(tf.global_variables_initializer())

The training is done by running the optimizer multiple times while feeding the dataset to the graph. To keep track of the state of the model, the MSE loss and the model variables (weight and bias) are printed every 40 epochs:

# loop to train the parameters
for ep in range(210):
    # run the optimizer and get the loss
    train_loss, _ = session.run([loss, opt], feed_dict={x_ph:X, y_ph:y})

    # print epoch number and loss
    if ep % 40 == 0:
        print('Epoch: %3d, MSE: %.4f, weight: %.3f, bias: %.3f' % (ep, train_loss, session.run(v_weight), session.run(v_bias)))

In the end, we can print the final values of the variables:

print('Final weight: %.3f, bias: %.3f' % (session.run(v_weight), session.run(v_bias)))

The output will be similar to the following:

>>  Epoch: 0, MSE: 4617.4390, weight: 1.295, bias: -0.407
Epoch: 40, MSE: 5.3334, weight: 0.496, bias: -0.727
Epoch: 80, MSE: 4.5894, weight: 0.529, bias: -0.012
Epoch: 120, MSE: 4.1029, weight: 0.512, bias: 0.608
Epoch: 160, MSE: 3.8552, weight: 0.506, bias: 1.092
Epoch: 200, MSE: 3.7597, weight: 0.501, bias: 1.418
Final weight: 0.500, bias: 1.473

During the training phase, you can see that the MSE loss decreases toward a non-zero value (of about 3.71). That's because we added random noise to the dataset, which prevents the MSE from reaching a perfect value of 0.
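That floor is, in expectation, the variance of the injected noise (scale=2.0, so a variance of 4.0). We can compute the empirical value for this particular sample directly (our own check, not in the original example), by measuring the MSE of the true generating line:

```python
import numpy as np

np.random.seed(10)
W, b = 0.5, 1.4
X = np.linspace(0, 100, num=100)
y = np.random.normal(loc=W * X + b, scale=2.0, size=len(X))

# MSE of the *true* line: the irreducible error even a perfect fit is left with
noise_mse = np.mean((y - (W * X + b)) ** 2)
print('%.2f' % noise_mse)
```

No model can do better than this value on average, which is why the training loss plateaus just below 4 rather than at 0.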

Also, as anticipated, the weight and bias of the model approach the values of 0.500 and 1.473, which are close to the values of 0.5 and 1.4 around which the dataset was built. The blue line visible in the following plot is the prediction of the trained linear model, while the points are our training examples:

Figure 2.5: Linear regression model predictions
For all the color references in the chapter, please refer to the color images bundle: http://www.packtpub.com/sites/default/files/downloads/9781789131116_ColorImages.pdf.