Skip to content

Commit 6a273e0

Browse files
committed
Fix reanalyse and format
1 parent 07ef102 commit 6a273e0

15 files changed

+45
-25
lines changed

games/atari.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self):
1818
# More information is available here: https://github.com/werner-duvaud/muzero-general/wiki/Hyperparameter-Optimization
1919

2020
self.seed = 0 # Seed for numpy, torch and the game
21-
self.max_num_gpus = None # Fix the maximum number of GPUs to use. By default muzero uses every GPUs available
21+
self.max_num_gpus = None # Fix the maximum number of GPUs to use. It's usually faster to use a single GPU (set it to 1) if it has enough memory. None will use every GPUs available
2222

2323

2424

games/breakout.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self):
1818
# More information is available here: https://github.com/werner-duvaud/muzero-general/wiki/Hyperparameter-Optimization
1919

2020
self.seed = 0 # Seed for numpy, torch and the game
21-
self.max_num_gpus = None # Fix the maximum number of GPUs to use. By default muzero uses every GPUs available
21+
self.max_num_gpus = None # Fix the maximum number of GPUs to use. It's usually faster to use a single GPU (set it to 1) if it has enough memory. None will use every GPUs available
2222

2323

2424

games/cartpole.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def __init__(self):
1313
# More information is available here: https://github.com/werner-duvaud/muzero-general/wiki/Hyperparameter-Optimization
1414

1515
self.seed = 0 # Seed for numpy, torch and the game
16-
self.max_num_gpus = None # Fix the maximum number of GPUs to use. By default muzero uses every GPUs available
16+
self.max_num_gpus = None # Fix the maximum number of GPUs to use. It's usually faster to use a single GPU (set it to 1) if it has enough memory. None will use every GPUs available
1717

1818

1919

games/connect4.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def __init__(self):
1212
# More information is available here: https://github.com/werner-duvaud/muzero-general/wiki/Hyperparameter-Optimization
1313

1414
self.seed = 0 # Seed for numpy, torch and the game
15-
self.max_num_gpus = None # Fix the maximum number of GPUs to use. By default muzero uses every GPUs available
15+
self.max_num_gpus = None # Fix the maximum number of GPUs to use. It's usually faster to use a single GPU (set it to 1) if it has enough memory. None will use every GPUs available
1616

1717

1818

games/gomoku.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def __init__(self):
1313
# More information is available here: https://github.com/werner-duvaud/muzero-general/wiki/Hyperparameter-Optimization
1414

1515
self.seed = 0 # Seed for numpy, torch and the game
16-
self.max_num_gpus = None # Fix the maximum number of GPUs to use. By default muzero uses every GPUs available
16+
self.max_num_gpus = None # Fix the maximum number of GPUs to use. It's usually faster to use a single GPU (set it to 1) if it has enough memory. None will use every GPUs available
1717

1818

1919

games/gridworld.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self):
1818
# More information is available here: https://github.com/werner-duvaud/muzero-general/wiki/Hyperparameter-Optimization
1919

2020
self.seed = 0 # Seed for numpy, torch and the game
21-
self.max_num_gpus = None # Fix the maximum number of GPUs to use. By default muzero uses every GPUs available
21+
self.max_num_gpus = None # Fix the maximum number of GPUs to use. It's usually faster to use a single GPU (set it to 1) if it has enough memory. None will use every GPUs available
2222

2323

2424

games/lunarlander.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def __init__(self):
1313
# More information is available here: https://github.com/werner-duvaud/muzero-general/wiki/Hyperparameter-Optimization
1414

1515
self.seed = 0 # Seed for numpy, torch and the game
16-
self.max_num_gpus = None # Fix the maximum number of GPUs to use. By default muzero uses every GPUs available
16+
self.max_num_gpus = None # Fix the maximum number of GPUs to use. It's usually faster to use a single GPU (set it to 1) if it has enough memory. None will use every GPUs available
1717

1818

1919

games/simple_grid.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def __init__(self):
1212
# More information is available here: https://github.com/werner-duvaud/muzero-general/wiki/Hyperparameter-Optimization
1313

1414
self.seed = 0 # Seed for numpy, torch and the game
15-
self.max_num_gpus = None # Fix the maximum number of GPUs to use. By default muzero uses every GPUs available
15+
self.max_num_gpus = None # Fix the maximum number of GPUs to use. It's usually faster to use a single GPU (set it to 1) if it has enough memory. None will use every GPUs available
1616

1717

1818

games/tictactoe.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def __init__(self):
1212
# More information is available here: https://github.com/werner-duvaud/muzero-general/wiki/Hyperparameter-Optimization
1313

1414
self.seed = 0 # Seed for numpy, torch and the game
15-
self.max_num_gpus = None # Fix the maximum number of GPUs to use. By default muzero uses every GPUs available
15+
self.max_num_gpus = None # Fix the maximum number of GPUs to use. It's usually faster to use a single GPU (set it to 1) if it has enough memory. None will use every GPUs available
1616

1717

1818

games/twentyone.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self):
1919
# More information is available here: https://github.com/werner-duvaud/muzero-general/wiki/Hyperparameter-Optimization
2020

2121
self.seed = 0 # Seed for numpy, torch and the game
22-
self.max_num_gpus = None # Fix the maximum number of GPUs to use. By default muzero uses every GPUs available
22+
self.max_num_gpus = None # Fix the maximum number of GPUs to use. It's usually faster to use a single GPU (set it to 1) if it has enough memory. None will use every GPUs available
2323

