From 27de67c31258ac4c4cfb5ad90938a44652c06cb1 Mon Sep 17 00:00:00 2001
From: Antonin RAFFIN
Date: Thu, 23 May 2024 14:06:43 +0200
Subject: [PATCH] Fix warning and remove DroQ class in favor of SAC config (#47)

* Fix numpy warning when computing terminal value

* Remove deprecated DroQ, use special config of SAC instead
---
 sbx/__init__.py                   | 11 +++-
 sbx/common/on_policy_algorithm.py |  3 +-
 sbx/droq/__init__.py              |  3 --
 sbx/droq/droq.py                  | 86 -------------------------------
 sbx/version.txt                   |  2 +-
 tests/test_run.py                 | 37 ++++++-------
 6 files changed, 30 insertions(+), 112 deletions(-)
 delete mode 100644 sbx/droq/__init__.py
 delete mode 100644 sbx/droq/droq.py

diff --git a/sbx/__init__.py b/sbx/__init__.py
index 70c06bc..a7c13bc 100644
--- a/sbx/__init__.py
+++ b/sbx/__init__.py
@@ -3,7 +3,6 @@
 from sbx.crossq import CrossQ
 from sbx.ddpg import DDPG
 from sbx.dqn import DQN
-from sbx.droq import DroQ
 from sbx.ppo import PPO
 from sbx.sac import SAC
 from sbx.td3 import TD3
@@ -14,11 +13,19 @@
 with open(version_file) as file_handler:
     __version__ = file_handler.read().strip()
 
+
+def DroQ(*args, **kwargs):
+    raise ImportError(
+        "Since SBX 0.16.0, `DroQ` is now a special configuration of SAC.\n "
+        "Please check the documentation for more information: "
+        "https://github.com/araffin/sbx?tab=readme-ov-file#note-about-droq"
+    )
+
+
 __all__ = [
     "CrossQ",
     "DDPG",
     "DQN",
-    "DroQ",
     "PPO",
     "SAC",
     "TD3",
diff --git a/sbx/common/on_policy_algorithm.py b/sbx/common/on_policy_algorithm.py
index 4b75411..015fdc2 100644
--- a/sbx/common/on_policy_algorithm.py
+++ b/sbx/common/on_policy_algorithm.py
@@ -179,8 +179,7 @@ def collect_rollouts(
                             self.policy.vf_state.params,
                             terminal_obs,
                         ).flatten()
-                    )
-
+                    ).item()
                     rewards[idx] += self.gamma * terminal_value
 
             rollout_buffer.add(
diff --git a/sbx/droq/__init__.py b/sbx/droq/__init__.py
deleted file mode 100644
index 0328bd5..0000000
--- a/sbx/droq/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from sbx.droq.droq import DroQ
-
-__all__ = ["DroQ"]
diff --git a/sbx/droq/droq.py b/sbx/droq/droq.py
deleted file mode 100644
index 4d54ba3..0000000
--- a/sbx/droq/droq.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import warnings
-from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union
-
-from stable_baselines3.common.buffers import ReplayBuffer
-from stable_baselines3.common.noise import ActionNoise
-from stable_baselines3.common.type_aliases import GymEnv, Schedule
-
-from sbx.tqc.policies import TQCPolicy
-from sbx.tqc.tqc import TQC
-
-
-class DroQ(TQC):
-    policy_aliases: ClassVar[Dict[str, Type[TQCPolicy]]] = {
-        "MlpPolicy": TQCPolicy,
-    }
-
-    def __init__(
-        self,
-        policy,
-        env: Union[GymEnv, str],
-        learning_rate: Union[float, Schedule] = 3e-4,
-        qf_learning_rate: Optional[float] = None,
-        buffer_size: int = 1_000_000,  # 1e6
-        learning_starts: int = 100,
-        batch_size: int = 256,
-        tau: float = 0.005,
-        gamma: float = 0.99,
-        train_freq: Union[int, Tuple[int, str]] = 1,
-        gradient_steps: int = 2,
-        # policy_delay = gradient_steps to follow original implementation
-        policy_delay: int = 2,
-        top_quantiles_to_drop_per_net: int = 2,
-        dropout_rate: float = 0.01,
-        layer_norm: bool = True,
-        action_noise: Optional[ActionNoise] = None,
-        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
-        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
-        ent_coef: Union[str, float] = "auto",
-        use_sde: bool = False,
-        sde_sample_freq: int = -1,
-        use_sde_at_warmup: bool = False,
-        tensorboard_log: Optional[str] = None,
-        policy_kwargs: Optional[Dict[str, Any]] = None,
-        verbose: int = 0,
-        seed: Optional[int] = None,
-        device: str = "auto",
-        _init_setup_model: bool = True,
-    ) -> None:
-        super().__init__(
-            policy=policy,
-            env=env,
-            learning_rate=learning_rate,
-            qf_learning_rate=qf_learning_rate,
-            buffer_size=buffer_size,
-            learning_starts=learning_starts,
-            batch_size=batch_size,
-            tau=tau,
-            gamma=gamma,
-            train_freq=train_freq,
-            gradient_steps=gradient_steps,
-            policy_delay=policy_delay,
-            action_noise=action_noise,
-            replay_buffer_class=replay_buffer_class,
-            replay_buffer_kwargs=replay_buffer_kwargs,
-            use_sde=use_sde,
-            sde_sample_freq=sde_sample_freq,
-            use_sde_at_warmup=use_sde_at_warmup,
-            top_quantiles_to_drop_per_net=top_quantiles_to_drop_per_net,
-            ent_coef=ent_coef,
-            policy_kwargs=policy_kwargs,
-            tensorboard_log=tensorboard_log,
-            verbose=verbose,
-            seed=seed,
-            _init_setup_model=False,
-        )
-
-        self.policy_kwargs["dropout_rate"] = dropout_rate
-        self.policy_kwargs["layer_norm"] = layer_norm
-
-        warnings.warn(
-            "Using DroQ class directly is deprecated and will be removed in v0.14.0 of SBX. "
-            "Please use SAC/TQC/CrossQ instead with the DroQ configuration, see https://github.com/araffin/sbx?tab=readme-ov-file#note-about-droq"
-        )
-
-        if _init_setup_model:
-            self._setup_model()
diff --git a/sbx/version.txt b/sbx/version.txt
index a551051..04a373e 100644
--- a/sbx/version.txt
+++ b/sbx/version.txt
@@ -1 +1 @@
-0.15.0
+0.16.0
diff --git a/tests/test_run.py b/tests/test_run.py
index bd57aa3..18d6dec 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -25,30 +25,31 @@ def check_save_load(model, model_class, tmp_path):
 
 
 def test_droq(tmp_path):
-    with pytest.warns(UserWarning, match="deprecated"):
-        model = DroQ(
-            "MlpPolicy",
-            "Pendulum-v1",
-            learning_starts=50,
-            learning_rate=1e-3,
-            tau=0.02,
-            gamma=0.98,
-            verbose=1,
-            buffer_size=5000,
-            gradient_steps=2,
-            ent_coef="auto_1.0",
-            seed=1,
-            dropout_rate=0.001,
-            layer_norm=True,
-            # action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)),
-        )
+    with pytest.raises(ImportError, match="a special configuration of SAC"):
+        model = DroQ("MlpPolicy", "Pendulum-v1", learning_starts=50)
+
+    # DroQ used to be a child class of TQC, now it can be used with SAC/CrossQ/TQC
+    model = TQC(
+        "MlpPolicy",
+        "Pendulum-v1",
+        learning_starts=50,
+        learning_rate=1e-3,
+        tau=0.02,
+        gamma=0.98,
+        verbose=1,
+        buffer_size=5000,
+        gradient_steps=2,
+        ent_coef="auto_1.0",
+        seed=1,
+        policy_kwargs=dict(dropout_rate=0.01, layer_norm=True),
+    )
     model.learn(total_timesteps=1500)
     # Check that something was learned
     evaluate_policy(model, model.get_env(), reward_threshold=-800)
     model.save(tmp_path / "test_save.zip")
     env = model.get_env()
-    model = check_save_load(model, DroQ, tmp_path)
+    model = check_save_load(model, TQC, tmp_path)
     # Check we have the same performance
     evaluate_policy(model, env, reward_threshold=-800)
 
 
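
The commit message and the new ImportError both point to using a "special configuration of SAC" instead of the removed DroQ class. Below is a minimal sketch of that configuration, assuming sbx's SAC accepts the same policy_kwargs (dropout_rate, layer_norm) and the gradient_steps/policy_delay arguments that the removed DroQ class forwarded to its parent; it is not part of the patch. The values mirror the removed DroQ defaults and the updated test; the environment and remaining settings are illustrative only.

from sbx import SAC

# DroQ-style SAC: dropout and layer normalization in the Q-networks,
# several gradient steps per environment step, and delayed policy updates.
# dropout_rate=0.01, layer_norm=True, gradient_steps=2, policy_delay=2 come
# from the removed DroQ defaults; everything else is illustrative.
model = SAC(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(dropout_rate=0.01, layer_norm=True),
    gradient_steps=2,
    policy_delay=2,
    learning_starts=50,
    verbose=1,
)
model.learn(total_timesteps=1500)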