Support for training on MADRAS env #4

Open · wants to merge 1 commit into master
baselines/a2c/a2c.py (2 additions, 1 deletion)

@@ -118,7 +118,7 @@ def learn(
     network,
     env,
     seed=None,
-    nsteps=5,
+    nsteps=20,
     total_timesteps=int(80e6),
     vf_coef=0.5,
     ent_coef=0.01,
@@ -187,6 +187,7 @@ def learn(
     nenvs = env.num_envs
     policy = build_policy(env, network, **network_kwargs)
 
+    print('Running %d environments in parallel' % nenvs)
     # Instantiate the model object (that creates step_model and train_model)
     model = Model(policy=policy, env=env, nsteps=nsteps, ent_coef=ent_coef, vf_coef=vf_coef,
         max_grad_norm=max_grad_norm, lr=lr, alpha=alpha, epsilon=epsilon, total_timesteps=total_timesteps, lrschedule=lrschedule)
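For context (not part of the diff): in baselines' A2C, the rollout length `nsteps` and the number of parallel environments together set the per-update batch size, so this change quadruples the batch relative to the default of 5. A minimal sketch of that relation, using an illustrative `nenvs` value:

    # Sketch: effect of the nsteps bump on the A2C update batch.
    # Assumes the baselines convention nbatch = nenvs * nsteps;
    # nenvs = 16 is an example value, not taken from the PR.
    nenvs = 16
    for nsteps in (5, 20):
        nbatch = nenvs * nsteps
        print('nsteps=%2d -> %d transitions per gradient update' % (nsteps, nbatch))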
baselines/a2c/runner.py (2 additions, 1 deletion)

@@ -39,7 +39,8 @@ def run(self):
             self.dones = dones
             for n, done in enumerate(dones):
                 if done:
-                    self.obs[n] = self.obs[n]*0
+                    # self.obs[n] = self.obs[n]*0
+                    pass
             self.obs = obs
             mb_rewards.append(rewards)
             mb_dones.append(self.dones)
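To make the behavior change concrete: the old loop zeroed a sub-environment's stored observation when its episode ended, while the patched loop keeps the stale final observation. A toy illustration (array shapes are made up for the example):

    import numpy as np

    obs = np.ones((4, 3))               # 4 parallel envs, toy 3-dim observations
    dones = [False, True, False, False]
    for n, done in enumerate(dones):
        if done:
            # old behavior: obs[n] = obs[n] * 0  (blank the terminal observation)
            pass                        # new behavior: keep the last observation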
baselines/common/models.py (14 additions)

@@ -58,6 +58,20 @@ def network_fn(X):

     return network_fn
 
+@register("mynn")
+def mynn(num_layers=2, num_hidden=[300, 500], activation=tf.tanh, layer_norm=False):
+    def network_fn(X):
+        h = tf.layers.flatten(X)
+        for i in range(num_layers):
+            h = fc(h, 'mlp_fc{}'.format(i), nh=num_hidden[i], init_scale=np.sqrt(2))
+            if layer_norm:
+                h = tf.contrib.layers.layer_norm(h, center=True, scale=True)
+            h = activation(h)
+
+        return h
+
+    return network_fn
 
 
 @register("cnn")
 def cnn(**conv_kwargs):
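A hypothetical usage sketch for the newly registered "mynn" network, assuming it is resolved through `get_network_builder`, the registry lookup that `baselines.common.models` provides for `@register`-ed networks:

    from baselines.common.models import get_network_builder

    # Build the "mynn" feature extractor with the defaults from the diff:
    # two fully connected layers of 300 and 500 units with tanh activations.
    network_fn = get_network_builder('mynn')(num_layers=2, num_hidden=[300, 500])
    # network_fn(X) maps a batch of observations X to the final hidden features.

Note that the loop indexes `num_hidden[i]`, so `num_layers` must not exceed `len(num_hidden)` or the lookup raises an IndexError.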
baselines/common/vec_env/subproc_vec_env.py (5 additions)

@@ -78,6 +78,11 @@ def reset(self):
             remote.send(('reset', None))
         return np.stack([remote.recv() for remote in self.remotes])
 
+    def reset_envno(self, no):
+        self._assert_not_closed()
+        self.remotes[no].send(('reset', None))
+        return self.remotes[no].recv()
+
     def close_extras(self):
         self.closed = True
         if self.waiting:
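A hypothetical call site for the new per-worker reset; only `reset_envno` itself comes from the diff, and 'CartPole-v0' stands in for the real MADRAS env id:

    import gym
    from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv

    # Illustrative setup with four worker processes.
    venv = SubprocVecEnv([lambda: gym.make('CartPole-v0') for _ in range(4)])
    venv.reset()
    obs0 = venv.reset_envno(0)   # fresh observation from worker 0 only;
                                 # the other three workers keep their state
    venv.close()

This presumably complements the runner change above: rather than blanking a finished environment's observation, a single MADRAS worker can be restarted on its own.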
baselines/run.py (6 additions, 4 deletions)

@@ -50,7 +50,7 @@
     'SpaceInvaders-Snes',
 }
 
-_game_envs['madras'] = {'gym-torcs-v0','gym-madras-v0'}
+_game_envs['madras'] = {'Madras-v0'}
 def train(args, extra_args):
     env_type, env_id = get_env_type(args.env)
     print('env_type: {}'.format(env_type))

@@ -88,6 +88,7 @@ def build_env(args):
     ncpu = multiprocessing.cpu_count()
     if sys.platform == 'darwin': ncpu //= 2
     nenv = args.num_env or ncpu
+    print('Found %d CPUs' % (nenv))
     alg = args.alg
     seed = args.seed

@@ -196,23 +197,24 @@ def main(args):
         rank = MPI.COMM_WORLD.Get_rank()
 
     model, env = train(args, extra_args)
-    env.close()
+    # env.close()
 
     if args.save_path is not None and rank == 0:
         save_path = osp.expanduser(args.save_path)
         model.save(save_path)
 
     if args.play:
         logger.log("Running trained model")
-        env = build_env(args)
+        # env = build_env(args)
         obs = env.reset()
         def initialize_placeholders(nlstm=128, **kwargs):
             return np.zeros((args.num_env or 1, 2*nlstm)), np.zeros((1))
         state, dones = initialize_placeholders(**extra_args)
         while True:
             actions, _, state, _ = model.step(obs, S=state, M=dones)
+            # actions, _, state, _ = model.step(obs)
Review comment from the author on the added line above: "Used in DDPG model play."

             obs, _, done, _ = env.step(actions)
-            env.render()
+            # env.render()
             done = done.any() if isinstance(done, np.ndarray) else done
 
             if done:
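With all of the above in place, a launch along the lines of `python -m baselines.run --alg=a2c --env=Madras-v0 --network=mynn --play` would exercise the new env registration, the `mynn` network, and the reused training environment in play mode; the exact flag combination is illustrative, not taken from the PR.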