Skip to content

Commit

Permalink
move replace target to graph building
Browse files — browse the repository at this point in the history
  • Loading branch information
MorvanZhou authored and Morvan Zhou committed Aug 15, 2017
1 parent 673ef4c commit 0f546db
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 30 deletions.
11 changes: 5 additions & 6 deletions contents/5.1_Double_DQN/RL_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def __init__(
self.learn_step_counter = 0
self.memory = np.zeros((self.memory_size, n_features*2+2))
self._build_net()
t_params = tf.get_collection('target_net_params')
e_params = tf.get_collection('eval_net_params')
self.replace_target_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

if sess is None:
self.sess = tf.Session()
self.sess.run(tf.global_variables_initializer())
Expand Down Expand Up @@ -114,14 +118,9 @@ def choose_action(self, observation):
action = np.random.randint(0, self.n_actions)
return action

def _replace_target_params(self):
t_params = tf.get_collection('target_net_params')
e_params = tf.get_collection('eval_net_params')
self.sess.run([tf.assign(t, e) for t, e in zip(t_params, e_params)])

def learn(self):
if self.learn_step_counter % self.replace_target_iter == 0:
self._replace_target_params()
self.sess.run(self.replace_target_op)
print('\ntarget_params_replaced\n')

if self.memory_counter > self.memory_size:
Expand Down
10 changes: 4 additions & 6 deletions contents/5.2_Prioritized_Replay_DQN/RL_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ def __init__(
self.learn_step_counter = 0

self._build_net()
t_params = tf.get_collection('target_net_params')
e_params = tf.get_collection('eval_net_params')
self.replace_target_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

if self.prioritized:
self.memory = Memory(capacity=memory_size)
Expand Down Expand Up @@ -254,14 +257,9 @@ def choose_action(self, observation):
action = np.random.randint(0, self.n_actions)
return action

def _replace_target_params(self):
    """Sync the target network with the eval network.

    Pairs variables from the 'target_net_params' and 'eval_net_params'
    collections and assigns each eval value onto its target counterpart
    in one sess.run call.
    """
    targets = tf.get_collection('target_net_params')
    sources = tf.get_collection('eval_net_params')
    self.sess.run([tf.assign(dst, src) for dst, src in zip(targets, sources)])

def learn(self):
if self.learn_step_counter % self.replace_target_iter == 0:
self._replace_target_params()
self.sess.run(self.replace_target_op)
print('\ntarget_params_replaced\n')

if self.prioritized:
Expand Down
11 changes: 5 additions & 6 deletions contents/5.3_Dueling_DQN/RL_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def __init__(
self.learn_step_counter = 0
self.memory = np.zeros((self.memory_size, n_features*2+2))
self._build_net()
t_params = tf.get_collection('target_net_params')
e_params = tf.get_collection('eval_net_params')
self.replace_target_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

if sess is None:
self.sess = tf.Session()
self.sess.run(tf.global_variables_initializer())
Expand Down Expand Up @@ -124,14 +128,9 @@ def choose_action(self, observation):
action = np.random.randint(0, self.n_actions)
return action

def _replace_target_params(self):
    """Overwrite every target-net variable with its eval-net counterpart."""
    pairs = zip(
        tf.get_collection('target_net_params'),
        tf.get_collection('eval_net_params'),
    )
    # One assign op per variable pair, executed together.
    self.sess.run([tf.assign(t_var, e_var) for t_var, e_var in pairs])

def learn(self):
if self.learn_step_counter % self.replace_target_iter == 0:
self._replace_target_params()
self.sess.run(self.replace_target_op)
print('\ntarget_params_replaced\n')

sample_index = np.random.choice(self.memory_size, size=self.batch_size)
Expand Down
10 changes: 4 additions & 6 deletions contents/5_Deep_Q_Network/RL_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def __init__(

# consist of [target_net, evaluate_net]
self._build_net()
t_params = tf.get_collection('target_net_params')
e_params = tf.get_collection('eval_net_params')
self.replace_target_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

self.sess = tf.Session()

Expand Down Expand Up @@ -132,15 +135,10 @@ def choose_action(self, observation):
action = np.random.randint(0, self.n_actions)
return action

def _replace_target_params(self):
    """Hard-update the target network from the eval network.

    Reads both variable collections and runs an assign for each
    (target, eval) pair in a single session call.
    """
    target_vars = tf.get_collection('target_net_params')
    eval_vars = tf.get_collection('eval_net_params')
    update_ops = []
    for target_var, eval_var in zip(target_vars, eval_vars):
        update_ops.append(tf.assign(target_var, eval_var))
    self.sess.run(update_ops)

def learn(self):
# check to replace target parameters
if self.learn_step_counter % self.replace_target_iter == 0:
self._replace_target_params()
self.sess.run(self.replace_target_op)
print('\ntarget_params_replaced\n')

# sample batch memory from all memory
Expand Down
10 changes: 4 additions & 6 deletions contents/6_OpenAI_gym/RL_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def __init__(

# consist of [target_net, evaluate_net]
self._build_net()
t_params = tf.get_collection('target_net_params')
e_params = tf.get_collection('eval_net_params')
self.replace_target_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)]

self.sess = tf.Session()

Expand Down Expand Up @@ -132,15 +135,10 @@ def choose_action(self, observation):
action = np.random.randint(0, self.n_actions)
return action

def _replace_target_params(self):
    """Copy eval-net weights onto the target net (hard update)."""
    collections = ('target_net_params', 'eval_net_params')
    targets, sources = (tf.get_collection(name) for name in collections)
    self.sess.run([tf.assign(dst, src) for dst, src in zip(targets, sources)])

def learn(self):
# check to replace target parameters
if self.learn_step_counter % self.replace_target_iter == 0:
self._replace_target_params()
self.sess.run(self.replace_target_op)
print('\ntarget_params_replaced\n')

# sample batch memory from all memory
Expand Down

0 comments on commit 0f546db

Please sign in to comment.