Merge pull request #20 from araffin/add-trpo
Add TRPO + Hyperparameter optimization
araffin authored May 11, 2019
2 parents b76641e + b923bf6 commit 64267d3
Showing 74 changed files with 1,218 additions and 263 deletions.
26 changes: 0 additions & 26 deletions .github/ISSUE_TEMPLATE/bug_report.md

This file was deleted.

1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@ logs/
.coverage.*
.idea/
cluster_sbatch.sh
cluster_sbatch_mpi.sh
46 changes: 39 additions & 7 deletions README.md
@@ -52,6 +52,27 @@ Continue training (here, load pretrained agent for Breakout and continue training
python train.py --algo a2c --env BreakoutNoFrameskip-v4 -i trained_agents/a2c/BreakoutNoFrameskip-v4.pkl -n 5000
```

Note: when training TRPO, you have to use `mpirun` to enable multiprocessing:

```
mpirun -n 16 python train.py --algo trpo --env BreakoutNoFrameskip-v4
```
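
TRPO in Stable-Baselines parallelizes over MPI worker processes rather than over a vectorized environment, which is why it has to be launched through `mpirun`. A minimal `mpi4py` sketch of the pattern (illustrative only, not this repository's code):

```
# Illustrative sketch: `mpirun -n 16` starts 16 copies of this script,
# each with its own MPI rank; inside TRPO, every worker collects rollouts
# and gradients are averaged across the workers.
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Give each worker a distinct seed so its rollouts are decorrelated.
seed = 1000 + rank
print("worker {}/{} using seed {}".format(rank + 1, comm.Get_size(), seed))
```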

## Hyperparameter Optimization

We use [Optuna](https://optuna.org/) for optimizing the hyperparameters.

Note: hyperparameter search is only implemented for PPO2/A2C/SAC/TRPO/DDPG for now.
When using the SuccessiveHalvingPruner ("halving"), you must specify `--n-jobs > 1`.

Budget of 1000 trials, with a maximum of 50000 steps per trial:

```
python train.py --algo ppo2 --env MountainCar-v0 -n 50000 -optimize --n-trials 1000 --n-jobs 2 \
--sampler random --pruner median
```
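
Under the hood, `-optimize` drives an Optuna study roughly along these lines. A minimal sketch only: the `evaluate` stub below stands in for an actual (prunable) training run and is not part of this repository's API; the sampler and pruner choices mirror the command above.

```
import optuna

def evaluate(learning_rate, gamma):
    # Stand-in for training an agent for up to 50000 steps and returning
    # its mean episodic reward (hypothetical, for illustration only).
    return -abs(learning_rate - 7e-4) * 1000 - (0.999 - gamma)

def objective(trial):
    # Each trial samples one candidate set of hyperparameters.
    learning_rate = trial.suggest_loguniform('learning_rate', 1e-5, 1e-2)
    gamma = trial.suggest_categorical('gamma', [0.95, 0.99, 0.999])
    return evaluate(learning_rate, gamma)

# Maximize reward with a random sampler and a median pruner, as above.
study = optuna.create_study(direction='maximize',
                            sampler=optuna.samplers.RandomSampler(),
                            pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=1000, n_jobs=2)
print(study.best_params)
```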


## Record a Video of a Trained Agent

Record 1000 steps:
@@ -61,7 +82,7 @@ python -m utils.record_video --algo ppo2 --env BipedalWalkerHardcore-v2 -n 1000
```


## Current Collection: 100+ Trained Agents!

Scores can be found in `benchmark.md`. To compute them, simply run `python -m utils.benchmark`.

@@ -76,6 +97,8 @@ Scores can be found in `benchmark.md`. To compute them, simply run `python -m utils.benchmark`.
| ACKTR |:heavy_check_mark:| :heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:| :heavy_check_mark:| :heavy_check_mark:| :heavy_check_mark: |
| PPO2 |:heavy_check_mark:|:heavy_check_mark:| :heavy_check_mark: |:heavy_check_mark: |:heavy_check_mark:|:heavy_check_mark:| :heavy_check_mark: |
| DQN |:heavy_check_mark:| :heavy_check_mark: |:heavy_check_mark:| :heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|:heavy_check_mark:|
| TRPO | | | | | | | |


Additional Atari Games (to be completed):

@@ -91,26 +114,28 @@ Additional Atari Games (to be completed):

| RL Algo | CartPole-v1 | MountainCar-v0 | Acrobot-v1 | Pendulum-v0 | MountainCarContinuous-v0 |
|----------|--------------|----------------|------------|--------------|--------------------------|
| A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| ACER | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | N/A | N/A |
| ACKTR | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | N/A | N/A |
| PPO2 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |:heavy_check_mark: |
| DQN | :heavy_check_mark: | :heavy_check_mark: |:heavy_check_mark:| N/A | N/A |
| DDPG | N/A | N/A | N/A| :heavy_check_mark: | :heavy_check_mark: |
| SAC | N/A | N/A | N/A| :heavy_check_mark: | |
| TRPO | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | :heavy_check_mark: |


### Box2D Environments

| RL Algo | BipedalWalker-v2 | LunarLander-v2 | LunarLanderContinuous-v2 | BipedalWalkerHardcore-v2 | CarRacing-v0 |
|----------|--------------|----------------|------------|--------------|--------------------------|
| A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | |
| ACER | N/A | :heavy_check_mark: | N/A | N/A | N/A |
| ACKTR | N/A | :heavy_check_mark: | N/A | N/A | N/A |
| PPO2 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |:heavy_check_mark:| |
| DQN | N/A | :heavy_check_mark: | N/A | N/A | N/A |
| DDPG | | N/A | :heavy_check_mark: | | |
| SAC | :heavy_check_mark: | N/A | :heavy_check_mark: |:heavy_check_mark: | |
| TRPO | | :heavy_check_mark: | :heavy_check_mark: | | |

### PyBullet Environments

Expand All @@ -121,17 +146,21 @@ Note: those environments are derived from [Roboschool](https://github.com/openai

| RL Algo | Walker2D | HalfCheetah | Ant | Reacher | Hopper | Humanoid |
|----------|-----------|-------------|-----|---------|---------|----------|
| A2C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
| PPO2 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |:heavy_check_mark:|
| DDPG | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | | |
| SAC | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
| TRPO | | :heavy_check_mark: | | | | |

PyBullet Envs (Continued)

| RL Algo | Minitaur | MinitaurDuck | InvertedDoublePendulum | InvertedPendulumSwingup |
|----------|-----------|-------------|-----|---------|
| A2C | | | | |
| PPO2 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |:heavy_check_mark:|
| DDPG | | | | |
| SAC | | | :heavy_check_mark: | :heavy_check_mark: |
| TRPO | | | | |

## Colab Notebook: Try it Online!

@@ -140,9 +169,12 @@ You can train agents online using [colab notebook](https://colab.research.google
## Installation

### Stable-Baselines PyPI Package

Min version: stable-baselines >= 2.5.1

```
apt-get install swig cmake libopenmpi-dev zlib1g-dev ffmpeg
pip install stable-baselines box2d box2d-kengz pyyaml pybullet optuna pytablewriter
```

Please see [Stable Baselines README](https://github.com/hill-a/stable-baselines) for alternatives.
19 changes: 19 additions & 0 deletions benchmark.md
@@ -1,17 +1,26 @@
|algo | env_id |mean_reward|std_reward|n_timesteps|n_episodes|
|-----|-----------------------------------|----------:|---------:|----------:|---------:|
|a2c |Acrobot-v1 | -86.616| 25.097| 149997| 1712|
|a2c |AntBulletEnv-v0 | 2271.322| 160.233| 150000| 150|
|a2c |BeamRiderNoFrameskip-v4 | 2809.115| 1298.573| 150181| 52|
|a2c |BipedalWalker-v2 | 255.012| 71.426| 149890| 169|
|a2c |BipedalWalkerHardcore-v2 | 102.754| 136.304| 149643| 137|
|a2c |BreakoutNoFrameskip-v4 | 384.865| 51.231| 146703| 52|
|a2c |CartPole-v1 | 499.903| 1.672| 149971| 300|
|a2c |EnduroNoFrameskip-v4 | 0.000| 0.000| 149574| 45|
|a2c |HalfCheetahBulletEnv-v0 | 2069.627| 103.895| 150000| 150|
|a2c |HopperBulletEnv-v0 | 1575.871| 669.267| 149075| 189|
|a2c |LunarLander-v2 | 36.321| 135.294| 149696| 463|
|a2c |LunarLanderContinuous-v2 | 203.912| 59.265| 149938| 253|
|a2c |MountainCar-v0 | -130.921| 32.188| 149904| 1145|
|a2c |MountainCarContinuous-v0 | 93.659| 0.199| 149985| 2187|
|a2c |MsPacmanNoFrameskip-v4 | 1581.111| 499.757| 150229| 189|
|a2c |Pendulum-v0 | -162.240| 99.351| 150000| 750|
|a2c |PongNoFrameskip-v4 | 18.973| 2.135| 148288| 75|
|a2c |QbertNoFrameskip-v4 | 5742.333| 2033.074| 151311| 150|
|a2c |SeaquestNoFrameskip-v4 | 746.420| 111.370| 149749| 81|
|a2c |SpaceInvadersNoFrameskip-v4 | 658.907| 197.833| 149846| 151|
|a2c |Walker2DBulletEnv-v0 | 618.318| 291.293| 149234| 187|
|acer |Acrobot-v1 | -90.850| 32.797| 149989| 1633|
|acer |BeamRiderNoFrameskip-v4 | 2440.692| 1357.964| 149127| 52|
|acer |CartPole-v1 | 498.620| 23.862| 149586| 300|
@@ -35,9 +44,12 @@
|acktr|QbertNoFrameskip-v4 | 9569.575| 3980.468| 150896| 106|
|acktr|SeaquestNoFrameskip-v4 | 1672.239| 105.092| 149148| 67|
|acktr|SpaceInvadersNoFrameskip-v4 | 738.045| 306.756| 149714| 156|
|ddpg |AntBulletEnv-v0 | 2070.790| 74.253| 150000| 150|
|ddpg |HalfCheetahBulletEnv-v0 | 2549.452| 37.652| 150000| 150|
|ddpg |LunarLanderContinuous-v2 | 244.566| 75.617| 149531| 660|
|ddpg |MountainCarContinuous-v0 | 91.858| 1.350| 149945| 1818|
|ddpg |Pendulum-v0 | -169.829| 93.303| 150000| 750|
|ddpg |Walker2DBulletEnv-v0 | 1954.753| 368.613| 149152| 155|
|dqn |Acrobot-v1 | -88.103| 33.037| 149954| 1683|
|dqn |BeamRiderNoFrameskip-v4 | 888.741| 248.487| 149395| 81|
|dqn |BreakoutNoFrameskip-v4 | 191.165| 97.795| 149817| 97|
@@ -89,3 +101,10 @@
|sac |Pendulum-v0 | -159.669| 86.665| 150000| 750|
|sac |ReacherBulletEnv-v0 | 17.529| 9.860| 150000| 1000|
|sac |Walker2DBulletEnv-v0 | 2052.646| 13.631| 150000| 150|
|trpo |CartPole-v1 | 485.392| 70.505| 149986| 309|
|trpo |HalfCheetahBulletEnv-v0 | 1850.967| 282.093| 150000| 150|
|trpo |LunarLander-v2 | 149.313| 108.546| 149893| 320|
|trpo |LunarLanderContinuous-v2 | 64.619| 94.377| 149127| 181|
|trpo |MountainCar-v0 | -144.537| 33.584| 149885| 1037|
|trpo |MountainCarContinuous-v0 | 93.428| 1.509| 149998| 1067|
|trpo |Pendulum-v0 | -176.951| 97.098| 150000| 750|
5 changes: 3 additions & 2 deletions docker/Dockerfile.cpu
@@ -22,8 +22,9 @@ RUN \
pip install pytest-cov && \
pip install pyyaml && \
pip install box2d-py==2.3.5 && \
pip install stable-baselines && \
pip install pybullet && \
pip install optuna && \
pip install pytablewriter==0.36.0


5 changes: 3 additions & 2 deletions docker/Dockerfile.gpu
@@ -21,9 +21,10 @@ RUN \
pip install pytest-cov && \
pip install pyyaml && \
pip install box2d-py==2.3.5 && \
pip install tensorflow-gpu==1.8.0 && \
pip install stable-baselines && \
pip install pybullet && \
pip install optuna && \
pip install pytablewriter==0.36.0

ENV PATH=$VENV/bin:$PATH
14 changes: 12 additions & 2 deletions enjoy.py
@@ -1,15 +1,25 @@
import argparse
import os
import warnings
import sys
import pkg_resources

# For pybullet envs
warnings.filterwarnings("ignore")
import gym
import pybullet_envs
import numpy as np

import stable_baselines
from stable_baselines.common import set_global_seeds
from stable_baselines.common.vec_env import VecNormalize, VecFrameStack


from utils import ALGOS, create_test_env, get_latest_run_id, get_saved_hyperparams

# Fix for breaking change in v2.6.0: compare parsed versions, since a plain
# string comparison would misorder versions such as "2.10.0" and "2.6.0"
if pkg_resources.parse_version(pkg_resources.get_distribution("stable_baselines").version) >= pkg_resources.parse_version("2.6.0"):
    sys.modules['stable_baselines.ddpg.memory'] = stable_baselines.deepq.replay_buffer
    stable_baselines.deepq.replay_buffer.Memory = stable_baselines.deepq.replay_buffer.ReplayBuffer

def main():
    parser = argparse.ArgumentParser()
@@ -52,7 +62,6 @@ def main():

    model_path = "{}/{}.pkl".format(log_path, env_id)

    assert os.path.isdir(log_path), "The {} folder was not found".format(log_path)
    assert os.path.isfile(model_path), "No model found for {} on {}, path: {}".format(algo, env_id, model_path)

@@ -127,5 +136,6 @@ def main():
    # SubprocVecEnv
    env.close()


if __name__ == '__main__':
    main()
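
The version shim near the top of `enjoy.py` works because `pickle` resolves classes by their module path at load time: registering the old path in `sys.modules` and pointing the old class name at the new one lets models saved before the rename deserialize cleanly. A self-contained illustration of the same trick, with a hypothetical module name (`legacy_memory` is not a real package):

```
import sys
import types

class ReplayBuffer:
    """Stands in for the class that replaced the old `Memory`."""

# Register a fake module at the old import path and alias the old
# class name to the new implementation.
legacy = types.ModuleType('legacy_memory')
legacy.Memory = ReplayBuffer
sys.modules['legacy_memory'] = legacy

# Unpickling data that references `legacy_memory.Memory` now resolves
# to ReplayBuffer instead of raising ImportError.
from legacy_memory import Memory
assert Memory is ReplayBuffer
```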
111 changes: 108 additions & 3 deletions hyperparams/a2c.yml
@@ -11,10 +11,14 @@ CartPole-v1:
  ent_coef: 0.0

LunarLander-v2:
  n_envs: 8
  n_timesteps: !!float 2e5
  policy: 'MlpPolicy'
  gamma: 0.995
  n_steps: 5
  learning_rate: 0.00083
  lr_schedule: 'linear'
  ent_coef: 0.00001

MountainCar-v0:
  normalize: true
@@ -29,3 +33,104 @@ Acrobot-v1:
  n_timesteps: !!float 5e5
  policy: 'MlpPolicy'
  ent_coef: .0

Pendulum-v0:
  n_envs: 8
  n_timesteps: !!float 2e6
  policy: 'MlpPolicy'
  ent_coef: 0.0
  gamma: 0.95

LunarLanderContinuous-v2:
  normalize: true
  n_envs: 16
  n_timesteps: !!float 5e6
  policy: 'MlpPolicy'
  gamma: 0.999
  ent_coef: 0.001
  lr_schedule: 'linear'

MountainCarContinuous-v0:
  normalize: true
  n_envs: 16
  n_timesteps: !!float 1e6
  policy: 'MlpPolicy'
  ent_coef: 0.0

BipedalWalker-v2:
  normalize: true
  n_envs: 16
  n_timesteps: !!float 5e6
  policy: 'MlpPolicy'
  lr_schedule: 'linear'
  ent_coef: 0.0

HalfCheetahBulletEnv-v0:
  normalize: true
  n_envs: 4
  n_timesteps: !!float 2e6
  policy: 'MlpPolicy'
  ent_coef: 0.0
  n_steps: 32
  vf_coef: 0.5
  lr_schedule: 'linear'
  gamma: 0.99
  learning_rate: 0.002

BipedalWalkerHardcore-v2:
  normalize: true
  n_envs: 16
  n_timesteps: !!float 10e7
  policy: 'MlpPolicy'
  frame_stack: 4
  ent_coef: 0.001
  lr_schedule: 'linear'

Walker2DBulletEnv-v0:
  normalize: true
  n_envs: 4
  n_timesteps: !!float 2e6
  policy: 'MlpPolicy'
  ent_coef: 0.0
  n_steps: 32
  vf_coef: 0.5
  lr_schedule: 'linear'
  gamma: 0.99
  learning_rate: 0.002

AntBulletEnv-v0:
  normalize: true
  n_envs: 4
  n_timesteps: !!float 2e6
  policy: 'MlpPolicy'
  ent_coef: 0.0
  n_steps: 32
  vf_coef: 0.5
  lr_schedule: 'linear'
  gamma: 0.99
  learning_rate: 0.002

HopperBulletEnv-v0:
  normalize: true
  n_envs: 4
  n_timesteps: !!float 2e6
  policy: 'MlpPolicy'
  ent_coef: 0.0
  n_steps: 32
  vf_coef: 0.5
  lr_schedule: 'linear'
  gamma: 0.99
  learning_rate: 0.002

# Not working yet
ReacherBulletEnv-v0:
  normalize: true
  n_envs: 8
  n_timesteps: !!float 1e6
  policy: 'MlpPolicy'
  ent_coef: 0.001
  n_steps: 32
  vf_coef: 0.5
  lr_schedule: 'linear'
  gamma: 0.99
  learning_rate: 0.002
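
A side note on the `!!float` tags used throughout these files: YAML 1.1 parsers such as PyYAML read a bare `2e6` (no decimal point) as a string, so the explicit tag is what forces a float. A small sketch, assuming only PyYAML is installed:

```
import yaml

config = yaml.safe_load("""
Pendulum-v0:
  n_envs: 8
  n_timesteps: !!float 2e6
  policy: 'MlpPolicy'
  ent_coef: 0.0
  gamma: 0.95
""")

hyperparams = config['Pendulum-v0']
# The tag makes this a float (2000000.0); without it, '2e6' would stay a string.
assert isinstance(hyperparams['n_timesteps'], float)
print(hyperparams)
```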