2424

2525

muzero.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -633,9 +633,7 @@ def hyperparameter_search(
633633
parallel_experiments = 2
634634
lr_init = nevergrad.p.Log(a_min=0.0001, a_max=0.1)
635635
discount = nevergrad.p.Log(lower=0.95, upper=0.9999)
636-
parametrization = nevergrad.p.Dict(
637-
lr_init=lr_init, discount=discount
638-
)
636+
parametrization = nevergrad.p.Dict(lr_init=lr_init, discount=discount)
639637
best_hyperparameters = hyperparameter_search(
640638
game_name, parametrization, budget, parallel_experiments, 20
641639
)

replay_buffer.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@ class ReplayBuffer:
1717
def __init__(self, initial_checkpoint, initial_buffer, config):
1818
self.config = config
1919
self.buffer = copy.deepcopy(initial_buffer)
20-
self.num_played_games = initial_checkpoint['num_played_games']
21-
self.num_played_steps = initial_checkpoint['num_played_steps']
22-
self.total_samples = sum([len(game_history.root_values) for game_history in self.buffer.values()])
20+
self.num_played_games = initial_checkpoint["num_played_games"]
21+
self.num_played_steps = initial_checkpoint["num_played_steps"]
22+
self.total_samples = sum(
23+
[len(game_history.root_values) for game_history in self.buffer.values()]
24+
)
2325
if self.total_samples != 0:
24-
print(f"Replay buffer initialized with {self.total_samples} samples ({self.num_played_games} games).\n")
26+
print(
27+
f"Replay buffer initialized with {self.total_samples} samples ({self.num_played_games} games).\n"
28+
)
2529

2630
# Fix random generator seed
2731
numpy.random.seed(self.config.seed)
@@ -203,11 +207,17 @@ def compute_target_value(self, game_history, index):
203207
# future, plus the discounted sum of all rewards until then.
204208
bootstrap_index = index + self.config.td_steps
205209
if bootstrap_index < len(game_history.root_values):
210+
root_values = (
211+
game_history.root_values
212+
if game_history.reanalysed_predicted_root_values is None
213+
else game_history.reanalysed_predicted_root_values
214+
)
215+
print(game_history.reanalysed_predicted_root_values is None)
206216
last_step_value = (
207-
game_history.root_values[bootstrap_index]
217+
root_values[bootstrap_index]
208218
if game_history.to_play_history[bootstrap_index]
209219
== game_history.to_play_history[index]
210-
else -game_history.root_values[bootstrap_index]
220+
else -root_values[bootstrap_index]
211221
)
212222

213223
value = last_step_value * self.config.discount ** self.config.td_steps
@@ -323,8 +333,9 @@ def reanalyse(self, replay_buffer, shared_storage):
323333
self.model.initial_inference(observations)[0],
324334
self.config.support_size,
325335
)
326-
for i in range(len(game_history.root_values)):
327-
game_history.root_values[i] = values[i].item()
336+
game_history.reanalysed_predicted_root_values = (
337+
torch.squeeze(values).detach().numpy()
338+
)
328339

329340
replay_buffer.update_game_history.remote(game_id, game_history)
330341
self.num_reanalysed_games += 1

results/cartpole/model.checkpoint

34 KB
Binary file not shown.

self_play.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ def play_game(
136136
numpy.array(observation).shape == self.config.observation_shape
137137
), f"Observation should match the observation_shape defined in MuZeroConfig. Expected {self.config.observation_shape} but got {numpy.array(observation).shape}."
138138
stacked_observations = game_history.get_stacked_observations(
139-
-1, self.config.stacked_observations,
139+
-1,
140+
self.config.stacked_observations,
140141
)
141142

142143
# Choose the action
@@ -223,7 +224,7 @@ def select_opponent_action(self, opponent, stacked_observations):
223224
def select_action(node, temperature):
224225
"""
225226
Select action according to the visit count distribution and the temperature.
226-
The temperature is changed dynamically with the visit_softmax_temperature function
227+
The temperature is changed dynamically with the visit_softmax_temperature function
227228
in the config.
228229
"""
229230
visit_counts = numpy.array(
@@ -300,7 +301,11 @@ def run(
300301
set(self.config.action_space)
301302
), "Legal actions should be a subset of the action space."
302303
root.expand(
303-
legal_actions, to_play, reward, policy_logits, hidden_state,
304+
legal_actions,
305+
to_play,
306+
reward,
307+
policy_logits,
308+
hidden_state,
304309
)
305310

306311
if add_exploration_noise:
@@ -484,6 +489,7 @@ def __init__(self):
484489
self.to_play_history = []
485490
self.child_visits = []
486491
self.root_values = []
492+
self.reanalysed_predicted_root_values = None
487493
# For PER
488494
self.priorities = None
489495
self.game_priority = None

trainer.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,12 @@ def update_lr(self):
280280

281281
@staticmethod
282282
def loss_function(
283-
value, reward, policy_logits, target_value, target_reward, target_policy,
283+
value,
284+
reward,
285+
policy_logits,
286+
target_value,
287+
target_reward,
288+
target_policy,
284289
):
285290
# Cross-entropy seems to have a better convergence than MSE
286291
value_loss = (-target_value * torch.nn.LogSoftmax(dim=1)(value)).sum(1)

0 commit comments

Comments
 (0)