Implemented CrossQ #243 (Merged)

araffin merged 37 commits into Stable-Baselines-Team:master from danielpalen:feat/crossq on Oct 24, 2024.
Changes from 2 commits.

Commits (37)
9afecf5 danielpalen: Implemented CrossQ
4fa78a7 danielpalen: Fixed code style
7ce57de danielpalen: Clean up, comments and refactored to sbx variable names
9c339b8 danielpalen: 1024 neuron Q function (sbx default)
2b1ff5e danielpalen: batch norm parameters as function arguments
aace2ac danielpalen: clean up. reshape instead of split
4df7111 danielpalen: Added policy delay
5583225 danielpalen: fixed commit-checks
567c2fb araffin: Fix f-string
8970ed0 araffin: Update documentation
8792621 araffin: Rename to torch layers
230a948 araffin: Fix for policy delay and minor edits
cd8bd7d araffin: Update tests
27a96f6 araffin: Update documentation
7d6c642 araffin: Merge branch 'master' into feat/crossq
3927a70 araffin: Update doc
2019327 araffin: Add more tests for crossQ
b0213ec araffin: Improve doc and expose batchnorm params
9772ecf araffin: Merge branch 'master' into feat/crossq
454224d araffin: Add some comments and todos and fix type check
a7bbac9 araffin: Merge branch 'feat/crossq' of github.com:danielpalen/stable-baselines…
bbd654c araffin: Use torch module for BN
bb80218 araffin: Re-organize losses
a717d13 araffin: Add set_bn_training_mode
cb1bc8f araffin: Simplify network creation with new SB3 version, and fix default momentum
a88a19b araffin: Use different b1 for Adam as in original implementation
32f66fe araffin: Reformat TOML file
03db09e araffin: Update CI workflow, skip mypy for 3.8
244b930 araffin: Merge branch 'master' into feat/crossq
497ea7e araffin: Merge branch 'master' into feat/crossq
72abe85 araffin: Update CrossQ doc
f1fc8f5 araffin: Use uv to download packages on github CI
497f5fe araffin: System install for Github CI
6e37805 araffin: Fix for pytorch install
94af853 araffin: Use +cpu version
6cd924e araffin: Pytorch 2.5.0 doesn't support python 3.8
125a8ca araffin: Update comments
@@ -0,0 +1,99 @@
.. _crossq:

.. automodule:: sb3_contrib.crossq


CrossQ
======

Implementation of CrossQ proposed in:

`Bhatt A.* & Palenicek D.* et al. Batch Normalization in Deep Reinforcement Learning for Greater Sample Efficiency and Simplicity. ICLR 2024.`

CrossQ is a simple and efficient algorithm that uses batch normalization to improve the sample efficiency of off-policy deep reinforcement learning algorithms.
It is based on the idea of carefully introducing batch normalization layers in the critic network and dropping the target networks.
This yields a simpler and more sample-efficient algorithm without requiring high update-to-data ratios.

.. rubric:: Available Policies

.. autosummary::
    :nosignatures:

    MlpPolicy


Notes
-----

- Original paper: https://openreview.net/pdf?id=PczQtTsTIX
- Original implementation: https://github.com/adityab/CrossQ


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:

============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ❌      ✔️
Box           ✔️      ✔️
MultiDiscrete ❌      ✔️
MultiBinary   ❌      ✔️
Dict          ❌      ✔️
============= ====== ===========


Example
-------

.. code-block:: python

    import gymnasium as gym
    import numpy as np

    from sb3_contrib import CrossQ

    model = CrossQ("MlpPolicy", "Walker2d-v4")
    model.learn(total_timesteps=1_000_000)
    model.save("crossq_walker")


Results
-------

Performance evaluation of CrossQ on six MuJoCo environments, compared to results from the original paper as well as to the SBX version.

.. image:: ../images/crossQ_performance.png

Comments
--------

This implementation is based on the SB3 SAC implementation.


Parameters
----------

.. autoclass:: CrossQ
    :members:
    :inherited-members:

.. _crossq_policies:

CrossQ Policies
---------------

.. autoclass:: MlpPolicy
    :members:
    :inherited-members:

.. autoclass:: sb3_contrib.crossq.policies.CrossQPolicy
    :members:
    :noindex:
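To make the doc's description concrete, here is a hedged, illustrative sketch of the critic update it describes: there is no target network; current and next state-action pairs are evaluated by the same critic in one joint forward pass, so the batch normalization layers see the statistics of both distributions. All names (critic, gamma, etc.) are assumptions for illustration, not code from this PR, and the SAC entropy term is omitted for brevity.

    import torch

    # Illustrative sketch only, not the PR's actual train() code.
    def crossq_q_values_and_target(critic, obs, actions, next_obs, next_actions, rewards, dones, gamma=0.99):
        batch_size = obs.shape[0]
        # Concatenate current and next batches so BN sees both distributions
        all_obs = torch.cat([obs, next_obs], dim=0)
        all_actions = torch.cat([actions, next_actions], dim=0)

        all_q = critic(all_obs, all_actions)  # single joint forward pass
        q_values, next_q_values = all_q[:batch_size], all_q[batch_size:]

        # TD target computed from the same online critic, no target network
        td_target = rewards + (1.0 - dones) * gamma * next_q_values.detach()
        return q_values, td_target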
@@ -0,0 +1,94 @@
import torch

__all__ = ["BatchRenorm1d"]


class BatchRenorm(torch.jit.ScriptModule):
    """
    BatchRenorm Module (https://arxiv.org/abs/1702.03275).
    Adapted from flax.linen.normalization.BatchNorm

    BatchRenorm is an improved version of vanilla BatchNorm. Contrary to BatchNorm,
    BatchRenorm uses the running statistics for normalizing the batches after a warmup phase.
    This makes it less prone to suffering from "outlier" batches that can happen
    during very long training runs and, therefore, more robust over long training runs.

    During the warmup phase, it behaves exactly like a BatchNorm layer.

    Args:
        num_features: Number of features in the input tensor.
        eps: A value added to the variance for numerical stability.
        momentum: The value used for the running_mean and running_var computation.
        affine: A boolean value that when set to True, this module has learnable
            affine parameters. Default: True
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 0.001,
        momentum: float = 0.01,
        affine: bool = True,
    ):
        super().__init__()
        self.register_buffer("running_mean", torch.zeros(num_features, dtype=torch.float))
        self.register_buffer("running_var", torch.ones(num_features, dtype=torch.float))
        self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
        self.scale = torch.nn.Parameter(torch.ones(num_features, dtype=torch.float))
        self.bias = torch.nn.Parameter(torch.zeros(num_features, dtype=torch.float))

        self.affine = affine
        self.eps = eps
        self.step = 0
        self.momentum = momentum
        self.rmax = 3.0
        self.dmax = 5.0

    def _check_input_dim(self, x: torch.Tensor) -> None:
        raise NotImplementedError()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            batch_mean = x.mean(0)
            batch_var = x.var(0)
            batch_std = (batch_var + self.eps).sqrt()

            # Use batch statistics during the initial warmup phase,
            # then switch to the renormalized statistics.
            if self.num_batches_tracked > 100_000:
                running_std = (self.running_var + self.eps).sqrt()
                running_mean = self.running_mean

                r = (batch_std / running_std).detach()
                r = r.clamp(1 / self.rmax, self.rmax)
                d = ((batch_mean - running_mean) / running_std).detach()
                d = d.clamp(-self.dmax, self.dmax)

                m = batch_mean - d * batch_var.sqrt() / r
                v = batch_var / (r**2)
            else:
                m, v = batch_mean, batch_var

            # Update running statistics
            self.running_mean += self.momentum * (batch_mean.detach() - self.running_mean)
            self.running_var += self.momentum * (batch_var.detach() - self.running_var)
            self.num_batches_tracked += 1
        else:
            m, v = self.running_mean, self.running_var

        # Normalize
        x = (x - m[None]) / (v[None] + self.eps).sqrt()

        if self.affine:
            x = self.scale * x + self.bias

        return x


class BatchRenorm1d(BatchRenorm):
    def _check_input_dim(self, x: torch.Tensor) -> None:
        if x.dim() == 1:
            raise ValueError(f"expected 2D or 3D input (got {x.dim()}D input)")
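For reference, the corrected statistics above match the Batch Renormalization paper: with r = clamp(sigma_B / sigma, 1/rmax, rmax) and d = clamp((mu_B - mu) / sigma, -dmax, dmax), normalizing with m = mu_B - d * sigma_B / r and v = sigma_B^2 / r^2 is algebraically the same as the paper's formulation x_hat = r * (x - mu_B) / sigma_B + d. A minimal usage sketch follows, assuming BatchRenorm1d from the file above is in scope (the exact import path in the PR may differ):

    import torch

    # Minimal usage sketch; BatchRenorm1d is assumed importable from the file above.
    layer = BatchRenorm1d(num_features=8)

    x = torch.randn(32, 8)

    layer.train()
    y_train = layer(x)  # warmup phase: batch statistics used, running stats updated

    layer.eval()
    y_eval = layer(x)   # inference: running statistics used instead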
@@ -0,0 +1,4 @@
from sb3_contrib.crossq.crossq import CrossQ
from sb3_contrib.crossq.policies import MlpPolicy

__all__ = ["CrossQ", "MlpPolicy"]
Review comment:

Could you add at least the multi input policy? (so we can try it in combination with HER)
Only the feature extractor should be changed normally.
And what do you think about adding CnnPolicy?

Reply:

This is a good point. I looked into it but have not added it yet. If I am not mistaken, this would also require some changes to the CrossQ train() function, since concatenating and splitting the batches would then need some control flow based on the policy used.
For simplicity's sake (for now), and since I did not have time to try and evaluate the multi input policy, I have not added it yet.
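For context, a hedged sketch of the control flow the reply alludes to (purely illustrative; the helper below is hypothetical and not code from this PR): with a multi input policy, observations are dicts, so the concatenation of current and next observations would have to happen per key rather than on a single tensor.

    from typing import Dict, Union

    import torch

    # Hypothetical helper, not from this PR: concatenate current and next
    # observations along the batch dimension, handling dict observations
    # (MultiInputPolicy) as well as plain tensors (MlpPolicy).
    def cat_obs(
        obs: Union[torch.Tensor, Dict[str, torch.Tensor]],
        next_obs: Union[torch.Tensor, Dict[str, torch.Tensor]],
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if isinstance(obs, dict):
            return {key: torch.cat([obs[key], next_obs[key]], dim=0) for key in obs}
        return torch.cat([obs, next_obs], dim=0)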