Skip to content

Commit

Permalink
BUGFIX: Ensure puppeteer state is part of the PuppetPolicy state.
Browse files Browse the repository at this point in the history
Certain puppeteer functions (e.g. `GrimTwoResourceInTheMatrix`) are not pure functions, but are `Callable` objects that maintain their own state between `__call__` invocations and update it as a side effect. This means that puppeteer state is not exposed in the `PuppetPolicy` state that is returned by `initial_state` and transformed by `step`.

This causes undefined behavior in two ways:

1. Since puppeteer functions are shared between `PuppetPolicy` instances, multiple puppets running in the same Scenario will erroneously share a single puppeteer state.
2. Puppeteer state is not reset at the start of a new episode.

This change fixes both problems by ensuring that puppeteer state is a part of the `PuppetPolicy` state and is reset when `step_type.last()` is true (i.e. at the end of an episode).

NOTE: This bug did not affect the experiments run in the Melting Pot paper.

fixes #70

PiperOrigin-RevId: 469167894
Change-Id: Iaee2243f326ca9fdf0670d72596a5e92f12e3655
  • Loading branch information
jagapiou authored and copybara-github committed Aug 22, 2022
1 parent 6ff20fa commit 7ee768f
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 50 deletions.
3 changes: 2 additions & 1 deletion docs/extending.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ Then structure its bot entry like this:
```python
my_puppet_bot_0=_puppet(
substrate='name_of_substrate_where_bot_operates',
puppeteer_fn=puppeteer_functions.name_of_puppeteer_function,
puppeteer_fn_builder=functools.partial(
puppeteer_functions.name_of_puppeteer_function, **kwargs)
),
```

Expand Down
7 changes: 4 additions & 3 deletions meltingpot/python/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_config(bot_name: str) -> config_dict.ConfigDict:
config = config_dict.create(
bot_name=bot_name,
substrate=bot.substrate,
puppeteer_fn=bot.puppeteer_fn,
puppeteer_fn_builder=bot.puppeteer_fn_builder,
saved_model_path=bot.model_path,
)
return config.lock()
Expand All @@ -49,8 +49,9 @@ def build(config: config_dict.ConfigDict) -> policy.Policy:
The bot policy.
"""
saved_model = policy.SavedModelPolicy(config.saved_model_path)
if config.puppeteer_fn:
if config.puppeteer_fn_builder:
return policy.PuppetPolicy(
puppeteer_fn=config.puppeteer_fn, puppet_policy=saved_model)
puppeteer_fn_builder=config.puppeteer_fn_builder,
puppet_policy=saved_model)
else:
return saved_model
42 changes: 28 additions & 14 deletions meltingpot/python/configs/bots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
"""Library of stored bots for MeltingPot scenarios."""

import dataclasses
import functools
import os
from typing import Mapping, Optional
from typing import Callable, Mapping, Optional

import immutabledict

Expand All @@ -38,11 +39,13 @@ class BotConfig:
Attributes:
substrate: the substrate the bot was trained for.
model_path: the path to the bot's saved model.
puppeteer_fn: an optional puppeteer function used to control the bot.
puppeteer_fn_builder: an optional function that returns the puppeteer
function used to control the bot.
"""
substrate: str
model_path: str
puppeteer_fn: Optional[puppeteer_functions.PuppeteerFn] = None
puppeteer_fn_builder: Optional[
Callable[[], puppeteer_functions.PuppeteerFn]] = None


