Skip to content

Commit 1f96c02

Browse files
authored
add bsuite (#38)
1 parent 6e529b6 commit 1f96c02

9 files changed

+295
-7
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ env = OpenspielCompatibilityV0(game=env, render_mode=None)
212212
If you use this in your research, please cite:
213213
```
214214
@software{shimmy2022github,
215-
author = {Jordan Terry, Mark Towers, Jun Jet Tai},
215+
author = {{Jun Jet Tai, Mark Towers} and Elliot Tower and Jordan Terry},
216216
title = {Shimmy: Gymnasium and Pettingzoo Wrappers for Commonly Used Environments},
217217
url = {http://github.com/Farama-Foundation/Shimmy},
218218
version = {0.2.0},

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def get_version():
3939
"dm-control": ["dm-control>=1.0.10", "imageio", "h5py>=3.7.0"],
4040
"dm-control-multi-agent": ["dm-control>=1.0.10", "pettingzoo>=1.22"],
4141
"openspiel": ["open_spiel>=1.2", "pettingzoo>=1.22"],
42+
"bsuite": ["bsuite>=0.3.5"],
4243
}
4344
extras["all"] = list({lib for libs in extras.values() for lib in libs})
4445
extras["testing"] = [

shimmy/__init__.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
from shimmy.dm_lab_compatibility import DmLabCompatibilityV0
77
from shimmy.openai_gym_compatibility import GymV21CompatibilityV0, GymV26CompatibilityV0
88

9-
__version__ = "0.2.1"
10-
119

1210
class NotInstallClass:
1311
"""Rather than an attribute error, this raises a more helpful import error with install instructions for shimmy."""
@@ -47,12 +45,34 @@ def __call__(self, *args: list[Any], **kwargs: Any):
4745
e,
4846
)
4947

48+
try:
49+
from shimmy.bsuite_compatibility import BSuiteCompatibilityV0
50+
except ImportError as e:
51+
BSuiteCompatibilityV0 = NotInstallClass(
52+
"BSuite is not installed, run `pip install 'shimmy[bsuite]'`",
53+
e,
54+
)
5055

5156
__all__ = [
57+
"BSuiteCompatibilityV0",
5258
"DmControlCompatibilityV0",
5359
"DmControlMultiAgentCompatibilityV0",
5460
"OpenspielCompatibilityV0",
5561
"DmLabCompatibilityV0",
5662
"GymV21CompatibilityV0",
5763
"GymV26CompatibilityV0",
5864
]
65+
66+
67+
__version__ = "0.2.1"
68+
69+
70+
try:
71+
import sys
72+
73+
from farama_notifications import notifications
74+
75+
if "shimmy" in notifications and __version__ in notifications["shimmy"]:
76+
print(notifications["shimmy"][__version__], file=sys.stderr)
77+
except Exception: # nosec
78+
pass

shimmy/bsuite_compatibility.py

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""Wrapper to convert a BSuite environment into a gymnasium compatible environment."""
2+
from __future__ import annotations
3+
4+
from typing import Any
5+
6+
import gymnasium
7+
import numpy as np
8+
from bsuite.environments import Environment
9+
from gymnasium.core import ObsType
10+
from gymnasium.error import UnsupportedMode
11+
12+
from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space
13+
14+
# Until the BSuite authors fix
15+
# https://github.com/deepmind/bsuite/pull/48
16+
# This needs to exist...
17+
np.int = int
18+
19+
20+
class BSuiteCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]):
21+
"""A compatibility wrapper that converts a BSuite environment into a gymnasium environment.
22+
23+
Note:
24+
Bsuite uses `np.random.RandomState`, a legacy random number generator while gymnasium
25+
uses `np.random.Generator`, therefore the return type of `np_random` is different from expected.
26+
"""
27+
28+
metadata = {"render_modes": []}
29+
30+
def __init__(
31+
self,
32+
env: Environment,
33+
render_mode: str | None = None,
34+
):
35+
"""Initialises the environment with a render mode along with render information."""
36+
self._env = env
37+
38+
self.observation_space = dm_spec2gym_space(env.observation_spec())
39+
self.action_space = dm_spec2gym_space(env.action_spec())
40+
41+
assert render_mode is None, "No render modes available in BSuite."
42+
43+
def reset(
44+
self, *, seed: int | None = None, options: dict[str, Any] | None = None
45+
) -> tuple[ObsType, dict[str, Any]]:
46+
"""Resets the bsuite environment."""
47+
super().reset(seed=seed)
48+
if seed is not None:
49+
self.np_random = np.random.RandomState(seed=seed)
50+
self._env._rng = self.np_random # pyright: ignore[reportGeneralTypeIssues]
51+
if hasattr(self._env, "raw_env"):
52+
self._env.raw_env._rng = self.np_random
53+
54+
timestep = self._env.reset()
55+
56+
obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep)
57+
58+
return obs, info # pyright: ignore[reportGeneralTypeIssues]
59+
60+
def step(self, action: int) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
61+
"""Steps through the bsuite environment."""
62+
timestep = self._env.step(action)
63+
64+
obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep)
65+
66+
return ( # pyright: ignore[reportGeneralTypeIssues]
67+
obs,
68+
reward,
69+
terminated,
70+
truncated,
71+
info,
72+
)
73+
74+
def render(self) -> np.ndarray | None:
75+
"""Renders the bsuite env."""
76+
raise UnsupportedMode(
77+
"Rendering is not built into BSuite, print the observation instead."
78+
)
79+
80+
def close(self):
81+
"""Closes the environment."""
82+
self._env.close()
83+
84+
@property
85+
def np_random(self) -> np.random.RandomState:
86+
"""This should be np.random.Generator but bsuite uses np.random.RandomState."""
87+
return self._env._rng # pyright: ignore[reportGeneralTypeIssues]
88+
89+
@np_random.setter
90+
def np_random(self, value: np.random.RandomState):
91+
self._env._rng = value # pyright: ignore[reportGeneralTypeIssues]
92+
93+
def __getattr__(self, item: str):
94+
"""If the attribute is missing, try getting the attribute from bsuite env."""
95+
return getattr(self._env, item)

shimmy/dm_control_compatibility.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from gymnasium.core import ObsType
1818
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
1919

20-
from shimmy.utils.dm_env import dm_control_step2gym_step, dm_spec2gym_space
20+
from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space
2121

2222

2323
class EnvType(Enum):
@@ -84,7 +84,7 @@ def reset(
8484

8585
timestep = self._env.reset()
8686

87-
obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep)
87+
obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep)
8888

8989
return obs, info
9090

@@ -94,7 +94,7 @@ def step(
9494
"""Steps through the dm-control environment."""
9595
timestep = self._env.step(action)
9696

97-
obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep)
97+
obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep)
9898

9999
if self.render_mode == "human":
100100
self.viewer.render(self.render_mode)

shimmy/registration.py

+35
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,46 @@
1010

1111
from shimmy.utils.envs_configs import (
1212
ALL_ATARI_GAMES,
13+
BSUITE_ENVS,
1314
DM_CONTROL_MANIPULATION_ENVS,
1415
DM_CONTROL_SUITE_ENVS,
1516
LEGACY_ATARI_GAMES,
1617
)
1718

1819

20+
def _register_bsuite_envs():
21+
"""Registers all bsuite environments in gymnasium."""
22+
try:
23+
import bsuite
24+
except ImportError:
25+
return
26+
27+
from bsuite.environments import Environment
28+
29+
from shimmy.bsuite_compatibility import BSuiteCompatibilityV0
30+
31+
# Add generic environment support
32+
def _make_bsuite_generic_env(env: Environment, render_mode: str):
33+
return BSuiteCompatibilityV0(env, render_mode=render_mode)
34+
35+
register("bsuite/compatibility-env-v0", _make_bsuite_generic_env)
36+
37+
# register all prebuilt envs
38+
def _make_bsuite_env(env_id: str, **env_kwargs: Mapping[str, Any]):
39+
env = bsuite.load(env_id, env_kwargs)
40+
return BSuiteCompatibilityV0(env)
41+
42+
# non deterministic envs
43+
nondeterministic = ["deep_sea", "bandit", "discounting_chain"]
44+
45+
for env_id in BSUITE_ENVS:
46+
register(
47+
f"bsuite/{env_id}-v0",
48+
partial(_make_bsuite_env, env_id=env_id),
49+
nondeterministic=env_id in nondeterministic,
50+
)
51+
52+
1953
def _register_dm_control_envs():
2054
"""Registers all dm-control environments in gymnasium."""
2155
try:
@@ -259,6 +293,7 @@ def register_gymnasium_envs():
259293
entry_point="shimmy.openai_gym_compatibility:GymV21CompatibilityV0",
260294
)
261295

296+
_register_bsuite_envs()
262297
_register_dm_control_envs()
263298
_register_atari_envs()
264299
_register_dm_lab()

shimmy/utils/dm_env.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def dm_obs2gym_obs(obs) -> np.ndarray | dict[str, Any]:
5151
return np.asarray(obs)
5252

5353

54-
def dm_control_step2gym_step(
54+
def dm_env_step2gym_step(
5555
timestep,
5656
) -> tuple[Any, float, bool, bool, dict[str, Any]]:
5757
"""Opens up the timestep to return obs, reward, terminated, truncated, info."""

shimmy/utils/envs_configs.py

+26
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,31 @@
11
"""Environment configures."""
22

3+
BSUITE_ENVS = (
4+
"bandit",
5+
"bandit_noise",
6+
"bandit_scale",
7+
"cartpole",
8+
"cartpole_noise",
9+
"cartpole_scale",
10+
"cartpole_swingup",
11+
"catch",
12+
"catch_noise",
13+
"catch_scale",
14+
"deep_sea",
15+
"deep_sea_stochastic",
16+
"discounting_chain",
17+
"memory_len",
18+
"memory_size",
19+
"mnist",
20+
"mnist_noise",
21+
"mnist_scale",
22+
"mountain_car",
23+
"mountain_car_noise",
24+
"mountain_car_scale",
25+
"umbrella_distract",
26+
"umbrella_length",
27+
)
28+
329
DM_CONTROL_SUITE_ENVS = (
430
("acrobot", "swingup"),
531
("acrobot", "swingup_sparse"),

tests/test_bsuite.py

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""Tests the functionality of the BSuiteCompatibilityV0 on bsuite envs."""
2+
import warnings
3+
4+
import bsuite
5+
import gymnasium as gym
6+
import pytest
7+
from gymnasium.envs.registration import registry
8+
from gymnasium.error import Error
9+
from gymnasium.utils.env_checker import check_env, data_equivalence
10+
11+
BSUITE_ENV_IDS = [
12+
env_id
13+
for env_id in registry
14+
if env_id.startswith("bsuite") and env_id != "bsuite/compatibility-env-v0"
15+
]
16+
17+
18+
def test_bsuite_suite_envs():
19+
"""Tests that all BSUITE_ENVS are equal to the known bsuite tasks."""
20+
env_ids = [env_id.split("/")[-1].split("-")[0] for env_id in BSUITE_ENV_IDS]
21+
assert list(bsuite._bsuite.EXPERIMENT_NAME_TO_ENVIRONMENT.keys()) == env_ids
22+
23+
24+
BSUITE_ENV_SETTINGS = dict()
25+
BSUITE_ENV_SETTINGS["bsuite/bandit-v0"] = dict()
26+
BSUITE_ENV_SETTINGS["bsuite/bandit_noise-v0"] = dict(
27+
noise_scale=1, seed=42, mapping_seed=42
28+
)
29+
BSUITE_ENV_SETTINGS["bsuite/bandit_scale-v0"] = dict(
30+
reward_scale=1, seed=42, mapping_seed=42
31+
)
32+
BSUITE_ENV_SETTINGS["bsuite/cartpole-v0"] = dict()
33+
BSUITE_ENV_SETTINGS["bsuite/cartpole_noise-v0"] = dict(noise_scale=1, seed=42)
34+
BSUITE_ENV_SETTINGS["bsuite/cartpole_scale-v0"] = dict(reward_scale=1, seed=42)
35+
BSUITE_ENV_SETTINGS["bsuite/cartpole_swingup-v0"] = dict()
36+
BSUITE_ENV_SETTINGS["bsuite/catch-v0"] = dict()
37+
BSUITE_ENV_SETTINGS["bsuite/catch_noise-v0"] = dict(noise_scale=1, seed=42)
38+
BSUITE_ENV_SETTINGS["bsuite/catch_scale-v0"] = dict(reward_scale=1, seed=42)
39+
BSUITE_ENV_SETTINGS["bsuite/deep_sea-v0"] = dict(size=42)
40+
BSUITE_ENV_SETTINGS["bsuite/deep_sea_stochastic-v0"] = dict(size=42)
41+
BSUITE_ENV_SETTINGS["bsuite/discounting_chain-v0"] = dict()
42+
BSUITE_ENV_SETTINGS["bsuite/memory_len-v0"] = dict(memory_length=8)
43+
BSUITE_ENV_SETTINGS["bsuite/memory_size-v0"] = dict(num_bits=8)
44+
BSUITE_ENV_SETTINGS["bsuite/mnist-v0"] = dict()
45+
BSUITE_ENV_SETTINGS["bsuite/mnist_noise-v0"] = dict(noise_scale=1, seed=42)
46+
BSUITE_ENV_SETTINGS["bsuite/mnist_scale-v0"] = dict(reward_scale=1, seed=42)
47+
BSUITE_ENV_SETTINGS["bsuite/mountain_car-v0"] = dict()
48+
BSUITE_ENV_SETTINGS["bsuite/mountain_car_noise-v0"] = dict(noise_scale=1, seed=42)
49+
BSUITE_ENV_SETTINGS["bsuite/mountain_car_scale-v0"] = dict(reward_scale=1, seed=42)
50+
BSUITE_ENV_SETTINGS["bsuite/umbrella_distract-v0"] = dict(n_distractor=3)
51+
BSUITE_ENV_SETTINGS["bsuite/umbrella_length-v0"] = dict(chain_length=3)
52+
53+
# todo - gymnasium v27 should remove the need for some of these warnings
54+
CHECK_ENV_IGNORE_WARNINGS = [
55+
f"\x1b[33mWARN: {message}\x1b[0m"
56+
for message in [
57+
"A Box observation space minimum value is -infinity. This is probably too low.",
58+
"A Box observation space maximum value is -infinity. This is probably too high.",
59+
"A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (28, 28)",
60+
"A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (42, 42)",
61+
"A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (10, 5)",
62+
"A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 1)",
63+
"A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 2)",
64+
"A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 3)",
65+
"A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 6)",
66+
"A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 8)",
67+
"A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 10)",
68+
]
69+
]
70+
71+
72+
@pytest.mark.parametrize("env_id", BSUITE_ENV_IDS)
73+
def test_check_env(env_id):
74+
"""Check that environment pass the gymnasium check_env."""
75+
env = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id])
76+
77+
with warnings.catch_warnings(record=True) as caught_warnings:
78+
check_env(env.unwrapped)
79+
80+
for warning_message in caught_warnings:
81+
assert isinstance(warning_message.message, Warning)
82+
if warning_message.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS:
83+
raise Error(f"Unexpected warning: {warning_message.message}")
84+
85+
env.close()
86+
87+
88+
@pytest.mark.parametrize("env_id", BSUITE_ENV_IDS)
89+
def test_seeding(env_id):
90+
"""Test that dm-control seeding works."""
91+
if gym.spec(env_id).nondeterministic:
92+
return
93+
94+
env_1 = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id])
95+
env_2 = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id])
96+
97+
obs_1, info_1 = env_1.reset(seed=42)
98+
obs_2, info_2 = env_2.reset(seed=42)
99+
assert data_equivalence(obs_1, obs_2)
100+
assert data_equivalence(info_1, info_2)
101+
for _ in range(100):
102+
actions = int(env_1.action_space.sample())
103+
obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions)
104+
obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions)
105+
assert data_equivalence(obs_1, obs_2)
106+
assert reward_1 == reward_2
107+
assert term_1 == term_2 and trunc_1 == trunc_2
108+
assert data_equivalence(info_1, info_2)
109+
110+
env_1.close()
111+
env_2.close()

0 commit comments

Comments
 (0)