Skip to content

Commit

Permalink
[Feature] Added support for vector-based rewards from environments in…
Browse files Browse the repository at this point in the history
… MO-Gymnasium (#992)

Co-authored-by: vmoens <[email protected]>
  • Loading branch information
dennismalmgren and vmoens authored May 5, 2023
1 parent 24abc75 commit a3f6969
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 24 deletions.
2 changes: 2 additions & 0 deletions .circleci/unittest/linux/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ if [[ $OSTYPE != 'darwin'* ]]; then
fi
echo "installing gymnasium"
pip install "gymnasium[atari,accept-rom-license]"
pip install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
else
pip install "gymnasium[atari,accept-rom-license]"
pip install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
fi
1 change: 1 addition & 0 deletions .circleci/unittest/linux_libs/scripts_gym/batch_scripts.sh
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ do
else
pip install gymnasium[atari]
fi
pip install mo-gymnasium

$DIR/run_test.sh

Expand Down
2 changes: 2 additions & 0 deletions .circleci/unittest/linux_stable/scripts/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ if [[ $OSTYPE != 'darwin'* ]]; then
fi
echo "installing gymnasium"
pip install "gymnasium[atari,accept-rom-license]"
pip install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
else
pip install "gymnasium[atari,accept-rom-license]"
pip install mo-gymnasium[mujoco] # requires here bc needs mujoco-py
fi
2 changes: 2 additions & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,8 @@ the following function will return ``1`` when queried:
dm_control.DMControlWrapper
gym.GymEnv
gym.GymWrapper
gym.MOGymEnv
gym.MOGymWrapper
gym.set_gym_backend
gym.gym_backend
habitat.HabitatEnv
Expand Down
73 changes: 55 additions & 18 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import importlib

import time
from sys import platform
from typing import Optional, Union
Expand Down Expand Up @@ -37,7 +39,14 @@
)
from torchrl.envs.libs.brax import _has_brax, BraxEnv
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv, DMControlWrapper
from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper
from torchrl.envs.libs.gym import (
_has_gym,
_is_from_pixels,
GymEnv,
GymWrapper,
MOGymEnv,
MOGymWrapper,
)
from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv
from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv
from torchrl.envs.libs.openml import OpenMLEnv
Expand All @@ -46,24 +55,12 @@
from torchrl.envs.vec_env import _has_envpool, MultiThreadedEnvWrapper, SerialEnv
from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator

D4RL_ERR = None
try:
import d4rl # noqa
_has_d4rl = importlib.util.find_spec("d4rl") is not None

_has_d4rl = True
except Exception as err:
# many things can wrong when importing d4rl :(
_has_d4rl = False
D4RL_ERR = err
_has_mo = importlib.util.find_spec("mo_gymnasium") is not None

SKLEARN_ERR = None
try:
import sklearn # noqa
_has_sklearn = importlib.util.find_spec("sklearn") is not None

_has_sklearn = True
except ModuleNotFoundError as err:
_has_sklearn = False
SKLEARN_ERR = err

if _has_gym:
try:
Expand Down Expand Up @@ -212,6 +209,46 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only):
)
check_env_specs(env)

@pytest.mark.parametrize("frame_skip", [1, 3])
@pytest.mark.parametrize(
"from_pixels,pixels_only",
[
[False, False],
[True, True],
[True, False],
],
)
@pytest.mark.parametrize("wrapper", [True, False])
def test_mo(self, frame_skip, from_pixels, pixels_only, wrapper):
if importlib.util.find_spec("gymnasium") is not None and not _has_mo:
raise pytest.skip("mo-gym not found")
else:
# avoid skipping, which we consider as errors in the gym CI
return

def make_env():
import mo_gymnasium

if wrapper:
return MOGymWrapper(
mo_gymnasium.make("minecart-v0"),
frame_skip=frame_skip,
from_pixels=from_pixels,
pixels_only=pixels_only,
)
else:
return MOGymEnv(
"minecart-v0",
frame_skip=frame_skip,
from_pixels=from_pixels,
pixels_only=pixels_only,
)

env = make_env()
check_env_specs(env)
env = SerialEnv(2, make_env)
check_env_specs(env)

