From a3f69692945cd50ccf44a34ce35376ee0dc4009d Mon Sep 17 00:00:00 2001 From: dennismalmgren Date: Fri, 5 May 2023 23:31:22 +0200 Subject: [PATCH] [Feature] Added support for vector-based rewards from environments in MO-Gymnasium (#992) Co-authored-by: vmoens --- .circleci/unittest/linux/scripts/setup_env.sh | 2 + .../linux_libs/scripts_gym/batch_scripts.sh | 1 + .../linux_stable/scripts/setup_env.sh | 2 + docs/source/reference/envs.rst | 2 + test/test_libs.py | 73 ++++++++++++++----- test/test_utils.py | 2 +- torchrl/envs/libs/gym.py | 67 +++++++++++++++-- 7 files changed, 125 insertions(+), 24 deletions(-) diff --git a/.circleci/unittest/linux/scripts/setup_env.sh b/.circleci/unittest/linux/scripts/setup_env.sh index 5d17d3eaec6..1d02f8ecd0c 100755 --- a/.circleci/unittest/linux/scripts/setup_env.sh +++ b/.circleci/unittest/linux/scripts/setup_env.sh @@ -115,6 +115,8 @@ if [[ $OSTYPE != 'darwin'* ]]; then fi echo "installing gymnasium" pip install "gymnasium[atari,accept-rom-license]" + pip install mo-gymnasium[mujoco] # requires here bc needs mujoco-py else pip install "gymnasium[atari,accept-rom-license]" + pip install mo-gymnasium[mujoco] # requires here bc needs mujoco-py fi diff --git a/.circleci/unittest/linux_libs/scripts_gym/batch_scripts.sh b/.circleci/unittest/linux_libs/scripts_gym/batch_scripts.sh index b5c90f49aa4..4ade3c2bbc9 100755 --- a/.circleci/unittest/linux_libs/scripts_gym/batch_scripts.sh +++ b/.circleci/unittest/linux_libs/scripts_gym/batch_scripts.sh @@ -141,6 +141,7 @@ do else pip install gymnasium[atari] fi + pip install mo-gymnasium $DIR/run_test.sh diff --git a/.circleci/unittest/linux_stable/scripts/setup_env.sh b/.circleci/unittest/linux_stable/scripts/setup_env.sh index 5d17d3eaec6..1d02f8ecd0c 100755 --- a/.circleci/unittest/linux_stable/scripts/setup_env.sh +++ b/.circleci/unittest/linux_stable/scripts/setup_env.sh @@ -115,6 +115,8 @@ if [[ $OSTYPE != 'darwin'* ]]; then fi echo "installing gymnasium" pip install "gymnasium[atari,accept-rom-license]" + pip install mo-gymnasium[mujoco] # requires here bc needs mujoco-py else pip install "gymnasium[atari,accept-rom-license]" + pip install mo-gymnasium[mujoco] # requires here bc needs mujoco-py fi diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 1864b496547..c247db925e2 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -481,6 +481,8 @@ the following function will return ``1`` when queried: dm_control.DMControlWrapper gym.GymEnv gym.GymWrapper + gym.MOGymEnv + gym.MOGymWrapper gym.set_gym_backend gym.gym_backend habitat.HabitatEnv diff --git a/test/test_libs.py b/test/test_libs.py index 07864e229da..f92394ecd01 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3,6 +3,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import argparse +import importlib + import time from sys import platform from typing import Optional, Union @@ -37,7 +39,14 @@ ) from torchrl.envs.libs.brax import _has_brax, BraxEnv from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv, DMControlWrapper -from torchrl.envs.libs.gym import _has_gym, _is_from_pixels, GymEnv, GymWrapper +from torchrl.envs.libs.gym import ( + _has_gym, + _is_from_pixels, + GymEnv, + GymWrapper, + MOGymEnv, + MOGymWrapper, +) from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv from torchrl.envs.libs.openml import OpenMLEnv @@ -46,24 +55,12 @@ from torchrl.envs.vec_env import _has_envpool, MultiThreadedEnvWrapper, SerialEnv from torchrl.modules import ActorCriticOperator, MLP, SafeModule, ValueOperator -D4RL_ERR = None -try: - import d4rl # noqa +_has_d4rl = importlib.util.find_spec("d4rl") is not None - _has_d4rl = True -except Exception as err: - # many things can wrong when importing d4rl :( - _has_d4rl = False - D4RL_ERR = err +_has_mo = importlib.util.find_spec("mo_gymnasium") is not None -SKLEARN_ERR = None -try: - import sklearn # noqa +_has_sklearn = importlib.util.find_spec("sklearn") is not None - _has_sklearn = True -except ModuleNotFoundError as err: - _has_sklearn = False - SKLEARN_ERR = err if _has_gym: try: @@ -212,6 +209,46 @@ def test_gym_fake_td(self, env_name, frame_skip, from_pixels, pixels_only): ) check_env_specs(env) + @pytest.mark.parametrize("frame_skip", [1, 3]) + @pytest.mark.parametrize( + "from_pixels,pixels_only", + [ + [False, False], + [True, True], + [True, False], + ], + ) + @pytest.mark.parametrize("wrapper", [True, False]) + def test_mo(self, frame_skip, from_pixels, pixels_only, wrapper): + if importlib.util.find_spec("gymnasium") is not None and not _has_mo: + raise pytest.skip("mo-gym not found") + else: + # avoid skipping, which we consider as errors in the gym CI + return + + def make_env(): + import mo_gymnasium + + if wrapper: + return MOGymWrapper( + mo_gymnasium.make("minecart-v0"), + frame_skip=frame_skip, + from_pixels=from_pixels, + pixels_only=pixels_only, + ) + else: + return MOGymEnv( + "minecart-v0", + frame_skip=frame_skip, + from_pixels=from_pixels, + pixels_only=pixels_only, + ) + + env = make_env() + check_env_specs(env) + env = SerialEnv(2, make_env) + check_env_specs(env) + def test_info_reader(self): try: import gym_super_mario_bros as mario_gym @@ -1240,7 +1277,7 @@ def make_vmas(): assert env.rollout(max_steps=3).device == devices[1 - first] -@pytest.mark.skipif(not _has_d4rl, reason=f"D4RL not found: {D4RL_ERR}") +@pytest.mark.skipif(not _has_d4rl, reason="D4RL not found") class TestD4RL: @pytest.mark.parametrize("task", ["walker2d-medium-replay-v2"]) def test_terminate_on_end(self, task): @@ -1333,7 +1370,7 @@ def test_d4rl_iteration(self, task, split_trajs): print(f"completed test after {time.time()-t0}s") -@pytest.mark.skipif(not _has_sklearn, reason=f"Scikit-learn not found: {SKLEARN_ERR}") +@pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found") @pytest.mark.parametrize( "dataset", [ diff --git a/test/test_utils.py b/test/test_utils.py index 1bd8cc45082..6a44226d780 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -255,7 +255,7 @@ def test_set_gym_environments_no_version_gymnasium_found(): # this version of gymnasium does not exist in implement_for # therefore, set_gym_backend will not set anything and raise an ImportError. - msg = f"could not set anything related to gym backed {gymnasium_name} with version={gymnasium_version}." + msg = f"could not set anything related to gym backend {gymnasium_name} with version={gymnasium_version}." with pytest.raises(ImportError, match=msg) as exc_info: with set_gym_backend(gymnasium): _utils_internal._set_gym_environments() diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index d3046f1dc2f..5632065c57d 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -40,6 +40,8 @@ if not _has_gym: _has_gym = importlib.util.find_spec("gymnasium") is not None +_has_mo = importlib.util.find_spec("mo_gymnasium") is not None + class set_gym_backend(_DecoratorContextManager): """Sets the gym-backend to a certain value. @@ -106,7 +108,8 @@ def _call(self): found_setter = True if not found_setter: raise ImportError( - f"could not set anything related to gym backed {self.backend.__name__} with version={self.backend.__version__}." + f"could not set anything related to gym backend " + f"{self.backend.__name__} with version={self.backend.__version__}." ) def __enter__(self): @@ -527,10 +530,17 @@ def _make_specs(self, env: "gym.Env") -> None: # noqa: F821 else: observation_spec = CompositeSpec(observation=observation_spec) self.observation_spec = observation_spec - self.reward_spec = UnboundedContinuousTensorSpec( - shape=[1], - device=self.device, - ) + if hasattr(env, "reward_space") and env.reward_space is not None: + self.reward_spec = _gym_to_torchrl_spec_transform( + env.reward_space, + device=self.device, + categorical_action_encoding=self._categorical_action_encoding, + ) + else: + self.reward_spec = UnboundedContinuousTensorSpec( + shape=[1], + device=self.device, + ) def _init_env(self): self.reset() @@ -671,3 +681,50 @@ def _check_kwargs(self, kwargs: Dict): def __repr__(self) -> str: return f"{self.__class__.__name__}(env={self.env_name}, batch_size={self.batch_size}, device={self.device})" + + +class MOGymWrapper(GymWrapper): + """FARAMA MO-Gymnasium environment wrapper. + + Examples: + >>> import mo_gymnasium as mo_gym + >>> env = MOGymWrapper(mo_gym.make('minecart-v0'), frame_skip=4) + >>> td = env.rand_step() + >>> print(td) + >>> print(env.available_envs) + + """ + + git_url = "https://github.com/Farama-Foundation/MO-Gymnasium" + libname = "mo-gymnasium" + + _make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs) + + +class MOGymEnv(GymEnv): + """FARAMA MO-Gymnasium environment wrapper. + + Examples: + >>> env = MOGymEnv(env_name="minecart-v0", frame_skip=4) + >>> td = env.rand_step() + >>> print(td) + >>> print(env.available_envs) + + """ + + git_url = "https://github.com/Farama-Foundation/MO-Gymnasium" + libname = "mo-gymnasium" + + @property + def lib(self) -> ModuleType: + if _has_mo: + import mo_gymnasium as mo_gym + + return mo_gym + else: + try: + import mo_gymnasium # noqa: F401 + except ImportError as err: + raise ImportError("MO-gymnasium not found, check installation") from err + + _make_specs = set_gym_backend("gymnasium")(GymEnv._make_specs)