Skip to content

Commit

Permalink
Added support for new PettingZoo API (#751)
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus28 authored Oct 2, 2022
1 parent b0c8d28 commit 128feb6
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 9 deletions.
2 changes: 0 additions & 2 deletions test/pettingzoo/test_pistonball.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import pprint

import pytest
from pistonball import get_args, train_agent, watch


@pytest.mark.skip(reason="TODO(Markus28): fix later")
def test_piston_ball(args=get_args()):
if args.watch:
watch(args)
Expand Down
2 changes: 0 additions & 2 deletions test/pettingzoo/test_tic_tac_toe.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import pprint

import pytest
from tic_tac_toe import get_args, train_agent, watch


@pytest.mark.skip(reason="TODO(Markus28): fix later")
def test_tic_tac_toe(args=get_args()):
if args.watch:
watch(args)
Expand Down
40 changes: 35 additions & 5 deletions tianshou/env/pettingzoo_env.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import warnings
from abc import ABC
from typing import Any, Dict, List, Tuple, Union

import gym.spaces
import pettingzoo
from packaging import version
from pettingzoo.utils.env import AECEnv
from pettingzoo.utils.wrappers import BaseWrapper

# Import-time check: warn when the installed PettingZoo release is older than
# 1.21.0, since later tianshou versions may drop support for it.
if version.parse(pettingzoo.__version__) < version.parse("1.21.0"):
    warnings.warn(
        f"You are using PettingZoo {pettingzoo.__version__}. "
        "Future tianshou versions may not support PettingZoo<1.21.0. "
        "Consider upgrading your PettingZoo version.",
        DeprecationWarning,
    )


class PettingZooEnv(AECEnv, ABC):
"""The interface for petting zoo environments.
Expand Down Expand Up @@ -57,7 +67,20 @@ def __init__(self, env: BaseWrapper):

def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
self.env.reset(*args, **kwargs)
observation, _, _, info = self.env.last(self)

# Here, we do not label the return values explicitly, to keep compatibility with
# the old step API. TODO: Change once PettingZoo>=1.21.0 is required.
last_return = self.env.last(self)

if len(last_return) == 4:
warnings.warn(
"The PettingZoo environment is using the old step API. "
"This API may not be supported in future versions of tianshou. "
"We recommend that you update the environment code or apply a "
"compatibility wrapper.", DeprecationWarning
)

observation, info = last_return[0], last_return[-1]
if isinstance(observation, dict) and 'action_mask' in observation:
observation_dict = {
'agent_id': self.env.agent_selection,
Expand All @@ -83,9 +106,16 @@ def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
else:
return observation_dict

def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
def step(
self, action: Any
) -> Union[Tuple[Dict, List[int], bool, Dict], Tuple[Dict, List[int], bool, bool,
Dict]]:
self.env.step(action)
observation, rew, done, info = self.env.last()

# Here, we do not label the return values explicitly, to keep compatibility with
# the old step API. TODO: Change once PettingZoo>=1.21.0 is required.
last_return = self.env.last()
observation = last_return[0]
if isinstance(observation, dict) and 'action_mask' in observation:
obs = {
'agent_id': self.env.agent_selection,
Expand All @@ -105,15 +135,15 @@ def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:

for agent_id, reward in self.env.rewards.items():
self.rewards[self.agent_idx[agent_id]] = reward
return obs, self.rewards, done, info
return (obs, self.rewards, *last_return[2:]) # type: ignore

def close(self) -> None:
    """Close the wrapped environment by delegating to ``self.env.close()``."""
    wrapped_env = self.env
    wrapped_env.close()

def seed(self, seed: Any = None) -> None:
    """Seed the wrapped environment.

    ``self.env.seed(seed)`` is tried first. If that call raises
    ``NotImplementedError`` or ``AttributeError``, the seed is passed to
    ``self.env.reset(seed=seed)`` instead.

    :param seed: the seed value forwarded to the wrapped environment.
    """
    try:
        self.env.seed(seed)
    except (NotImplementedError, AttributeError):
        # A single handler covers both failure modes; the earlier body-less
        # ``except NotImplementedError:`` clause was redundant and unparsable.
        self.env.reset(seed=seed)

def render(self, mode: str = "human") -> Any:
Expand Down

0 comments on commit 128feb6

Please sign in to comment.