diff --git a/policy_gradient_reinforce_tf2.py b/policy_gradient_reinforce_tf2.py
index e73e03f..8b203fa 100644
--- a/policy_gradient_reinforce_tf2.py
+++ b/policy_gradient_reinforce_tf2.py
@@ -36,7 +36,14 @@ def update_network(network, rewards, states, actions, num_actions):
     discounted_rewards -= np.mean(discounted_rewards)
     discounted_rewards /= np.std(discounted_rewards)
     states = np.vstack(states)
-    loss = network.train_on_batch(states, discounted_rewards)
+    # the original training call below does not work:
+    # loss = network.train_on_batch(states, discounted_rewards)
+    # to fix this we make two changes:
+    # 1. one-hot encode the actions to match the network's softmax output
+    one_hot_actions = np.array([[1 if a == i else 0 for i in range(num_actions)] for a in actions])
+    # 2. pass the discounted rewards via the 'sample_weight' parameter, which scales
+    #    each sample's 'categorical_crossentropy' loss by its return
+    loss = network.train_on_batch(states, one_hot_actions, sample_weight=discounted_rewards)
+
     return loss
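
For context, a minimal sketch of how the patched update fits together with a policy network compiled for 'categorical_crossentropy'. The network architecture, the GAMMA constant, and the return computation below are illustrative assumptions, not part of this patch; the key point is that weighting each sample's cross-entropy loss by its return makes train_on_batch minimize the REINFORCE objective -G_t * log(pi(a_t | s_t)).

import numpy as np
from tensorflow import keras

GAMMA = 0.99  # assumed discount factor; not defined in this patch

def update_network(network, rewards, states, actions, num_actions):
    # compute discounted returns G_t = r_t + GAMMA * G_{t+1}, then standardize
    discounted_rewards = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + GAMMA * running
        discounted_rewards[t] = running
    discounted_rewards -= np.mean(discounted_rewards)
    discounted_rewards /= np.std(discounted_rewards)

    states = np.vstack(states)
    # one-hot encode the actions to match the softmax output layer
    one_hot_actions = np.array(
        [[1 if a == i else 0 for i in range(num_actions)] for a in actions])
    # sample_weight scales each sample's cross-entropy loss -log(pi(a_t | s_t))
    # by its standardized return, giving the REINFORCE gradient estimator
    loss = network.train_on_batch(states, one_hot_actions,
                                  sample_weight=discounted_rewards)
    return loss

# assumed setup for a CartPole-like task (4-dim state, 2 actions):
# a softmax policy head so train_on_batch with one-hot targets yields
# -sum_t G_t * log(pi(a_t | s_t))
network = keras.Sequential([
    keras.layers.Dense(32, activation='relu', input_shape=(4,)),
    keras.layers.Dense(2, activation='softmax'),
])
network.compile(optimizer='adam', loss='categorical_crossentropy')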