Implemented CrossQ #243

Merged
merged 37 commits on Oct 24, 2024

Commits
9afecf5
Implemented CrossQ
danielpalen May 3, 2024
4fa78a7
Fixed code style
danielpalen May 5, 2024
7ce57de
Clean up, comments and refactored to sbx variable names
danielpalen May 12, 2024
9c339b8
1024 neuron Q function (sbx default)
danielpalen May 12, 2024
2b1ff5e
batch norm parameters as function arguments
danielpalen May 12, 2024
aace2ac
clean up. reshape instead of split
danielpalen May 12, 2024
4df7111
Added policy delay
danielpalen May 12, 2024
5583225
fixed commit-checks
danielpalen May 12, 2024
567c2fb
Fix f-string
araffin May 13, 2024
8970ed0
Update documentation
araffin May 13, 2024
8792621
Rename to torch layers
araffin May 13, 2024
230a948
Fix for policy delay and minor edits
araffin May 13, 2024
cd8bd7d
Update tests
araffin May 13, 2024
27a96f6
Update documentation
araffin May 13, 2024
7d6c642
Merge branch 'master' into feat/crossq
araffin Jul 1, 2024
3927a70
Update doc
araffin Jul 6, 2024
2019327
Add more tests for crossQ
araffin Jul 6, 2024
b0213ec
Improve doc and expose batchnorm params
araffin Jul 6, 2024
9772ecf
Merge branch 'master' into feat/crossq
araffin Jul 6, 2024
454224d
Add some comments and todos and fix type check
araffin Jul 6, 2024
a7bbac9
Merge branch 'feat/crossq' of github.com:danielpalen/stable-baselines…
araffin Jul 6, 2024
bbd654c
Use torch module for BN
araffin Jul 19, 2024
bb80218
Re-organize losses
araffin Jul 19, 2024
a717d13
Add set_bn_training_mode
araffin Jul 19, 2024
cb1bc8f
Simplify network creation with new SB3 version, and fix default momentum
araffin Jul 20, 2024
a88a19b
Use different b1 for Adam as in original implementation
araffin Jul 20, 2024
32f66fe
Reformat TOML file
araffin Jul 20, 2024
03db09e
Update CI workflow, skip mypy for 3.8
araffin Jul 22, 2024
244b930
Merge branch 'master' into feat/crossq
araffin Aug 13, 2024
497ea7e
Merge branch 'master' into feat/crossq
araffin Oct 18, 2024
72abe85
Update CrossQ doc
araffin Oct 24, 2024
f1fc8f5
Use uv to download packages on github CI
araffin Oct 24, 2024
497f5fe
System install for Github CI
araffin Oct 24, 2024
6e37805
Fix for pytorch install
araffin Oct 24, 2024
94af853
Use +cpu version
araffin Oct 24, 2024
6cd924e
Pytorch 2.5.0 doesn't support python 3.8
araffin Oct 24, 2024
125a8ca
Update comments
araffin Oct 24, 2024
15 changes: 10 additions & 5 deletions .github/workflows/ci.yml
@@ -30,21 +30,24 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
+          # Use uv for faster downloads
+          pip install uv
           # cpu version of pytorch
-          pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cpu
+          # See https://github.com/astral-sh/uv/issues/1497
+          uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
 
           # Install Atari Roms
-          pip install autorom
+          uv pip install --system autorom
           wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
           base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
           AutoROM --accept-license --source-file Roms.tar.gz
 
           # Install master version
           # and dependencies for docs and tests
-          pip install "stable_baselines3[extra_no_roms,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
-          pip install .
+          uv pip install --system "stable_baselines3[extra_no_roms,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
+          uv pip install --system .
           # Use headless version
-          pip install opencv-python-headless
+          uv pip install --system opencv-python-headless
 
       - name: Lint with ruff
         run: |
@@ -58,6 +61,8 @@ jobs:
       - name: Type check
         run: |
           make type
+        # Do not run for python 3.8 (mypy internal error)
+        if: matrix.python-version != '3.8'
       - name: Test with pytest
         run: |
           make pytest
1 change: 1 addition & 0 deletions README.md
@@ -31,6 +31,7 @@ See documentation for the full list of included features.
 - [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/)
 - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
 - [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
+- [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX)
 
 **Gym Wrappers**:
 - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
7 changes: 7 additions & 0 deletions docs/common/torch_layers.rst
@@ -0,0 +1,7 @@
.. _th_layers:

Torch Layers
============

.. automodule:: sb3_contrib.common.torch_layers
:members:
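The new ``torch_layers`` module above is where CrossQ's ``BatchRenorm`` layer lives (see the changelog entry further down). As a rough illustration of the statistics batch renormalization computes, here is a framework-agnostic NumPy sketch. It is NOT the actual ``sb3_contrib.common.torch_layers.BatchRenorm`` API — class name, defaults, and the warmup schedule of the real layer are assumptions for illustration only.

```python
import numpy as np


class BatchRenormSketch:
    """Illustrative sketch of Batch Renormalization (Ioffe, 2017).

    Hypothetical stand-in, not the real sb3_contrib layer: the real layer
    additionally has affine parameters and a warmup schedule for r_max/d_max.
    """

    def __init__(self, num_features, momentum=0.01, eps=1e-5, r_max=3.0, d_max=5.0):
        self.momentum, self.eps = momentum, eps
        self.r_max, self.d_max = r_max, d_max
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)

    def forward(self, x, training):
        if not training:
            # Inference normalizes with running statistics, exactly like BatchNorm
            return (x - self.running_mean) / np.sqrt(self.running_var + self.eps)
        batch_mean = x.mean(axis=0)
        batch_var = x.var(axis=0)
        batch_std = np.sqrt(batch_var + self.eps)
        running_std = np.sqrt(self.running_var + self.eps)
        # Correction terms r and d pull batch statistics toward the running
        # statistics; clipping them keeps training stable.
        r = np.clip(batch_std / running_std, 1.0 / self.r_max, self.r_max)
        d = np.clip((batch_mean - self.running_mean) / running_std, -self.d_max, self.d_max)
        x_hat = (x - batch_mean) / batch_std * r + d
        # Exponential moving average of the running statistics
        self.running_mean += self.momentum * (batch_mean - self.running_mean)
        self.running_var += self.momentum * (batch_var - self.running_var)
        return x_hat
```

With r = 1 and d = 0 this reduces to plain BatchNorm; the paper anneals r_max and d_max up from those values over a warmup period, so training starts BN-like and gradually becomes consistent with the inference-time running statistics.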
1 change: 1 addition & 0 deletions docs/guide/algos.rst
@@ -10,6 +10,7 @@ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Pr
 ============ =========== ============ ================= =============== ================
 ARS ✔️ ❌️ ❌ ❌ ✔️
 MaskablePPO ❌ ✔️ ✔️ ✔️ ✔️
+CrossQ ✔️ ❌ ❌ ❌ ✔️
 QR-DQN ️❌ ️✔️ ❌ ❌ ✔️
 RecurrentPPO ✔️ ✔️ ✔️ ✔️ ✔️
 TQC ✔️ ❌ ❌ ❌ ✔️
23 changes: 23 additions & 0 deletions docs/guide/examples.rst
@@ -113,3 +113,26 @@ Train a PPO agent with a recurrent policy on the CartPole environment.
         obs, rewards, dones, info = vec_env.step(action)
         episode_starts = dones
         vec_env.render("human")
+
+CrossQ
+------
+
+Train a CrossQ agent on the Pendulum environment.
+
+.. code-block:: python
+
+    from sb3_contrib import CrossQ
+
+    model = CrossQ(
+        "MlpPolicy",
+        "Pendulum-v1",
+        verbose=1,
+        policy_kwargs=dict(
+            net_arch=dict(
+                pi=[256, 256],
+                qf=[1024, 1024],
+            )
+        ),
+    )
+    model.learn(total_timesteps=5_000, log_interval=4)
+    model.save("crossq_pendulum")
Binary file added docs/images/crossQ_performance.png
2 changes: 2 additions & 0 deletions docs/index.rst
@@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
    :caption: RL Algorithms
 
    modules/ars
+   modules/crossq
    modules/ppo_mask
    modules/ppo_recurrent
    modules/qrdqn
@@ -42,6 +43,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
    :maxdepth: 1
    :caption: Common
 
+   common/torch_layers
    common/utils
    common/wrappers
10 changes: 7 additions & 3 deletions docs/misc/changelog.rst
@@ -3,16 +3,19 @@
 Changelog
 ==========
 
-
-Release 2.4.0a9 (WIP)
+Release 2.4.0a10 (WIP)
 --------------------------
 
+**New algorithm: added CrossQ**
+
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 - Upgraded to Stable-Baselines3 >= 2.4.0
 
 New Features:
 ^^^^^^^^^^^^^
+- Added ``CrossQ`` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen)
+- Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen)
 
 Bug Fixes:
 ^^^^^^^^^^
@@ -28,6 +31,7 @@ Others:
 ^^^^^^^
 - Updated PyTorch version on CI to 2.3.1
 - Remove unnecessary SDE noise resampling in PPO/TRPO update
+- Switched to uv to download packages on GitHub CI
 
 Documentation:
 ^^^^^^^^^^^^^^
@@ -584,4 +588,4 @@ Contributors:
 -------------
 
 @ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
-@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl @corentinlger
+@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl @danielpalen @corentinlger
134 changes: 134 additions & 0 deletions docs/modules/crossq.rst
@@ -0,0 +1,134 @@
.. _crossq:

.. automodule:: sb3_contrib.crossq


CrossQ
======

Implementation of CrossQ proposed in:

`Bhatt A.* & Palenicek D.* et al. Batch Normalization in Deep Reinforcement Learning for Greater Sample Efficiency and Simplicity. ICLR 2024.`

CrossQ is an algorithm that uses batch normalization to improve the sample efficiency of off-policy deep reinforcement learning algorithms.
It is based on the idea of carefully introducing batch normalization layers in the critic network and dropping target networks.
This results in a simpler and more sample-efficient algorithm without requiring high update-to-data ratios.

.. rubric:: Available Policies

.. autosummary::
:nosignatures:

MlpPolicy
Member:
Could you add at least the multi input policy? (so we can try it in combination with HER)
Only the feature extractor should be changed normally.

And what do you think about adding CnnPolicy?

danielpalen (Contributor, Author):

This is a good point. I looked into it but have not added it yet: if I am not mistaken, it would also require changes to the CrossQ train() function, since concatenating and splitting the batches would then need control flow that depends on the policy in use.
For simplicity's sake, and because I did not have time to evaluate the multi-input policy, I left it out for now.
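For the ``Box``/``MlpPolicy`` case the PR does handle, the batching trick this exchange refers to can be sketched as follows. This is a NumPy illustration with an assumed ``critic`` callable and helper name, not the actual ``CrossQ.train()`` code:

```python
import numpy as np


def joint_critic_pass(critic, obs, actions, next_obs, next_actions):
    """Sketch of CrossQ's joint forward pass (hypothetical helper).

    Current and next (obs, action) pairs go through the critic in ONE batch,
    so BatchNorm layers normalize over the mixture of both distributions;
    the Q-values are then separated again.
    """
    batch_size = obs.shape[0]
    joint_obs = np.concatenate([obs, next_obs], axis=0)          # (2B, obs_dim)
    joint_act = np.concatenate([actions, next_actions], axis=0)  # (2B, act_dim)
    joint_q = critic(joint_obs, joint_act)                       # (2B, 1)
    # "reshape instead of split" (cf. commit aace2ac): view as (2, B, 1)
    q_values, next_q_values = joint_q.reshape(2, batch_size, -1)
    return q_values, next_q_values
```

With Dict observation spaces the concatenation would have to happen per key, which is exactly the extra control flow the reply above mentions.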


.. note::

Compared to the original implementation, the default network architecture for the q-value function is ``[1024, 1024]``
instead of ``[2048, 2048]`` as it provides a good compromise between speed and performance.

.. note::

There is currently no ``CnnPolicy`` for using CrossQ with images. We welcome help from contributors to add this feature.


Notes
-----

- Original paper: https://openreview.net/pdf?id=PczQtTsTIX
- Original Implementation: https://github.com/adityab/CrossQ
- SBX (SB3 Jax) Implementation: https://github.com/araffin/sbx


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ❌ ✔️
Box ✔️ ✔️
MultiDiscrete ❌ ✔️
MultiBinary ❌ ✔️
Dict ❌ ❌
============= ====== ===========


Example
-------

.. code-block:: python

from sb3_contrib import CrossQ

model = CrossQ("MlpPolicy", "Walker2d-v4")
model.learn(total_timesteps=1_000_000)
model.save("crossq_walker")


Results
-------

Performance evaluation of CrossQ on six MuJoCo environments, see `PR #243 <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/243>`_.
Compared to results from the original paper as well as a version from `SBX <https://github.com/araffin/sbx>`_.

.. image:: ../images/crossQ_performance.png


Open RL benchmark report: https://wandb.ai/openrlbenchmark/sb3-contrib/reports/SB3-Contrib-CrossQ--Vmlldzo4NTE2MTEx


How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone RL-Zoo:

.. code-block:: bash

git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/

Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):

.. code-block:: bash

python train.py --algo crossq --env $ENV_ID --n-eval-envs 5 --eval-episodes 20 --eval-freq 25000


Plot the results:

.. code-block:: bash

python scripts/all_plots.py -a crossq -e HalfCheetah Ant Hopper Walker2D -f logs/ -o logs/crossq_results
python scripts/plot_from_file.py -i logs/crossq_results.pkl -latex -l CrossQ


Comments
--------

This implementation is based on the SB3 SAC implementation.
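One SAC-relevant difference: commit a717d13 adds a ``set_bn_training_mode`` helper, because batch statistics should only be updated during the joint critic forward pass and stay frozen elsewhere (e.g. when computing the actor loss). A dependency-free sketch of what such a helper does — the tiny ``Module`` classes below are stand-ins for ``torch.nn`` modules, and the real helper's behavior may differ:

```python
class Module:
    """Minimal stand-in for torch.nn.Module: just enough for this sketch."""

    def __init__(self, *children):
        self.training = True
        self._children = list(children)

    def modules(self):
        # Yield self and all descendants, like torch.nn.Module.modules()
        yield self
        for child in self._children:
            yield from child.modules()


class Linear(Module):
    pass


class BatchRenorm(Module):
    pass


def set_bn_training_mode(model, training):
    """Flip ONLY the normalization layers' train/eval flag, leaving the rest
    of the network untouched (name taken from commit a717d13; sketch only)."""
    for module in model.modules():
        if isinstance(module, BatchRenorm):
            module.training = training
```

In the training loop, the critic's normalization layers would be put in training mode only for the joint pass that computes current and next Q-values, and set back to eval mode afterwards.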


Parameters
----------

.. autoclass:: CrossQ
:members:
:inherited-members:

.. _crossq_policies:

CrossQ Policies
---------------

.. autoclass:: MlpPolicy
:members:
:inherited-members:

.. autoclass:: sb3_contrib.crossq.policies.CrossQPolicy
:members:
:noindex:
1 change: 0 additions & 1 deletion docs/modules/ppo_recurrent.rst
@@ -109,7 +109,6 @@ Clone the repo for the experiment:
 
     git clone https://github.com/DLR-RM/rl-baselines3-zoo
     cd rl-baselines3-zoo
-    git checkout feat/recurrent-ppo
 
 
 Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -54,3 +54,6 @@ exclude_lines = [
     "raise NotImplementedError()",
     "if typing.TYPE_CHECKING:",
 ]
+
+# [tool.pyright]
+# extraPaths = ["../torchy-baselines/"]
2 changes: 2 additions & 0 deletions sb3_contrib/__init__.py
@@ -1,6 +1,7 @@
 import os
 
 from sb3_contrib.ars import ARS
+from sb3_contrib.crossq import CrossQ
 from sb3_contrib.ppo_mask import MaskablePPO
 from sb3_contrib.ppo_recurrent import RecurrentPPO
 from sb3_contrib.qrdqn import QRDQN
@@ -14,6 +15,7 @@
 
 __all__ = [
     "ARS",
+    "CrossQ",
     "MaskablePPO",
     "RecurrentPPO",
     "QRDQN",