Skip to content

Commit

Permalink
Add environment loading tests
Browse files Browse the repository at this point in the history
  • Loading branch information
edmundmills committed Dec 15, 2023
1 parent 6aef1b1 commit 8ae500e
Showing 1 changed file with 72 additions and 13 deletions.
85 changes: 72 additions & 13 deletions tests/unit/test_environments.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,99 @@
import unittest
from unittest import TestCase

from chatarena.config import EnvironmentConfig
from chatarena.environments import PettingzooTicTacToe, load_environment, register_env
from chatarena.config import AgentConfig, BackendConfig, EnvironmentConfig
from chatarena.environments import (
Chameleon,
Environment,
ModeratedConversation,
PettingzooChess,
PettingzooTicTacToe,
load_environment,
register_env,
)


class TestEnvironments(TestCase):
def test_env_registration(self):
@register_env
class TestEnv:
class TestEnv(Environment):
type_name = "test"

@classmethod
def from_config(cls, config: EnvironmentConfig):
return cls()
return cls(player_names=config.player_names)

env_config = EnvironmentConfig(env_type="test")
env_config = EnvironmentConfig(
env_type="test", player_names=["player1", "player2"]
)
env = load_environment(env_config)
assert isinstance(env, TestEnv)

def test_chess_environment(self):
player_names = ["player1", "player2"]
env = PettingzooTicTacToe(player_names)

class TestTicTacToeEnvironment(TestCase):
def config(self):
return EnvironmentConfig(
env_type="pettingzoo:tictactoe", player_names=["player1", "player2"]
)

def test_registration_and_loading(self):
env = load_environment(self.config())
assert isinstance(env, PettingzooTicTacToe)

def test_game(self):
env = load_environment(self.config())
env.reset()
assert env.get_next_player() == "player1"
env.print()

moves = ["X: (3, 1)", "O: (2, 2)", "X: (1, 2)", "O: (1, 1)"]

for i, move in enumerate(moves):
assert env.check_action(move, env.get_next_player())
timestep = env.step(env.get_next_player(), move)
print(timestep.reward)
print(timestep.terminal)
env.print()
env.step(env.get_next_player(), move)
assert not env.is_terminal()


class TestChameleonEnvironment(TestCase):
def test_registration_and_loading(self):
config = EnvironmentConfig(
env_type="chameleon", player_names=["player1", "player2"]
)
env = load_environment(config)
assert isinstance(env, Chameleon)


class TestConversationEnvironment(TestCase):
def test_registration_and_loading(self):
config = EnvironmentConfig(
env_type="conversation", player_names=["player1", "player2"]
)
env = load_environment(config)
assert isinstance(env, Environment)


class TestModeratedConversationEnvironment(TestCase):
def test_registration_and_loading(self):
moderator = AgentConfig(
role_desc="moderator",
backend=BackendConfig(backend_type="human"),
terminal_condition="all_done",
)
config = EnvironmentConfig(
env_type="moderated_conversation",
player_names=["player1", "player2"],
moderator=moderator,
)
env = load_environment(config)
assert isinstance(env, ModeratedConversation)


class TestPettingzooChessEnvironment(TestCase):
def test_registration_and_loading(self):
config = EnvironmentConfig(
env_type="pettingzoo:chess", player_names=["player1", "player2"]
)
env = load_environment(config)
assert isinstance(env, PettingzooChess)


if __name__ == "__main__":
Expand Down

0 comments on commit 8ae500e

Please sign in to comment.