Skip to content

Commit

Permalink
all training complete
Browse files: browse the repository at this point in the history
  • Loading branch information
Andrew Szot committed May 19, 2018
1 parent 26d2587 commit 93d7d46
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
9 changes: 5 additions & 4 deletions a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _thunk():
episode_rewards = np.zeros((nenvs, ))
final_rewards = np.zeros((nenvs, ))

for update in tqdm(range(load_count + 1, total_timesteps)):
for update in tqdm(range(load_count + 1, total_timesteps + 1)):
# mb stands for mini batch
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
for n in range(nsteps):
Expand Down Expand Up @@ -331,10 +331,11 @@ def _thunk():
if __name__ == '__main__':
os.environ["CUDA_VISIBLE_DEVICES"]="0"

load_count = 100000
load_path = 'weights/model_%i.ckpt' % load_count
load_count = 0
load_path = 'weights/a2c_%i.ckpt' % load_count
load_path = None

train(CnnPolicy, 'a2c', load_count, load_path, './a2c_logs')
train(CnnPolicy, 'a2c', load_count=load_count, load_path=load_path,
log_path='./a2c_logs')


2 changes: 1 addition & 1 deletion env_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def __init__(self, imag_state, imag_reward, input_states, input_actions,
os.environ["CUDA_VISIBLE_DEVICES"]="1"
with tf.Session() as sess:
actor_critic = get_actor_critic(sess, nenvs, nsteps, ob_space, ac_space, CnnPolicy, should_summary=False)
actor_critic.load('weights/model_100000.ckpt')
actor_critic.load('weights/a2c_200000.ckpt')

with tf.variable_scope('env_model'):
env_model = create_env_model(ob_space, num_actions, num_pixels, len(mode_rewards['regular']))
Expand Down
2 changes: 1 addition & 1 deletion i2a.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def get_cache_loaded_a2c(sess, nenvs, nsteps, ob_space, ac_space):
with tf.variable_scope('actor'):
g_actor_critic = get_actor_critic(sess, nenvs, nsteps, ob_space,
ac_space, CnnPolicy, should_summary=False)
g_actor_critic.load('weights/model_100000.ckpt')
g_actor_critic.load('weights/a2c_200000.ckpt')
print('Actor restored!')
return g_actor_critic

Expand Down

0 comments on commit 93d7d46

Please sign in to comment.