Skip to content

Commit

Permalink
Create a Puppeteer class.
Browse files Browse the repository at this point in the history
More principled fix for #70

PiperOrigin-RevId: 469493212
Change-Id: I3f1ab3917911a03b02b42f50ee40bd03684767fe
  • Loading branch information
jagapiou authored and copybara-github committed Aug 23, 2022
1 parent e05c704 commit 3515a62
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 301 deletions.
7 changes: 3 additions & 4 deletions docs/extending.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@ my_bot_0=_saved_model(
),
```

If instead your bot is a puppet, then select a puppeteer function from those
defined in [`puppeteer_functions`](https://github.com/deepmind/meltingpot/tree/main/meltingpot/python/utils/bots/puppeteer_functions.py).
If instead your bot is a puppet, then select a `Puppeteer` from those
defined in [`puppeteers`](https://github.com/deepmind/meltingpot/tree/main/meltingpot/python/utils/bots/puppeteers.py).
Then structure its bot entry like this:

```python
my_puppet_bot_0=_puppet(
substrate='name_of_substrate_where_bot_operates',
puppeteer_fn_builder=functools.partial(
puppeteer_functions.name_of_puppeteer_function, **kwargs)
puppeteer_builder=functools.partial(puppeteers.name_of_puppeteer, **kwargs)
),
```

Expand Down
9 changes: 4 additions & 5 deletions meltingpot/python/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_config(bot_name: str) -> config_dict.ConfigDict:
config = config_dict.create(
bot_name=bot_name,
substrate=bot.substrate,
puppeteer_fn_builder=bot.puppeteer_fn_builder,
puppeteer_builder=bot.puppeteer_builder,
saved_model_path=bot.model_path,
)
return config.lock()
Expand All @@ -49,9 +49,8 @@ def build(config: config_dict.ConfigDict) -> policy.Policy:
The bot policy.
"""
saved_model = policy.SavedModelPolicy(config.saved_model_path)
if config.puppeteer_fn_builder:
return policy.PuppetPolicy(
puppeteer_fn_builder=config.puppeteer_fn_builder,
puppet_policy=saved_model)
if config.puppeteer_builder:
puppeteer = config.puppeteer_builder()
return policy.PuppetPolicy(puppeteer=puppeteer, puppet=saved_model)
else:
return saved_model
46 changes: 21 additions & 25 deletions meltingpot/python/configs/bots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import immutabledict

from meltingpot.python.utils.bots import puppeteer_functions
from meltingpot.python.utils.bots import puppeteers


def _find_models_root() -> str:
Expand All @@ -39,13 +39,12 @@ class BotConfig:
Attributes:
substrate: the substrate the bot was trained for.
model_path: the path to the bot's saved model.
puppeteer_fn_builder: an optional function that returns the puppeteer
function used to control the bot.
puppeteer_builder: an optional function that returns the puppeteer
used to control the bot.
"""
substrate: str
model_path: str
puppeteer_fn_builder: Optional[
Callable[[], puppeteer_functions.PuppeteerFn]] = None
puppeteer_builder: Optional[Callable[[], puppeteers.Puppeteer]] = None


def _saved_model(substrate: str,
Expand All @@ -63,21 +62,20 @@ def _saved_model(substrate: str,


def _puppet(substrate: str,
puppeteer_fn_builder: Callable[[], puppeteer_functions.PuppeteerFn],
puppeteer_builder: Callable[[], puppeteers.Puppeteer],
models_root: str = MODELS_ROOT) -> BotConfig:
"""Returns the config for a puppet bot.
Args:
substrate: the substrate on which the bot was trained.
puppeteer_fn_builder: function that returns the puppeteer function that
controls the puppet.
puppeteer_builder: returns the puppeteer that controls the puppet.
models_root: The path to the directory containing the saved_models.
"""
puppet_path = os.path.join(models_root, substrate, 'puppet')
return BotConfig(
substrate=substrate,
model_path=puppet_path,
puppeteer_fn_builder=puppeteer_fn_builder)
puppeteer_builder=puppeteer_builder)


BOT_CONFIGS: Mapping[str, BotConfig] = immutabledict.immutabledict(
Expand Down Expand Up @@ -228,8 +226,8 @@ def _puppet(substrate: str,
),
chicken_puppet_grim=_puppet(
substrate='chicken_in_the_matrix',
puppeteer_fn_builder=functools.partial(
puppeteer_functions.GrimTwoResourceInTheMatrix, threshold=2),
puppeteer_builder=functools.partial(
puppeteers.GrimTwoResourceInTheMatrix, threshold=2),
),
chicken_pure_dove_0=_saved_model(
substrate='chicken_in_the_matrix',
Expand Down Expand Up @@ -345,23 +343,21 @@ def _puppet(substrate: str,
),
cleanup_puppet_alternate_clean_first=_puppet(
substrate='clean_up',
puppeteer_fn_builder=(
lambda: puppeteer_functions.cleanup_alternate_clean_first),
puppeteer_builder=puppeteers.CleanupAlternateCleanFirst,
),
cleanup_puppet_alternate_eat_first=_puppet(
substrate='clean_up',
puppeteer_fn_builder=(
lambda: puppeteer_functions.cleanup_alternate_eat_first),
puppeteer_builder=puppeteers.CleanupAlternateEatFirst,
),
cleanup_puppet_reciprocator_threshold_low=_puppet(
substrate='clean_up',
puppeteer_fn_builder=functools.partial(
puppeteer_functions.ConditionalCleaner, threshold=1),
puppeteer_builder=functools.partial(
puppeteers.ConditionalCleaner, threshold=1),
),
cleanup_puppet_reciprocator_threshold_mid=_puppet(
substrate='clean_up',
puppeteer_fn_builder=functools.partial(
puppeteer_functions.ConditionalCleaner, threshold=2),
puppeteer_builder=functools.partial(
puppeteers.ConditionalCleaner, threshold=2),
),
closed_commons_zapper_0=_saved_model(
substrate='commons_harvest_closed',
Expand Down Expand Up @@ -593,13 +589,13 @@ def _puppet(substrate: str,
),
prisoners_dilemma_puppet_grim_threshold_high=_puppet(
substrate='prisoners_dilemma_in_the_matrix',
puppeteer_fn_builder=functools.partial(
puppeteer_functions.GrimTwoResourceInTheMatrix, threshold=2),
puppeteer_builder=functools.partial(
puppeteers.GrimTwoResourceInTheMatrix, threshold=2),
),
prisoners_dilemma_puppet_grim_threshold_low=_puppet(
substrate='prisoners_dilemma_in_the_matrix',
puppeteer_fn_builder=functools.partial(
puppeteer_functions.GrimTwoResourceInTheMatrix, threshold=1),
puppeteer_builder=functools.partial(
puppeteers.GrimTwoResourceInTheMatrix, threshold=1),
),
pure_coordination_type_1_specialist_0=_saved_model(
substrate='pure_coordination_in_the_matrix',
Expand Down Expand Up @@ -663,8 +659,8 @@ def _puppet(substrate: str,
),
stag_hunt_puppet_grim=_puppet(
substrate='stag_hunt_in_the_matrix',
puppeteer_fn_builder=functools.partial(
puppeteer_functions.GrimTwoResourceInTheMatrix, threshold=1),
puppeteer_builder=functools.partial(
puppeteers.GrimTwoResourceInTheMatrix, threshold=1),
),
stag_hunt_stag_specialist_3=_saved_model(
substrate='stag_hunt_in_the_matrix',
Expand Down
68 changes: 9 additions & 59 deletions meltingpot/python/utils/bots/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@

import abc
import contextlib
from typing import Callable, Generic, Mapping, Tuple, TypeVar
from typing import Generic, Mapping, Tuple, TypeVar

import dm_env
import immutabledict
import numpy as np
import tensorflow as tf
import tree

from meltingpot.python.utils.bots import permissive_model
from meltingpot.python.utils.bots import puppeteer_functions
from meltingpot.python.utils.bots import puppeteers

State = TypeVar('State')

Expand Down Expand Up @@ -253,71 +252,22 @@ def close(self) -> None:
SavedModelPolicy = TF1SavedModelPolicy


_GOAL_OBS_NAME = 'GOAL'


class _Puppeteer(Generic[State]):
"""A puppeteer that controls the timestep forwarded to the puppet."""

def __init__(
self,
puppeteer_fn_builder: Callable[[], puppeteer_functions.PuppeteerFn],
) -> None:
"""Initializes the instance.
Args:
puppeteer_fn_builder: Builds the puppeteer function that will be called at
every step to obtain the goal of that step for the underlying puppet.
"""
self._puppeteer_fn_builder = puppeteer_fn_builder

def initial_state(self) -> State:
"""Returns the initial state of the puppeteer."""
step_count = 0
puppeteer_fn = self._puppeteer_fn_builder()
return (step_count, puppeteer_fn)

def step(self, timestep: dm_env.TimeStep,
prev_state: State) -> Tuple[dm_env.TimeStep, State]:
"""Steps the puppeteer.
Args:
timestep: information from the environment.
prev_state: the previous state of the puppeteer.
Returns:
timestep: the timestep to forward to the puppet.
next_state: the state for the next step call.
"""
step_count, puppeteer_fn = prev_state
goal = puppeteer_fn(step_count, timestep.observation)
if timestep.step_type == dm_env.StepType.LAST:
next_state = self.initial_state()
else:
next_state = (step_count + 1, puppeteer_fn)

puppet_observation = immutabledict.immutabledict(
timestep.observation, **{_GOAL_OBS_NAME: goal})
puppet_timestep = timestep._replace(observation=puppet_observation)
return puppet_timestep, next_state


class PuppetPolicy(Policy[State], Generic[State]):
"""A puppet policy controlled by a puppeteer function."""

def __init__(
self,
puppeteer_fn_builder: Callable[[], puppeteer_functions.PuppeteerFn],
puppet_policy: Policy) -> None:
puppeteer: puppeteers.Puppeteer,
puppet: Policy) -> None:
"""Creates a new PuppetBot.
Args:
puppeteer_fn_builder: Builds the puppeteer function that will be called at
every step to obtain the goal of that step for the underlying puppet.
puppet_policy: The puppet policy. Will be closed with this wrapper.
puppeteer: Puppeteer that will be called at every step to modify the
timestep forwarded to the underlying puppet.
puppet: The puppet policy. Will be closed with this wrapper.
"""
self._puppeteer = _Puppeteer(puppeteer_fn_builder)
self._puppet = puppet_policy
self._puppeteer = puppeteer
self._puppet = puppet

def step(self, timestep: dm_env.TimeStep,
prev_state: State) -> Tuple[int, State]:
Expand Down
Loading

0 comments on commit 3515a62

Please sign in to comment.