Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add bsuite #38

Merged
merged 14 commits into from
Mar 7, 2023
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ env = OpenspielCompatibilityV0(game=env, render_mode=None)
If you use this in your research, please cite:
```
@software{shimmy2022github,
author = {Jordan Terry, Mark Towers, Jun Jet Tai},
author = {{Jun Jet Tai, Mark Towers} and Elliot Tower and Jordan Terry},
title = {Shimmy: Gymnasium and Pettingzoo Wrappers for Commonly Used Environments},
url = {http://github.com/Farama-Foundation/Shimmy},
version = {0.2.0},
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def get_version():
"dm-control": ["dm-control>=1.0.10", "imageio", "h5py>=3.7.0"],
"dm-control-multi-agent": ["dm-control>=1.0.10", "pettingzoo>=1.22"],
"openspiel": ["open_spiel>=1.2", "pettingzoo>=1.22"],
"bsuite": ["bsuite>=0.3.5"],
}
extras["all"] = list({lib for libs in extras.values() for lib in libs})
extras["testing"] = [
Expand Down
24 changes: 22 additions & 2 deletions shimmy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from shimmy.dm_lab_compatibility import DmLabCompatibilityV0
from shimmy.openai_gym_compatibility import GymV21CompatibilityV0, GymV26CompatibilityV0

__version__ = "0.2.1"


class NotInstallClass:
"""Rather than an attribute error, this raises a more helpful import error with install instructions for shimmy."""
Expand Down Expand Up @@ -47,12 +45,34 @@ def __call__(self, *args: list[Any], **kwargs: Any):
e,
)

# BSuite is an optional extra. When its wrapper cannot be imported, bind the
# public name to a NotInstallClass built from the install hint and the
# original ImportError.
try:
    from shimmy.bsuite_compatibility import BSuiteCompatibilityV0
except ImportError as import_error:
    BSuiteCompatibilityV0 = NotInstallClass(
        "BSuite is not installed, run `pip install 'shimmy[bsuite]'`",
        import_error,
    )

# Public names of the package.
__all__ = [
    "BSuiteCompatibilityV0",
    "DmControlCompatibilityV0",
    "DmControlMultiAgentCompatibilityV0",
    "OpenspielCompatibilityV0",
    "DmLabCompatibilityV0",
    "GymV21CompatibilityV0",
    "GymV26CompatibilityV0",
]


__version__ = "0.2.1"


# Best-effort notification hook: if a message is registered for this exact
# shimmy version, print it to stderr. Any failure is deliberately ignored.
try:
    import sys

    from farama_notifications import notifications

    if "shimmy" in notifications:
        if __version__ in notifications["shimmy"]:
            print(notifications["shimmy"][__version__], file=sys.stderr)
except Exception:  # nosec
    pass
95 changes: 95 additions & 0 deletions shimmy/bsuite_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Wrapper to convert a BSuite environment into a gymnasium compatible environment."""
from __future__ import annotations

from typing import Any

import gymnasium
import numpy as np
from bsuite.environments import Environment
from gymnasium.core import ObsType
from gymnasium.error import UnsupportedMode

from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space

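# HACK: presumably bsuite still references the `np.int` alias, so it is patched
# back onto numpy below. Remove this once the upstream fix is merged and released:
# https://github.com/deepmind/bsuite/pull/48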
np.int = int


class BSuiteCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]):
    """A compatibility wrapper that converts a BSuite environment into a gymnasium environment.

    Note:
        Bsuite uses `np.random.RandomState`, a legacy random number generator while gymnasium
        uses `np.random.Generator`, therefore the return type of `np_random` is different from expected.
    """

    metadata = {"render_modes": []}

    def __init__(
        self,
        env: Environment,
        render_mode: str | None = None,
    ):
        """Wraps a bsuite environment and derives the gymnasium spaces from its specs.

        Args:
            env: The bsuite environment to wrap.
            render_mode: Must be ``None``; no render modes are available in BSuite.
        """
        self._env = env

        # Both spaces are converted from the wrapped environment's own specs.
        self.observation_space = dm_spec2gym_space(env.observation_spec())
        self.action_space = dm_spec2gym_space(env.action_spec())

        assert render_mode is None, "No render modes available in BSuite."

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the bsuite environment.

        Args:
            seed: When given, a fresh ``np.random.RandomState(seed)`` is installed on
                the wrapped environment (and on its ``raw_env`` if it has one).
            options: Accepted for gymnasium API compatibility; not used here.

        Returns:
            The first observation and the info dict extracted from the reset timestep.
        """
        super().reset(seed=seed)
        if seed is not None:
            # The `np_random` setter stores the generator on the wrapped env as `_rng`,
            # so no additional assignment to `self._env._rng` is needed.
            self.np_random = np.random.RandomState(seed=seed)
            if hasattr(self._env, "raw_env"):
                self._env.raw_env._rng = self.np_random

        timestep = self._env.reset()

        # Only the observation and info are part of gymnasium's reset return value.
        obs, _, _, _, info = dm_env_step2gym_step(timestep)

        return obs, info  # pyright: ignore[reportGeneralTypeIssues]

    def step(self, action: int) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
        """Steps through the bsuite environment.

        Args:
            action: The action forwarded to the wrapped environment's ``step``.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` extracted from the timestep.
        """
        timestep = self._env.step(action)

        obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep)

        return (  # pyright: ignore[reportGeneralTypeIssues]
            obs,
            reward,
            terminated,
            truncated,
            info,
        )

    def render(self) -> np.ndarray | None:
        """Renders the bsuite env.

        Raises:
            UnsupportedMode: Always, since rendering is not built into BSuite.
        """
        raise UnsupportedMode(
            "Rendering is not built into BSuite, print the observation instead."
        )

    def close(self):
        """Closes the environment."""
        self._env.close()

    @property
    def np_random(self) -> np.random.RandomState:
        """This should be np.random.Generator but bsuite uses np.random.RandomState."""
        return self._env._rng  # pyright: ignore[reportGeneralTypeIssues]

    @np_random.setter
    def np_random(self, value: np.random.RandomState):
        self._env._rng = value  # pyright: ignore[reportGeneralTypeIssues]

    def __getattr__(self, item: str):
        """If the attribute is missing, try getting the attribute from bsuite env.

        Raises:
            AttributeError: If ``_env`` itself is missing, or the wrapped environment
                does not have the attribute either.
        """
        # `__getattr__` only runs after normal lookup has failed, so being asked for
        # `_env` means the wrapper has not been initialised (for example an instance
        # created without `__init__` while being copied or unpickled). Delegating
        # would look up `self._env` again and recurse until RecursionError.
        if item == "_env":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return getattr(self._env, item)
6 changes: 3 additions & 3 deletions shimmy/dm_control_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from gymnasium.core import ObsType
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer

from shimmy.utils.dm_env import dm_control_step2gym_step, dm_spec2gym_space
from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space


class EnvType(Enum):
Expand Down Expand Up @@ -84,7 +84,7 @@ def reset(

timestep = self._env.reset()

obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep)
obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep)

return obs, info

Expand All @@ -94,7 +94,7 @@ def step(
"""Steps through the dm-control environment."""
timestep = self._env.step(action)

obs, reward, terminated, truncated, info = dm_control_step2gym_step(timestep)
obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep)

if self.render_mode == "human":
self.viewer.render(self.render_mode)
Expand Down
35 changes: 35 additions & 0 deletions shimmy/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,46 @@

from shimmy.utils.envs_configs import (
ALL_ATARI_GAMES,
BSUITE_ENVS,
DM_CONTROL_MANIPULATION_ENVS,
DM_CONTROL_SUITE_ENVS,
LEGACY_ATARI_GAMES,
)


def _register_bsuite_envs():
    """Registers the bsuite environments in gymnasium when bsuite can be imported.

    Returns without registering anything if ``import bsuite`` fails. Otherwise the
    generic ``bsuite/compatibility-env-v0`` entry is registered first, followed by
    one ``bsuite/<name>-v0`` entry for every name in ``BSUITE_ENVS``.
    """
    try:
        import bsuite
    except ImportError:
        # bsuite is optional; nothing to register without it.
        return

    from bsuite.environments import Environment

    from shimmy.bsuite_compatibility import BSuiteCompatibilityV0

    # Entry point for wrapping an already constructed bsuite environment.
    def _make_bsuite_generic_env(env: Environment, render_mode: str):
        return BSuiteCompatibilityV0(env, render_mode=render_mode)

    register("bsuite/compatibility-env-v0", _make_bsuite_generic_env)

    # Entry point for the prebuilt envs: the keyword arguments are handed to
    # `bsuite.load` as a single mapping.
    def _make_bsuite_env(env_id: str, **env_kwargs: Mapping[str, Any]):
        loaded_env = bsuite.load(env_id, env_kwargs)
        return BSuiteCompatibilityV0(loaded_env)

    # Names that are flagged as nondeterministic at registration time.
    nondeterministic_names = ("deep_sea", "bandit", "discounting_chain")

    for name in BSUITE_ENVS:
        entry_point = partial(_make_bsuite_env, env_id=name)
        register(
            f"bsuite/{name}-v0",
            entry_point,
            nondeterministic=name in nondeterministic_names,
        )


def _register_dm_control_envs():
"""Registers all dm-control environments in gymnasium."""
try:
Expand Down Expand Up @@ -259,6 +293,7 @@ def register_gymnasium_envs():
entry_point="shimmy.openai_gym_compatibility:GymV21CompatibilityV0",
)

_register_bsuite_envs()
_register_dm_control_envs()
_register_atari_envs()
_register_dm_lab()
2 changes: 1 addition & 1 deletion shimmy/utils/dm_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def dm_obs2gym_obs(obs) -> np.ndarray | dict[str, Any]:
return np.asarray(obs)


def dm_control_step2gym_step(
def dm_env_step2gym_step(
timestep,
) -> tuple[Any, float, bool, bool, dict[str, Any]]:
"""Opens up the timestep to return obs, reward, terminated, truncated, info."""
Expand Down
26 changes: 26 additions & 0 deletions shimmy/utils/envs_configs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,31 @@
"""Environment configurations."""

# Names of the bsuite experiments. The order is significant: a test compares
# this sequence against bsuite's own experiment table.
BSUITE_ENVS = (
    # bandit
    "bandit",
    "bandit_noise",
    "bandit_scale",
    # cartpole
    "cartpole",
    "cartpole_noise",
    "cartpole_scale",
    "cartpole_swingup",
    # catch
    "catch",
    "catch_noise",
    "catch_scale",
    # deep sea
    "deep_sea",
    "deep_sea_stochastic",
    # discounting chain
    "discounting_chain",
    # memory
    "memory_len",
    "memory_size",
    # mnist
    "mnist",
    "mnist_noise",
    "mnist_scale",
    # mountain car
    "mountain_car",
    "mountain_car_noise",
    "mountain_car_scale",
    # umbrella
    "umbrella_distract",
    "umbrella_length",
)

DM_CONTROL_SUITE_ENVS = (
("acrobot", "swingup"),
("acrobot", "swingup_sparse"),
Expand Down
111 changes: 111 additions & 0 deletions tests/test_bsuite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Tests the functionality of the BSuiteCompatibilityV0 on bsuite envs."""
import warnings

import bsuite
import gymnasium as gym
import pytest
from gymnasium.envs.registration import registry
from gymnasium.error import Error
from gymnasium.utils.env_checker import check_env, data_equivalence

# Every registered id that starts with "bsuite", except the generic
# compatibility entry.
BSUITE_ENV_IDS = [
    registered_id
    for registered_id in registry
    if registered_id.startswith("bsuite")
    and registered_id != "bsuite/compatibility-env-v0"
]


def test_bsuite_suite_envs():
    """Checks that the registered bsuite ids match bsuite's experiment names, in order."""
    # "bsuite/<name>-v0" -> "<name>": drop the namespace, then the version suffix.
    registered_names = [
        full_id.rpartition("/")[2].partition("-")[0] for full_id in BSUITE_ENV_IDS
    ]
    known_names = list(bsuite._bsuite.EXPERIMENT_NAME_TO_ENVIRONMENT.keys())
    assert known_names == registered_names


# Keyword arguments passed to `gym.make` for each registered bsuite id.
BSUITE_ENV_SETTINGS = {
    "bsuite/bandit-v0": {},
    "bsuite/bandit_noise-v0": {"noise_scale": 1, "seed": 42, "mapping_seed": 42},
    "bsuite/bandit_scale-v0": {"reward_scale": 1, "seed": 42, "mapping_seed": 42},
    "bsuite/cartpole-v0": {},
    "bsuite/cartpole_noise-v0": {"noise_scale": 1, "seed": 42},
    "bsuite/cartpole_scale-v0": {"reward_scale": 1, "seed": 42},
    "bsuite/cartpole_swingup-v0": {},
    "bsuite/catch-v0": {},
    "bsuite/catch_noise-v0": {"noise_scale": 1, "seed": 42},
    "bsuite/catch_scale-v0": {"reward_scale": 1, "seed": 42},
    "bsuite/deep_sea-v0": {"size": 42},
    "bsuite/deep_sea_stochastic-v0": {"size": 42},
    "bsuite/discounting_chain-v0": {},
    "bsuite/memory_len-v0": {"memory_length": 8},
    "bsuite/memory_size-v0": {"num_bits": 8},
    "bsuite/mnist-v0": {},
    "bsuite/mnist_noise-v0": {"noise_scale": 1, "seed": 42},
    "bsuite/mnist_scale-v0": {"reward_scale": 1, "seed": 42},
    "bsuite/mountain_car-v0": {},
    "bsuite/mountain_car_noise-v0": {"noise_scale": 1, "seed": 42},
    "bsuite/mountain_car_scale-v0": {"reward_scale": 1, "seed": 42},
    "bsuite/umbrella_distract-v0": {"n_distractor": 3},
    "bsuite/umbrella_length-v0": {"chain_length": 3},
}

# todo - gymnasium v27 should remove the need for some of these warnings
# Warnings from `check_env` that the tests tolerate. Each entry is the warning
# text wrapped in the same escape sequences as the literal prefix/suffix below.
CHECK_ENV_IGNORE_WARNINGS = [
    f"\x1b[33mWARN: {text}\x1b[0m"
    for text in (
        "A Box observation space minimum value is -infinity. This is probably too low.",
        "A Box observation space maximum value is -infinity. This is probably too high.",
        # The same "unconventional shape" warning, once per tolerated observation shape.
        *(
            "A Box observation space has an unconventional shape (neither an image, nor a 1D vector). We recommend flattening the observation to have only a 1D vector or use a custom policy to properly process the data. Actual observation shape: "
            f"{shape}"
            for shape in (
                (28, 28),
                (42, 42),
                (10, 5),
                (1, 1),
                (1, 2),
                (1, 3),
                (1, 6),
                (1, 8),
                (1, 10),
            )
        ),
    )
]


@pytest.mark.parametrize("env_id", BSUITE_ENV_IDS)
def test_check_env(env_id):
    """Checks that the environment passes gymnasium's `check_env`."""
    env = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id])

    with warnings.catch_warnings(record=True) as recorded:
        check_env(env.unwrapped)

    # Any recorded warning outside the tolerated list fails the test.
    for record in recorded:
        emitted = record.message
        assert isinstance(emitted, Warning)
        if emitted.args[0] in CHECK_ENV_IGNORE_WARNINGS:
            continue
        raise Error(f"Unexpected warning: {emitted}")

    env.close()


@pytest.mark.parametrize("env_id", BSUITE_ENV_IDS)
def test_seeding(env_id):
    """Tests that two bsuite envs reset with the same seed produce identical rollouts.

    Environments whose spec is flagged ``nondeterministic`` are reported as skipped
    rather than silently counted as passing.
    """
    if gym.spec(env_id).nondeterministic:
        pytest.skip(f"{env_id} is registered as nondeterministic")

    env_1 = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id])
    env_2 = gym.make(env_id, **BSUITE_ENV_SETTINGS[env_id])

    try:
        obs_1, info_1 = env_1.reset(seed=42)
        obs_2, info_2 = env_2.reset(seed=42)
        assert data_equivalence(obs_1, obs_2)
        assert data_equivalence(info_1, info_2)
        for _ in range(100):
            # The same sampled action is applied to both environments.
            actions = int(env_1.action_space.sample())
            obs_1, reward_1, term_1, trunc_1, info_1 = env_1.step(actions)
            obs_2, reward_2, term_2, trunc_2, info_2 = env_2.step(actions)
            assert data_equivalence(obs_1, obs_2)
            assert reward_1 == reward_2
            assert term_1 == term_2 and trunc_1 == trunc_2
            assert data_equivalence(info_1, info_2)
    finally:
        # Close both environments even when an assertion above fails.
        env_1.close()
        env_2.close()