merge dm_control to gymnasium
huangshiyu13 authored Sep 8, 2023
2 parents d378b77 + 0491ad4 commit 20cf3d3
Showing 5 changed files with 14 additions and 59 deletions.
examples/dm_control/README.md (1 addition, 1 deletion)
@@ -1,6 +1,6 @@
 ## Installation
 ```bash
-pip install shimmy[dm-control]
+pip install "shimmy[dm-control]"
 ```
 
 ## Usage
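For context, once `shimmy[dm-control]` is installed, dm_control tasks are exposed through the Gymnasium registry under the `dm_control/` namespace, which is what lets this commit drop OpenRL's dedicated dm_control entry point. A minimal sketch of that path (the `dm_control/cartpole-balance-v0` id is one of the ids Shimmy registers; treat the exact id as an assumption):

```python
# Sketch, assuming shimmy[dm-control] is installed and registers
# dm_control tasks under the "dm_control/" namespace in Gymnasium.
import gymnasium as gym
import shimmy  # noqa: F401  # importing shimmy ensures its environments are registered

env = gym.make("dm_control/cartpole-balance-v0")
obs, info = env.reset(seed=0)
for _ in range(5):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    if terminated or truncated:
        obs, info = env.reset()
env.close()
```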
examples/dm_control/train_ppo.py (13 additions, 8 deletions)
@@ -1,10 +1,15 @@
 import numpy as np
+import torch
 from gymnasium.wrappers import FlattenObservation
 
 from openrl.configs.config import create_config_parser
 from openrl.envs.common import make
 from openrl.envs.wrappers.base_wrapper import BaseWrapper
-from openrl.envs.wrappers.extra_wrappers import FrameSkip, GIFWrapper
+from openrl.envs.wrappers.extra_wrappers import (
+    ConvertEmptyBoxWrapper,
+    FrameSkip,
+    GIFWrapper,
+)
 from openrl.modules.common import PPONet as Net
 from openrl.runners.common import PPOAgent as Agent
 
@@ -18,15 +23,15 @@ def train():
     cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])
 
     # create environment, set environment parallelism to 64
+    env_num = 64
     env = make(
         env_name,
-        env_num=64,
         cfg=cfg,
+        env_num=env_num,
         asynchronous=True,
-        env_wrappers=[FrameSkip, FlattenObservation],
+        env_wrappers=[FrameSkip, FlattenObservation, ConvertEmptyBoxWrapper],
     )
 
-    net = Net(env, cfg=cfg, device="cuda")
+    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
     # initialize the trainer
     agent = Agent(
         net,
@@ -44,18 +49,18 @@ def evaluation():
     # begin to test
     # Create an environment for testing and set the number of environments to interact with to 4. Set rendering mode to group_rgb_array.
     render_mode = "group_rgb_array"
 
     env = make(
         env_name,
         render_mode=render_mode,
         env_num=4,
         asynchronous=True,
-        env_wrappers=[FrameSkip, FlattenObservation],
         cfg=cfg,
+        env_wrappers=[FrameSkip, FlattenObservation, ConvertEmptyBoxWrapper],
     )
     # Wrap the environment with GIFWrapper to record the GIF, and set the frame rate to 5.
     env = GIFWrapper(env, gif_path="./new.gif", fps=5)
 
-    net = Net(env, cfg=cfg, device="cuda")
+    net = Net(env, cfg=cfg, device="cuda" if torch.cuda.is_available() else "cpu")
     # initialize the trainer
     agent = Agent(
         net,
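The diff imports `ConvertEmptyBoxWrapper` from `openrl.envs.wrappers.extra_wrappers` but does not show its implementation. As an illustration only, not OpenRL's code: dm_control observation dicts can contain scalar `Box` entries with shape `()`, which flattening code may not handle, so a wrapper of this kind might reshape them to `(1,)`. The sketch below is a guess at that idea under those assumptions.

```python
import gymnasium as gym
import numpy as np
from gymnasium import spaces


class EmptyBoxToVectorWrapper(gym.ObservationWrapper):
    """Illustrative only: give shape-() Box entries in a Dict observation
    space a (1,) shape so downstream flattening can handle them. This is a
    guess at what a wrapper like ConvertEmptyBoxWrapper might do, not
    OpenRL's implementation."""

    def __init__(self, env):
        super().__init__(env)
        assert isinstance(env.observation_space, spaces.Dict)
        new_spaces = {}
        for key, space in env.observation_space.spaces.items():
            if isinstance(space, spaces.Box) and space.shape == ():
                # Promote the scalar Box to a 1-element vector Box.
                new_spaces[key] = spaces.Box(
                    low=space.low.reshape(1),
                    high=space.high.reshape(1),
                    dtype=space.dtype,
                )
            else:
                new_spaces[key] = space
        self.observation_space = spaces.Dict(new_spaces)

    def observation(self, observation):
        # Reshape scalar entries to match the promoted spaces.
        return {
            key: np.reshape(obs, (1,)) if np.shape(obs) == () else obs
            for key, obs in observation.items()
        }
```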
openrl/envs/common/registration.py (0 additions, 7 deletions)
@@ -78,13 +78,6 @@ def make(
         env_fns = make_snake_envs(
             id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
         )
-
-    elif id.startswith("dm_control/"):
-        from openrl.envs.dmc import make_dmc_envs
-
-        env_fns = make_dmc_envs(
-            id=id, env_num=env_num, render_mode=convert_render_mode, **kwargs
-        )
     elif id.startswith("GymV21Environment-v0:") or id.startswith(
         "GymV26Environment-v0:"
     ):
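With the `dm_control/` branch removed, ids with that prefix are no longer routed through `openrl.envs.dmc`; they fall through to the generic Gymnasium path, where Shimmy's registrations resolve them. A sketch of the resulting call, reusing the wrappers from the example above (the task id and the omission of `cfg` are assumptions):

```python
# Sketch: dm_control ids now resolve through the generic Gymnasium registry
# (provided by shimmy[dm-control]) instead of a special-cased openrl.envs.dmc path.
from gymnasium.wrappers import FlattenObservation

from openrl.envs.common import make
from openrl.envs.wrappers.extra_wrappers import ConvertEmptyBoxWrapper, FrameSkip

env = make(
    "dm_control/cartpole-balance-v0",  # hypothetical task id registered by Shimmy
    env_num=2,
    asynchronous=False,
    env_wrappers=[FrameSkip, FlattenObservation, ConvertEmptyBoxWrapper],
)
```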
openrl/envs/dmc/__init__.py (0 additions, 30 deletions)

This file was deleted.

openrl/envs/dmc/dmc_env.py (0 additions, 13 deletions)

This file was deleted.
