From 89439e46c0756f6c26abcae5223796e9deae0796 Mon Sep 17 00:00:00 2001 From: Codetroopa Date: Wed, 24 Oct 2018 18:51:16 -0700 Subject: [PATCH] Add replay functionality through the sc2_env --- rl/agents/a2c/runner.py | 7 +++++++ rl/environment.py | 7 +++++++ run.py | 7 +++++++ 3 files changed, 21 insertions(+) diff --git a/rl/agents/a2c/runner.py b/rl/agents/a2c/runner.py index 7376d60..31ed75b 100644 --- a/rl/agents/a2c/runner.py +++ b/rl/agents/a2c/runner.py @@ -12,6 +12,8 @@ class A2CRunner(): def __init__(self, agent, envs, + save_replay_episodes, + replay_dir, summary_writer=None, train=True, n_steps=8, @@ -27,6 +29,8 @@ def __init__(self, """ self.agent = agent self.envs = envs + self.save_replay_episodes = save_replay_episodes + self.replay_dir = replay_dir self.summary_writer = summary_writer self.train = train self.n_steps = n_steps @@ -51,6 +55,9 @@ def _summarize_episode(self, timestep): print("episode %d: score = %f" % (self.episode_counter, score)) self.episode_counter += 1 + + if self.save_replay_episodes != 0 and self.episode_counter % self.save_replay_episodes == 0: + self.envs.save_replay(self.replay_dir) return score def run_batch(self, train_summary=False): diff --git a/rl/environment.py b/rl/environment.py index 1843a05..de8ea71 100644 --- a/rl/environment.py +++ b/rl/environment.py @@ -60,6 +60,9 @@ def worker(remote, env_fn_wrapper): elif cmd == 'observation_spec': spec = env.observation_spec() remote.send(spec) + elif cmd == 'save_replay': + env.save_replay(action) + remote.send(None) else: raise NotImplementedError @@ -118,6 +121,10 @@ def observation_spec(self): specs = [remote.recv() for remote in self.remotes] return specs + def save_replay(self, replay_dir): + self.remotes[0].send(('save_replay', replay_dir)) + self.remotes[0].recv() + def make_sc2env(**kwargs): env = sc2_env.SC2Env(**kwargs) diff --git a/run.py b/run.py index 58f7d43..ea0ca98 100644 --- a/run.py +++ b/run.py @@ -66,6 +66,10 @@ help='root directory for checkpoint storage') parser.add_argument('--summary_dir', type=str, default=os.path.join('out','summary'), help='root directory for summary storage') +parser.add_argument('--save_replay_episodes', type=int, default=0, + help='number of episodes until a replay is recorded. 0 means no replay.') +parser.add_argument('--replay_dir', type=str, default=os.path.join('out','replays'), + help='root directory for replay storage') args = parser.parse_args() # TODO write args to config file and store together with summaries (https://pypi.python.org/pypi/ConfigArgParse) @@ -77,6 +81,7 @@ ckpt_path = os.path.join(args.save_dir, args.experiment_id) summary_type = 'train' if args.train else 'eval' summary_path = os.path.join(args.summary_dir, args.experiment_id, summary_type) +replay_path = os.path.abspath(os.path.join(args.replay_dir, args.experiment_id)) def _save_if_training(agent, summary_writer): @@ -124,6 +129,8 @@ def main(): runner = A2CRunner( envs=envs, agent=agent, + save_replay_episodes=args.save_replay_episodes, + replay_dir=replay_path, train=args.train, summary_writer=summary_writer, discount=args.discount,