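"""train.py: train a deep Q-network agent to play Connect 4.

Episodes are played against the Game environment (6x7 board, four in a row
to win); each transition is stored via DeepQNetwork.storeExperience and one
model-update step is taken per move.
"""
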
from collections import deque

import numpy as np
import tensorflow as tf

from rl.deep_q_network import DeepQNetwork
from game import Game
import useful_tools
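
# NOTE: this script uses the TensorFlow 1.x graph API (tf.Session,
# tf.placeholder, tf.train.*); under TF 2.x it would need tf.compat.v1
# with tf.compat.v1.disable_v2_behavior().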

# initialize the game environment: 6x7 board, four in a row to win
env = Game([6, 7], 4)
# initialize tensorflow
sess = tf.Session()
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.0001, decay=0.9)
writer = tf.summary.FileWriter("logs/value_network", sess.graph)

# prepare custom tensorboard summaries
episode_reward = tf.Variable(0.)
tf.summary.scalar("Last 100 Episodes Average Episode Reward", episode_reward)
summary_vars = [episode_reward]
summary_placeholders = [tf.placeholder("float") for i in range(len(summary_vars))]
summary_ops = [summary_vars[i].assign(summary_placeholders[i])
               for i in range(len(summary_vars))]
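# (Python-side metrics reach TensorBoard by assigning a placeholder-fed value
# to a dummy variable whose scalar summary is part of the merged summary op.)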

# define the Q-value network
state_dim = np.prod(env.boardSize)  # 6 * 7 = 42 flattened board cells
num_actions = env.boardSize[1]      # 7 columns to drop a piece into
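# Architecture: 42 board inputs -> 256 ReLU -> 64 ReLU -> 7 linear outputs,
# one Q-value per column.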
def value_network(states):
    W1 = tf.get_variable("W1", [state_dim, 256],
                         initializer=tf.random_normal_initializer(stddev=0.1))
    b1 = tf.get_variable("b1", [256],
                         initializer=tf.constant_initializer(0))
    h1 = tf.nn.relu(tf.matmul(states, W1) + b1)
    W2 = tf.get_variable("W2", [256, 64],
                         initializer=tf.random_normal_initializer(stddev=0.1))
    b2 = tf.get_variable("b2", [64],
                         initializer=tf.constant_initializer(0))
    h2 = tf.nn.relu(tf.matmul(h1, W2) + b2)
    Wo = tf.get_variable("Wo", [64, num_actions],
                         initializer=tf.random_normal_initializer(stddev=0.1))
    bo = tf.get_variable("bo", [num_actions],
                         initializer=tf.constant_initializer(0))
    p = tf.matmul(h2, Wo) + bo
    return p
summaries = tf.summary.merge_all()
q_network = DeepQNetwork(sess,
                         optimizer,
                         value_network,
                         state_dim,
                         num_actions,
                         init_exp=0.6,         # initial exploration probability
                         final_exp=0.1,        # final exploration probability
                         anneal_steps=120000,  # steps over which exploration is annealed
                         discount_factor=0.8)  # discount applied to future rewards
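# The epsilon-greedy exploration rate is annealed from 0.6 down to 0.1 over
# the first 120000 training steps (schedule as implemented in DeepQNetwork).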

# load checkpoint if there is any
saver = tf.train.Saver()
checkpoint = tf.train.get_checkpoint_state("model")
if checkpoint and checkpoint.model_checkpoint_path:
    saver.restore(sess, checkpoint.model_checkpoint_path)
    print("successfully loaded checkpoint")

# how many episodes to train for
training_episodes = 200000
# rolling window of the last 100 episode rewards
episode_history = deque(maxlen=100)
# per-outcome counters, reset after every status report
lost = 0
draw = 0
won = 0
cheated = 0
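# Terminal rewards used by the environment (inferred from the bookkeeping in
# the training loop): -33 illegal move into a full column, -1 loss, 10 draw,
# 100 win.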

# start training
reward = 0.0
for i_episode in range(training_episodes):
    state = np.array(env.reset())
    done = False
    t = 0
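    # One environment step per loop iteration: choose a column, store the
    # transition (DeepQNetwork presumably keeps an experience-replay buffer),
    # and run one model-update step.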
    while not done:
        if t != 0:
            action = q_network.eGreedyAction(state[np.newaxis, :])
        else:
            # open with a random move, resampling until the column is legal;
            # later moves come from the network, which is penalised (reward
            # -33, counted as "cheated") if it plays into a full column
            action = np.random.randint(env.boardSize[1])
            while useful_tools.isColumnFull(action + 1, env.gameState):
                action = np.random.randint(env.boardSize[1])
        next_state, reward, done = env.step(action)
        q_network.storeExperience(state, action, reward, next_state, done)
        q_network.updateModel()
        state = np.array(next_state)
        t += 1
    # tally the episode outcome from its terminal reward
    if reward == -33:
        cheated += 1
    elif reward == -1:
        lost += 1
    elif reward == 100:
        won += 1
    elif reward == 10:
        draw += 1
    episode_history.append(reward)
    # report status and checkpoint every 100 episodes
    if i_episode % 100 == 0:
        mean_rewards = np.mean(episode_history)
        print("Episode {}".format(i_episode))
        print("Reward for this episode: {}".format(reward))
        print("Average reward for last 100 episodes: {}".format(mean_rewards))
        print("cheated: " + str(cheated))
        print("lost: " + str(lost))
        print("won: " + str(won))
        print("draw: " + str(draw))
        # update tensorboard
        sess.run(summary_ops[0],
                 feed_dict={summary_placeholders[0]: float(mean_rewards)})
        result = sess.run(summaries)
        writer.add_summary(result, i_episode)
        # reset the outcome counters for the next reporting window
        lost = 0
        draw = 0
        won = 0
        cheated = 0
        # save checkpoint
        saver.save(sess, "model/saved_network")
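
# Usage sketch (assumes TF 1.x and the repo's rl/, game.py and useful_tools.py
# on the import path):
#   python train.py
#   tensorboard --logdir logs/value_network   # view the reward curve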