From 128feb677f0e269ddfc2592fad776a599d10b0a3 Mon Sep 17 00:00:00 2001 From: Markus Krimmel Date: Sun, 2 Oct 2022 18:33:12 +0200 Subject: [PATCH] Added support for new PettingZoo API (#751) --- test/pettingzoo/test_pistonball.py | 2 -- test/pettingzoo/test_tic_tac_toe.py | 2 -- tianshou/env/pettingzoo_env.py | 40 +++++++++++++++++++++++++---- 3 files changed, 35 insertions(+), 9 deletions(-) diff --git a/test/pettingzoo/test_pistonball.py b/test/pettingzoo/test_pistonball.py index 9e9c8bdb2..4a6c59655 100644 --- a/test/pettingzoo/test_pistonball.py +++ b/test/pettingzoo/test_pistonball.py @@ -1,10 +1,8 @@ import pprint -import pytest from pistonball import get_args, train_agent, watch -@pytest.mark.skip(reason="TODO(Markus28): fix later") def test_piston_ball(args=get_args()): if args.watch: watch(args) diff --git a/test/pettingzoo/test_tic_tac_toe.py b/test/pettingzoo/test_tic_tac_toe.py index 29b251b81..524cdb92a 100644 --- a/test/pettingzoo/test_tic_tac_toe.py +++ b/test/pettingzoo/test_tic_tac_toe.py @@ -1,10 +1,8 @@ import pprint -import pytest from tic_tac_toe import get_args, train_agent, watch -@pytest.mark.skip(reason="TODO(Markus28): fix later") def test_tic_tac_toe(args=get_args()): if args.watch: watch(args) diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index 1722dc563..d1ab131de 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -1,10 +1,20 @@ +import warnings from abc import ABC from typing import Any, Dict, List, Tuple, Union import gym.spaces +import pettingzoo +from packaging import version from pettingzoo.utils.env import AECEnv from pettingzoo.utils.wrappers import BaseWrapper +if version.parse(pettingzoo.__version__) < version.parse("1.21.0"): + warnings.warn( + f"You are using PettingZoo {pettingzoo.__version__}. " + f"Future tianshou versions may not support PettingZoo<1.21.0. " + f"Consider upgrading your PettingZoo version.", DeprecationWarning + ) + class PettingZooEnv(AECEnv, ABC): """The interface for petting zoo environments. @@ -57,7 +67,20 @@ def __init__(self, env: BaseWrapper): def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]: self.env.reset(*args, **kwargs) - observation, _, _, info = self.env.last(self) + + # Here, we do not label the return values explicitly to keep compatibility with + # old step API. TODO: Change once PettingZoo>=1.21.0 is required + last_return = self.env.last(self) + + if len(last_return) == 4: + warnings.warn( + "The PettingZoo environment is using the old step API. " + "This API may not be supported in future versions of tianshou. " + "We recommend that you update the environment code or apply a " + "compatibility wrapper.", DeprecationWarning + ) + + observation, info = last_return[0], last_return[-1] if isinstance(observation, dict) and 'action_mask' in observation: observation_dict = { 'agent_id': self.env.agent_selection, @@ -83,9 +106,16 @@ def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]: else: return observation_dict - def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]: + def step( + self, action: Any + ) -> Union[Tuple[Dict, List[int], bool, Dict], Tuple[Dict, List[int], bool, bool, + Dict]]: self.env.step(action) - observation, rew, done, info = self.env.last() + + # Here, we do not label the return values explicitly to keep compatibility with + # old step API. TODO: Change once PettingZoo>=1.21.0 is required + last_return = self.env.last() + observation = last_return[0] if isinstance(observation, dict) and 'action_mask' in observation: obs = { 'agent_id': self.env.agent_selection, @@ -105,7 +135,7 @@ def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]: for agent_id, reward in self.env.rewards.items(): self.rewards[self.agent_idx[agent_id]] = reward - return obs, self.rewards, done, info + return (obs, self.rewards, *last_return[2:]) # type: ignore def close(self) -> None: self.env.close() @@ -113,7 +143,7 @@ def close(self) -> None: def seed(self, seed: Any = None) -> None: try: self.env.seed(seed) - except NotImplementedError: + except (NotImplementedError, AttributeError): self.env.reset(seed=seed) def render(self, mode: str = "human") -> Any: