feature(nyz): adapt DingEnvWrapper to gymnasium (#817)
PaParaZz1 authored Jul 6, 2024
1 parent 7f95159 commit 96ccaed
Showing 4 changed files with 116 additions and 32 deletions.
60 changes: 31 additions & 29 deletions ding/envs/env/ding_env_wrapper.py
@@ -7,7 +7,7 @@
 import treetensor.numpy as tnp
 
 from ding.envs.common.common_function import affine_transform
-from ding.envs.env_wrappers import create_env_wrapper
+from ding.envs.env_wrappers import create_env_wrapper, GymToGymnasiumWrapper
 from ding.torch_utils import to_ndarray
 from ding.utils import CloudPickleWrapper
 from .base_env import BaseEnv, BaseEnvTimestep
@@ -23,7 +23,14 @@ class DingEnvWrapper(BaseEnv):
         create_evaluator_env_cfg, enable_save_replay, observation_space, action_space, reward_space, clone
     """
 
-    def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True, caller: str = 'collector') -> None:
+    def __init__(
+            self,
+            env: Union[gym.Env, gymnasium.Env] = None,
+            cfg: dict = None,
+            seed_api: bool = True,
+            caller: str = 'collector',
+            is_gymnasium: bool = False
+    ) -> None:
         """
         Overview:
             Initialize the DingEnvWrapper. Either an environment instance or a config to create the environment \
@@ -32,17 +39,20 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
             usually used in simple environments. For the latter, i.e., a config to create an environment instance: \
             The `cfg` parameter must contain `env_id`.
         Arguments:
-            - env (:obj:`gym.Env`): An environment instance to be wrapped.
+            - env (:obj:`Union[gym.Env, gymnasium.Env]`): An environment instance to be wrapped.
             - cfg (:obj:`dict`): The configuration dictionary to create an environment instance.
             - seed_api (:obj:`bool`): Whether to use seed API. Defaults to True.
             - caller (:obj:`str`): A string representing the caller of this method, including ``collector`` or \
                 ``evaluator``. Different callers may need different wrappers. Default is 'collector'.
+            - is_gymnasium (:obj:`bool`): Whether the environment is a gymnasium environment. Defaults to False, i.e., \
+                the environment is a gym environment.
         """
         self._env = None
         self._raw_env = env
         self._cfg = cfg
         self._seed_api = seed_api  # some envs may disable the `env.seed` api
         self._caller = caller
+
         if self._cfg is None:
             self._cfg = {}
         self._cfg = EasyDict(self._cfg)
@@ -55,6 +65,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
         if 'env_id' not in self._cfg:
             self._cfg.env_id = None
         if env is not None:
+            self._is_gymnasium = isinstance(env, gymnasium.Env)
             self._env = env
             self._wrap_env(caller)
             self._observation_space = self._env.observation_space
@@ -66,6 +77,7 @@ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True,
             self._init_flag = True
         else:
             assert 'env_id' in self._cfg
+            self._is_gymnasium = is_gymnasium
             self._init_flag = False
             self._observation_space = None
             self._action_space = None
@@ -82,7 +94,8 @@ def reset(self) -> np.ndarray:
             - obs (:obj:`Dict`): The new observation after reset.
         """
         if not self._init_flag:
-            self._env = gym.make(self._cfg.env_id)
+            gym_proxy = gymnasium if self._is_gymnasium else gym
+            self._env = gym_proxy.make(self._cfg.env_id)
             self._wrap_env(self._caller)
             self._observation_space = self._env.observation_space
             self._action_space = self._env.action_space
@@ -98,29 +111,16 @@ def reset(self) -> np.ndarray:
                 name_prefix='rl-video-{}'.format(id(self))
             )
             self._replay_path = None
-        if isinstance(self._env, gym.Env):
-            if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
-                np_seed = 100 * np.random.randint(1, 1000)
-                if self._seed_api:
-                    self._env.seed(self._seed + np_seed)
-                self._action_space.seed(self._seed + np_seed)
-            elif hasattr(self, '_seed'):
-                if self._seed_api:
-                    self._env.seed(self._seed)
-                self._action_space.seed(self._seed)
-            obs = self._env.reset()
-        elif isinstance(self._env, gymnasium.Env):
-            if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
-                np_seed = 100 * np.random.randint(1, 1000)
-                self._action_space.seed(self._seed + np_seed)
-                obs = self._env.reset(seed=self._seed + np_seed)
-            elif hasattr(self, '_seed'):
-                self._action_space.seed(self._seed)
-                obs = self._env.reset(seed=self._seed)
-            else:
-                obs = self._env.reset()
-        else:
-            raise RuntimeError("not support env type: {}".format(type(self._env)))
+        if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+            np_seed = 100 * np.random.randint(1, 1000)
+            if self._seed_api:
+                self._env.seed(self._seed + np_seed)
+            self._action_space.seed(self._seed + np_seed)
+        elif hasattr(self, '_seed'):
+            if self._seed_api:
+                self._env.seed(self._seed)
+            self._action_space.seed(self._seed)
+        obs = self._env.reset()
         if self.observation_space.dtype == np.float32:
             obs = to_ndarray(obs, dtype=np.float32)
         else:
@@ -221,7 +221,7 @@ def random_action(self) -> np.ndarray:
         random_action = self.action_space.sample()
         if isinstance(random_action, np.ndarray):
             pass
-        elif isinstance(random_action, int):
+        elif isinstance(random_action, (int, np.int64)):
             random_action = to_ndarray([random_action], dtype=np.int64)
         elif isinstance(random_action, dict):
             random_action = to_ndarray(random_action)
@@ -241,6 +241,8 @@ def _wrap_env(self, caller: str = 'collector') -> None:
             - caller (:obj:`str`): The caller of the environment, including ``collector`` or ``evaluator``. \
                 Different callers may need different wrappers. Default is 'collector'.
         """
+        if self._is_gymnasium:
+            self._env = GymToGymnasiumWrapper(self._env)
         # wrapper_cfgs: Union[str, List]
         wrapper_cfgs = self._cfg.env_wrapper
         if isinstance(wrapper_cfgs, str):
@@ -362,4 +364,4 @@ def clone(self, caller: str = 'collector') -> BaseEnv:
                 raw_env.__setattr__('spec', spec)
             except Exception:
                 raw_env = self._raw_env
-        return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller)
+        return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller, self._is_gymnasium)
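Taken together, the change gives DingEnvWrapper two construction paths for gymnasium environments. A minimal sketch of both (assuming the constructor's unshown defaults, such as a default `env_wrapper` field, are filled in as before; this snippet is illustrative, not part of the commit):

import gymnasium
from easydict import EasyDict
from ding.envs import DingEnvWrapper

# Path 1: pass an env instance; gymnasium is detected via isinstance(env, gymnasium.Env)
# and GymToGymnasiumWrapper is applied inside _wrap_env.
ding_env = DingEnvWrapper(env=gymnasium.make('CartPole-v1'))
obs = ding_env.reset()

# Path 2: pass only a config; the env is created lazily in reset(), and
# is_gymnasium=True routes creation through gymnasium.make instead of gym.make.
lazy_env = DingEnvWrapper(cfg=EasyDict(dict(env_id='CartPole-v1')), is_gymnasium=True)
obs = lazy_env.reset()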
22 changes: 22 additions & 0 deletions ding/envs/env/tests/test_ding_env_wrapper.py
@@ -1,4 +1,5 @@
 import gym
+import gymnasium
 import numpy as np
 import pytest
 from easydict import EasyDict
@@ -68,6 +69,27 @@ def test_cartpole_pendulum(self, env_id):
         # assert isinstance(action, np.ndarray)
         print('random_action: {}, action_space: {}'.format(action.shape, ding_env.action_space))
 
+    @pytest.mark.unittest
+    @pytest.mark.parametrize('env_id', ['CartPole-v0', 'Pendulum-v1'])
+    def test_cartpole_pendulum_gymnasium(self, env_id):
+        env = gymnasium.make(env_id)
+        ding_env = DingEnvWrapper(env=env)
+        print(ding_env.observation_space, ding_env.action_space, ding_env.reward_space)
+        cfg = EasyDict(dict(
+            collector_env_num=16,
+            evaluator_env_num=3,
+            is_train=True,
+        ))
+        l1 = ding_env.create_collector_env_cfg(cfg)
+        assert isinstance(l1, list)
+        l1 = ding_env.create_evaluator_env_cfg(cfg)
+        assert isinstance(l1, list)
+        obs = ding_env.reset()
+        assert isinstance(obs, np.ndarray)
+        action = ding_env.random_action()
+        # assert isinstance(action, np.ndarray)
+        print('random_action: {}, action_space: {}'.format(action.shape, ding_env.action_space))
+
     @pytest.mark.envtest
     def test_mujoco(self):
         env_cfg = EasyDict(
17 changes: 14 additions & 3 deletions ding/envs/env_wrappers/env_wrappers.py
@@ -1438,7 +1438,7 @@ class GymToGymnasiumWrapper(gym.Wrapper):
     Overview:
         This class is used to wrap a gymnasium environment so that it exposes the classic gym interface.
     Interfaces:
-        __init__, seed, reset
+        __init__, seed, reset, step
     """
 
     def __init__(self, env: gymnasium.Env) -> None:
@@ -1470,9 +1470,20 @@ def reset(self) -> np.ndarray:
             - observation (:obj:`np.ndarray`): The new observation after reset.
         """
         if self._seed is not None:
-            return self.env.reset(seed=self._seed)
+            obs, info = self.env.reset(seed=self._seed)
         else:
-            return self.env.reset()
+            obs, info = self.env.reset()
+        return obs
+
+    def step(self, *args, **kwargs):
+        """
+        Overview:
+            Execute the given action in the environment, and return the new observation, \
+            reward, done status, and info. To keep consistency with gym, done is True when \
+            either terminated or truncated is True.
+        """
+        obs, rew, terminated, truncated, info = self.env.step(*args, **kwargs)
+        return obs, rew, terminated or truncated, info
 
 
 @ENV_WRAPPER_REGISTRY.register('reward_in_obs')
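At the API boundary the wrapper collapses gymnasium's return signatures back to gym's. A short sketch of the conversion (assuming standard attribute forwarding from `gym.Wrapper`, e.g. for `action_space`; not a verbatim excerpt):

import gymnasium
from ding.envs.env_wrappers import GymToGymnasiumWrapper

env = GymToGymnasiumWrapper(gymnasium.make('CartPole-v1'))
obs = env.reset()  # gymnasium's reset returns (obs, info); the wrapper returns obs only
obs, rew, done, info = env.step(env.action_space.sample())  # done = terminated or truncated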
49 changes: 49 additions & 0 deletions ding/example/dqn_nstep_gymnasium.py
@@ -0,0 +1,49 @@
+import gymnasium as gym
+from ditk import logging
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+    eps_greedy_handler, CkptSaver, nstep_reward_enhancer, final_ctx_saver
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config
+
+
+def main():
+    logging.getLogger().setLevel(logging.INFO)
+    main_config.exp_name = 'cartpole_dqn_nstep_gymnasium'
+    main_config.policy.nstep = 3
+    cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+    with task.start(async_mode=False, ctx=OnlineRLContext()):
+        collector_env = BaseEnvManagerV2(
+            env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+            cfg=cfg.env.manager
+        )
+        evaluator_env = BaseEnvManagerV2(
+            env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+            cfg=cfg.env.manager
+        )
+
+        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+        model = DQN(**cfg.policy.model)
+        buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+        policy = DQNPolicy(cfg.policy, model=model)
+
+        task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+        task.use(eps_greedy_handler(cfg))
+        task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+        task.use(nstep_reward_enhancer(cfg))
+        task.use(data_pusher(cfg, buffer_))
+        task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+        task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+        task.use(final_ctx_saver(cfg.exp_name))
+        task.run()
+
+
+if __name__ == "__main__":
+    main()
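The only change from the gym-based nstep DQN example is the `import gymnasium as gym` line: because an env instance is passed in, DingEnvWrapper detects gymnasium via isinstance and adapts it transparently. A rough sketch of what each `env_fn` entry produces at runtime (an illustrative check, not part of the commit):

import gymnasium
from ding.envs import DingEnvWrapper

env = DingEnvWrapper(gymnasium.make("CartPole-v0"))
obs = env.reset()
timestep = env.step(env.random_action())  # BaseEnvTimestep(obs, reward, done, info)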
