Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented CrossQ #243

Merged
merged 37 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
9afecf5
Implemented CrossQ
danielpalen May 3, 2024
4fa78a7
Fixed code style
danielpalen May 5, 2024
7ce57de
Clean up, comments and refactored to sbx variable names
danielpalen May 12, 2024
9c339b8
1024 neuron Q function (sbx default)
danielpalen May 12, 2024
2b1ff5e
batch norm parameters as function arguments
danielpalen May 12, 2024
aace2ac
clean up. reshape instead of split
danielpalen May 12, 2024
4df7111
Added policy delay
danielpalen May 12, 2024
5583225
fixed commit-checks
danielpalen May 12, 2024
567c2fb
Fix f-string
araffin May 13, 2024
8970ed0
Update documentation
araffin May 13, 2024
8792621
Rename to torch layers
araffin May 13, 2024
230a948
Fix for policy delay and minor edits
araffin May 13, 2024
cd8bd7d
Update tests
araffin May 13, 2024
27a96f6
Update documentation
araffin May 13, 2024
7d6c642
Merge branch 'master' into feat/crossq
araffin Jul 1, 2024
3927a70
Update doc
araffin Jul 6, 2024
2019327
Add more tests for crossQ
araffin Jul 6, 2024
b0213ec
Improve doc and expose batchnorm params
araffin Jul 6, 2024
9772ecf
Merge branch 'master' into feat/crossq
araffin Jul 6, 2024
454224d
Add some comments and todos and fix type check
araffin Jul 6, 2024
a7bbac9
Merge branch 'feat/crossq' of github.com:danielpalen/stable-baselines…
araffin Jul 6, 2024
bbd654c
Use torch module for BN
araffin Jul 19, 2024
bb80218
Re-organize losses
araffin Jul 19, 2024
a717d13
Add set_bn_training_mode
araffin Jul 19, 2024
cb1bc8f
Simplify network creation with new SB3 version, and fix default momentum
araffin Jul 20, 2024
a88a19b
Use different b1 for Adam as in original implementation
araffin Jul 20, 2024
32f66fe
Reformat TOML file
araffin Jul 20, 2024
03db09e
Update CI workflow, skip mypy for 3.8
araffin Jul 22, 2024
244b930
Merge branch 'master' into feat/crossq
araffin Aug 13, 2024
497ea7e
Merge branch 'master' into feat/crossq
araffin Oct 18, 2024
72abe85
Update CrossQ doc
araffin Oct 24, 2024
f1fc8f5
Use uv to download packages on github CI
araffin Oct 24, 2024
497f5fe
System install for Github CI
araffin Oct 24, 2024
6e37805
Fix for pytorch install
araffin Oct 24, 2024
94af853
Use +cpu version
araffin Oct 24, 2024
6cd924e
Pytorch 2.5.0 doesn't support python 3.8
araffin Oct 24, 2024
125a8ca
Update comments
araffin Oct 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ See documentation for the full list of included features.
- [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/)
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
- [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX)

**Gym Wrappers**:
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
Expand Down
Binary file added docs/images/crossQ_performance.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
:caption: RL Algorithms

modules/ars
modules/crossq
modules/ppo_mask
modules/ppo_recurrent
modules/qrdqn
Expand Down
99 changes: 99 additions & 0 deletions docs/modules/crossq.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
.. _crossq:

.. automodule:: sb3_contrib.crossq


CrossQ
======

Implementation of CrossQ proposed in:

`Bhatt A.* & Palenicek D.* et al. Batch Normalization in Deep Reinforcement Learning for Greater Sample Efficiency and Simplicity. ICLR 2024.`

CrossQ is a simple and efficient algorithm that uses batch normalization to improve the sample efficiency of off-policy deep reinforcement learning algorithms.
It is based on the idea of carefully introducing batch normalization layers in the critic network and dropping target networks.
This yield a simpler and more sample-efficient algorithm without requiring high update-to-data ratios.

.. rubric:: Available Policies

.. autosummary::
:nosignatures:

MlpPolicy
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add at least the multi input policy? (so we can try it in combination with HER)
Only the feature extractor should be changed normally.

And what do you think about adding CnnPolicy?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good point. I looked into it and have not yet added it. If I am not mistaken this would also require some changes to the CrossQ train() function. Since, now concatenating and splitting the batches would also require some control flow based on the used policy.
For simplicity sake (for now) and since I did not have time to try and evaluate the multi input policy I did not add that yet.



Notes
-----

- Original paper: https://openreview.net/pdf?id=PczQtTsTIX
- Original Implementation: https://github.com/adityab/CrossQ


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ❌ ✔️
Box ✔️ ✔️
MultiDiscrete ❌ ✔️
MultiBinary ❌ ✔️
Dict ❌ ✔️
============= ====== ===========


