Skip to content

Commit

Permalink
Fix order of logging and training (DLR-RM#1708)
Browse files — browse the repository at this point in the history
  • Loading branch information
tobiabir committed Oct 6, 2023
1 parent 574307a commit 43dc8f9
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
10 changes: 5 additions & 5 deletions stable_baselines3/common/off_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,6 @@ def learn(
if rollout.continue_training is False:
break

iteration += 1

if log_interval is not None and iteration % log_interval == 0:
self._dump_logs()

if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
# If no `gradient_steps` is specified,
# do as many gradients steps as steps performed during the rollout
Expand All @@ -347,6 +342,11 @@ def learn(
if gradient_steps > 0:
self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)

iteration += 1

if log_interval is not None and iteration % log_interval == 0:
self._dump_logs()

callback.on_training_end()

return self
Expand Down
7 changes: 4 additions & 3 deletions stable_baselines3/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,12 @@ def learn(
if continue_training is False:
break

iteration += 1
self._update_current_progress_remaining(self.num_timesteps, total_timesteps)

self.train()

iteration += 1

# Display training infos
if log_interval is not None and iteration % log_interval == 0:
assert self.ep_info_buffer is not None
Expand All @@ -285,8 +288,6 @@ def learn(
self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
self.logger.dump(step=self.num_timesteps)

self.train()

callback.on_training_end()

return self
Expand Down
2 changes: 1 addition & 1 deletion tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def test_ppo_warnings():
# in that case
with pytest.warns(UserWarning, match="there will be a truncated mini-batch of size 1"):
model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, batch_size=63, verbose=1)
model.learn(64)
model.learn(64, log_interval=2)

loss = model.logger.name_to_value["train/loss"]
assert loss > 0
Expand Down

0 comments on commit 43dc8f9

Please sign in to comment.