From 340e59e1525ac5ddccc1e861ce5caf9820eb9cad Mon Sep 17 00:00:00 2001
From: huangshiyu
Date: Fri, 8 Sep 2023 14:53:37 +0800
Subject: [PATCH 1/2] merge dm_control to gymnasium

---
 examples/dm_control/README.md      |  2 +-
 examples/dm_control/train_ppo.py   | 20 +++++++++++---------
 openrl/envs/common/registration.py |  7 -------
 openrl/envs/dmc/__init__.py        | 30 ------------------------------
 openrl/envs/dmc/dmc_env.py         | 13 -------------
 5 files changed, 12 insertions(+), 60 deletions(-)
 delete mode 100644 openrl/envs/dmc/__init__.py
 delete mode 100644 openrl/envs/dmc/dmc_env.py

diff --git a/examples/dm_control/README.md b/examples/dm_control/README.md
index 5bbe63fa..ff0feb49 100644
--- a/examples/dm_control/README.md
+++ b/examples/dm_control/README.md
@@ -1,6 +1,6 @@
 ## Installation
 ```bash
-pip install shimmy[dm-control]
+pip install "shimmy[dm-control]"
 ```
 
 ## Usage
diff --git a/examples/dm_control/train_ppo.py b/examples/dm_control/train_ppo.py
index aa77222f..b79588e8 100644
--- a/examples/dm_control/train_ppo.py
+++ b/examples/dm_control/train_ppo.py
@@ -1,13 +1,15 @@
 import numpy as np
 from gymnasium.wrappers import FlattenObservation
+import torch
 
 from openrl.configs.config import create_config_parser
 from openrl.envs.common import make
 from openrl.envs.wrappers.base_wrapper import BaseWrapper
-from openrl.envs.wrappers.extra_wrappers import FrameSkip, GIFWrapper
+from openrl.envs.wrappers.extra_wrappers import FrameSkip, GIFWrapper,ConvertEmptyBoxWrapper
 from openrl.modules.common import PPONet as Net
 from openrl.runners.common import PPOAgent as Agent
 
+
 env_name = "dm_control/cartpole-balance-v0"
 # env_name = "dm_control/walker-walk-v0"
 
@@ -18,15 +20,15 @@ def train():
     cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
 
     # create environment, set environment parallelism to 64
+    env_num = 64
     env = make(
         env_name,
-        env_num=64,
-        cfg=cfg,
+        env_num=env_num,
         asynchronous=True,
-        env_wrappers=[FrameSkip, FlattenObservation],
+        env_wrappers=[FrameSkip, FlattenObservation,ConvertEmptyBoxWrapper],
     )
 
