From 588c6bdaeaa118a075162eddcd77c753d880bee2 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 12 Jan 2024 16:17:44 +0100 Subject: [PATCH] Update QRDQN defaults (#225) --- docs/misc/changelog.rst | 32 ++++++++++++++++++++++++++++++++ sb3_contrib/qrdqn/qrdqn.py | 5 +++-- sb3_contrib/version.txt | 2 +- setup.py | 2 +- 4 files changed, 37 insertions(+), 4 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index a46f2f59..ba3bb602 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,38 @@ Changelog ========== +Release 2.3.0a1 (WIP) +-------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Upgraded to Stable-Baselines3 >= 2.3.0 +- The default ``learning_starts`` parameter of ``QRDQN`` has been changed to be consistent with the other off-policy algorithms + + +.. code-block:: python + + # SB3 < 2.3.0 default hyperparameters, 50_000 corresponded to Atari default hyperparameters + # model = QRDQN("MlpPolicy", env, learning_starts=50_000) + # SB3 >= 2.3.0: + model = QRDQN("MlpPolicy", env, learning_starts=100) + + +New Features: +^^^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + +Documentation: +^^^^^^^^^^^^^^ + Release 2.2.1 (2023-11-17) -------------------------- diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py index c7dbacd5..303862e6 100644 --- a/sb3_contrib/qrdqn/qrdqn.py +++ b/sb3_contrib/qrdqn/qrdqn.py @@ -20,7 +20,8 @@ class QRDQN(OffPolicyAlgorithm): """ Quantile Regression Deep Q-Network (QR-DQN) Paper: https://arxiv.org/abs/1710.10044 - Default hyperparameters are taken from the paper and are tuned for Atari games. + Default hyperparameters are taken from the paper and are tuned for Atari games + (except for the ``learning_starts`` parameter). :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) 
:param env: The environment to learn from (if registered in Gym, can be str) @@ -77,7 +78,7 @@ def __init__( env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 5e-5, buffer_size: int = 1000000, # 1e6 - learning_starts: int = 50000, + learning_starts: int = 100, batch_size: int = 32, tau: float = 1.0, gamma: float = 0.99, diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index c043eea7..4d04ad95 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -2.2.1 +2.3.0a1 diff --git a/setup.py b/setup.py index e3cac102..1dd98c9f 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,7 @@ packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=2.2.1,<3.0", + "stable_baselines3>=2.3.0a0,<3.0", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin",