forked from werner-duvaud/muzero-general
Add open_spiel game wrapper (werner-duvaud#103)
* Add Spiel Game wrapper
* Raise error if open_spiel not installed
* Add a little documentation
Showing 1 changed file with 291 additions and 0 deletions.
@@ -0,0 +1,291 @@
import datetime
import os

import numpy
import torch

from .abstract_game import AbstractGame


# This is a Game wrapper for open_spiel games. It allows you to run any game in the open_spiel library.

try:
    import pyspiel

except ImportError:
    import sys
    sys.exit("You need to install open_spiel by running pip install open_spiel. For the full documentation, see: https://github.com/deepmind/open_spiel/blob/master/docs/install.md")

# The game you want to run. See https://github.com/deepmind/open_spiel/blob/master/docs/games.md for a list of games
game = pyspiel.load_game("tic_tac_toe")
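# As a hedged example (not part of the original wrapper): any other game id from the list
# above should drop in here the same way, provided it is available in your open_spiel
# build, e.g.
#     game = pyspiel.load_game("connect_four")
# The config below derives the observation shape, action space, player list and maximum
# game length from whichever game is loaded.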


class MuZeroConfig:
    def __init__(self):
        # More information is available here: https://github.com/werner-duvaud/muzero-general/wiki/Hyperparameter-Optimization

        self.game = game

        self.seed = 0 # Seed for numpy, torch and the game
        self.max_num_gpus = None # Fix the maximum number of GPUs to use. It's usually faster to use a single GPU (set it to 1) if it has enough memory. None will use all available GPUs


        ### Game
        self.observation_shape = tuple(self.game.observation_tensor_shape()) # Dimensions of the game observation, must be 3D (channel, height, width). For a 1D array, please reshape it to (1, 1, length of array)
        self.action_space = list(range(self.game.policy_tensor_shape()[0])) # Fixed list of all possible actions. You should only edit the length
        self.players = list(range(self.game.num_players())) # List of players. You should only edit the length
        self.stacked_observations = 0 # Number of previous observations and previous actions to add to the current observation
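        # Quick sanity check (a sketch, not part of the original wrapper; the exact shapes
        # are an assumption about open_spiel's encoding): for the default "tic_tac_toe"
        # game this is expected to give a (3, 3, 3) observation tensor and 9 actions.
        # If you swap the game above, you can verify the derived values with e.g.
        #     print(self.game.observation_tensor_shape(), self.game.policy_tensor_shape(), self.game.num_players())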

        # Evaluate
        self.muzero_player = 0 # Turn MuZero begins to play (0: MuZero plays first, 1: MuZero plays second)
        self.opponent = "self" # Hard coded agent that MuZero faces to assess its progress in multiplayer games. It doesn't influence training. None, "random" or "expert" if implemented in the Game class


        ### Self-Play
        self.num_workers = 1 # Number of simultaneous threads/workers self-playing to feed the replay buffer
        self.selfplay_on_gpu = False
        self.max_moves = self.game.max_game_length() # Maximum number of moves if game is not finished before
        self.num_simulations = 25 # Number of future moves self-simulated
        self.discount = 0.1 # Chronological discount of the reward
        self.temperature_threshold = None # Number of moves before dropping the temperature given by visit_softmax_temperature_fn to 0 (ie selecting the best action). If None, visit_softmax_temperature_fn is used every time

        # Root prior exploration noise
        self.root_dirichlet_alpha = 0.1
        self.root_exploration_fraction = 0.25

        # UCB formula
        self.pb_c_base = 19652
        self.pb_c_init = 1.25
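        # For context (a hedged note, not part of the original wrapper): in the MuZero
        # paper's pseudocode these two constants enter the pUCT exploration term roughly as
        #     pb_c = log((parent_visits + pb_c_base + 1) / pb_c_base) + pb_c_init
        #     score = value + prior * sqrt(parent_visits) / (1 + child_visits) * pb_c
        # so a larger pb_c_init means more exploration during the tree search.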


        ### Network
        self.network = "resnet" # "resnet" / "fullyconnected"
        self.support_size = 10 # Value and reward are scaled (with almost sqrt) and encoded on a vector with a range of -support_size to support_size. Choose it so that support_size <= sqrt(max(abs(discounted reward)))

        # Residual Network
        self.downsample = False # Downsample observations before representation network, False / "CNN" (lighter) / "resnet" (See paper appendix Network Architecture)
        self.blocks = 2 # Number of blocks in the ResNet
        self.channels = 16 # Number of channels in the ResNet
        self.reduced_channels_reward = 16 # Number of channels in reward head
        self.reduced_channels_value = 16 # Number of channels in value head
        self.reduced_channels_policy = 16 # Number of channels in policy head
        self.resnet_fc_reward_layers = [8] # Define the hidden layers in the reward head of the dynamics network
        self.resnet_fc_value_layers = [8] # Define the hidden layers in the value head of the prediction network
        self.resnet_fc_policy_layers = [8] # Define the hidden layers in the policy head of the prediction network

        # Fully Connected Network
        self.encoding_size = 32
        self.fc_representation_layers = [] # Define the hidden layers in the representation network
        self.fc_dynamics_layers = [16] # Define the hidden layers in the dynamics network
        self.fc_reward_layers = [16] # Define the hidden layers in the reward network
        self.fc_value_layers = [] # Define the hidden layers in the value network
        self.fc_policy_layers = [] # Define the hidden layers in the policy network


        ### Training
        self.results_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../results", os.path.basename(__file__)[:-3], datetime.datetime.now().strftime("%Y-%m-%d--%H-%M-%S")) # Path to store the model weights and TensorBoard logs
        self.save_model = True # Save the checkpoint in results_path as model.checkpoint
        self.training_steps = 1000000 # Total number of training steps (ie weights update according to a batch)
        self.batch_size = 64 # Number of parts of games to train on at each training step
        self.checkpoint_interval = 10 # Number of training steps before using the model for self-playing
        self.value_loss_weight = 0.25 # Scale the value loss to avoid overfitting of the value function, paper recommends 0.25 (See paper appendix Reanalyze)
        self.train_on_gpu = torch.cuda.is_available() # Train on GPU if available

        self.optimizer = "Adam" # "Adam" or "SGD". Paper uses SGD
        self.weight_decay = 1e-4 # L2 weights regularization
        self.momentum = 0.9 # Used only if optimizer is SGD

        # Exponential learning rate schedule
        self.lr_init = 0.003 # Initial learning rate
        self.lr_decay_rate = 1 # Set it to 1 to use a constant learning rate
        self.lr_decay_steps = 10000


        ### Replay Buffer
        self.replay_buffer_size = 3000 # Number of self-play games to keep in the replay buffer
        self.num_unroll_steps = 20 # Number of game moves to keep for every batch element
        self.td_steps = 20 # Number of steps in the future to take into account for calculating the target value
        self.PER = True # Prioritized Replay (See paper appendix Training), select in priority the elements in the replay buffer which are unexpected for the network
        self.PER_alpha = 0.5 # How much prioritization is used, 0 corresponding to the uniform case, paper suggests 1

        # Reanalyze (See paper appendix Reanalyse)
        self.use_last_model_value = True # Use the last model to provide a fresher, stable n-step value (See paper appendix Reanalyze)
        self.reanalyse_on_gpu = False

        ### Adjust the self play / training ratio to avoid over/underfitting
        self.self_play_delay = 0 # Number of seconds to wait after each played game
        self.training_delay = 0 # Number of seconds to wait after each training step
        self.ratio = None # Desired training steps per self played step ratio. Equivalent to a synchronous version, training can take much longer. Set it to None to disable it

    def visit_softmax_temperature_fn(self, trained_steps):
        """
        Parameter to alter the visit count distribution to ensure that the action selection becomes greedier as training progresses.
        The smaller it is, the more likely the best action (ie with the highest visit count) is chosen.

        Returns:
            Positive float.
        """
        return 1
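        # A hedged sketch of a decaying schedule (not part of the original wrapper): other
        # muzero-general configs typically lower the temperature as training advances, e.g.
        #     if trained_steps < 0.5 * self.training_steps:
        #         return 1.0
        #     elif trained_steps < 0.75 * self.training_steps:
        #         return 0.5
        #     else:
        #         return 0.25
        # Here the temperature is simply kept constant at 1.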


class Game(AbstractGame):
    """
    Game wrapper.
    """

    def __init__(self, seed=None):
        self.env = Spiel()

    def step(self, action):
        """
        Apply action to the game.

        Args:
            action: action of the action_space to take.

        Returns:
            The new observation, the reward and a boolean indicating whether the game has ended.
        """
        observation, reward, done = self.env.step(action)
        return observation, reward * 20, done

    def to_play(self):
        """
        Return the current player.

        Returns:
            The current player, it should be an element of the players list in the config.
        """
        return self.env.to_play()

    def legal_actions(self):
        """
        Should return the legal actions at each turn; if they are not available, it can return
        the whole action space. At each turn, the game has to be able to handle one of the returned actions.
        For complex games where computing legal moves is too expensive, the idea is to define the legal actions
        as equal to the action space but to return a negative reward if the action is illegal.

        Returns:
            An array of integers, a subset of the action space.
        """
        return self.env.legal_actions()

    def reset(self):
        """
        Reset the game for a new game.

        Returns:
            Initial observation of the game.
        """
        return self.env.reset()

    def render(self):
        """
        Display the game observation.
        """
        self.env.render()
        input("Press enter to take a step ")

    def legal_actions_human(self):
        return self.env.human_legal_actions()

    def human_to_action(self):
        """
        For multiplayer games, ask the user for a legal action
        and return the corresponding action number.

        Returns:
            An integer from the action space.
        """
        while True:
            try:
                print("Legal Actions: ", self.legal_actions_human())
                choice = input("Enter your move: ")
                if choice in self.legal_actions_human():
                    break
            except:
                pass
            print("Wrong input, try again")

        # Map the human-readable action string back to an open_spiel action id
        return self.env.board.string_to_action(choice)

    def action_to_string(self, action_number):
        """
        Convert an action number to a string representing the action.

        Args:
            action_number: an integer from the action space.

        Returns:
            String representing the action.
        """
        # Note: this conversion assumes a 3x3 board, i.e. the default tic_tac_toe game
        row = action_number // 3 + 1
        col = action_number % 3 + 1
        return f"Play row {row}, column {col}"


class Spiel:
    def __init__(self):
        self.game = game
        self.board = self.game.new_initial_state()
        self.player = 1

    def to_play(self):
        # Map the internal player (1 / -1) to the player ids (0 / 1) used by MuZero
        return 0 if self.player == 1 else 1

    def reset(self):
        self.board = self.game.new_initial_state()
        self.player = 1
        return self.get_observation()

    def step(self, action):
        # child(action) returns a copy of the open_spiel state with the action applied
        self.board = self.board.child(action)

        done = self.board.is_terminal()
        reward = 1 if self.have_winner() else 0
        observation = self.get_observation()

        self.player *= -1

        return observation, reward, done

    def get_observation(self):
        current_player = 1 if self.player == 1 else 0
        return numpy.array(self.board.observation_tensor(current_player)).reshape(self.game.observation_tensor_shape())

    def legal_actions(self):
        return self.board.legal_actions()

    def have_winner(self):
        rewards = self.board.rewards()

        if self.player == 1:
            if rewards[0] == 1.0:
                return True
        elif self.player == -1:
            if rewards[1] == 1.0:
                return True

        return False

    def human_legal_actions(self):
        return [self.board.action_to_string(x) for x in self.board.legal_actions()]

    def render(self):
        print(self.board)
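

# A minimal smoke test (a sketch, not part of the original wrapper). It assumes open_spiel
# is installed and that this module is run as part of the games package, e.g. with
# "python -m games.spiel" from the repository root (a path assumption), since the relative
# import of AbstractGame above prevents running the file directly as a plain script.
# It plays random legal moves through the wrapper to check that observations, rewards and
# terminal flags come back in the expected format.
if __name__ == "__main__":
    import random

    env = Game()
    observation = env.reset()
    done = False
    while not done:
        action = random.choice(env.legal_actions())
        observation, reward, done = env.step(action)
        print("to_play:", env.to_play(), "| reward:", reward, "| observation shape:", observation.shape)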