Example
-------

.. code-block:: python

import gymnasium as gym
import numpy as np

from sb3_contrib import CrossQ

model = CrossQ("MlpPolicy", "Walker2d-v4")
model.learn(total_timesteps=1_000_000)
model.save("crossq_walker")


Results
araffin marked this conversation as resolved.
Show resolved Hide resolved
-------

Performance evaluation of CrossQ on six MuJoCo environments.
Compared to results from the original paper as well as a version from SBX.

.. image:: ../images/crossQ_performance.png

Comments
--------

This implementation is based on SB3 SAC implementation.


Parameters
----------

.. autoclass:: CrossQ
:members:
:inherited-members:

.. _crossq_policies:

CrossQ Policies
---------------

.. autoclass:: MlpPolicy
:members:
:inherited-members:

.. autoclass:: sb3_contrib.crossq.policies.CrossQPolicy
:members:
:noindex:

2 changes: 2 additions & 0 deletions sb3_contrib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

from sb3_contrib.ars import ARS
from sb3_contrib.crossq import CrossQ
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.ppo_recurrent import RecurrentPPO
from sb3_contrib.qrdqn import QRDQN
Expand All @@ -14,6 +15,7 @@

__all__ = [
"ARS",
"CrossQ",
"MaskablePPO",
"RecurrentPPO",
"QRDQN",
Expand Down
94 changes: 94 additions & 0 deletions sb3_contrib/common/network_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import torch

__all__ = ["BatchRenorm1d"]


class BatchRenorm(torch.jit.ScriptModule):
"""
BatchRenorm Module (https://arxiv.org/abs/1702.03275).
Adapted from flax.linen.normalization.BatchNorm

BatchRenorm is an improved version of vanilla BatchNorm. Contrary to BatchNorm,
BatchRenorm uses the running statistics for normalizing the batches after a warmup phase.
This makes it less prone to suffer from "outlier" batches that can happen
during very long training runs and, therefore, is more robust during long training runs.

During the warmup phase, it behaves exactly like a BatchNorm layer.

Args:
num_features: Number of features in the input tensor.
eps: A value added to the variance for numerical stability.
momentum: The value used for the running_mean and running_var computation.
affine: A boolean value that when set to True, this module has learnable
affine parameters. Default: True
"""

def __init__(
self,
num_features: int,
eps: float = 0.001,
momentum: float = 0.01,
affine: bool = True,
):
super().__init__()
self.register_buffer("running_mean", torch.zeros(num_features, dtype=torch.float))
self.register_buffer("running_var", torch.ones(num_features, dtype=torch.float))
self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
self.scale = torch.nn.Parameter(torch.ones(num_features, dtype=torch.float))
self.bias = torch.nn.Parameter(torch.zeros(num_features, dtype=torch.float))

self.affine = affine
self.eps = eps
self.step = 0
self.momentum = momentum
self.rmax = 3.0
self.dmax = 5.0

def _check_input_dim(self, x: torch.Tensor) -> None:
raise NotImplementedError()

def forward(self, x: torch.Tensor) -> torch.Tensor:

if self.training:
batch_mean = x.mean(0)
batch_var = x.var(0)
batch_std = (batch_var + self.eps).sqrt()

# Use batch statistics during initial warm up phase.
araffin marked this conversation as resolved.
Show resolved Hide resolved
if self.num_batches_tracked > 100_000:

running_std = (self.running_var + self.eps).sqrt()
running_mean = self.running_mean

r = (batch_std / running_std).detach()
r = r.clamp(1 / self.rmax, self.rmax)
d = ((batch_mean - running_mean) / running_std).detach()
d = d.clamp(-self.dmax, self.dmax)

m = batch_mean - d * batch_var.sqrt() / r
v = batch_var / (r**2)

else:
m, v = batch_mean, batch_var

# Update Running Statistics
self.running_mean += self.momentum * (batch_mean.detach() - self.running_mean)
self.running_var += self.momentum * (batch_var.detach() - self.running_var)
self.num_batches_tracked += 1

else:
m, v = self.running_mean, self.running_var

# Normalize
x = (x - m[None]) / (v[None] + self.eps).sqrt()

if self.affine:
x = self.scale * x + self.bias

return x


class BatchRenorm1d(BatchRenorm):
def _check_input_dim(self, x: torch.Tensor) -> None:
if x.dim() == 1:
raise ValueError("expected 2D or 3D input (got {x.dim()}D input)")
4 changes: 4 additions & 0 deletions sb3_contrib/crossq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from sb3_contrib.crossq.crossq import CrossQ
from sb3_contrib.crossq.policies import MlpPolicy

__all__ = ["CrossQ", "MlpPolicy"]
Loading
Loading