Skip to content

Commit

Permalink
Merge pull request werner-duvaud#117 from mokemokechicken/feature/sam…
Browse files Browse the repository at this point in the history
…ple_n_games_at_one_time_in_get_batch

sample N games at one time in replay_buffer
  • Loading branch information
ahainaut authored Feb 9, 2021
2 parents 1fe969c + 9e957f5 commit 97e4931
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ def get_batch(self):
) = ([], [], [], [], [], [], [])
weight_batch = [] if self.config.PER else None

for _ in range(self.config.batch_size):
game_id, game_history, game_prob = self.sample_game()
for game_id, game_history, game_prob in self.sample_n_games(self.config.batch_size):
game_pos, pos_prob = self.sample_position(game_history)

values, rewards, policies, actions = self.make_target(
Expand Down Expand Up @@ -154,6 +153,24 @@ def sample_game(self, force_uniform=False):

return game_id, self.buffer[game_id], game_prob

def sample_n_games(self, n_games, force_uniform=False):
    """
    Sample n_games games from the buffer with a single numpy draw.

    When self.config.PER is truthy and force_uniform is False, each game is
    drawn with probability proportional to its game_history.game_priority.
    Otherwise games are drawn uniformly. In both cases numpy.random.choice
    is used with its default replacement setting, so the same game can
    appear more than once in the result.

    Args:
        n_games: number of games to draw.
        force_uniform: if True, ignore priorities even when PER is enabled.

    Returns:
        A list of n_games tuples (game_id, game_history, game_prob).
        game_prob is the normalized sampling probability of the game in
        the prioritized case and None in the uniform case.
    """
    if self.config.PER and not force_uniform:
        game_id_list = list(self.buffer)
        game_probs = numpy.array(
            [game_history.game_priority for game_history in self.buffer.values()],
            dtype="float32",
        )
        # Normalize priorities so they can be used as a probability vector.
        game_probs /= numpy.sum(game_probs)
        # Remember each game's probability so it can be returned to the caller.
        game_prob_dict = dict(zip(game_id_list, game_probs))
        selected_games = numpy.random.choice(game_id_list, n_games, p=game_probs)
    else:
        selected_games = numpy.random.choice(list(self.buffer), n_games)
        # No probabilities are computed for uniform draws; .get() yields None.
        game_prob_dict = {}
    return [
        (game_id, self.buffer[game_id], game_prob_dict.get(game_id))
        for game_id in selected_games
    ]

def sample_position(self, game_history, force_uniform=False):
"""
Sample position from game either uniformly or according to some priority.
Expand Down

0 comments on commit 97e4931

Please sign in to comment.