Merge pull request #6 from vwxyzjn/actor-threads
Add actor threads example
vwxyzjn authored Aug 17, 2023
2 parents 14f1dec + e11dc0a commit 6367b2c
Showing 22 changed files with 10,872 additions and 388 deletions.
README.md (28 changes: 15 additions & 13 deletions)
@@ -41,7 +41,7 @@ Cleanba is CleanRL's implementation of DeepMind's Sebulba distributed training f
![](static/scalability.png)
![](static/cleanba_scaling_efficiency.png)

-**Understandable**: We adopt the single-file implementation philosophy used in CleanRL, making our core codebase succinct and easy to understand. For example, our `cleanba/cleanba_ppo_envpool_impala_atari_wrapper.py` is ~800 lines of code.
+**Understandable**: We adopt the single-file implementation philosophy used in CleanRL, making our core codebase succinct and easy to understand. For example, our `cleanba/cleanba_ppo.py` is ~800 lines of code.



@@ -50,14 +50,16 @@ Cleanba is CleanRL's implementation of DeepMind's Sebulba distributed training f
Prerequisites:
* Python >=3.8
* [Poetry 1.3.2+](https://python-poetry.org)
+* CUDA 11.2+
+* CuDNN 8.2+


### Installation:
```
poetry install
-poetry run pip install --upgrade "jax[cuda]==0.3.25" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-poetry run python cleanba/cleanba_ppo_envpool_impala_atari_wrapper.py
-poetry run python cleanba/cleanba_ppo_envpool_impala_atari_wrapper.py --help
+poetry run pip install --upgrade "jax[cuda11_cudnn82]==0.4.8" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+poetry run python cleanba/cleanba_ppo.py
+poetry run python cleanba/cleanba_ppo.py --help
```
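
If the GPU build installed correctly, a quick sanity check (a minimal sketch, not part of the original README; the filename `check_gpu.py` is made up for illustration) is to ask JAX which backend and devices it sees:

```
# check_gpu.py: confirm that JAX picked up the CUDA backend installed above
import jax

print(jax.default_backend())  # expect "gpu"; "cpu" means the CUDA wheels were not used
print(jax.devices())          # expect one entry per visible GPU
```

Running it via `poetry run python check_gpu.py` keeps it inside the Poetry-managed environment.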

### Experiments:
@@ -67,28 +69,28 @@ Here are some common setups. You can also run the commands with `--track` to tra

```
# a0-l0-d1: single GPU
-python cleanba/cleanba_ppo_envpool_impala_atari_wrapper.py --actor-device-ids 0 --learner-device-ids 0 --local-num-envs 120 --track
+python cleanba/cleanba_ppo.py --actor-device-ids 0 --learner-device-ids 0 --local-num-envs 120 --track
# a0-l0,1-d1: two GPUs
-python cleanba/cleanba_ppo_envpool_impala_atari_wrapper.py --actor-device-ids 0 --learner-device-ids 0 1 --local-num-envs 120
+python cleanba/cleanba_ppo.py --actor-device-ids 0 --learner-device-ids 0 1 --local-num-envs 120
# a0-l1,2-d1: three GPUs
-python cleanba/cleanba_ppo_envpool_impala_atari_wrapper.py --actor-device-ids 0 --learner-device-ids 1 2 --local-num-envs 120
+python cleanba/cleanba_ppo.py --actor-device-ids 0 --learner-device-ids 1 2 --local-num-envs 120
# a0-l1,2,3-d1: four GPUs
-python cleanba/cleanba_ppo_envpool_impala_atari_wrapper.py --actor-device-ids 0 --learner-device-ids 1 2 3 --local-num-envs 120
+python cleanba/cleanba_ppo.py --actor-device-ids 0 --learner-device-ids 1 2 3
# a0-l1,2,3,4-d1: five GPUs
-python cleanba/cleanba_ppo_envpool_impala_atari_wrapper.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 --local-num-envs 120
+python cleanba/cleanba_ppo.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 --local-num-envs 120
# a0-l1,2,3,4,5,6-d1: seven GPUs
-python cleanba/cleanba_ppo_envpool_impala_atari_wrapper.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 5 6 --local-num-envs 120
+python cleanba/cleanba_ppo.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 5 6 --local-num-envs 120
# a0-l0-d2: 8 GPUs (distributed 2 times on 4 GPUs)
# execute them in separate terminals; here we assume all 8 GPUs are on the same machine
# however it is possible to scale to hundreds of GPUs allowed by `jax.distributed`
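# (background, an assumption rather than something stated in this README: `jax.distributed.initialize()`
# can auto-detect the multi-process layout from SLURM_* environment variables such as SLURM_NTASKS,
# SLURM_PROCID, and SLURM_STEP_NODELIST, which is presumably why the two commands below set them
# by hand to simulate a 2-process SLURM job on one machine)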
-CUDA_VISIBLE_DEVICES="0,1,2,3" SLURM_JOB_ID=26017 SLURM_STEP_NODELIST=localhost SLURM_NTASKS=2 SLURM_PROCID=0 SLURM_LOCALID=0 SLURM_STEP_NUM_NODES=2 python cleanba/cleanba_ppo_envpool_impala_atari_wrapper.py --distributed --actor-device-ids 0 --learner-device-ids 1 2 3
-CUDA_VISIBLE_DEVICES="4,5,6,7" SLURM_JOB_ID=26017 SLURM_STEP_NODELIST=localhost SLURM_NTASKS=2 SLURM_PROCID=1 SLURM_LOCALID=0 SLURM_STEP_NUM_NODES=2 python cleanba/cleanba_ppo_envpool_impala_atari_wrapper.py --distributed --actor-device-ids 0 --learner-device-ids 1 2 3
+CUDA_VISIBLE_DEVICES="0,1,2,3" SLURM_JOB_ID=26017 SLURM_STEP_NODELIST=localhost SLURM_NTASKS=2 SLURM_PROCID=0 SLURM_LOCALID=0 SLURM_STEP_NUM_NODES=2 python cleanba/cleanba_ppo.py --distributed --actor-device-ids 0 --learner-device-ids 1 2 3
+CUDA_VISIBLE_DEVICES="4,5,6,7" SLURM_JOB_ID=26017 SLURM_STEP_NODELIST=localhost SLURM_NTASKS=2 SLURM_PROCID=1 SLURM_LOCALID=0 SLURM_STEP_NUM_NODES=2 python cleanba/cleanba_ppo.py --distributed --actor-device-ids 0 --learner-device-ids 1 2 3
# if you have slurm it's possible to run the following
python -m cleanrl_utils.benchmark \
--env-ids Breakout-v5 \
---command "poetry run python cleanrl/cleanba_ppo_envpool_impala_atari_wrapper_large.py --distributed --learner-device-ids 1 2 3 --track --save-model --upload-model" \
+--command "poetry run python cleanrl/cleanba_ppo.py --distributed --learner-device-ids 1 2 3 --track --save-model --upload-model" \
--num-seeds 1 \
--workers 1 \
--slurm-gpus-per-task 4 \
