The value function used as a baseline, approximated with a neural network, can be implemented by adding a few lines to our previous code:
- Add the neural network, the operations for computing the MSE loss function, and the optimization procedure to the computational graph:
...
# placeholder that will contain the reward to go values (i.e. the y values)
rtg_ph = tf.placeholder(shape=(None,), dtype=tf.float32, name='rtg')
# MLP value function
s_values = tf.squeeze(mlp(obs_ph, hidden_sizes, 1, activation=tf.tanh))
# MSE loss function
v_loss = tf.reduce_mean((rtg_ph - s_values)**2)
# value function optimization
v_opt = tf.train.AdamOptimizer(vf_lr).minimize(v_loss)
...
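For reference, here is a minimal sketch of the mlp helper called above, assuming it is the same function used to build the policy network earlier (the last_activation argument is an assumption based on common usage, not confirmed by the code above):
import tensorflow as tf

def mlp(x, hidden_layers, output_size, activation=tf.tanh, last_activation=None):
    # stack fully connected layers on top of the input
    for units in hidden_layers:
        x = tf.layers.dense(x, units=units, activation=activation)
    # final layer producing output_size values (1 for the state value)
    return tf.layers.dense(x, units=output_size, activation=last_activation)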
- Run s_values and store the predictions, as we'll later need them to compute G_t - V(s_t). This operation can be done in the innermost loop (the differences from the REINFORCE code are shown in bold):
...
# run s_values in addition to act_multn
act, val = sess.run([act_multn, s_values], feed_dict={obs_ph:[obs]})
obs2, rew, done, _ = env.step(np.squeeze(act))
# add the new transition, including the state value prediction
env_buf.append([obs.copy(), rew, act, np.squeeze(val)])
...
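Because the store method of the buffer (shown in the last step) indexes the trajectory with NumPy column slicing, the list of transitions has to be converted to an array before being stored. A sketch of how the episode could be closed, under that assumption:
if done:
    # convert the trajectory to a NumPy array so that Buffer.store
    # can slice it column-wise, then reset the environment
    buffer.store(np.array(env_buf))
    env_buf = []
    obs = env.reset()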
- Retrieve rtg_batch, which contains the "target" values, from the buffer, and optimize the value function:
obs_batch, act_batch, ret_batch, rtg_batch = buffer.get_batch()
sess.run([p_opt, v_opt], feed_dict={obs_ph:obs_batch, act_ph:act_batch, ret_ph:ret_batch, rtg_ph:rtg_batch})
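Although not strictly needed, it can be handy to fetch the loss values in the same call for logging; here, p_loss is assumed to be the policy loss tensor defined in the original REINFORCE code:
# also fetch the losses for logging (p_loss is assumed to be the
# policy loss defined in the REINFORCE code)
_, _, p_l, v_l = sess.run([p_opt, v_opt, p_loss, v_loss], feed_dict={obs_ph:obs_batch, act_ph:act_batch, ret_ph:ret_batch, rtg_ph:rtg_batch})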
- Compute the reward to go, G_t, which serves as the target for the value function, and the advantage estimates, G_t - V(s_t), used to weight the policy gradient. This change is done in the Buffer class. We have to create a new, empty self.rtg list in the initialization method of the class and modify the store and get_batch functions as follows:
def store(self, temp_traj):
    if len(temp_traj) > 0:
        self.obs.extend(temp_traj[:,0])
        rtg = discounted_rewards(temp_traj[:,1], self.gamma)
        # ret = G - V
        self.ret.extend(rtg - temp_traj[:,3])
        self.rtg.extend(rtg)
        self.act.extend(temp_traj[:,2])

def get_batch(self):
    return self.obs, self.act, self.ret, self.rtg
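For completeness, a minimal sketch of the discounted_rewards helper used in store, which computes the reward to go G_t = r_t + gamma * G_{t+1} by iterating backward over the trajectory:
import numpy as np

def discounted_rewards(rews, gamma):
    # compute the reward to go for each timestep, working backward
    rtg = np.zeros_like(rews, dtype=np.float32)
    rtg[-1] = rews[-1]
    for i in reversed(range(len(rews)-1)):
        rtg[i] = rews[i] + gamma*rtg[i+1]
    return rtg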
You can now test REINFORCE with baseline on whatever environment you want and compare its performance with that of the basic REINFORCE implementation.
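If the training loop is wrapped in a function, a run could look like the following (the function name REINFORCE_baseline and the hyperparameter values are illustrative assumptions, not taken from the preceding code):
# hypothetical entry point: the name and hyperparameters are illustrative
REINFORCE_baseline('LunarLander-v2', hidden_sizes=[64], p_lr=8e-3, vf_lr=7e-3, gamma=0.99, num_epochs=1000)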