Skip to content

Commit

Permalink
Update QRDQN defaults (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin authored Jan 12, 2024
1 parent 9f333ff commit 588c6bd
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 4 deletions.
32 changes: 32 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,38 @@
Changelog
==========

Release 2.3.0a1 (WIP)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Upgraded to Stable-Baselines3 >= 2.3.0
- The default ``leanrning_starts`` parameter of ``QRDQN`` have been changed to be consistent with the other offpolicy algorithms


.. code-block:: python
# SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to Atari defaults hyperparameters
# model = QRDQN("MlpPolicy", env, learning_start=50_000)
# SB3 >= 2.3.0:
model = QRDQN("MlpPolicy", env, learning_start=100)
New Features:
^^^^^^^^^^^^^

Bug Fixes:
^^^^^^^^^^

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^

Documentation:
^^^^^^^^^^^^^^


Release 2.2.1 (2023-11-17)
--------------------------
Expand Down
5 changes: 3 additions & 2 deletions sb3_contrib/qrdqn/qrdqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ class QRDQN(OffPolicyAlgorithm):
"""
Quantile Regression Deep Q-Network (QR-DQN)
Paper: https://arxiv.org/abs/1710.10044
Default hyperparameters are taken from the paper and are tuned for Atari games.
Default hyperparameters are taken from the paper and are tuned for Atari games
(except for the ``learning_starts`` parameter).
:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
Expand Down Expand Up @@ -77,7 +78,7 @@ def __init__(
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 5e-5,
buffer_size: int = 1000000, # 1e6
learning_starts: int = 50000,
learning_starts: int = 100,
batch_size: int = 32,
tau: float = 1.0,
gamma: float = 0.99,
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.2.1
2.3.0a1
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=2.2.1,<3.0",
"stable_baselines3>=2.3.0a0,<3.0",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
Expand Down

0 comments on commit 588c6bd

Please sign in to comment.