-    net = Net(env, cfg=cfg, device="cuda")
+    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
     # initialize the trainer
     agent = Agent(
         net,
@@ -44,18 +46,18 @@ def evaluation():
     # begin to test
     # Create an environment for testing and set the number of environments to interact with to 4. Set rendering mode to group_rgb_array.
     render_mode = "group_rgb_array"
+
     env = make(
         env_name,
         render_mode=render_mode,
         env_num=4,
         asynchronous=True,
-        env_wrappers=[FrameSkip, FlattenObservation],
-        cfg=cfg,
+        env_wrappers=[FrameSkip, FlattenObservation,ConvertEmptyBoxWrapper],
     )
     # Wrap the environment with GIFWrapper to record the GIF, and set the frame rate to 5.
-    env = GIFWrapper(env, gif_path="./new.gif", fps=5)
+    # env = GIFWrapper(env, gif_path="./new.gif", fps=5)
 
-    net = Net(env, cfg=cfg, device="cuda")
+    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
     # initialize the trainer
     agent = Agent(
         net,
diff --git a/openrl/envs/common/registration.py b/openrl/envs/common/registration.py
index 11685c5d..099f2b39 100644
--- a/openrl/envs/common/registration.py
+++ b/openrl/envs/common/registration.py
@@ -78,13 +78,6 @@ def make(
         env_fns = make_snake_envs(
             id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
         )
-
-    elif id.startswith("dm_control/"):
-        from openrl.envs.dmc import make_dmc_envs
-
-        env_fns = make_dmc_envs(
-            id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
-        )
     elif id.startswith("GymV21Environment-v0:") or id.startswith(
         "GymV26Environment-v0:"
     ):
diff --git a/openrl/envs/dmc/__init__.py b/openrl/envs/dmc/__init__.py
deleted file mode 100644
index 4f6ff39e..00000000
--- a/openrl/envs/dmc/__init__.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import copy
-from typing import Callable, List, Optional, Union
-
-import dmc2gym
-
-from openrl.envs.common import build_envs
-from openrl.envs.dmc.dmc_env import make
-
-
-def make_dmc_envs(
-    id: str,
-    env_num: int = 1,
-    render_mode: Optional[Union[str, List[str]]] = None,
-    **kwargs,
-):
-    from openrl.envs.wrappers import RemoveTruncated, Single2MultiAgentWrapper
-    from openrl.envs.wrappers.extra_wrappers import ConvertEmptyBoxWrapper
-
-    env_wrappers = copy.copy(kwargs.pop("env_wrappers", []))
-    env_wrappers += [ConvertEmptyBoxWrapper, RemoveTruncated, Single2MultiAgentWrapper]
-    env_fns = build_envs(
-        make=make,
-        id=id,
-        env_num=env_num,
-        render_mode=render_mode,
-        wrappers=env_wrappers,
-        **kwargs,
-    )
-
-    return env_fns
diff --git a/openrl/envs/dmc/dmc_env.py b/openrl/envs/dmc/dmc_env.py
deleted file mode 100644
index 2c295737..00000000
--- a/openrl/envs/dmc/dmc_env.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from typing import Any, Optional
-
-import gymnasium as gym
-import numpy as np
-
-
-def make(
-    id: str,
-    render_mode: Optional[str] = None,
-    **kwargs: Any,
-):
-    env = gym.make(id, render_mode=render_mode)
-    return env

From 0491ad426e65187ca82aa77091083fadbf244702 Mon Sep 17 00:00:00 2001
From: huangshiyu
Date: Fri, 8 Sep 2023 14:55:33 +0800
Subject: [PATCH 2/2] merge dm_control to gymnasium

---
 examples/dm_control/train_ppo.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/examples/dm_control/train_ppo.py b/examples/dm_control/train_ppo.py
index b79588e8..09812303 100644
--- a/examples/dm_control/train_ppo.py
+++ b/examples/dm_control/train_ppo.py
@@ -1,15 +1,18 @@
 import numpy as np
-from gymnasium.wrappers import FlattenObservation
 import torch
+from gymnasium.wrappers import FlattenObservation
 
 from openrl.configs.config import create_config_parser
 from openrl.envs.common import make
 from openrl.envs.wrappers.base_wrapper import BaseWrapper
-from openrl.envs.wrappers.extra_wrappers import FrameSkip, GIFWrapper,ConvertEmptyBoxWrapper
+from openrl.envs.wrappers.extra_wrappers import (
+    ConvertEmptyBoxWrapper,
+    FrameSkip,
+    GIFWrapper,
+)
 from openrl.modules.common import PPONet as Net
 from openrl.runners.common import PPOAgent as Agent
 
-
 env_name = "dm_control/cartpole-balance-v0"
 # env_name = "dm_control/walker-walk-v0"
 
@@ -25,7 +28,7 @@ def train():
         env_name,
         env_num=env_num,
         asynchronous=True,
-        env_wrappers=[FrameSkip, FlattenObservation,ConvertEmptyBoxWrapper],
+        env_wrappers=[FrameSkip, FlattenObservation, ConvertEmptyBoxWrapper],
     )
 
     net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
@@ -52,10 +55,10 @@ def evaluation():
         render_mode=render_mode,
         env_num=4,
         asynchronous=True,
-        env_wrappers=[FrameSkip, FlattenObservation,ConvertEmptyBoxWrapper],
+        env_wrappers=[FrameSkip, FlattenObservation, ConvertEmptyBoxWrapper],
    )
     # Wrap the environment with GIFWrapper to record the GIF, and set the frame rate to 5.
-    # env = GIFWrapper(env, gif_path="./new.gif", fps=5)
+    env = GIFWrapper(env, gif_path="./new.gif", fps=5)
 
     net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
     # initialize the trainer