-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathbsuite_compatibility.py
95 lines (72 loc) · 3.23 KB
/
bsuite_compatibility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""Wrapper to convert a BSuite environment into a gymnasium compatible environment."""
from __future__ import annotations
from typing import Any
import gymnasium
import numpy as np
from bsuite.environments import Environment
from gymnasium.core import ObsType
from gymnasium.error import UnsupportedMode
from shimmy.utils.dm_env import dm_env_step2gym_step, dm_spec2gym_space
# Compatibility shim: bind `np.int` on the imported numpy module to the builtin `int`.
# This has to stay in place until the BSuite authors land the fix tracked in
# https://github.com/deepmind/bsuite/pull/48
# NOTE(review): presumably bsuite still references `np.int`, which recent NumPy
# releases no longer provide -- confirm before removing this line.
np.int = int
class BSuiteCompatibilityV0(gymnasium.Env[ObsType, np.ndarray]):
    """A compatibility wrapper that converts a BSuite environment into a gymnasium environment.

    The wrapped environment is stored on ``self._env``. Observation and action
    spaces are derived from the environment's specs via ``dm_spec2gym_space`` and
    every timestep is converted with ``dm_env_step2gym_step``. Attributes that are
    not found on the wrapper are looked up on the wrapped environment.

    Note:
        Bsuite uses `np.random.RandomState`, a legacy random number generator while gymnasium
        uses `np.random.Generator`, therefore the return type of `np_random` is different from expected.
    """

    metadata = {"render_modes": []}

    def __init__(
        self,
        env: Environment,
        render_mode: str | None = None,
    ):
        """Initialises the wrapper around a bsuite environment.

        Args:
            env: The bsuite environment to wrap.
            render_mode: Must be ``None``; no render modes are available.

        Raises:
            AssertionError: If ``render_mode`` is not ``None``.
        """
        # Raised explicitly (instead of using an `assert` statement) so the check is
        # not stripped under `python -O`. The exception type and message are kept
        # so existing callers that catch AssertionError keep working.
        if render_mode is not None:
            raise AssertionError("No render modes available in BSuite.")

        self._env = env

        self.observation_space = dm_spec2gym_space(env.observation_spec())
        self.action_space = dm_spec2gym_space(env.action_spec())

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[ObsType, dict[str, Any]]:
        """Resets the bsuite environment.

        Args:
            seed: If given, a fresh ``np.random.RandomState`` seeded with this value
                is installed on the wrapped environment before resetting.
            options: Accepted for gymnasium API compatibility; not used here.

        Returns:
            The ``(observation, info)`` pair taken from the converted first timestep.
        """
        super().reset(seed=seed)

        if seed is not None:
            # The `np_random` setter stores the generator on `self._env._rng`.
            self.np_random = np.random.RandomState(seed=seed)
            # If the environment exposes a `raw_env`, point it at the same generator.
            if hasattr(self._env, "raw_env"):
                self._env.raw_env._rng = self.np_random

        timestep = self._env.reset()

        # Only the observation and info are part of gymnasium's reset return value.
        obs, _, _, _, info = dm_env_step2gym_step(timestep)

        return obs, info  # pyright: ignore[reportGeneralTypeIssues]

    def step(self, action: int) -> tuple[ObsType, float, bool, bool, dict[str, Any]]:
        """Steps through the bsuite environment.

        Args:
            action: The action forwarded unchanged to the wrapped environment.

        Returns:
            The ``(obs, reward, terminated, truncated, info)`` tuple produced by
            converting the resulting timestep with ``dm_env_step2gym_step``.
        """
        timestep = self._env.step(action)

        obs, reward, terminated, truncated, info = dm_env_step2gym_step(timestep)

        return (  # pyright: ignore[reportGeneralTypeIssues]
            obs,
            reward,
            terminated,
            truncated,
            info,
        )

    def render(self) -> np.ndarray | None:
        """Renders the bsuite env.

        Raises:
            UnsupportedMode: Always, since this wrapper offers no rendering.
        """
        raise UnsupportedMode(
            "Rendering is not built into BSuite, print the observation instead."
        )

    def close(self):
        """Closes the environment by closing the wrapped bsuite environment."""
        self._env.close()

    @property
    def np_random(self) -> np.random.RandomState:
        """This should be np.random.Generator but bsuite uses np.random.RandomState."""
        return self._env._rng  # pyright: ignore[reportGeneralTypeIssues]

    @np_random.setter
    def np_random(self, value: np.random.RandomState):
        # Stored on the wrapped environment, which is where the getter reads it from.
        self._env._rng = value  # pyright: ignore[reportGeneralTypeIssues]

    def __getattr__(self, item: str):
        """If the attribute is missing, try getting the attribute from bsuite env.

        Raises:
            AttributeError: If ``_env`` itself has not been set on this instance yet,
                or if the wrapped environment does not have the attribute either.
        """
        # `__getattr__` only runs when normal lookup failed. If `_env` is the missing
        # name (e.g. on an instance created without `__init__` by copy/pickle),
        # reading `self._env` below would re-enter this method forever.
        if item == "_env":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        return getattr(self._env, item)