Commit

Update SB3 and remove gSDE resampling (#251)
araffin authored Jun 29, 2024
1 parent 25b4326 commit dc25cc6
Showing 6 changed files with 29 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -14,7 +14,7 @@ lint:
# see https://www.flake8rules.com/
ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
-ruff check ${LINT_PATHS} --exit-zero
+ruff check ${LINT_PATHS} --exit-zero --output-format=concise

format:
# Sort imports
26 changes: 26 additions & 0 deletions docs/misc/changelog.rst
@@ -3,6 +3,32 @@
Changelog
==========


+Release 2.4.0a4 (WIP)
+--------------------------
+
+Breaking Changes:
+^^^^^^^^^^^^^^^^^
+- Upgraded to Stable-Baselines3 >= 2.4.0
+
+New Features:
+^^^^^^^^^^^^^
+
+Bug Fixes:
+^^^^^^^^^^
+
+Deprecations:
+^^^^^^^^^^^^^
+
+Others:
+^^^^^^^
+- Updated PyTorch version on CI to 2.3.1
+- Remove unnecessary SDE noise resampling in PPO/TRPO update
+
+Documentation:
+^^^^^^^^^^^^^^


Release 2.3.0 (2024-03-31)
--------------------------

4 changes: 0 additions & 4 deletions sb3_contrib/ppo_recurrent/ppo_recurrent.py
@@ -342,10 +342,6 @@ def train(self) -> None:
# Convert mask from float to bool
mask = rollout_data.mask > 1e-8

-# Re-sample the noise matrix because the log_std has changed
-if self.use_sde:
-self.policy.reset_noise(self.batch_size)

values, log_prob, entropy = self.policy.evaluate_actions(
rollout_data.observations,
actions,
5 changes: 0 additions & 5 deletions sb3_contrib/trpo/trpo.py
@@ -261,11 +261,6 @@ def train(self) -> None:
# Convert discrete action from float to long
actions = rollout_data.actions.long().flatten()

-# Re-sample the noise matrix because the log_std has changed
-if self.use_sde:
-# batch_size is only used for the value function
-self.policy.reset_noise(actions.shape[0])

with th.no_grad():
# Note: is copy enough, no need for deepcopy?
# If using gSDE and deepcopy, we need to use `old_distribution.distribution`
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
@@ -1 +1 @@
-2.3.0
+2.4.0a4
2 changes: 1 addition & 1 deletion setup.py
@@ -65,7 +65,7 @@
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
-"stable_baselines3>=2.3.0,<3.0",
+"stable_baselines3>=2.4.0a4,<3.0",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
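
Note on the removed `reset_noise` calls: the changelog entry calls the resampling in the PPO/TRPO update "unnecessary". A plausible reading, stated here as an assumption rather than as the author's explanation, is that evaluating the log-probability of stored actions depends only on the distribution's mean and standard deviation, not on the currently sampled exploration noise. The sketch below is a toy 1-D stand-in written for this note (the class `ToyStateDependentNoise` and its methods are hypothetical and are not the SB3 implementation); it only illustrates that resampling the noise changes sampled actions but leaves `log_prob` unchanged.

```python
import math
import random


class ToyStateDependentNoise:
    """Hypothetical 1-D gSDE-style distribution, for illustration only."""

    def __init__(self, log_std: float, seed: int = 0) -> None:
        self.log_std = log_std
        self._rng = random.Random(seed)
        self.noise_weight = 0.0
        self.reset_noise()

    def reset_noise(self) -> None:
        # Draw a new exploration weight theta ~ N(0, std^2).
        self.noise_weight = self._rng.gauss(0.0, math.exp(self.log_std))

    def sample(self, mean: float, feature: float) -> float:
        # Sampling uses the stored weight: action = mean + theta * feature.
        return mean + self.noise_weight * feature

    def log_prob(self, action: float, mean: float, feature: float) -> float:
        # Marginalising over theta gives action ~ N(mean, (std * |feature|)^2),
        # so the stored noise weight does not appear here at all.
        std = math.exp(self.log_std) * abs(feature) + 1e-6
        z = (action - mean) / std
        return -0.5 * z * z - math.log(std) - 0.5 * math.log(2.0 * math.pi)


dist = ToyStateDependentNoise(log_std=-0.5)

first_action = dist.sample(mean=0.1, feature=2.0)
before = dist.log_prob(action=0.3, mean=0.1, feature=2.0)

dist.reset_noise()  # the kind of call this commit removes from train()

second_action = dist.sample(mean=0.1, feature=2.0)
after = dist.log_prob(action=0.3, mean=0.1, feature=2.0)

print("sampled actions differ:", first_action != second_action)
print("log_prob unchanged:", before == after)
```

If the real distribution behaves like this toy, the removed calls only cost extra random sampling during each gradient step without affecting the loss.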