def test_info_reader(self):
try:
import gym_super_mario_bros as mario_gym
Expand Down Expand Up @@ -1240,7 +1277,7 @@ def make_vmas():
assert env.rollout(max_steps=3).device == devices[1 - first]


@pytest.mark.skipif(not _has_d4rl, reason=f"D4RL not found: {D4RL_ERR}")
@pytest.mark.skipif(not _has_d4rl, reason="D4RL not found")
class TestD4RL:
@pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"])
def test_terminate_on_end(self, task):
Expand Down Expand Up @@ -1333,7 +1370,7 @@ def test_d4rl_iteration(self, task, split_trajs):
print(f"completed test after {time.time()-t0}s")


@pytest.mark.skipif(not _has_sklearn, reason=f"Scikit-learn not found: {SKLEARN_ERR}")
@pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found")
@pytest.mark.parametrize(
"dataset",
[
Expand Down
2 changes: 1 addition & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def test_set_gym_environments_no_version_gymnasium_found():

# this version of gymnasium does not exist in implement_for
# therefore, set_gym_backend will not set anything and raise an ImportError.
msg = f"could not set anything related to gym backed {gymnasium_name} with version={gymnasium_version}."
msg = f"could not set anything related to gym backend {gymnasium_name} with version={gymnasium_version}."
with pytest.raises(ImportError, match=msg) as exc_info:
with set_gym_backend(gymnasium):
_utils_internal._set_gym_environments()
Expand Down
67 changes: 62 additions & 5 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
if not _has_gym:
_has_gym = importlib.util.find_spec("gymnasium") is not None

_has_mo = importlib.util.find_spec("mo_gymnasium") is not None


class set_gym_backend(_DecoratorContextManager):
"""Sets the gym-backend to a certain value.
Expand Down Expand Up @@ -106,7 +108,8 @@ def _call(self):
found_setter = True
if not found_setter:
raise ImportError(
f"could not set anything related to gym backed {self.backend.__name__} with version={self.backend.__version__}."
f"could not set anything related to gym backend "
f"{self.backend.__name__} with version={self.backend.__version__}."
)

def __enter__(self):
Expand Down Expand Up @@ -527,10 +530,17 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821
else:
observation_spec = CompositeSpec(observation=observation_spec)
self.observation_spec = observation_spec
self.reward_spec = UnboundedContinuousTensorSpec(
shape=[1],
device=self.device,
)
if hasattr(env, "reward_space") and env.reward_space is not None:
self.reward_spec = _gym_to_torchrl_spec_transform(
env.reward_space,
device=self.device,
categorical_action_encoding=self._categorical_action_encoding,
)
else:
self.reward_spec = UnboundedContinuousTensorSpec(
shape=[1],
device=self.device,
)

def _init_env(self):
self.reset()
Expand Down Expand Up @@ -671,3 +681,50 @@ def _check_kwargs(self, kwargs: Dict):

def __repr__(self) -> str:
return f"{self.__class__.__name__}(env={self.env_name}, batch_size={self.batch_size}, device={self.device})"


class MOGymWrapper(GymWrapper):
"""FARAMA MO-Gymnasium environment wrapper.
Examples:
>>> import mo_gymnasium as mo_gym
>>> env = MOGymWrapper(mo_gym.make('minecart-v0'), frame_skip=4)
>>> td = env.rand_step()
>>> print(td)
>>> print(env.available_envs)
"""

git_url = "https://github.com/Farama-Foundation/MO-Gymnasium"
libname = "mo-gymnasium"

_make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs)


class MOGymEnv(GymEnv):
"""FARAMA MO-Gymnasium environment wrapper.
Examples:
>>> env = MOGymEnv(env_name="minecart-v0", frame_skip=4)
>>> td = env.rand_step()
>>> print(td)
>>> print(env.available_envs)
"""

git_url = "https://github.com/Farama-Foundation/MO-Gymnasium"
libname = "mo-gymnasium"

@property
def lib(self) -> ModuleType:
if _has_mo:
import mo_gymnasium as mo_gym

return mo_gym
else:
try:
import mo_gymnasium # noqa: F401
except ImportError as err:
raise ImportError("MO-gymnasium not found, check installation") from err

_make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs)

0 comments on commit a3f6969

Please sign in to comment.