Skip to content

Commit

Permalink
fix a bug in batch._is_batch_set (#825)
Browse files Browse the repository at this point in the history
- [ ] I have marked all applicable categories:
    + [x] exception-raising fix
    + [ ] algorithm implementation fix
    + [ ] documentation modification
    + [ ] new feature
- [ ] I have reformatted the code using `make format` (**required**)
- [ ] I have checked the code using `make commit-checks` (**required**)
- [ ] If applicable, I have mentioned the relevant/related issue(s)
- [ ] If applicable, I have listed every items in this Pull Request
below

I'm developing a new PettingZoo environment. It is a two players turns
board game.

 ```
   obs_space = dict(
      board = gym.spaces.MultiBinary([8, 8]),
      player = gym.spaces.Tuple([gym.spaces.Discrete(8)] * 2),
      other_player = gym.spaces.Tuple([gym.spaces.Discrete(8)] * 2)
    )    
    self._observation_space = gym.spaces.Dict(spaces=obs_space)
    self._action_space = gym.spaces.Tuple([gym.spaces.Discrete(8)] * 2)
 ...

# this cache ensures that same space object is returned for the same
agent
  # allows action space seeding to work as expected
  @functools.lru_cache(maxsize=None)
  def observation_space(self, agent):
# gymnasium spaces are defined and documented here:
https://gymnasium.farama.org/api/spaces/
      return self._observation_space

  @functools.lru_cache(maxsize=None)
  def action_space(self, agent):
      return self._action_space

```

My test is:

```
def test_with_tianshou():

  action = None

# env = gym.make('qwertyenv/CollectCoins-v0', pieces=['rock', 'rock'])

  env = CollectCoinsEnv(pieces=['rock', 'rock'], with_mask=True)

  def another_action_taken(action_taken):
    nonlocal action
    action = action_taken

# Wrapping the original environment as to make sure a valid action will
be taken.
  env = EnsureValidAction(
      env,
      env.check_action_valid,
      env.provide_alternative_valid_action,
      another_action_taken
  )

  env = PettingZooEnv(env)

policies = MultiAgentPolicyManager([RandomPolicy(), RandomPolicy()],
env)

  env = DummyVectorEnv([lambda: env])

  collector = Collector(policies, env)

  result = collector.collect(n_step=200, render=0.1)


```

I have also a wrapper that may be redundant as of Tianshou capability to action_mask, yet it is still part of the code:

```
from typing import TypeVar, Callable
import gymnasium as gym
from pettingzoo.utils.wrappers import BaseWrapper

Action = TypeVar("Action")


class ActionWrapper(BaseWrapper):
  def __init__(self, env: gym.Env):
    super().__init__(env)

  def step(self, action):
    action = self.action(action)
    self.env.step(action)

  def action(self, action):
    pass

  def render(self, *args, **kwargs):
    self.env.render(*args, **kwargs)


class EnsureValidAction(ActionWrapper):
  """
A gym environment wrapper to help with the case that the agent wants to
take invalid actions.
For example consider a Chess game, where you let the action_space be any
piece moving to any square on the board,
but then when a wrong move is taken, instead of returing a big negative
reward, you just take another action,
this time a valid one. To make sure the learning algorithm is aware of
the action taken, a callback should be provided.
  """
  def __init__(self, env: gym.Env,
    check_action_valid: Callable[[Action], bool],
    provide_alternative_valid_action: Callable[[Action], Action],
    alternative_action_cb: Callable[[Action], None]):

    super().__init__(env)
    self.check_action_valid = check_action_valid
self.provide_alternative_valid_action = provide_alternative_valid_action
    self.alternative_action_cb = alternative_action_cb

  def action(self, action: Action) -> Action:
    if self.check_action_valid(action):
      return action
    alternative_action = self.provide_alternative_valid_action(action)
    self.alternative_action_cb(alternative_action)
    return alternative_action
  
```


To make above work I had to patch a bit PettingZoo (opened a pull-request there), and a small patch here (this PR).

Maybe I'm doing something wrong, yet I fail to see it.

With my both fixes of PZ and of Tianshou, I have two tests, one of the environment by itself, and the other as of above.
  • Loading branch information
zbenmo authored Mar 13, 2023
1 parent bc222e8 commit 73600ed
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tianshou/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def _is_batch_set(obj: Any) -> bool:
# "for element in obj" will just unpack the first dimension,
# but obj.tolist() will flatten ndarray of objects
# so do not use obj.tolist()
if obj.shape == ():
return False
return obj.dtype == object and \
all(isinstance(element, (dict, Batch)) for element in obj)
elif isinstance(obj, (list, tuple)):
Expand Down

0 comments on commit 73600ed

Please sign in to comment.