Commit
Showing 7 changed files with 496 additions and 1 deletion.
.gitignore

@@ -137,4 +137,7 @@ dmypy.json
 # Pyre type checker
 .pyre/

 /db-data
+
+# wandb
+/wandb/
catanatron_experimental/catanatron_experimental/machine_learning/custom_cnn.py
51 additions & 0 deletions
@@ -0,0 +1,51 @@
import torch as th
from torch import nn
import gymnasium as gym
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class CustomCNN(BaseFeaturesExtractor):
    """
    Custom CNN to process the board observations.

    :param observation_space: (gym.spaces.Dict) Observation space with "board" and "numeric" entries.
    :param cnn_arch: List of integers specifying the number of filters in each Conv layer.
    :param features_dim: (int) Number of features extracted.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Dict,
        cnn_arch,
        features_dim: int = 256,
    ):
        super(CustomCNN, self).__init__(observation_space, features_dim)
        n_input_channels = observation_space["board"].shape[0]

        layers = []
        in_channels = n_input_channels
        for out_channels in cnn_arch:
            layers.append(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
            in_channels = out_channels
        layers.append(nn.Flatten())
        self.cnn = nn.Sequential(*layers)

        # Compute the number of features after the CNN by passing a sample board through it
        with th.no_grad():
            sample_board = th.as_tensor(
                observation_space.sample()["board"][None]
            ).float()
            n_flatten = self.cnn(sample_board).shape[1]

        n_numeric_features = observation_space["numeric"].shape[0]
        self.linear = nn.Sequential(
            nn.Linear(n_flatten + n_numeric_features, features_dim), nn.ReLU()
        )

    def forward(self, observations: dict) -> th.Tensor:
        board_features = self.cnn(observations["board"])
        concatenated_tensor = th.cat([board_features, observations["numeric"]], dim=1)
        return self.linear(concatenated_tensor)
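
For reference, here is a minimal sketch of exercising CustomCNN standalone. The board/numeric shapes and the cnn_arch values are illustrative assumptions; the real sizes come from Catanatron's observation space. In training, the extractor would typically reach MaskablePPO through policy_kwargs=dict(features_extractor_class=CustomCNN, features_extractor_kwargs=dict(cnn_arch=[32, 64])) on a "MultiInputPolicy".

import numpy as np
import torch as th
import gymnasium as gym
from catanatron_experimental.machine_learning.custom_cnn import CustomCNN

# Hypothetical shapes: a 16-channel board tensor and a 74-dim numeric vector.
obs_space = gym.spaces.Dict(
    {
        "board": gym.spaces.Box(0, 1, shape=(16, 21, 11), dtype=np.float32),
        "numeric": gym.spaces.Box(0, 100, shape=(74,), dtype=np.float32),
    }
)
extractor = CustomCNN(obs_space, cnn_arch=[32, 64], features_dim=256)
# Batch a single sampled observation and run it through the extractor.
obs = {k: th.as_tensor(obs_space.sample()[k][None]) for k in obs_space.spaces}
print(extractor(obs).shape)  # torch.Size([1, 256])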
catanatron_experimental/catanatron_experimental/machine_learning/players/ppo.py
119 additions & 0 deletions
@@ -0,0 +1,119 @@
from typing import Iterable
import numpy as np
import os
from sb3_contrib import MaskablePPO

from catanatron.game import Game
from catanatron.models.actions import Action
from catanatron.models.player import Player
from catanatron_gym.envs.catanatron_env import from_action_space, to_action_space
from catanatron_gym.features import create_sample, get_feature_ordering
from catanatron_gym.board_tensor_features import (
    create_board_tensor,
    is_graph_feature,
)
from catanatron_experimental.machine_learning.custom_cnn import CustomCNN


class PPOPlayer(Player):
    """
    Proximal Policy Optimization (PPO) reinforcement learning agent.
    """

    def __init__(self, color, model_path=None):
        super().__init__(color)
        self.model = None
        self.numeric_features = None
        if model_path is None:
            script_dir = os.path.dirname(os.path.abspath(__file__))
            model_path = os.path.join(script_dir, "..", "model.zip")
        if model_path:
            self.load(model_path)

    def decide(self, game: Game, playable_actions: Iterable[Action]):
        if self.model is None:
            raise ValueError("Model not loaded. Call load() first.")

        # Initialize numeric_features based on the current game
        if self.numeric_features is None:
            num_players = len(game.state.players)
            self.features = get_feature_ordering(num_players)
            self.numeric_features = [
                f for f in self.features if not is_graph_feature(f)
            ]

        # Generate observation from the game state
        obs = self.generate_observation(game)

        # Generate action mask from playable actions
        action_mask = self.generate_action_mask(playable_actions)

        # Predict the action index
        action_index, _ = self.model.predict(
            obs, action_masks=action_mask, deterministic=True
        )

        # Map the action index to the actual Action
        try:
            selected_action = self.action_index_to_action(
                action_index, playable_actions
            )
            return selected_action
        except Exception as e:
            print(f"Error mapping action index to Action: {e}")
            # Default to the first playable action
            return list(playable_actions)[0]

    def generate_observation(self, game: Game):
        # Create the sample
        sample = create_sample(game, self.color)

        # Generate board tensor
        board_tensor = create_board_tensor(
            game, self.color, channels_first=True
        ).astype(np.float32)

        # Extract numeric features
        numeric = np.array(
            [float(sample[i]) for i in self.numeric_features], dtype=np.float32
        )

        # Create the observation
        obs = {"board": board_tensor, "numeric": numeric}
        return obs

    def generate_action_mask(self, playable_actions: Iterable[Action]):
        action_mask = np.zeros(self.model.action_space.n, dtype=bool)
        for action in playable_actions:
            try:
                action_index = self.action_to_action_index(action)
                if (
                    action_index is not None
                    and 0 <= action_index < self.model.action_space.n
                ):
                    action_mask[action_index] = True
            except Exception as e:
                print(f"Error in action_to_action_index: {e}")
                continue
        return action_mask

    def action_to_action_index(self, action: Action):
        action_index = to_action_space(action)
        return action_index

    def action_index_to_action(
        self, action_index: int, playable_actions: Iterable[Action]
    ):
        action = from_action_space(action_index, playable_actions)
        if action in playable_actions:
            return action
        else:
            raise ValueError(f"Action {action} not in playable actions.")

    def load(self, path):
        self.model = MaskablePPO.load(
            path,
            custom_objects={
                "features_extractor_class": CustomCNN,
            },
        )
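
A sketch of dropping PPOPlayer into a game, assuming the top-level catanatron Game/RandomPlayer/Color API and a trained model on disk (the model path here is hypothetical; with model_path=None the player falls back to ../model.zip relative to this module):

from catanatron import Color, Game, RandomPlayer
from catanatron_experimental.machine_learning.players.ppo import PPOPlayer

players = [
    PPOPlayer(Color.BLUE, model_path="models/model.zip"),  # hypothetical path
    RandomPlayer(Color.RED),
]
game = Game(players)
print(game.play())  # winning color, or None if no one won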
catanatron_experimental/catanatron_experimental/machine_learning/reward_functions.py
55 additions & 0 deletions
@@ -0,0 +1,55 @@
# reward_functions.py

import numpy as np
from catanatron.state_functions import get_actual_victory_points


def partial_rewards(game, p0_color, vps_to_win):
    """
    Calculate the partial rewards for the game.

    Args:
        game: The game instance.
        p0_color: The color representing the player's position.
        vps_to_win: The victory points required to win the game.

    Returns:
        A float representing the partial reward.
    """
    winning_color = game.winning_color()
    if winning_color is None:
        return 0

    total = 0
    if p0_color == winning_color:
        total += 0.20
    else:
        total -= 0.20
    enemy_vps = [
        get_actual_victory_points(game.state, color)
        for color in game.state.colors
        if color != p0_color
    ]
    enemy_avg_vp = sum(enemy_vps) / len(enemy_vps)
    my_vps = get_actual_victory_points(game.state, p0_color)
    vp_diff = (my_vps - enemy_avg_vp) / (vps_to_win - 1)

    total += 0.80 * vp_diff
    print(f"my_vps = {my_vps} enemy_avg_vp = {enemy_avg_vp} partial_rewards = {total}")
    return total


def mask_fn(env) -> np.ndarray:
    """
    Generates a boolean mask of valid actions for the environment.

    Args:
        env: The environment instance.

    Returns:
        A numpy array of booleans indicating valid actions.
    """
    valid_actions = env.unwrapped.get_valid_actions()
    mask = np.zeros(env.action_space.n, dtype=bool)
    mask[valid_actions] = True
    return mask
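
These helpers are training-time wiring. Below is a sketch of plugging them into a catanatron_gym environment for MaskablePPO; the env id and the "representation"/"vps_to_win"/"reward_function" config keys are assumptions about catanatron_gym's registration:

import gymnasium as gym
from sb3_contrib.common.wrappers import ActionMasker
from catanatron_experimental.machine_learning.reward_functions import (
    mask_fn,
    partial_rewards,
)

env = gym.make(
    "catanatron_gym:catanatron-v1",  # assumed env id
    config={
        "representation": "mixed",  # Dict obs with "board" + "numeric" keys
        "vps_to_win": 10,
        "reward_function": lambda game, p0_color: partial_rewards(game, p0_color, 10),
    },
)
env = ActionMasker(env, mask_fn)  # MaskablePPO reads action masks via this wrapper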