From f1df16f3da72b8490bc4039d26f7dbbc66d0dde4 Mon Sep 17 00:00:00 2001 From: jjshoots Date: Mon, 27 Feb 2023 15:59:54 +0000 Subject: [PATCH 01/14] add bsuite --- setup.py | 1 + shimmy/__init__.py | 7 +++ shimmy/bsuite_compatibility.py | 93 +++++++++++++++++++++++++++++++++ tests/test_bsuite.py | 94 ++++++++++++++++++++++++++++++++++ 4 files changed, 195 insertions(+) create mode 100644 shimmy/bsuite_compatibility.py create mode 100644 tests/test_bsuite.py diff --git a/setup.py b/setup.py index b1c75da2..fdf3a679 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ def get_version(): "dm-control": ["dm-control>=1.0.10", "imageio", "h5py>=3.7.0"], "dm-control-multi-agent": ["dm-control>=1.0.10", "pettingzoo>=1.22"], "openspiel": ["open_spiel>=1.2", "pettingzoo>=1.22"], + "bsuite": ["bsuite>=0.3.5"], } extras["all"] = list({lib for libs in extras.values() for lib in libs}) extras["testing"] = [ diff --git a/shimmy/__init__.py b/shimmy/__init__.py index e07296e8..0829c51c 100644 --- a/shimmy/__init__.py +++ b/shimmy/__init__.py @@ -47,6 +47,13 @@ def __call__(self, *args: list[Any], **kwargs: Any): e, ) +try: + from shimmy.bsuite_compatibility import BSuiteCompatibilityV0 +except ImportError as e: + BSuiteCompatibilityV0 = NotInstallClass( + "BSuite is not installed, run `pip install 'shimmy[bsuite]'`", + e, + ) __all__ = [ "DmControlCompatibilityV0", diff --git a/shimmy/bsuite_compatibility.py b/shimmy/bsuite_compatibility.py new file mode 100644 index 00000000..4dd44af2 --- /dev/null +++ b/shimmy/bsuite_compatibility.py @@ -0,0 +1,93 @@ +"""Wrapper to convert a BSuite environment into a gymnasium compatible environment. +""" +from __future__ import annotations + +from typing import Any + +import gymnasium +import numpy as np +from gymnasium.core import ObsType + +from shimmy.utils.dm_env import dm_control_step2gym_step, dm_spec2gym_space + +from bsuite.environments import Environment + + +class BSuiteCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]): + """A compatibility wrapper that converts a BSuite environment into a gymnasium environment. + + Note: + Bsuite uses `np.random.RandomState`, a legacy random number generator while gymnasium + uses `np.random.Generator`, therefore the return type of `np_random` is different from expected. + """ + + metadata = {"render_modes": []} + + def __init__( + self, + env: Environment, + render_mode: str | None = None, + ): + """Initialises the environment with a render mode along with render information.""" + self._env = env + + self.observation_space = dm_spec2gym_space(env.observation_spec()) + self.action_space = dm_spec2gym_space(env.action_spec()) + + assert render_mode is None, f"No render modes available in BSuite." + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[ObsType, dict[str, Any]]: + """Resets the dm-control environment.""" + super().reset(seed=seed) + if seed is not None: + self.np_random = np.random.RandomState(seed=seed) + self._env._rng = self.np_random # pyright: ignore[reportGeneralTypeIssues] + + timestep = self._env.reset() + + obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep) + + return obs, info # pyright: ignore[reportGeneralTypeIssues] + + def step( + self, action: int + ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: + """Steps through the dm-control environment.""" + timestep = self._env.step(action) + + obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep) + + return ( # pyright: ignore[reportGeneralTypeIssues] + obs, + reward, + terminated, + truncated, + info, + ) + + def render(self) -> np.ndarray | None: + """Renders the dm-control env.""" + raise AssertionError("Rendering is not built into BSuite, print the observation instead.") + + def close(self): + """Closes the environment.""" + + self._env.close() + + if hasattr(self, "viewer"): + self.viewer.close() + + @property + def np_random(self) -> np.random.RandomState: + """This should be np.random.Generator but dm-control uses np.random.RandomState.""" + return self._env._rng # pyright: ignore[reportGeneralTypeIssues] + + @np_random.setter + def np_random(self, value: np.random.RandomState): + self._env._rng = value # pyright: ignore[reportGeneralTypeIssues] + + def __getattr__(self, item: str): + """If the attribute is missing, try getting the attribute from dm_control env.""" + return getattr(self._env, item) diff --git a/tests/test_bsuite.py b/tests/test_bsuite.py new file mode 100644 index 00000000..01d581b1 --- /dev/null +++ b/tests/test_bsuite.py @@ -0,0 +1,94 @@ +"""Tests the functionality of the BSuiteCompatibilityV0 on bsuite envs.""" +import warnings + +import pytest +from gymnasium.error import Error +from gymnasium.utils.env_checker import check_env, data_equivalence + +import bsuite +from shimmy.bsuite_compatibility import BSuiteCompatibilityV0 + +BSUITE_NAME_TO_LOADERS = bsuite._bsuite.EXPERIMENT_NAME_TO_ENVIRONMENT +BSUITE_ENV_SETTINGS = dict() +BSUITE_ENV_SETTINGS["bandit"] = dict() +BSUITE_ENV_SETTINGS["bandit_noise"] = dict(noise_scale=1, seed=42, mapping_seed=42) +BSUITE_ENV_SETTINGS["bandit_scale"] = dict(reward_scale=1, seed=42, mapping_seed=42) +BSUITE_ENV_SETTINGS["cartpole"] = dict() +BSUITE_ENV_SETTINGS["cartpole_noise"] = dict(noise_scale=1, seed=42) +BSUITE_ENV_SETTINGS["cartpole_scale"] = dict(reward_scale=1, seed=42) +BSUITE_ENV_SETTINGS["cartpole_swingup"] = dict() +BSUITE_ENV_SETTINGS["catch"] = dict() +BSUITE_ENV_SETTINGS["catch_noise"] = dict(noise_scale=1, seed=42) +BSUITE_ENV_SETTINGS["catch_scale"] = dict(reward_scale=1, seed=42) +BSUITE_ENV_SETTINGS["deep_sea"] = dict(size=42) +BSUITE_ENV_SETTINGS["deep_sea_stochastic"] = dict(size=42) +BSUITE_ENV_SETTINGS["discounting_chain"] = dict() +BSUITE_ENV_SETTINGS["memory_len"] = dict(memory_length=8) +BSUITE_ENV_SETTINGS["memory_size"] = dict(num_bits=8) +BSUITE_ENV_SETTINGS["mnist"] = dict() +BSUITE_ENV_SETTINGS["mnist_noise"] = dict(noise_scale=1, seed=42) +BSUITE_ENV_SETTINGS["mnist_scale"] = dict(reward_scale=1, seed=42) +BSUITE_ENV_SETTINGS["mountain_car"] = dict() +BSUITE_ENV_SETTINGS["mountain_car_noise"] = dict(noise_scale=1, seed=42) +BSUITE_ENV_SETTINGS["mountain_car_scale"] = dict(reward_scale=1, seed=42) +BSUITE_ENV_SETTINGS["umbrella_distract"] = dict(n_distractor=3) +BSUITE_ENV_SETTINGS["umbrella_length"] = dict(chain_length=3) + +# todo - gymnasium v27 should remove the need for some of these warnings +CHECK_ENV_IGNORE_WARNINGS = [ + f"\x1b[33mWARN: {message}\x1b[0m" + for message in [ + "A Box observation space minimum value is -infinity. This is probably too low.", + "A Box observation space maximum value is -infinity. This is probably too high.", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (28, 28)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (42, 42)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (10, 5)", + ] +] + + +@pytest.mark.parametrize("env_id", BSUITE_NAME_TO_LOADERS) +def test_check_env(env_id): + """Check that environment pass the gymnasium check_env.""" + env = bsuite.load(env_id, BSUITE_ENV_SETTINGS[env_id]) + env = BSuiteCompatibilityV0(env) + + with warnings.catch_warnings(record=True) as caught_warnings: + check_env(env.unwrapped) + + for warning_message in caught_warnings: + assert isinstance(warning_message.message, Warning) + if warning_message.message.args[0] not in CHECK_ENV_IGNORE_WARNINGS: + raise Error(f"Unexpected warning: {warning_message.message}") + + env.close() + + +@pytest.mark.parametrize("env_id", BSUITE_NAME_TO_LOADERS) +def test_seeding(env_id): + """Test that dm-control seeding works.""" + + # bandit and deep_sea and SOMETIMES discounting_chain fail this test + if env_id in ["bandit", "deep_sea", "discounting_chain"]: + return + + env_1 = bsuite.load(env_id, BSUITE_ENV_SETTINGS[env_id]) + env_1 = BSuiteCompatibilityV0(env_1) + env_2 = bsuite.load(env_id, BSUITE_ENV_SETTINGS[env_id]) + env_2 = BSuiteCompatibilityV0(env_2) + + obs_1, info_1 = env_1.reset(seed=42) + obs_2, info_2 = env_2.reset(seed=42) + assert data_equivalence(obs_1, obs_2) + assert data_equivalence(info_1, info_2) + for _ in range(100): + actions = int(env_1.action_space.sample()) + obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions) + obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions) + assert data_equivalence(obs_1, obs_2) + assert reward_1 == reward_2 + assert term_1 == term_2 and trunc_1 == trunc_2 + assert data_equivalence(info_1, info_2) + + env_1.close() + env_2.close() From 2da87a0eb749409383833e1e53f65aad42f95ad4 Mon Sep 17 00:00:00 2001 From: jjshoots Date: Mon, 27 Feb 2023 16:01:38 +0000 Subject: [PATCH 02/14] fix docstrings and rename function --- shimmy/bsuite_compatibility.py | 16 ++++++++-------- shimmy/dm_control_compatibility.py | 6 +++--- shimmy/utils/dm_env.py | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/shimmy/bsuite_compatibility.py b/shimmy/bsuite_compatibility.py index 4dd44af2..4a8c316c 100644 --- a/shimmy/bsuite_compatibility.py +++ b/shimmy/bsuite_compatibility.py @@ -8,7 +8,7 @@ import numpy as np from gymnasium.core import ObsType -from shimmy.utils.dm_env import dm_control_step2gym_step, dm_spec2gym_space +from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space from bsuite.environments import Environment @@ -39,7 +39,7 @@ def __init__( def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None ) -> tuple[ObsType, dict[str, Any]]: - """Resets the dm-control environment.""" + """Resets the bsuite environment.""" super().reset(seed=seed) if seed is not None: self.np_random = np.random.RandomState(seed=seed) @@ -47,17 +47,17 @@ def reset( timestep = self._env.reset() - obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep) + obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep) return obs, info # pyright: ignore[reportGeneralTypeIssues] def step( self, action: int ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: - """Steps through the dm-control environment.""" + """Steps through the bsuite environment.""" timestep = self._env.step(action) - obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep) + obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep) return ( # pyright: ignore[reportGeneralTypeIssues] obs, @@ -68,7 +68,7 @@ def step( ) def render(self) -> np.ndarray | None: - """Renders the dm-control env.""" + """Renders the bsuite env.""" raise AssertionError("Rendering is not built into BSuite, print the observation instead.") def close(self): @@ -81,7 +81,7 @@ def close(self): @property def np_random(self) -> np.random.RandomState: - """This should be np.random.Generator but dm-control uses np.random.RandomState.""" + """This should be np.random.Generator but bsuite uses np.random.RandomState.""" return self._env._rng # pyright: ignore[reportGeneralTypeIssues] @np_random.setter @@ -89,5 +89,5 @@ def np_random(self, value: np.random.RandomState): self._env._rng = value # pyright: ignore[reportGeneralTypeIssues] def __getattr__(self, item: str): - """If the attribute is missing, try getting the attribute from dm_control env.""" + """If the attribute is missing, try getting the attribute from bsuite env.""" return getattr(self._env, item) diff --git a/shimmy/dm_control_compatibility.py b/shimmy/dm_control_compatibility.py index 7e344fb1..314a1d26 100644 --- a/shimmy/dm_control_compatibility.py +++ b/shimmy/dm_control_compatibility.py @@ -17,7 +17,7 @@ from gymnasium.core import ObsType from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer -from shimmy.utils.dm_env import dm_control_step2gym_step, dm_spec2gym_space +from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space class EnvType(Enum): @@ -84,7 +84,7 @@ def reset( timestep = self._env.reset() - obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep) + obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep) return obs, info @@ -94,7 +94,7 @@ def step( """Steps through the dm-control environment.""" timestep = self._env.step(action) - obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep) + obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep) if self.render_mode == "human": self.viewer.render(self.render_mode) diff --git a/shimmy/utils/dm_env.py b/shimmy/utils/dm_env.py index 25b2df34..316ae8a8 100644 --- a/shimmy/utils/dm_env.py +++ b/shimmy/utils/dm_env.py @@ -51,7 +51,7 @@ def dm_obs2gym_obs(obs) -> np.ndarray | dict[str, Any]: return np.asarray(obs) -def dm_control_step2gym_step( +def dm_env_step2gym_step( timestep, ) -> tuple[Any, float, bool, bool, dict[str, Any]]: """Opens up the timestep to return obs, reward, terminated, truncated, info.""" From 9489ece6adb7c3c9417864aac66b06e1a11f2d48 Mon Sep 17 00:00:00 2001 From: jjshoots Date: Mon, 27 Feb 2023 16:12:33 +0000 Subject: [PATCH 03/14] fix precommit --- shimmy/bsuite_compatibility.py | 26 ++++++++++---------------- tests/test_bsuite.py | 3 +-- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/shimmy/bsuite_compatibility.py b/shimmy/bsuite_compatibility.py index 4a8c316c..5adf26c2 100644 --- a/shimmy/bsuite_compatibility.py +++ b/shimmy/bsuite_compatibility.py @@ -1,17 +1,15 @@ -"""Wrapper to convert a BSuite environment into a gymnasium compatible environment. -""" +"""Wrapper to convert a BSuite environment into a gymnasium compatible environment.""" from __future__ import annotations from typing import Any import gymnasium import numpy as np +from bsuite.environments import Environment from gymnasium.core import ObsType from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space -from bsuite.environments import Environment - class BSuiteCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]): """A compatibility wrapper that converts a BSuite environment into a gymnasium environment. @@ -34,7 +32,7 @@ def __init__( self.observation_space = dm_spec2gym_space(env.observation_spec()) self.action_space = dm_spec2gym_space(env.action_spec()) - assert render_mode is None, f"No render modes available in BSuite." + assert render_mode is None, "No render modes available in BSuite." def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None @@ -43,7 +41,7 @@ def reset( super().reset(seed=seed) if seed is not None: self.np_random = np.random.RandomState(seed=seed) - self._env._rng = self.np_random # pyright: ignore[reportGeneralTypeIssues] + self._env._rng = self.np_random # pyright: ignore[reportGeneralTypeIssues] timestep = self._env.reset() @@ -51,9 +49,7 @@ def reset( return obs, info # pyright: ignore[reportGeneralTypeIssues] - def step( - self, action: int - ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: + def step(self, action: int) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: """Steps through the bsuite environment.""" timestep = self._env.step(action) @@ -69,24 +65,22 @@ def step( def render(self) -> np.ndarray | None: """Renders the bsuite env.""" - raise AssertionError("Rendering is not built into BSuite, print the observation instead.") + raise AssertionError( + "Rendering is not built into BSuite, print the observation instead." + ) def close(self): """Closes the environment.""" - self._env.close() - if hasattr(self, "viewer"): - self.viewer.close() - @property def np_random(self) -> np.random.RandomState: """This should be np.random.Generator but bsuite uses np.random.RandomState.""" - return self._env._rng # pyright: ignore[reportGeneralTypeIssues] + return self._env._rng # pyright: ignore[reportGeneralTypeIssues] @np_random.setter def np_random(self, value: np.random.RandomState): - self._env._rng = value # pyright: ignore[reportGeneralTypeIssues] + self._env._rng = value # pyright: ignore[reportGeneralTypeIssues] def __getattr__(self, item: str): """If the attribute is missing, try getting the attribute from bsuite env.""" diff --git a/tests/test_bsuite.py b/tests/test_bsuite.py index 01d581b1..590b38c9 100644 --- a/tests/test_bsuite.py +++ b/tests/test_bsuite.py @@ -1,11 +1,11 @@ """Tests the functionality of the BSuiteCompatibilityV0 on bsuite envs.""" import warnings +import bsuite import pytest from gymnasium.error import Error from gymnasium.utils.env_checker import check_env, data_equivalence -import bsuite from shimmy.bsuite_compatibility import BSuiteCompatibilityV0 BSUITE_NAME_TO_LOADERS = bsuite._bsuite.EXPERIMENT_NAME_TO_ENVIRONMENT @@ -67,7 +67,6 @@ def test_check_env(env_id): @pytest.mark.parametrize("env_id", BSUITE_NAME_TO_LOADERS) def test_seeding(env_id): """Test that dm-control seeding works.""" - # bandit and deep_sea and SOMETIMES discounting_chain fail this test if env_id in ["bandit", "deep_sea", "discounting_chain"]: return From 7512cb1d48104eaf8a813058a0e4b21f5f018a4c Mon Sep 17 00:00:00 2001 From: jjshoots Date: Mon, 27 Feb 2023 16:43:53 +0000 Subject: [PATCH 04/14] add farama notifications --- shimmy/__init__.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/shimmy/__init__.py b/shimmy/__init__.py index 0829c51c..1e56d4c4 100644 --- a/shimmy/__init__.py +++ b/shimmy/__init__.py @@ -6,8 +6,6 @@ from shimmy.dm_lab_compatibility import DmLabCompatibilityV0 from shimmy.openai_gym_compatibility import GymV21CompatibilityV0, GymV26CompatibilityV0 -__version__ = "0.2.1" - class NotInstallClass: """Rather than an attribute error, this raises a more helpful import error with install instructions for shimmy.""" @@ -63,3 +61,16 @@ def __call__(self, *args: list[Any], **kwargs: Any): "GymV21CompatibilityV0", "GymV26CompatibilityV0", ] + + +__version__ = "0.2.1" + + +try: + import sys + from farama_notifications import notifications + + if "shimmy" in notifications and __version__ in notifications["shimmy"]: + print(notifications["shimmy"][__version__], file=sys.stderr) +except Exception: # nosec + pass From f7431b48ec9dc97d30b09ce244aff2c96f180ac7 Mon Sep 17 00:00:00 2001 From: jjshoots Date: Mon, 27 Feb 2023 18:40:38 +0000 Subject: [PATCH 05/14] add unsupportedmode --- shimmy/__init__.py | 1 + shimmy/bsuite_compatibility.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/shimmy/__init__.py b/shimmy/__init__.py index 1e56d4c4..2cd20a0c 100644 --- a/shimmy/__init__.py +++ b/shimmy/__init__.py @@ -54,6 +54,7 @@ def __call__(self, *args: list[Any], **kwargs: Any): ) __all__ = [ + "BSuiteCompatibilityV0", "DmControlCompatibilityV0", "DmControlMultiAgentCompatibilityV0", "OpenspielCompatibilityV0", diff --git a/shimmy/bsuite_compatibility.py b/shimmy/bsuite_compatibility.py index 5adf26c2..2e44bd3d 100644 --- a/shimmy/bsuite_compatibility.py +++ b/shimmy/bsuite_compatibility.py @@ -4,6 +4,7 @@ from typing import Any import gymnasium +from gymnasium.error import UnsupportedMode import numpy as np from bsuite.environments import Environment from gymnasium.core import ObsType @@ -65,7 +66,7 @@ def step(self, action: int) -> tuple[ObsType, float, bool, bool, dict[str, Any]] def render(self) -> np.ndarray | None: """Renders the bsuite env.""" - raise AssertionError( + raise UnsupportedMode( "Rendering is not built into BSuite, print the observation instead." ) From dbe9199e1075de25157f9f1a18d8f2ac4a026712 Mon Sep 17 00:00:00 2001 From: jjshoots Date: Mon, 27 Feb 2023 19:47:36 +0000 Subject: [PATCH 06/14] register things --- shimmy/registration.py | 35 +++++++++++++++++ shimmy/utils/envs_configs.py | 26 +++++++++++++ tests/test_bsuite.py | 74 ++++++++++++++++++++---------------- 3 files changed, 102 insertions(+), 33 deletions(-) diff --git a/shimmy/registration.py b/shimmy/registration.py index 1bcf5128..2b76baff 100644 --- a/shimmy/registration.py +++ b/shimmy/registration.py @@ -10,12 +10,46 @@ from shimmy.utils.envs_configs import ( ALL_ATARI_GAMES, + BSUITE_ENVS, DM_CONTROL_MANIPULATION_ENVS, DM_CONTROL_SUITE_ENVS, LEGACY_ATARI_GAMES, ) +def _register_bsuite_envs(): + """Registers all bsuite environments in gymnasium.""" + try: + import bsuite + except ImportError: + return + + from bsuite.environments import Environment + from shimmy.bsuite_compatibility import BSuiteCompatibilityV0 + + # Add generic environment support + def _make_bsuite_generic_env(env: Environment, render_mode: str): + return BSuiteCompatibilityV0(env, render_mode=render_mode) + + register("bsuite/compatibility-env-v0", _make_bsuite_generic_env) + + # register all prebuilt envs + def _make_bsuite_env( + env_id: str, **env_kwargs: Mapping[str, Any] + ): + env = bsuite.load(env_id, env_kwargs) + return BSuiteCompatibilityV0(env) + + for env_id in BSUITE_ENVS: + register( + f"bsuite/{env_id}-v0", + partial( + _make_bsuite_env, + env_id=env_id + ), + ) + + def _register_dm_control_envs(): """Registers all dm-control environments in gymnasium.""" try: @@ -259,6 +293,7 @@ def register_gymnasium_envs(): entry_point="shimmy.openai_gym_compatibility:GymV21CompatibilityV0", ) + _register_bsuite_envs() _register_dm_control_envs() _register_atari_envs() _register_dm_lab() diff --git a/shimmy/utils/envs_configs.py b/shimmy/utils/envs_configs.py index 8036d286..76ecf217 100644 --- a/shimmy/utils/envs_configs.py +++ b/shimmy/utils/envs_configs.py @@ -1,5 +1,31 @@ """Environment configures.""" +BSUITE_ENVS = ( + "bandit", + "bandit_noise", + "bandit_scale", + "cartpole", + "cartpole_noise", + "cartpole_scale", + "cartpole_swingup", + "catch", + "catch_noise", + "catch_scale", + "deep_sea", + "deep_sea_stochastic", + "discounting_chain", + "memory_len", + "memory_size", + "mnist", + "mnist_noise", + "mnist_scale", + "mountain_car", + "mountain_car_noise", + "mountain_car_scale", + "umbrella_distract", + "umbrella_length", +) + DM_CONTROL_SUITE_ENVS = ( ("acrobot", "swingup"), ("acrobot", "swingup_sparse"), diff --git a/tests/test_bsuite.py b/tests/test_bsuite.py index 590b38c9..19d67e44 100644 --- a/tests/test_bsuite.py +++ b/tests/test_bsuite.py @@ -3,36 +3,47 @@ import bsuite import pytest +import gymnasium as gym from gymnasium.error import Error from gymnasium.utils.env_checker import check_env, data_equivalence -from shimmy.bsuite_compatibility import BSuiteCompatibilityV0 +from gymnasium.envs.registration import registry + +BSUITE_ENV_IDS = [ + env_id + for env_id in registry + if env_id.startswith("bsuite") and env_id != "bsuite/compatibility-env-v0" +] + +def test_bsuite_suite_envs(): + """Tests that all BSUITE_ENVS are equal to the known bsuite tasks.""" + env_ids = [env_id.split("/")[-1].split("-")[0] for env_id in BSUITE_ENV_IDS] + assert list(bsuite._bsuite.EXPERIMENT_NAME_TO_ENVIRONMENT.keys()) == env_ids -BSUITE_NAME_TO_LOADERS = bsuite._bsuite.EXPERIMENT_NAME_TO_ENVIRONMENT BSUITE_ENV_SETTINGS = dict() -BSUITE_ENV_SETTINGS["bandit"] = dict() -BSUITE_ENV_SETTINGS["bandit_noise"] = dict(noise_scale=1, seed=42, mapping_seed=42) -BSUITE_ENV_SETTINGS["bandit_scale"] = dict(reward_scale=1, seed=42, mapping_seed=42) -BSUITE_ENV_SETTINGS["cartpole"] = dict() -BSUITE_ENV_SETTINGS["cartpole_noise"] = dict(noise_scale=1, seed=42) -BSUITE_ENV_SETTINGS["cartpole_scale"] = dict(reward_scale=1, seed=42) -BSUITE_ENV_SETTINGS["cartpole_swingup"] = dict() -BSUITE_ENV_SETTINGS["catch"] = dict() -BSUITE_ENV_SETTINGS["catch_noise"] = dict(noise_scale=1, seed=42) -BSUITE_ENV_SETTINGS["catch_scale"] = dict(reward_scale=1, seed=42) -BSUITE_ENV_SETTINGS["deep_sea"] = dict(size=42) -BSUITE_ENV_SETTINGS["deep_sea_stochastic"] = dict(size=42) -BSUITE_ENV_SETTINGS["discounting_chain"] = dict() -BSUITE_ENV_SETTINGS["memory_len"] = dict(memory_length=8) -BSUITE_ENV_SETTINGS["memory_size"] = dict(num_bits=8) -BSUITE_ENV_SETTINGS["mnist"] = dict() -BSUITE_ENV_SETTINGS["mnist_noise"] = dict(noise_scale=1, seed=42) -BSUITE_ENV_SETTINGS["mnist_scale"] = dict(reward_scale=1, seed=42) -BSUITE_ENV_SETTINGS["mountain_car"] = dict() -BSUITE_ENV_SETTINGS["mountain_car_noise"] = dict(noise_scale=1, seed=42) -BSUITE_ENV_SETTINGS["mountain_car_scale"] = dict(reward_scale=1, seed=42) -BSUITE_ENV_SETTINGS["umbrella_distract"] = dict(n_distractor=3) -BSUITE_ENV_SETTINGS["umbrella_length"] = dict(chain_length=3) +BSUITE_ENV_SETTINGS["bsuite/bandit-v0"] = dict() +BSUITE_ENV_SETTINGS["bsuite/bandit_noise-v0"] = dict(noise_scale=1, seed=42, mapping_seed=42) +BSUITE_ENV_SETTINGS["bsuite/bandit_scale-v0"] = dict(reward_scale=1, seed=42, mapping_seed=42) +BSUITE_ENV_SETTINGS["bsuite/cartpole-v0"] = dict() +BSUITE_ENV_SETTINGS["bsuite/cartpole_noise-v0"] = dict(noise_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/cartpole_scale-v0"] = dict(reward_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/cartpole_swingup-v0"] = dict() +BSUITE_ENV_SETTINGS["bsuite/catch-v0"] = dict() +BSUITE_ENV_SETTINGS["bsuite/catch_noise-v0"] = dict(noise_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/catch_scale-v0"] = dict(reward_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/deep_sea-v0"] = dict(size=42) +BSUITE_ENV_SETTINGS["bsuite/deep_sea_stochastic-v0"] = dict(size=42) +BSUITE_ENV_SETTINGS["bsuite/discounting_chain-v0"] = dict() +BSUITE_ENV_SETTINGS["bsuite/memory_len-v0"] = dict(memory_length=8) +BSUITE_ENV_SETTINGS["bsuite/memory_size-v0"] = dict(num_bits=8) +BSUITE_ENV_SETTINGS["bsuite/mnist-v0"] = dict() +BSUITE_ENV_SETTINGS["bsuite/mnist_noise-v0"] = dict(noise_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/mnist_scale-v0"] = dict(reward_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/mountain_car-v0"] = dict() +BSUITE_ENV_SETTINGS["bsuite/mountain_car_noise-v0"] = dict(noise_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/mountain_car_scale-v0"] = dict(reward_scale=1, seed=42) +BSUITE_ENV_SETTINGS["bsuite/umbrella_distract-v0"] = dict(n_distractor=3) +BSUITE_ENV_SETTINGS["bsuite/umbrella_length-v0"] = dict(chain_length=3) # todo - gymnasium v27 should remove the need for some of these warnings CHECK_ENV_IGNORE_WARNINGS = [ @@ -47,11 +58,10 @@ ] -@pytest.mark.parametrize("env_id", BSUITE_NAME_TO_LOADERS) +@pytest.mark.parametrize("env_id", BSUITE_ENV_IDS) def test_check_env(env_id): """Check that environment pass the gymnasium check_env.""" - env = bsuite.load(env_id, BSUITE_ENV_SETTINGS[env_id]) - env = BSuiteCompatibilityV0(env) + env = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id]) with warnings.catch_warnings(record=True) as caught_warnings: check_env(env.unwrapped) @@ -64,17 +74,15 @@ def test_check_env(env_id): env.close() -@pytest.mark.parametrize("env_id", BSUITE_NAME_TO_LOADERS) +@pytest.mark.parametrize("env_id", BSUITE_ENV_IDS) def test_seeding(env_id): """Test that dm-control seeding works.""" # bandit and deep_sea and SOMETIMES discounting_chain fail this test if env_id in ["bandit", "deep_sea", "discounting_chain"]: return - env_1 = bsuite.load(env_id, BSUITE_ENV_SETTINGS[env_id]) - env_1 = BSuiteCompatibilityV0(env_1) - env_2 = bsuite.load(env_id, BSUITE_ENV_SETTINGS[env_id]) - env_2 = BSuiteCompatibilityV0(env_2) + env_1 = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id]) + env_2 = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id]) obs_1, info_1 = env_1.reset(seed=42) obs_2, info_2 = env_2.reset(seed=42) From 5012302b0486f123524ce08fd44dde687de26d48 Mon Sep 17 00:00:00 2001 From: jjshoots Date: Mon, 27 Feb 2023 19:48:05 +0000 Subject: [PATCH 07/14] precommit --- shimmy/__init__.py | 1 + shimmy/bsuite_compatibility.py | 2 +- shimmy/registration.py | 10 +++------- tests/test_bsuite.py | 15 ++++++++++----- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/shimmy/__init__.py b/shimmy/__init__.py index 2cd20a0c..01c415c1 100644 --- a/shimmy/__init__.py +++ b/shimmy/__init__.py @@ -69,6 +69,7 @@ def __call__(self, *args: list[Any], **kwargs: Any): try: import sys + from farama_notifications import notifications if "shimmy" in notifications and __version__ in notifications["shimmy"]: diff --git a/shimmy/bsuite_compatibility.py b/shimmy/bsuite_compatibility.py index 2e44bd3d..2fa86d61 100644 --- a/shimmy/bsuite_compatibility.py +++ b/shimmy/bsuite_compatibility.py @@ -4,10 +4,10 @@ from typing import Any import gymnasium -from gymnasium.error import UnsupportedMode import numpy as np from bsuite.environments import Environment from gymnasium.core import ObsType +from gymnasium.error import UnsupportedMode from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space diff --git a/shimmy/registration.py b/shimmy/registration.py index 2b76baff..49839e77 100644 --- a/shimmy/registration.py +++ b/shimmy/registration.py @@ -25,6 +25,7 @@ def _register_bsuite_envs(): return from bsuite.environments import Environment + from shimmy.bsuite_compatibility import BSuiteCompatibilityV0 # Add generic environment support @@ -34,19 +35,14 @@ def _make_bsuite_generic_env(env: Environment, render_mode: str): register("bsuite/compatibility-env-v0", _make_bsuite_generic_env) # register all prebuilt envs - def _make_bsuite_env( - env_id: str, **env_kwargs: Mapping[str, Any] - ): + def _make_bsuite_env(env_id: str, **env_kwargs: Mapping[str, Any]): env = bsuite.load(env_id, env_kwargs) return BSuiteCompatibilityV0(env) for env_id in BSUITE_ENVS: register( f"bsuite/{env_id}-v0", - partial( - _make_bsuite_env, - env_id=env_id - ), + partial(_make_bsuite_env, env_id=env_id), ) diff --git a/tests/test_bsuite.py b/tests/test_bsuite.py index 19d67e44..a2b24f01 100644 --- a/tests/test_bsuite.py +++ b/tests/test_bsuite.py @@ -2,28 +2,33 @@ import warnings import bsuite -import pytest import gymnasium as gym +import pytest +from gymnasium.envs.registration import registry from gymnasium.error import Error from gymnasium.utils.env_checker import check_env, data_equivalence -from gymnasium.envs.registration import registry - BSUITE_ENV_IDS = [ env_id for env_id in registry if env_id.startswith("bsuite") and env_id != "bsuite/compatibility-env-v0" ] + def test_bsuite_suite_envs(): """Tests that all BSUITE_ENVS are equal to the known bsuite tasks.""" env_ids = [env_id.split("/")[-1].split("-")[0] for env_id in BSUITE_ENV_IDS] assert list(bsuite._bsuite.EXPERIMENT_NAME_TO_ENVIRONMENT.keys()) == env_ids + BSUITE_ENV_SETTINGS = dict() BSUITE_ENV_SETTINGS["bsuite/bandit-v0"] = dict() -BSUITE_ENV_SETTINGS["bsuite/bandit_noise-v0"] = dict(noise_scale=1, seed=42, mapping_seed=42) -BSUITE_ENV_SETTINGS["bsuite/bandit_scale-v0"] = dict(reward_scale=1, seed=42, mapping_seed=42) +BSUITE_ENV_SETTINGS["bsuite/bandit_noise-v0"] = dict( + noise_scale=1, seed=42, mapping_seed=42 +) +BSUITE_ENV_SETTINGS["bsuite/bandit_scale-v0"] = dict( + reward_scale=1, seed=42, mapping_seed=42 +) BSUITE_ENV_SETTINGS["bsuite/cartpole-v0"] = dict() BSUITE_ENV_SETTINGS["bsuite/cartpole_noise-v0"] = dict(noise_scale=1, seed=42) BSUITE_ENV_SETTINGS["bsuite/cartpole_scale-v0"] = dict(reward_scale=1, seed=42) From ff71b1f8d5ee6e2793a05dee146172d97218616b Mon Sep 17 00:00:00 2001 From: snow-fox Date: Tue, 28 Feb 2023 16:25:22 +0000 Subject: [PATCH 08/14] fix tests --- shimmy/bsuite_compatibility.py | 2 ++ shimmy/registration.py | 4 ++++ tests/test_bsuite.py | 10 ++++++++-- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/shimmy/bsuite_compatibility.py b/shimmy/bsuite_compatibility.py index 2fa86d61..e7775c0d 100644 --- a/shimmy/bsuite_compatibility.py +++ b/shimmy/bsuite_compatibility.py @@ -43,6 +43,8 @@ def reset( if seed is not None: self.np_random = np.random.RandomState(seed=seed) self._env._rng = self.np_random # pyright: ignore[reportGeneralTypeIssues] + if hasattr(self._env, "raw_env"): + self._env.raw_env._rng = self.np_random timestep = self._env.reset() diff --git a/shimmy/registration.py b/shimmy/registration.py index 49839e77..7d6fc4a1 100644 --- a/shimmy/registration.py +++ b/shimmy/registration.py @@ -39,10 +39,14 @@ def _make_bsuite_env(env_id: str, **env_kwargs: Mapping[str, Any]): env = bsuite.load(env_id, env_kwargs) return BSuiteCompatibilityV0(env) + # non deterministic envs + nondeterministic = ["deep_sea", "bandit"] + for env_id in BSUITE_ENVS: register( f"bsuite/{env_id}-v0", partial(_make_bsuite_env, env_id=env_id), + nondeterministic=env_id in nondeterministic, ) diff --git a/tests/test_bsuite.py b/tests/test_bsuite.py index a2b24f01..2ddcfe78 100644 --- a/tests/test_bsuite.py +++ b/tests/test_bsuite.py @@ -59,6 +59,12 @@ def test_bsuite_suite_envs(): "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (28, 28)", "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (42, 42)", "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (10, 5)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 1)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 2)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 3)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 6)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 8)", + "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: (1, 10)", ] ] @@ -82,8 +88,8 @@ def test_check_env(env_id): @pytest.mark.parametrize("env_id", BSUITE_ENV_IDS) def test_seeding(env_id): """Test that dm-control seeding works.""" - # bandit and deep_sea and SOMETIMES discounting_chain fail this test - if env_id in ["bandit", "deep_sea", "discounting_chain"]: + # bandit and deep_sea fail this test + if env_id in ["bsuite/deep_sea-v0", "bsuite/bandit-v0"]: return env_1 = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id]) From 7e25751bfd4a781d1da37360ac889cdf5e29ad49 Mon Sep 17 00:00:00 2001 From: jjshoots Date: Mon, 6 Mar 2023 16:39:08 +0000 Subject: [PATCH 09/14] cursed thing --- shimmy/bsuite_compatibility.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/shimmy/bsuite_compatibility.py b/shimmy/bsuite_compatibility.py index e7775c0d..c05afb77 100644 --- a/shimmy/bsuite_compatibility.py +++ b/shimmy/bsuite_compatibility.py @@ -11,6 +11,10 @@ from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space +# Until the BSuite authors fix +# https://github.com/deepmind/bsuite/pull/48 +# This needs to exist... +np.int = int class BSuiteCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]): """A compatibility wrapper that converts a BSuite environment into a gymnasium environment. From 3fb00d6ca4b139b9c33aca5f195ac4e9d684f7a7 Mon Sep 17 00:00:00 2001 From: jjshoots Date: Mon, 6 Mar 2023 16:57:27 +0000 Subject: [PATCH 10/14] ignore discounting chain --- shimmy/bsuite_compatibility.py | 1 + tests/test_bsuite.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/shimmy/bsuite_compatibility.py b/shimmy/bsuite_compatibility.py index c05afb77..39db8d05 100644 --- a/shimmy/bsuite_compatibility.py +++ b/shimmy/bsuite_compatibility.py @@ -16,6 +16,7 @@ # This needs to exist... np.int = int + class BSuiteCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]): """A compatibility wrapper that converts a BSuite environment into a gymnasium environment. diff --git a/tests/test_bsuite.py b/tests/test_bsuite.py index 2ddcfe78..6edf4a15 100644 --- a/tests/test_bsuite.py +++ b/tests/test_bsuite.py @@ -89,7 +89,7 @@ def test_check_env(env_id): def test_seeding(env_id): """Test that dm-control seeding works.""" # bandit and deep_sea fail this test - if env_id in ["bsuite/deep_sea-v0", "bsuite/bandit-v0"]: + if env_id in ["bsuite/deep_sea-v0", "bsuite/bandit-v0", "bsuite/discount_chain-v0"]: return env_1 = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id]) From aacd55eb71aefca53688b0aded09d2ccf25bf194 Mon Sep 17 00:00:00 2001 From: jjshoots Date: Mon, 6 Mar 2023 16:58:34 +0000 Subject: [PATCH 11/14] nondeterministic discounting_chain --- shimmy/registration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shimmy/registration.py b/shimmy/registration.py index 7d6fc4a1..f853b0b4 100644 --- a/shimmy/registration.py +++ b/shimmy/registration.py @@ -40,7 +40,7 @@ def _make_bsuite_env(env_id: str, **env_kwargs: Mapping[str, Any]): return BSuiteCompatibilityV0(env) # non deterministic envs - nondeterministic = ["deep_sea", "bandit"] + nondeterministic = ["deep_sea", "bandit", "discounting_chain"] for env_id in BSUITE_ENVS: register( From 5ff73052621a04974710c8a03dc079f3feeb228c Mon Sep 17 00:00:00 2001 From: jjshoots Date: Mon, 6 Mar 2023 17:12:12 +0000 Subject: [PATCH 12/14] change citation --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0f7f0826..c80b8265 100644 --- a/README.md +++ b/README.md @@ -212,7 +212,7 @@ env = OpenspielCompatibilityV0(game=env, render_mode=None) If you use this in your research, please cite: ``` @software{shimmy2022github, - author = {Jordan Terry, Mark Towers, Jun Jet Tai}, + author = {Jun Jet Tai, Mark Towers, Elliot Tower, Jordan Terry}, title = {Shimmy: Gymnasium and Pettingzoo Wrappers for Commonly Used Environments}, url = {http://github.com/Farama-Foundation/Shimmy}, version = {0.2.0}, From 5a541b1433ac8448ebdd366b1cd327c9f2be12c1 Mon Sep 17 00:00:00 2001 From: jjshoots Date: Mon, 6 Mar 2023 17:21:48 +0000 Subject: [PATCH 13/14] equal contribution --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c80b8265..6df05313 100644 --- a/README.md +++ b/README.md @@ -212,7 +212,7 @@ env = OpenspielCompatibilityV0(game=env, render_mode=None) If you use this in your research, please cite: ``` @software{shimmy2022github, - author = {Jun Jet Tai, Mark Towers, Elliot Tower, Jordan Terry}, + author = {{Jun Jet Tai, Mark Towers} and Elliot Tower and Jordan Terry}, title = {Shimmy: Gymnasium and Pettingzoo Wrappers for Commonly Used Environments}, url = {http://github.com/Farama-Foundation/Shimmy}, version = {0.2.0}, From 0483c36f7fcacf7d307da482234842577a3c990d Mon Sep 17 00:00:00 2001 From: jjshoots Date: Tue, 7 Mar 2023 11:52:22 +0000 Subject: [PATCH 14/14] fix bsuite tests --- tests/test_bsuite.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_bsuite.py b/tests/test_bsuite.py index 6edf4a15..820827fd 100644 --- a/tests/test_bsuite.py +++ b/tests/test_bsuite.py @@ -88,8 +88,7 @@ def test_check_env(env_id): @pytest.mark.parametrize("env_id", BSUITE_ENV_IDS) def test_seeding(env_id): """Test that dm-control seeding works.""" - # bandit and deep_sea fail this test - if env_id in ["bsuite/deep_sea-v0", "bsuite/bandit-v0", "bsuite/discount_chain-v0"]: + if gym.spec(env_id).nondeterministic: return env_1 = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id])