
Trajectory-ranked Reward EXtrapolation (T-REX) for Inverse Reinforcement Learning - A Tensorflow implementation trained on OpenAI Gym environments



msinto93/T-REX-IRL


T-REX-IRL

Trajectory-ranked Reward EXtrapolation (T-REX) - A Tensorflow implementation trained on OpenAI Gym environments.

Based on the paper "Extrapolating Beyond Suboptimal Demonstrations via Inverse Reinforcement Learning from Observations" (Brown et al., ICML 2019).

T-REX learns a reward function from a set of ranked, low-scoring demonstrations. A policy can then be trained on this learned reward with a reinforcement learning algorithm, and this policy can significantly outperform the suboptimal demonstrations.
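The core of T-REX is a pairwise ranking loss. The reward network predicts a reward for each state, the predictions are summed over each trajectory (or trajectory snippet), and a softmax cross-entropy over the two sums pushes the higher-ranked trajectory towards the higher predicted return. The sketch below shows this loss for a single ranked pair in plain Python; the function name and the list-of-floats inputs are illustrative assumptions, not this repository's TensorFlow code.

```python
import math
from typing import Sequence


def trex_pair_loss(rewards_worse: Sequence[float], rewards_better: Sequence[float]) -> float:
    """T-REX cross-entropy loss for one ranked pair of trajectories.

    Each argument holds the reward network's per-state predictions for one
    trajectory; `rewards_better` belongs to the higher-ranked trajectory.
    Returns -log P(better > worse), with P a softmax over predicted returns.
    """
    return_worse = sum(rewards_worse)    # predicted return of the lower-ranked trajectory
    return_better = sum(rewards_better)  # predicted return of the higher-ranked trajectory

    # log(exp(a) + exp(b)), computed stably by factoring out the larger value
    m = max(return_worse, return_better)
    log_sum_exp = m + math.log(math.exp(return_worse - m) + math.exp(return_better - m))

    # -log( exp(return_better) / (exp(return_worse) + exp(return_better)) )
    return log_sum_exp - return_better


if __name__ == "__main__":
    print(trex_pair_loss([0.0, 0.0], [0.0, 0.0]))  # equal returns: log(2) ≈ 0.6931
    print(trex_pair_loss([0.0], [10.0]))           # correctly ranked by a wide margin: ≈ 4.5e-05
```

Minimising this loss over many ranked pairs is what lets the learned reward extrapolate: nothing caps the predicted return, so the network is free to score states beyond the best demonstration even higher.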

This implementation has been trained and tested on OpenAI Gym Atari environments, achieving scores much greater than any seen in the original demonstrations.

For the reinforcement learning algorithm (used both to generate the initial demonstrations and to train the final policy on the learned reward function), the OpenAI Baselines implementation of Proximal Policy Optimisation (PPO) is used. It has been modified slightly to allow a choice between learning from the true environment reward (the default) and learning from a supplied reward function (the trained T-REX network).

Note: Only 2 files from the OpenAI Baselines repo have been modified:

  • baselines/ppo2/ppo2.py - Added extra arguments to the runner call to choose between the default environment reward and a learned reward function.
  • baselines/ppo2/runner.py - Added the functionality to load and run inference on the learned reward function if this option is chosen.

The rest of the Baselines code is an exact clone of the original.
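Conceptually, the change to the runner is a switch on where each step's reward comes from. The following is a minimal sketch under assumptions: `select_rewards`, `use_learned_reward` and `reward_fn` are hypothetical names, and the real runner.py runs batched TensorFlow inference rather than calling a Python function per observation.

```python
from typing import Callable, List, Sequence, TypeVar

Obs = TypeVar("Obs")


def select_rewards(
    observations: Sequence[Obs],
    env_rewards: Sequence[float],
    use_learned_reward: bool,
    reward_fn: Callable[[Obs], float],
) -> List[float]:
    """Return the per-environment rewards PPO should train on for one step.

    If `use_learned_reward` is False, the true environment rewards pass
    through unchanged (the Baselines default). Otherwise the learned reward
    network (`reward_fn`, hypothetical here) scores each observation and the
    environment reward is ignored.
    """
    if not use_learned_reward:
        return [float(r) for r in env_rewards]
    return [float(reward_fn(obs)) for obs in observations]
```

Because only the reward signal is swapped, the rest of PPO (advantage estimation, clipping, optimisation) runs unchanged in both modes.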

Requirements

Note: The versions stated are the ones I used; however, this will likely still work with other versions.

  • any other prerequisites for running the OpenAI Baselines code, as listed in the Baselines repository

Usage

Note: This example shows usage for the 'Breakout' environment; to use any other environment, simply modify the --env parameter.

  • The first step is to train the default OpenAI Baselines PPO algorithm in the environment, saving checkpoints frequently (every --save_interval training updates) so that demonstrations of varying quality can be generated from different stages of training:
  $ python -m baselines.run --alg=ppo2 --env='BreakoutNoFrameskip-v4' --num_timesteps=10e6 --save_interval=50

This saves the checkpoints in a folder in the /tmp directory named after the date and time (e.g. /tmp/openai-2019-05-27-18-26-59-016163/checkpoints). Once the episode reward starts to exceed the reward of the demonstrations used in the paper, training can be stopped manually: to keep the results comparable, we do not use any demonstrations with a higher reward than those used in the paper.
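To estimate how many checkpoints (and therefore demonstration sources) this step produces, the sketch below reproduces the save rule of the Baselines ppo2 training loop as I understand it: the number of updates is total_timesteps // (num_envs * nsteps), and a checkpoint named by the zero-padded update number is written when update % save_interval == 0 or on the first update. nsteps=128 matches the Baselines Atari defaults, while num_envs defaults to the CPU count in baselines.run, so the value 8 here is only an assumption.

```python
from typing import List


def checkpoint_updates(total_timesteps: int, save_interval: int,
                       num_envs: int = 8, nsteps: int = 128) -> List[int]:
    """Update indices at which PPO2 would save a checkpoint (assumed save rule).

    num_envs=8 is an assumed number of parallel environments; nsteps=128 is
    the assumed rollout length per environment per update.
    """
    batch_size = num_envs * nsteps              # environment steps consumed per update
    num_updates = total_timesteps // batch_size
    return [u for u in range(1, num_updates + 1)
            if u % save_interval == 0 or u == 1]


if __name__ == "__main__":
    updates = checkpoint_updates(int(10e6), save_interval=50)
    print(len(updates), updates[:3], updates[-1])  # 196 [1, 50, 100] 9750
    print(["%.5i" % u for u in updates[:3]])       # ['00001', '00050', '00100']
```

With more CPUs (and therefore more environments) each update consumes more timesteps, so fewer checkpoints are saved for the same --num_timesteps.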

  • Next, generate the demonstration samples from these checkpoints:
  $ python generate_demonstrations.py --env='Breakout' --ckpt_dir='/tmp/openai-2019-05-27-18-26-59-016163/checkpoints'
  • The T-REX reward network is then trained on these demonstration samples by running:
  $ python train_trex.py --env='Breakout' --ckpt_dir='./ckpts/Breakout'

Note that this time --ckpt_dir specifies where the checkpoints for the T-REX network should be saved.
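T-REX needs only a ranking of the demonstrations. When ground-truth returns are available, as they are here, one simple way to obtain that ranking is to sort the demonstrations by return after discarding any whose return exceeds the best demonstration score used in the paper. The sketch assumes each demonstration is a (states, episode_return) tuple; rank_demonstrations and max_return are illustrative names, not arguments of train_trex.py.

```python
from typing import Any, List, Sequence, Tuple

# Hypothetical demonstration format: (sequence of states, ground-truth episode return)
Demo = Tuple[Sequence[Any], float]


def rank_demonstrations(demos: Sequence[Demo], max_return: float) -> List[Demo]:
    """Drop demos scoring above `max_return`, then sort them worst-to-best.

    After sorting, a higher list index means a higher rank. Demos with equal
    returns carry no preference and should not be paired for training.
    """
    kept = [d for d in demos if d[1] <= max_return]
    return sorted(kept, key=lambda d: d[1])
```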

  • We then train the OpenAI Baselines PPO algorithm as before, but this time the learned reward function (the T-REX network) provides the reward instead of the true environment reward. The algorithm loads the latest checkpoint in --reward_ckpt_dir and uses this network to provide the reward during training. As in the paper, we run for 50 million frames:
  $ python -m baselines.run --alg=ppo2 --env='BreakoutNoFrameskip-v4' --num_timesteps=50e6 --save_interval=5000 --learned_reward=True --reward_ckpt_dir='./ckpts/Breakout' 

As before, this saves the checkpoints in a folder in the /tmp directory named after the date and time (e.g. /tmp/openai-2019-05-29-22-48-24-125657/checkpoints).
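Unless otherwise stated, testing uses the last checkpoint saved during training. Assuming the Baselines ppo2 convention of naming each checkpoint file by its zero-padded update number (e.g. 00050), the latest one can be found by parsing the file names as integers. The helper below is a hypothetical sketch, not part of test_learned_policy.py.

```python
import os
from typing import Optional


def latest_ppo2_checkpoint(ckpt_dir: str) -> Optional[str]:
    """Return the path of the highest-numbered checkpoint in `ckpt_dir`, or None.

    Assumes checkpoints are named by their zero-padded update number
    (e.g. '00050', '09750'); any other entries in the directory are ignored.
    """
    numbered = [name for name in os.listdir(ckpt_dir) if name.isdigit()]
    if not numbered:
        return None
    latest = max(numbered, key=int)  # compare numerically, robust if padding width changes
    return os.path.join(ckpt_dir, latest)
```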

  • Finally, we can test the policy by running it in the environment and, as in the paper, taking the best average performance over 3 random seeds with 30 trials per seed:
  $ python test_learned_policy.py --env='Breakout' --ckpt_dir='/tmp/openai-2019-05-29-22-48-24-125657/checkpoints'
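The reported test score is the best of the per-seed averages: for each of the 3 seeds, average the returns of its 30 episodes, then take the maximum of those averages. A minimal sketch of that aggregation, assuming the episode returns have already been collected into one list per seed (best_average_return is a hypothetical helper):

```python
from statistics import mean
from typing import Sequence


def best_average_return(returns_per_seed: Sequence[Sequence[float]]) -> float:
    """Mean episode return for each seed, then the best of those means."""
    if not returns_per_seed or any(len(r) == 0 for r in returns_per_seed):
        raise ValueError("need at least one seed with at least one episode")
    return max(mean(seed_returns) for seed_returns in returns_per_seed)
```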

Results

Results are obtained by running the PPO policy trained on the learned reward function in the environment, as explained above. Unless otherwise stated, the checkpoint used for testing is the last one saved during training. For each environment, the best score (based on the ground-truth reward) between the paper and this implementation is highlighted in bold.

Results from this implementation:

Results from the paper:

Note: To get the best and average reward values across the demonstrations for each environment, run

  $ python utils/get_demonstration_stats.py --env='Breakout'

changing the --env parameter each time.
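The two statistics are the maximum and the mean of the ground-truth episode returns over the saved demonstrations. Assuming those returns are available as a list of floats (the script's actual input format is not shown here), the computation reduces to:

```python
from statistics import mean
from typing import Dict, Sequence


def demonstration_stats(episode_returns: Sequence[float]) -> Dict[str, float]:
    """Best and average ground-truth return across a set of demonstrations."""
    if not episode_returns:
        raise ValueError("no demonstrations given")
    return {"best": max(episode_returns), "average": mean(episode_returns)}
```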

Differences

There are some minor differences between this implementation of T-REX and that used in the paper:

  • This implementation subsamples trajectory pairs from the saved demonstrations live during training; the paper does this offline as a preprocessing step before training, subsampling 6,000 trajectory pairs from the saved demonstrations then training on these subsamples.

  • This implementation uses a fixed trajectory length of 50 when subsampling from the demonstrations; the paper chooses a random trajectory length each time between 50 and 100.

  • This implementation trains on a batch of trajectory pairs at each step (batch size = 16), where a batch is made up of the unrolled states of the 16 trajectory pairs combined; the paper trains on a single trajectory pair at each step (where the 'batch' is just the unrolled trajectory states of 1 trajectory pair).
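The on-the-fly sampling described in these differences can be sketched as follows: draw two demonstrations with different returns, cut a fixed-length snippet of 50 states from each, and label which snippet comes from the higher-return demonstration; 16 such pairs form one batch. This is an illustrative sketch with hypothetical names (sample_batch and the (states, return) demonstration format). In particular, it uses independent uniform start positions for the two snippets, which may differ from the exact snippet-alignment rule used by this implementation or the paper.

```python
import random
from typing import Any, List, Sequence, Tuple

# Hypothetical demonstration format: (sequence of states, ground-truth episode return)
Demo = Tuple[Sequence[Any], float]
# (snippet_a, snippet_b, label): label is 1 if snippet_b comes from the higher-return demo
Pair = Tuple[Sequence[Any], Sequence[Any], int]


def sample_batch(demos: Sequence[Demo], rng: random.Random,
                 batch_size: int = 16, snippet_len: int = 50) -> List[Pair]:
    """Sample `batch_size` ranked snippet pairs from the demonstrations."""
    usable = [d for d in demos if len(d[0]) >= snippet_len]
    if len({ret for _, ret in usable}) < 2:
        raise ValueError("need long-enough demos with at least two distinct returns")
    batch: List[Pair] = []
    while len(batch) < batch_size:
        (states_a, ret_a), (states_b, ret_b) = rng.sample(usable, 2)
        if ret_a == ret_b:
            continue  # equal returns give no preference; resample
        start_a = rng.randint(0, len(states_a) - snippet_len)  # randint is inclusive
        start_b = rng.randint(0, len(states_b) - snippet_len)
        snippet_a = states_a[start_a:start_a + snippet_len]
        snippet_b = states_b[start_b:start_b + snippet_len]
        batch.append((snippet_a, snippet_b, int(ret_b > ret_a)))
    return batch
```

Stacking the 32 snippets of a batch (16 pairs × 2 snippets × 50 states) into one array is then one way to score all of them with the reward network in a single forward pass.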

License

MIT License
