diff --git a/README.md b/README.md index 0f7f0826..6df05313 100644 --- a/README.md +++ b/README.md @@ -212,7 +212,7 @@ env = OpenspielCompatibilityV0(game=env, render_mode=None) If you use this in your research, please cite: ``` @software{shimmy2022github, - author = {Jordan Terry, Mark Towers, Jun Jet Tai}, + author = {{Jun Jet Tai, Mark Towers} and Elliot Tower and Jordan Terry}, title = {Shimmy: Gymnasium and Pettingzoo Wrappers for Commonly Used Environments}, url = {http://github.com/Farama-Foundation/Shimmy}, version = {0.2.0}, diff --git a/setup.py b/setup.py index b1c75da2..fdf3a679 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ def get_version(): "dm-control": ["dm-control>=1.0.10", "imageio", "h5py>=3.7.0"], "dm-control-multi-agent": ["dm-control>=1.0.10", "pettingzoo>=1.22"], "openspiel": ["open_spiel>=1.2", "pettingzoo>=1.22"], + "bsuite": ["bsuite>=0.3.5"], } extras["all"] = list({lib for libs in extras.values() for lib in libs}) extras["testing"] = [ diff --git a/shimmy/__init__.py b/shimmy/__init__.py index e07296e8..01c415c1 100644 --- a/shimmy/__init__.py +++ b/shimmy/__init__.py @@ -6,8 +6,6 @@ from shimmy.dm_lab_compatibility import DmLabCompatibilityV0 from shimmy.openai_gym_compatibility import GymV21CompatibilityV0, GymV26CompatibilityV0 -__version__ = "0.2.1" - class NotInstallClass: """Rather than an attribute error, this raises a more helpful import error with install instructions for shimmy.""" @@ -47,8 +45,16 @@ def __call__(self, *args: list[Any], **kwargs: Any): e, ) +try: + from shimmy.bsuite_compatibility import BSuiteCompatibilityV0 +except ImportError as e: + BSuiteCompatibilityV0 = NotInstallClass( + "BSuite is not installed, run `pip install 'shimmy[bsuite]'`", + e, + ) __all__ = [ + "BSuiteCompatibilityV0", "DmControlCompatibilityV0", "DmControlMultiAgentCompatibilityV0", "OpenspielCompatibilityV0", @@ -56,3 +62,17 @@ def __call__(self, *args: list[Any], **kwargs: Any): "GymV21CompatibilityV0", "GymV26CompatibilityV0", ] + + +__version__ = "0.2.1" + + +try: + import sys + + from farama_notifications import notifications + + if "shimmy" in notifications and __version__ in notifications["shimmy"]: + print(notifications["shimmy"][__version__], file=sys.stderr) +except Exception: # nosec + pass diff --git a/shimmy/bsuite_compatibility.py b/shimmy/bsuite_compatibility.py new file mode 100644 index 00000000..39db8d05 --- /dev/null +++ b/shimmy/bsuite_compatibility.py @@ -0,0 +1,95 @@ +"""Wrapper to convert a BSuite environment into a gymnasium compatible environment.""" +from __future__ import annotations + +from typing import Any + +import gymnasium +import numpy as np +from bsuite.environments import Environment +from gymnasium.core import ObsType +from gymnasium.error import UnsupportedMode + +from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space + +# Until the BSuite authors fix +# https://github.com/deepmind/bsuite/pull/48 +# This needs to exist... +np.int = int + + +class BSuiteCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]): + """A compatibility wrapper that converts a BSuite environment into a gymnasium environment. + + Note: + Bsuite uses `np.random.RandomState`, a legacy random number generator while gymnasium + uses `np.random.Generator`, therefore the return type of `np_random` is different from expected. + """ + + metadata = {"render_modes": []} + + def __init__( + self, + env: Environment, + render_mode: str | None = None, + ): + """Initialises the environment with a render mode along with render information.""" + self._env = env + + self.observation_space = dm_spec2gym_space(env.observation_spec()) + self.action_space = dm_spec2gym_space(env.action_spec()) + + assert render_mode is None, "No render modes available in BSuite." + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[ObsType, dict[str, Any]]: + """Resets the bsuite environment.""" + super().reset(seed=seed) + if seed is not None: + self.np_random = np.random.RandomState(seed=seed) + self._env._rng = self.np_random # pyright: ignore[reportGeneralTypeIssues] + if hasattr(self._env, "raw_env"): + self._env.raw_env._rng = self.np_random + + timestep = self._env.reset() + + obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep) + + return obs, info # pyright: ignore[reportGeneralTypeIssues] + + def step(self, action: int) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: + """Steps through the bsuite environment.""" + timestep = self._env.step(action) + + obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep) + + return ( # pyright: ignore[reportGeneralTypeIssues] + obs, + reward, + terminated, + truncated, + info, + ) + + def render(self) -> np.ndarray | None: + """Renders the bsuite env.""" + raise UnsupportedMode( + "Rendering is not built into BSuite, print the observation instead." + ) + + def close(self): + """Closes the environment.""" + self._env.close() + + @property + def np_random(self) -> np.random.RandomState: + """This should be np.random.Generator but bsuite uses np.random.RandomState.""" + return self._env._rng # pyright: ignore[reportGeneralTypeIssues] + + @np_random.setter + def np_random(self, value: np.random.RandomState): + self._env._rng = value # pyright: ignore[reportGeneralTypeIssues] + + def __getattr__(self, item: str): + """If the attribute is missing, try getting the attribute from bsuite env.""" + return getattr(self._env, item) diff --git a/shimmy/dm_control_compatibility.py b/shimmy/dm_control_compatibility.py index 7e344fb1..314a1d26 100644 --- a/shimmy/dm_control_compatibility.py +++ b/shimmy/dm_control_compatibility.py @@ -17,7 +17,7 @@ from gymnasium.core import ObsType from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer -from shimmy.utils.dm_env import dm_control_step2gym_step, dm_spec2gym_space +from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space class EnvType(Enum): @@ -84,7 +84,7 @@ def reset( timestep = self._env.reset() - obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep) + obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep) return obs, info @@ -94,7 +94,7 @@ def step( """Steps through the dm-control environment.""" timestep = self._env.step(action) - obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep) + obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep) if self.render_mode == "human": self.viewer.render(self.render_mode) diff --git a/shimmy/registration.py b/shimmy/registration.py index 1bcf5128..f853b0b4 100644 --- a/shimmy/registration.py +++ b/shimmy/registration.py @@ -10,12 +10,46 @@ from shimmy.utils.envs_configs import ( ALL_ATARI_GAMES, + BSUITE_ENVS, DM_CONTROL_MANIPULATION_ENVS, DM_CONTROL_SUITE_ENVS, LEGACY_ATARI_GAMES, ) +def _register_bsuite_envs(): + """Registers all bsuite environments in gymnasium.""" + try: + import bsuite + except ImportError: + return + + from bsuite.environments import Environment + + from shimmy.bsuite_compatibility import BSuiteCompatibilityV0 + + # Add generic environment support + def _make_bsuite_generic_env(env: Environment, render_mode: str): + return BSuiteCompatibilityV0(env, render_mode=render_mode) + + register("bsuite/compatibility-env-v0", _make_bsuite_generic_env) + + # register all prebuilt envs + def _make_bsuite_env(env_id: str, **env_kwargs: Mapping[str, Any]): + env = bsuite.load(env_id, env_kwargs) + return BSuiteCompatibilityV0(env) + + # non deterministic envs + nondeterministic = ["deep_sea", "bandit", "discounting_chain"] + + for env_id in BSUITE_ENVS: + register( + f"bsuite/{env_id}-v0", + partial(_make_bsuite_env, env_id=env_id), + nondeterministic=env_id in nondeterministic, + ) + + def _register_dm_control_envs(): """Registers all dm-control environments in gymnasium.""" try: @@ -259,6 +293,7 @@ def register_gymnasium_envs(): entry_point="shimmy.openai_gym_compatibility:GymV21CompatibilityV0", ) + _register_bsuite_envs() _register_dm_control_envs() _register_atari_envs() _register_dm_lab() diff --git a/shimmy/utils/dm_env.py b/shimmy/utils/dm_env.py index 25b2df34..316ae8a8 100644 --- a/shimmy/utils/dm_env.py +++ b/shimmy/utils/dm_env.py @@ -51,7 +51,7 @@ def dm_obs2gym_obs(obs) -> np.ndarray | dict[str, Any]: return np.asarray(obs) -def dm_control_step2gym_step( +def dm_env_step2gym_step( timestep, ) -> tuple[Any, float, bool, bool, dict[str, Any]]: """Opens up the timestep to return obs, reward, terminated, truncated, info.""" diff --git a/shimmy/utils/envs_configs.py b/shimmy/utils/envs_configs.py index 8036d286..76ecf217 100644 --- a/shimmy/utils/envs_configs.py +++ b/shimmy/utils/envs_configs.py @@ -1,5 +1,31 @@ """Environment configures.""" +BSUITE_ENVS = ( + "bandit", + "bandit_noise", + "bandit_scale", + "cartpole", + "cartpole_noise", + "cartpole_scale", + "cartpole_swingup", + "catch", + "catch_noise", + "catch_scale", + "deep_sea", + "deep_sea_stochastic", + "discounting_chain", + "memory_len", + "memory_size", + "mnist", + "mnist_noise", + "mnist_scale", + "mountain_car", + "mountain_car_noise", + "mountain_car_scale", + "umbrella_distract", + "umbrella_length", +) + DM_CONTROL_SUITE_ENVS = ( ("acrobot", "swingup"), ("acrobot", "swingup_sparse"), diff --git a/tests/test_bsuite.py b/tests/test_bsuite.py new file mode 100644 index 00000000..820827fd --- /dev/null +++ b/tests/test_bsuite.py @@ -0,0 +1,111 @@ +"""Tests the functionality of the BSuiteCompatibilityV0 on bsuite envs.""" +import warnings + +import bsuite +import gymnasium as gym +import pytest +from gymnasium.envs.registration import registry +from gymnasium.error import Error +from gymnasium.utils.env_checker import check_env, data_equivalence + +BSUITE_ENV_IDS = [ + env_id + for env_id in registry + if env_id.startswith("bsuite") and env_id != "bsuite/compatibility-env-v0" +] + + +def test_bsuite_suite_envs(): + """Tests that all BSUITE_ENVS are equal to the known bsuite tasks.""" + env_ids = [env_id.split("/")[-1].split("-")[0] for env_id in BSUITE_ENV_IDS] + assert list(bsuite._bsuite.EXPERIMENT_NAME_TO_ENVIRONMENT.keys()) == env_ids + + +BSUITE_ENV_SETTINGS = dict() +BSUITE_ENV_SETTINGS["bsuite/bandit-v0"] = dict() +BSUITE_ENV_SETTINGS["bsuite/bandit_noise-v0"] = dict( + noise_scale=1, seed=42, mapping_seed=42 +) +BSUITE_ENV_SETTINGS["bsuite/bandit_scale-v0"] = dict( + reward_scale=1, seed=42, mapping_seed=42 +) +BSUITE_ENV_SETTINGS["bsuite/cartpole-v0"] = dict() +BSUITE_ENV_SETTINGS["bsuite/cartpole_noise-v0"] = dict(noise_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/cartpole_scale-v0"] = dict(reward_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/cartpole_swingup-v0"] = dict() +BSUITE_ENV_SETTINGS["bsuite/catch-v0"] = dict() +BSUITE_ENV_SETTINGS["bsuite/catch_noise-v0"] = dict(noise_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/catch_scale-v0"] = dict(reward_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/deep_sea-v0"] = dict(size=42) +BSUITE_ENV_SETTINGS["bsuite/deep_sea_stochastic-v0"] = dict(size=42) +BSUITE_ENV_SETTINGS["bsuite/discounting_chain-v0"] = dict() +BSUITE_ENV_SETTINGS["bsuite/memory_len-v0"] = dict(memory_length=8) +BSUITE_ENV_SETTINGS["bsuite/memory_size-v0"] = dict(num_bits=8) +BSUITE_ENV_SETTINGS["bsuite/mnist-v0"] = dict() +BSUITE_ENV_SETTINGS["bsuite/mnist_noise-v0"] = dict(noise_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/mnist_scale-v0"] = dict(reward_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/mountain_car-v0"] = dict() +BSUITE_ENV_SETTINGS["bsuite/mountain_car_noise-v0"] = dict(noise_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/mountain_car_scale-v0"] = dict(reward_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/umbrella_distract-v0"] = dict(n_distractor=3) +BSUITE_ENV_SETTINGS["bsuite/umbrella_length-v0"] = dict(chain_length=3) + +# todo - gymnasium v27 should remove the need for some of these warnings +CHECK_ENV_IGNORE_WARNINGS = [ + f"\x1b[33mWARN: {message}\x1b[0m" + for message in [ + "A Box observation space minimum value is -infinity. This is probably too low.", + "A Box observation space maximum value is -infinity. This is probably too high.", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (28, 28)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (42, 42)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (10, 5)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 1)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 2)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 3)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 6)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 8)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 10)", + ] +] + + +@pytest.mark.parametrize("env_id", BSUITE_ENV_IDS) +def test_check_env(env_id): + """Check that environment pass the gymnasium check_env.""" + env = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id]) + + with warnings.catch_warnings(record=True) as caught_warnings: + check_env(env.unwrapped) + + for warning_message in caught_warnings: + assert isinstance(warning_message.message, Warning) + if warning_message.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS: + raise Error(f"Unexpected warning: {warning_message.message}") + + env.close() + + +@pytest.mark.parametrize("env_id", BSUITE_ENV_IDS) +def test_seeding(env_id): + """Test that dm-control seeding works.""" + if gym.spec(env_id).nondeterministic: + return + + env_1 = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id]) + env_2 = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id]) + + obs_1, info_1 = env_1.reset(seed=42) + obs_2, info_2 = env_2.reset(seed=42) + assert data_equivalence(obs_1, obs_2) + assert data_equivalence(info_1, info_2) + for _ in range(100): + actions = int(env_1.action_space.sample()) + obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions) + obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions) + assert data_equivalence(obs_1, obs_2) + assert reward_1 == reward_2 + assert term_1 == term_2 and trunc_1 == trunc_2 + assert data_equivalence(info_1, info_2) + + env_1.close() + env_2.close()