Skip to content

Commit

Permalink
Add registration functions for envs and intelligence backends (#115)
Browse files Browse the repository at this point in the history
  • Loading branch information
edmundmills authored Dec 19, 2023
1 parent 767ed20 commit 2dcaa08
Show file tree
Hide file tree
Showing 14 changed files with 131 additions and 42 deletions.
12 changes: 1 addition & 11 deletions chatarena/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
from ..config import BackendConfig
from .anthropic import Claude
from .base import IntelligenceBackend
from .base import BACKEND_REGISTRY, IntelligenceBackend, register_backend
from .cohere import CohereAIChat
from .hf_transformers import TransformersConversational
from .human import Human
from .openai import OpenAIChat

ALL_BACKENDS = [
Human,
OpenAIChat,
CohereAIChat,
TransformersConversational,
Claude,
]

BACKEND_REGISTRY = {backend.type_name: backend for backend in ALL_BACKENDS}


# Load a backend from a config dictionary
def load_backend(config: BackendConfig):
Expand Down
3 changes: 2 additions & 1 deletion chatarena/backends/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..message import SYSTEM_NAME as SYSTEM
from ..message import Message
from .base import IntelligenceBackend
from .base import IntelligenceBackend, register_backend

try:
import anthropic
Expand All @@ -25,6 +25,7 @@
DEFAULT_MODEL = "claude-v1"


@register_backend
class Claude(IntelligenceBackend):
"""Interface to the Claude offered by Anthropic."""

Expand Down
11 changes: 10 additions & 1 deletion chatarena/backends/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import List
from typing import Dict, List, Type

from ..config import BackendConfig, Configurable
from ..message import Message
Expand Down Expand Up @@ -64,3 +64,12 @@ def reset(self):
raise NotImplementedError
else:
pass


BACKEND_REGISTRY: Dict[str, Type[IntelligenceBackend]] = {}


def register_backend(cls: Type[IntelligenceBackend]) -> Type[IntelligenceBackend]:
"""Register a new backend."""
BACKEND_REGISTRY[cls.type_name] = cls
return cls
3 changes: 2 additions & 1 deletion chatarena/backends/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from tenacity import retry, stop_after_attempt, wait_random_exponential

from ..message import Message
from .base import IntelligenceBackend
from .base import IntelligenceBackend, register_backend

# Try to import the cohere package and check whether the API key is set
try:
Expand All @@ -23,6 +23,7 @@
DEFAULT_MODEL = "command-xlarge"


@register_backend
class CohereAIChat(IntelligenceBackend):
"""Interface to the Cohere API."""

Expand Down
3 changes: 2 additions & 1 deletion chatarena/backends/hf_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..message import SYSTEM_NAME as SYSTEM
from ..message import Message
from .base import IntelligenceBackend
from .base import IntelligenceBackend, register_backend


@contextmanager
Expand All @@ -32,6 +32,7 @@ def suppress_stdout_stderr():
is_transformers_available = True


@register_backend
class TransformersConversational(IntelligenceBackend):
"""Interface to the Transformers ConversationalPipeline."""

Expand Down
3 changes: 2 additions & 1 deletion chatarena/backends/human.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ..config import BackendConfig
from .base import IntelligenceBackend
from .base import IntelligenceBackend, register_backend


# An Error class for the human backend
Expand All @@ -9,6 +9,7 @@ def __init__(self, agent_name: str):
super().__init__(f"Human backend requires a UI to get input from {agent_name}.")


@register_backend
class Human(IntelligenceBackend):
stateful = False
type_name = "human"
Expand Down
3 changes: 2 additions & 1 deletion chatarena/backends/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tenacity import retry, stop_after_attempt, wait_random_exponential

from ..message import SYSTEM_NAME, Message
from .base import IntelligenceBackend
from .base import IntelligenceBackend, register_backend

try:
import openai
Expand All @@ -31,6 +31,7 @@
BASE_PROMPT = f"The messages always end with the token {END_OF_MESSAGE}."


@register_backend
class OpenAIChat(IntelligenceBackend):
"""Interface to the ChatGPT style model with system, user, assistant roles separation."""

Expand Down
12 changes: 1 addition & 11 deletions chatarena/environments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
from ..config import EnvironmentConfig
from .base import Environment, TimeStep
from .base import ENV_REGISTRY, Environment, TimeStep, register_env
from .chameleon import Chameleon
from .conversation import Conversation, ModeratedConversation
from .pettingzoo_chess import PettingzooChess
from .pettingzoo_tictactoe import PettingzooTicTacToe

ALL_ENVIRONMENTS = [
Conversation,
ModeratedConversation,
Chameleon,
PettingzooChess,
PettingzooTicTacToe,
]

ENV_REGISTRY = {env.type_name: env for env in ALL_ENVIRONMENTS}


# Load an environment from a config dictionary
def load_environment(config: EnvironmentConfig):
Expand Down
19 changes: 18 additions & 1 deletion chatarena/environments/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Dict, List
from typing import Dict, List, Type

from ..config import Configurable, EnvironmentConfig
from ..message import Message
Expand Down Expand Up @@ -185,3 +185,20 @@ def get_one_rewards(self) -> Dict[str, float]:
Dict[str, float]: A dictionary of players and their rewards (all one).
"""
return {player_name: 1.0 for player_name in self.player_names}


ENV_REGISTRY: Dict[str, Type[Environment]] = {}


def register_env(cls: Type[Environment]) -> Type[Environment]:
"""
Register an environment class.
Parameters:
cls (Type[Environment]): The class to register.
Returns:
Type[Environment]: The class that was registered.
"""
ENV_REGISTRY[cls.type_name] = cls
return cls
3 changes: 2 additions & 1 deletion chatarena/environments/chameleon.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ..agent import SIGNAL_END_OF_CONVERSATION
from ..message import Message, MessagePool
from .base import Environment, TimeStep
from .base import Environment, TimeStep, register_env

DEFAULT_TOPIC_CODES = {
"Fruits": [
Expand Down Expand Up @@ -50,6 +50,7 @@
}


@register_env
class Chameleon(Environment):
type_name = "chameleon"

Expand Down
4 changes: 3 additions & 1 deletion chatarena/environments/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from ..agent import SIGNAL_END_OF_CONVERSATION, Moderator
from ..config import AgentConfig, EnvironmentConfig
from ..message import Message, MessagePool
from .base import Environment, TimeStep
from .base import Environment, TimeStep, register_env


@register_env
class Conversation(Environment):
"""
Turn-based fully observable conversation environment.
Expand Down Expand Up @@ -93,6 +94,7 @@ def step(self, player_name: str, action: str) -> TimeStep:
return timestep


@register_env
class ModeratedConversation(Conversation):
"""
Turn-based fully observable conversation environment.
Expand Down
3 changes: 2 additions & 1 deletion chatarena/environments/pettingzoo_chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pettingzoo.classic import chess_v6
from pettingzoo.classic.chess.chess_utils import chess, get_move_plane

from chatarena.environments.base import Environment, TimeStep
from chatarena.environments.base import Environment, TimeStep, register_env

from ..message import Message, MessagePool

Expand All @@ -27,6 +27,7 @@ def action_string_to_alphazero_format(action: str, player_index: int) -> int:
return x1 * 8 * 73 + y1 * 73 + move_plane


@register_env
class PettingzooChess(Environment):
type_name = "pettingzoo:chess"

Expand Down
3 changes: 2 additions & 1 deletion chatarena/environments/pettingzoo_tictactoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pettingzoo.classic import tictactoe_v3

from chatarena.environments.base import Environment, TimeStep
from chatarena.environments.base import Environment, TimeStep, register_env

from ..message import Message, MessagePool

Expand All @@ -27,6 +27,7 @@ def action_string_to_action(action: str) -> int:
return row + column * 3


@register_env
class PettingzooTicTacToe(Environment):
type_name = "pettingzoo:tictactoe"

Expand Down
91 changes: 82 additions & 9 deletions tests/unit/test_environments.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,99 @@
import unittest
from unittest import TestCase

from chatarena.environments import PettingzooTicTacToe
from chatarena.config import AgentConfig, BackendConfig, EnvironmentConfig
from chatarena.environments import (
Chameleon,
Environment,
ModeratedConversation,
PettingzooChess,
PettingzooTicTacToe,
load_environment,
register_env,
)


class TestEnvironments(TestCase):
def test_chess_environment(self):
player_names = ["player1", "player2"]
env = PettingzooTicTacToe(player_names)
def test_env_registration(self):
@register_env
class TestEnv(Environment):
type_name = "test"

@classmethod
def from_config(cls, config: EnvironmentConfig):
return cls(player_names=config.player_names)

env_config = EnvironmentConfig(
env_type="test", player_names=["player1", "player2"]
)
env = load_environment(env_config)
assert isinstance(env, TestEnv)


class TestTicTacToeEnvironment(TestCase):
def config(self):
return EnvironmentConfig(
env_type="pettingzoo:tictactoe", player_names=["player1", "player2"]
)

def test_registration_and_loading(self):
env = load_environment(self.config())
assert isinstance(env, PettingzooTicTacToe)

def test_game(self):
env = load_environment(self.config())
env.reset()
assert env.get_next_player() == "player1"
env.print()

moves = ["X: (3, 1)", "O: (2, 2)", "X: (1, 2)", "O: (1, 1)"]

for i, move in enumerate(moves):
assert env.check_action(move, env.get_next_player())
timestep = env.step(env.get_next_player(), move)
print(timestep.reward)
print(timestep.terminal)
env.print()
env.step(env.get_next_player(), move)
assert not env.is_terminal()


class TestChameleonEnvironment(TestCase):
def test_registration_and_loading(self):
config = EnvironmentConfig(
env_type="chameleon", player_names=["player1", "player2"]
)
env = load_environment(config)
assert isinstance(env, Chameleon)


class TestConversationEnvironment(TestCase):
def test_registration_and_loading(self):
config = EnvironmentConfig(
env_type="conversation", player_names=["player1", "player2"]
)
env = load_environment(config)
assert isinstance(env, Environment)


class TestModeratedConversationEnvironment(TestCase):
def test_registration_and_loading(self):
moderator = AgentConfig(
role_desc="moderator",
backend=BackendConfig(backend_type="human"),
terminal_condition="all_done",
)
config = EnvironmentConfig(
env_type="moderated_conversation",
player_names=["player1", "player2"],
moderator=moderator,
)
env = load_environment(config)
assert isinstance(env, ModeratedConversation)


class TestPettingzooChessEnvironment(TestCase):
def test_registration_and_loading(self):
config = EnvironmentConfig(
env_type="pettingzoo:chess", player_names=["player1", "player2"]
)
env = load_environment(config)
assert isinstance(env, PettingzooChess)


if __name__ == "__main__":
Expand Down

0 comments on commit 2dcaa08

Please sign in to comment.