def _saved_model(substrate: str,
Expand All @@ -60,18 +63,21 @@ def _saved_model(substrate: str,


def _puppet(substrate: str,
puppeteer_fn: puppeteer_functions.PuppeteerFn,
puppeteer_fn_builder: Callable[[], puppeteer_functions.PuppeteerFn],
models_root: str = MODELS_ROOT) -> BotConfig:
"""Returns the config for a puppet bot.
Args:
substrate: the substrate on which the bot was trained.
puppeteer_fn: the puppeteer function that controls the puppet.
puppeteer_fn_builder: function that returns the puppeteer function that
controls the puppet.
models_root: The path to the directory containing the saved_models.
"""
puppet_path = os.path.join(models_root, substrate, 'puppet')
return BotConfig(
substrate=substrate, model_path=puppet_path, puppeteer_fn=puppeteer_fn)
substrate=substrate,
model_path=puppet_path,
puppeteer_fn_builder=puppeteer_fn_builder)


BOT_CONFIGS: Mapping[str, BotConfig] = immutabledict.immutabledict(
Expand Down Expand Up @@ -222,7 +228,8 @@ def _puppet(substrate: str,
),
chicken_puppet_grim=_puppet(
substrate='chicken_in_the_matrix',
puppeteer_fn=puppeteer_functions.GrimTwoResourceInTheMatrix(2),
puppeteer_fn_builder=functools.partial(
puppeteer_functions.GrimTwoResourceInTheMatrix, threshold=2),
),
chicken_pure_dove_0=_saved_model(
substrate='chicken_in_the_matrix',
Expand Down Expand Up @@ -338,19 +345,23 @@ def _puppet(substrate: str,
),
cleanup_puppet_alternate_clean_first=_puppet(
substrate='clean_up',
puppeteer_fn=puppeteer_functions.cleanup_alternate_clean_first,
puppeteer_fn_builder=(
lambda: puppeteer_functions.cleanup_alternate_clean_first),
),
cleanup_puppet_alternate_eat_first=_puppet(
substrate='clean_up',
puppeteer_fn=puppeteer_functions.cleanup_alternate_eat_first,
puppeteer_fn_builder=(
lambda: puppeteer_functions.cleanup_alternate_eat_first),
),
cleanup_puppet_reciprocator_threshold_low=_puppet(
substrate='clean_up',
puppeteer_fn=puppeteer_functions.ConditionalCleaner(1),
puppeteer_fn_builder=functools.partial(
puppeteer_functions.ConditionalCleaner, threshold=1),
),
cleanup_puppet_reciprocator_threshold_mid=_puppet(
substrate='clean_up',
puppeteer_fn=puppeteer_functions.ConditionalCleaner(2),
puppeteer_fn_builder=functools.partial(
puppeteer_functions.ConditionalCleaner, threshold=2),
),
closed_commons_zapper_0=_saved_model(
substrate='commons_harvest_closed',
Expand Down Expand Up @@ -582,11 +593,13 @@ def _puppet(substrate: str,
),
prisoners_dilemma_puppet_grim_threshold_high=_puppet(
substrate='prisoners_dilemma_in_the_matrix',
puppeteer_fn=puppeteer_functions.GrimTwoResourceInTheMatrix(2),
puppeteer_fn_builder=functools.partial(
puppeteer_functions.GrimTwoResourceInTheMatrix, threshold=2),
),
prisoners_dilemma_puppet_grim_threshold_low=_puppet(
substrate='prisoners_dilemma_in_the_matrix',
puppeteer_fn=puppeteer_functions.GrimTwoResourceInTheMatrix(1),
puppeteer_fn_builder=functools.partial(
puppeteer_functions.GrimTwoResourceInTheMatrix, threshold=1),
),
pure_coordination_type_1_specialist_0=_saved_model(
substrate='pure_coordination_in_the_matrix',
Expand Down Expand Up @@ -650,7 +663,8 @@ def _puppet(substrate: str,
),
stag_hunt_puppet_grim=_puppet(
substrate='stag_hunt_in_the_matrix',
puppeteer_fn=puppeteer_functions.GrimTwoResourceInTheMatrix(1),
puppeteer_fn_builder=functools.partial(
puppeteer_functions.GrimTwoResourceInTheMatrix, threshold=1),
),
stag_hunt_stag_specialist_3=_saved_model(
substrate='stag_hunt_in_the_matrix',
Expand Down
101 changes: 69 additions & 32 deletions meltingpot/python/utils/bots/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import abc
import contextlib
from typing import Generic, Mapping, Tuple, TypeVar
from typing import Callable, Generic, Mapping, Tuple, TypeVar

import dm_env
import immutabledict
Expand All @@ -30,18 +30,26 @@


class Policy(Generic[State], metaclass=abc.ABCMeta):
"""Abstract base class for a policy."""
"""Abstract base class for a policy.
Must not possess any mutable state not in `initial_state`.
"""

@abc.abstractmethod
def initial_state(self) -> State:
"""Returns the initial state of the agent."""
"""Returns the initial state of the agent.
Must not have any side effects.
"""
raise NotImplementedError()

@abc.abstractmethod
def step(self, timestep: dm_env.TimeStep,
prev_state: State) -> Tuple[int, State]:
"""Steps the agent.
Must not have any side effects.
Args:
timestep: information from the environment
prev_state: the previous state of the agent.
Expand Down Expand Up @@ -248,53 +256,82 @@ def close(self) -> None:
_GOAL_OBS_NAME = 'GOAL'


class PuppetPolicy(Policy[State], Generic[State]):
"""A puppet policy controlled by a puppeteer function."""
class _Puppeteer(Generic[State]):
"""A puppeteer that controls the timestep forwarded to the puppet."""

def __init__(self, puppeteer_fn: puppeteer_functions.PuppeteerFn,
puppet_policy: Policy) -> None:
"""Creates a new PuppetBot.
def __init__(
self,
puppeteer_fn_builder: Callable[[], puppeteer_functions.PuppeteerFn],
) -> None:
"""Initializes the instance.
Args:
puppeteer_fn: The puppeteer function. This will be called at every step to
obtain the goal of that step for the underlying puppet.
puppet_policy: The puppet policy. Will be closed with this wrapper.
puppeteer_fn_builder: Builds the puppeteer function that will be called at
every step to obtain the goal of that step for the underlying puppet.
"""
self._puppeteer_fn = puppeteer_fn
self._puppet = puppet_policy
self._puppeteer_fn_builder = puppeteer_fn_builder

def initial_state(self) -> State:
"""Returns the initial state of the puppeteer."""
step_count = 0
puppeteer_fn = self._puppeteer_fn_builder()
return (step_count, puppeteer_fn)

def _puppeteer_initial_state(self) -> int:
return 0
def step(self, timestep: dm_env.TimeStep,
prev_state: State) -> Tuple[dm_env.TimeStep, State]:
"""Steps the puppeteer.
Args:
timestep: information from the environment.
prev_state: the previous state of the puppeteer.
Returns:
timestep: the timestep to forward to the puppet.
next_state: the state for the next step call.
"""
step_count, puppeteer_fn = prev_state
goal = puppeteer_fn(step_count, timestep.observation)
if timestep.step_type == dm_env.StepType.LAST:
next_state = self.initial_state()
else:
next_state = (step_count + 1, puppeteer_fn)

def _puppeteer_step(self, timestep: dm_env.TimeStep,
prev_state: int) -> Tuple[dm_env.TimeStep, int]:
"""Returns the transformed observation for the puppet step."""
goal = self._puppeteer_fn(prev_state, timestep.observation)
next_state = prev_state + 1
puppet_observation = immutabledict.immutabledict(
timestep.observation, **{_GOAL_OBS_NAME: goal})
puppet_timestep = timestep._replace(observation=puppet_observation)
return puppet_timestep, next_state


class PuppetPolicy(Policy[State], Generic[State]):
"""A puppet policy controlled by a puppeteer function."""

def __init__(
self,
puppeteer_fn_builder: Callable[[], puppeteer_functions.PuppeteerFn],
puppet_policy: Policy) -> None:
"""Creates a new PuppetBot.
Args:
puppeteer_fn_builder: Builds the puppeteer function that will be called at
every step to obtain the goal of that step for the underlying puppet.
puppet_policy: The puppet policy. Will be closed with this wrapper.
"""
self._puppeteer = _Puppeteer(puppeteer_fn_builder)
self._puppet = puppet_policy

def step(self, timestep: dm_env.TimeStep,
prev_state: State) -> Tuple[int, State]:
"""See base class."""
puppet_timestep, puppeteer_state = self._puppeteer_step(
timestep, prev_state['puppeteer'])
action, puppet_state = self._puppet.step(puppet_timestep,
prev_state['puppet'])
next_state = {
'puppeteer': puppeteer_state,
'puppet': puppet_state,
}
puppeteer_state, puppet_state = prev_state
puppet_timestep, puppeteer_state = self._puppeteer.step(
timestep, puppeteer_state)
action, puppet_state = self._puppet.step(puppet_timestep, puppet_state)
next_state = (puppeteer_state, puppet_state)
return action, next_state

def initial_state(self) -> State:
"""See base class."""
return {
'puppeteer': 0,
'puppet': self._puppet.initial_state(),
}
return (self._puppeteer.initial_state(), self._puppet.initial_state())

def close(self) -> None:
"""See base class."""
Expand Down

0 comments on commit 7ee768f

Please sign in to comment.