Skip to content

Commit

Permalink
Add support for gymnasium v1.0 (#261)
Browse files Browse the repository at this point in the history
* Add support for gymnasium v1.0

* Fix for gym v1.0

* Update CI matrix

* Update SB3 min version

* Fix warning
  • Loading branch information
araffin authored Nov 4, 2024
1 parent e05ee42 commit 9856423
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 16 deletions.
20 changes: 12 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ jobs:
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]

include:
# Default version
- gymnasium-version: "1.0.0"
# Add a new config to test gym<1.0
- python-version: "3.10"
gymnasium-version: "0.29.1"
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -36,19 +41,18 @@ jobs:
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.4.1+cpu --index https://download.pytorch.org/whl/cpu
# Install Atari Roms
uv pip install --system autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz
# Install master version
# and dependencies for docs and tests
uv pip install --system "stable_baselines3[extra_no_roms,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
uv pip install --system "stable_baselines3[extra,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
uv pip install --system .
# Use headless version
uv pip install --system opencv-python-headless
- name: Install specific version of gym
run: |
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
# Only run for python 3.10, downgrade gym to 0.29.1

- name: Lint with ruff
run: |
make lint
Expand Down
2 changes: 1 addition & 1 deletion docs/conda_env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ dependencies:
- python=3.11
- pytorch=2.5.0=py3.11_cpu_0
- pip:
- gymnasium>=0.28.1,<0.30
- gymnasium>=0.29.1,<1.1.0
- stable-baselines3>=2.0.0,<3.0
- cloudpickle
- opencv-python-headless
Expand Down
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.4.0a10 (WIP)
Release 2.4.0a11 (WIP)
--------------------------

**New algorithm: added CrossQ**
Expand All @@ -16,6 +16,7 @@ New Features:
^^^^^^^^^^^^^
- Added ``CrossQ`` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen)
- Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen)
- Added support for Gymnasium v1.0

Bug Fixes:
^^^^^^^^^^
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ env = ["PYTHONHASHSEED=0"]
filterwarnings = [
# Tensorboard warnings
"ignore::DeprecationWarning:tensorboard",
# tqdm warning about rich being experimental
"ignore:rich is experimental",
]
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]

Expand Down
8 changes: 6 additions & 2 deletions sb3_contrib/common/maskable/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def get_action_masks(env: GymEnv) -> np.ndarray:
if isinstance(env, VecEnv):
return np.stack(env.env_method(EXPECTED_METHOD_NAME))
else:
return getattr(env, EXPECTED_METHOD_NAME)()
return env.get_wrapper_attr(EXPECTED_METHOD_NAME)()


def is_masking_supported(env: GymEnv) -> bool:
Expand All @@ -35,4 +35,8 @@ def is_masking_supported(env: GymEnv) -> bool:
except AttributeError:
return False
else:
return hasattr(env, EXPECTED_METHOD_NAME)
try:
env.get_wrapper_attr(EXPECTED_METHOD_NAME)
return True
except AttributeError:
return False
1 change: 1 addition & 0 deletions sb3_contrib/common/wrappers/time_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, env: gym.Env, max_steps: int = 1000, test_mode: bool = False)
low, high = obs_space.low, obs_space.high
low, high = np.concatenate((low, [0.0])), np.concatenate((high, [1.0])) # type: ignore[arg-type]
self.dtype = obs_space.dtype
low, high = low.astype(self.dtype), high.astype(self.dtype)

if isinstance(env.observation_space, spaces.Dict):
env.observation_space.spaces["observation"] = spaces.Box(
Expand Down
2 changes: 1 addition & 1 deletion sb3_contrib/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.4.0a10
2.4.0a11
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
packages=[package for package in find_packages() if package.startswith("sb3_contrib")],
package_data={"sb3_contrib": ["py.typed", "version.txt"]},
install_requires=[
"stable_baselines3>=2.4.0a6,<3.0",
"stable_baselines3>=2.4.0a11,<3.0",
],
description="Contrib package of Stable Baselines3, experimental code.",
author="Antonin Raffin",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from gymnasium import spaces
from gymnasium.envs.classic_control import CartPoleEnv
from gymnasium.wrappers.time_limit import TimeLimit
from gymnasium.wrappers import TimeLimit
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
Expand Down Expand Up @@ -43,7 +43,7 @@ def __init__(self):
self.x_threshold * 2,
self.theta_threshold_radians * 2,
]
)
).astype(np.float32)
self.observation_space = spaces.Box(-high, high, dtype=np.float32)

@staticmethod
Expand Down

0 comments on commit 9856423

Please sign in to comment.