From 43dc8f9d27b9f61839094fc50a004668a98848cd Mon Sep 17 00:00:00 2001
From: Tobias Birchler
Date: Wed, 4 Oct 2023 23:07:08 +0200
Subject: [PATCH] Fix order of logging and training (#1708)

---
 stable_baselines3/common/off_policy_algorithm.py | 10 +++++-----
 stable_baselines3/common/on_policy_algorithm.py  |  7 ++++---
 tests/test_run.py                                |  2 +-
 3 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py
index 76e277ecc..04325e6e0 100644
--- a/stable_baselines3/common/off_policy_algorithm.py
+++ b/stable_baselines3/common/off_policy_algorithm.py
@@ -334,11 +334,6 @@ def learn(
             if rollout.continue_training is False:
                 break
 
-            iteration += 1
-
-            if log_interval is not None and iteration % log_interval == 0:
-                self._dump_logs()
-
             if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
                 # If no `gradient_steps` is specified,
                 # do as many gradients steps as steps performed during the rollout
@@ -347,6 +342,11 @@
                 if gradient_steps > 0:
                     self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
 
+            iteration += 1
+
+            if log_interval is not None and iteration % log_interval == 0:
+                self._dump_logs()
+
         callback.on_training_end()
 
         return self
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index 1e0f9e6c9..2641b915f 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -268,9 +268,12 @@ def learn(
             if continue_training is False:
                 break
 
-            iteration += 1
             self._update_current_progress_remaining(self.num_timesteps, total_timesteps)
 
+            self.train()
+
+            iteration += 1
+
             # Display training infos
             if log_interval is not None and iteration % log_interval == 0:
                 assert self.ep_info_buffer is not None
@@ -285,8 +288,6 @@
                 self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")
                 self.logger.dump(step=self.num_timesteps)
 
-            self.train()
-
             callback.on_training_end()
 
         return self
diff --git a/tests/test_run.py b/tests/test_run.py
index 31c7b956e..8bc43dce1 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -229,7 +229,7 @@ def test_ppo_warnings():
     # in that case
     with pytest.warns(UserWarning, match="there will be a truncated mini-batch of size 1"):
         model = PPO("MlpPolicy", "Pendulum-v1", n_steps=64, batch_size=63, verbose=1)
-        model.learn(64)
+        model.learn(64, log_interval=2)
 
     loss = model.logger.name_to_value["train/loss"]
     assert loss > 0
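
The net effect of the reordering, sketched below as a self-contained toy loop (hypothetical names, not SB3 code): a training step now runs before the iteration counter is incremented and the logs are dumped, so a dump at iteration i reflects the losses computed during iteration i rather than the previous one. The test tweak above (`log_interval=2`) is presumably needed because, under the new ordering, a dump right after training would flush the freshly recorded `train/loss` before the assertion reads it.

```python
# Toy illustration of the reordered learn() loop; names are hypothetical stand-ins,
# not SB3 APIs (collect_rollouts/train/_dump_logs in the patch above play these roles).
def toy_learn(total_iterations: int = 6, log_interval: int = 2) -> None:
    iteration = 0
    latest_loss = None
    for step in range(total_iterations):
        # 1. Collect experience (stand-in for collect_rollouts()).
        rollout = [step * 0.1] * 4

        # 2. Train on the freshly collected data (stand-in for train()).
        latest_loss = sum(rollout) / len(rollout)

        # 3. Only now bump the counter and dump logs (stand-in for _dump_logs()),
        #    so the logged loss belongs to the iteration that just trained.
        iteration += 1
        if log_interval is not None and iteration % log_interval == 0:
            print(f"iteration={iteration} train/loss={latest_loss:.3f}")


toy_learn()
```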