Skip to content

Commit

Permalink
Fix warning and remove DroQ class in favor of SAC config (#47)
Browse files Browse the repository at this point in the history
* Fix numpy warning when computing terminal value

* Remove deprecated DroQ, use special config of SAC instead
  • Loading branch information
araffin authored May 23, 2024
1 parent 42caa65 commit 27de67c
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 112 deletions.
11 changes: 9 additions & 2 deletions sbx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from sbx.crossq import CrossQ
from sbx.ddpg import DDPG
from sbx.dqn import DQN
from sbx.droq import DroQ
from sbx.ppo import PPO
from sbx.sac import SAC
from sbx.td3 import TD3
Expand All @@ -14,11 +13,19 @@
with open(version_file) as file_handler:
__version__ = file_handler.read().strip()


def DroQ(*args, **kwargs):
raise ImportError(
"Since SBX 0.16.0, `DroQ` is now a special configuration of SAC.\n "
"Please check the documentation for more information: "
"https://github.com/araffin/sbx?tab=readme-ov-file#note-about-droq"
)


__all__ = [
"CrossQ",
"DDPG",
"DQN",
"DroQ",
"PPO",
"SAC",
"TD3",
Expand Down
3 changes: 1 addition & 2 deletions sbx/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,7 @@ def collect_rollouts(
self.policy.vf_state.params,
terminal_obs,
).flatten()
)

).item()
rewards[idx] += self.gamma * terminal_value

rollout_buffer.add(
Expand Down
3 changes: 0 additions & 3 deletions sbx/droq/__init__.py

This file was deleted.

86 changes: 0 additions & 86 deletions sbx/droq/droq.py

This file was deleted.

2 changes: 1 addition & 1 deletion sbx/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.15.0
0.16.0
37 changes: 19 additions & 18 deletions tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,30 +25,31 @@ def check_save_load(model, model_class, tmp_path):


def test_droq(tmp_path):
with pytest.warns(UserWarning, match="deprecated"):
model = DroQ(
"MlpPolicy",
"Pendulum-v1",
learning_starts=50,
learning_rate=1e-3,
tau=0.02,
gamma=0.98,
verbose=1,
buffer_size=5000,
gradient_steps=2,
ent_coef="auto_1.0",
seed=1,
dropout_rate=0.001,
layer_norm=True,
# action_noise=NormalActionNoise(np.zeros(1), np.zeros(1)),
)
with pytest.raises(ImportError, match="a special configuration of SAC"):
model = DroQ("MlpPolicy", "Pendulum-v1", learning_starts=50)

# DroQ used to be a child class of TQC, now it can be used with SAC/CrossQ/TQC
model = TQC(
"MlpPolicy",
"Pendulum-v1",
learning_starts=50,
learning_rate=1e-3,
tau=0.02,
gamma=0.98,
verbose=1,
buffer_size=5000,
gradient_steps=2,
ent_coef="auto_1.0",
seed=1,
policy_kwargs=dict(dropout_rate=0.01, layer_norm=True),
)
model.learn(total_timesteps=1500)
# Check that something was learned
evaluate_policy(model, model.get_env(), reward_threshold=-800)
model.save(tmp_path / "test_save.zip")

env = model.get_env()
model = check_save_load(model, DroQ, tmp_path)
model = check_save_load(model, TQC, tmp_path)
# Check we have the same performance
evaluate_policy(model, env, reward_threshold=-800)

Expand Down

0 comments on commit 27de67c

Please sign in to comment.