From 93d7d4637e390313ebfd42811b187cab18b8ee09 Mon Sep 17 00:00:00 2001 From: Andrew Szot Date: Sat, 19 May 2018 01:29:50 -0700 Subject: [PATCH] all training complete --- a2c.py | 9 +++++---- env_model.py | 2 +- i2a.py | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/a2c.py b/a2c.py index 76e7745..6e2b7bc 100644 --- a/a2c.py +++ b/a2c.py @@ -259,7 +259,7 @@ def _thunk(): episode_rewards = np.zeros((nenvs, )) final_rewards = np.zeros((nenvs, )) - for update in tqdm(range(load_count + 1, total_timesteps)): + for update in tqdm(range(load_count + 1, total_timesteps + 1)): # mb stands for mini batch mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[] for n in range(nsteps): @@ -331,10 +331,11 @@ def _thunk(): if __name__ == '__main__': os.environ["CUDA_VISIBLE_DEVICES"]="0" - load_count = 100000 - load_path = 'weights/model_%i.ckpt' % load_count + load_count = 0 + load_path = 'weights/a2c_%i.ckpt' % load_count load_path = None - train(CnnPolicy, 'a2c', load_count, load_path, './a2c_logs') + train(CnnPolicy, 'a2c', load_count=load_count, load_path=load_path, + log_path='./a2c_logs') diff --git a/env_model.py b/env_model.py index 940bbc5..cf96e94 100644 --- a/env_model.py +++ b/env_model.py @@ -164,7 +164,7 @@ def __init__(self, imag_state, imag_reward, input_states, input_actions, os.environ["CUDA_VISIBLE_DEVICES"]="1" with tf.Session() as sess: actor_critic = get_actor_critic(sess, nenvs, nsteps, ob_space, ac_space, CnnPolicy, should_summary=False) - actor_critic.load('weights/model_100000.ckpt') + actor_critic.load('weights/a2c_200000.ckpt') with tf.variable_scope('env_model'): env_model = create_env_model(ob_space, num_actions, num_pixels, len(mode_rewards['regular'])) diff --git a/i2a.py b/i2a.py index 721c162..2b4cca4 100644 --- a/i2a.py +++ b/i2a.py @@ -135,7 +135,7 @@ def get_cache_loaded_a2c(sess, nenvs, nsteps, ob_space, ac_space): with tf.variable_scope('actor'): g_actor_critic = get_actor_critic(sess, nenvs, nsteps, ob_space, ac_space, CnnPolicy, should_summary=False) - g_actor_critic.load('weights/model_100000.ckpt') + g_actor_critic.load('weights/a2c_200000.ckpt') print('Actor restored!') return g_actor_critic