diff --git a/rl_games/common/a2c_common.py b/rl_games/common/a2c_common.py index 30383c29..f9bd5a14 100644 --- a/rl_games/common/a2c_common.py +++ b/rl_games/common/a2c_common.py @@ -236,7 +236,6 @@ def __init__(self, base_name, params): self.game_shaped_rewards = torch_ext.AverageMeter(self.value_size, self.games_to_track).to(self.ppo_device) self.game_lengths = torch_ext.AverageMeter(1, self.games_to_track).to(self.ppo_device) self.obs = None - self.games_num = self.config['minibatch_size'] // self.seq_length # it is used only for current rnn implementation self.batch_size = self.horizon_length * self.num_actors * self.num_agents self.batch_size_envs = self.horizon_length * self.num_actors @@ -245,6 +244,16 @@ def __init__(self, base_name, params): self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0) self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env) + # either minibatch_size_per_env or minibatch_size should be present in a config + # if both are present, minibatch_size is used + # otherwise minibatch_size_per_env is used to calculate minibatch_size + self.minibatch_size_per_env = self.config.get('minibatch_size_per_env', 0) + self.minibatch_size = self.config.get('minibatch_size', self.num_actors * self.minibatch_size_per_env) + + assert(self.minibatch_size > 0) + + self.games_num = self.minibatch_size // self.seq_length # it is used only for current rnn implementation + self.num_minibatches = self.batch_size // self.minibatch_size assert(self.batch_size % self.minibatch_size == 0)