diff --git a/sb3_contrib/tqc/policies.py b/sb3_contrib/tqc/policies.py index e8022f9c..141f994b 100644 --- a/sb3_contrib/tqc/policies.py +++ b/sb3_contrib/tqc/policies.py @@ -213,6 +213,8 @@ def __init__( n_quantiles: int = 25, n_critics: int = 2, share_features_extractor: bool = False, + dropout_rate: float = 0.0, + layer_norm: bool = False, ): super().__init__( observation_space, @@ -230,7 +232,14 @@ def __init__( self.quantiles_total = n_quantiles * n_critics for i in range(n_critics): - qf_net = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn) + qf_net = create_mlp( + features_dim + action_dim, + n_quantiles, + net_arch, + activation_fn, + dropout_rate=dropout_rate, + layer_norm=layer_norm, + ) qf_net = nn.Sequential(*qf_net) self.add_module(f"qf{i}", qf_net) self.q_networks.append(qf_net) @@ -298,6 +307,9 @@ def __init__( n_quantiles: int = 25, n_critics: int = 2, share_features_extractor: bool = False, + # For the critic only + dropout_rate: float = 0.0, + layer_norm: bool = False, ): super().__init__( observation_space, @@ -341,6 +353,8 @@ def __init__( "n_critics": n_critics, "net_arch": critic_arch, "share_features_extractor": share_features_extractor, + "dropout_rate": dropout_rate, + "layer_norm": layer_norm, } self.critic_kwargs.update(tqc_kwargs) self.actor, self.actor_target = None, None diff --git a/sb3_contrib/tqc/tqc.py b/sb3_contrib/tqc/tqc.py index df65496a..fad65390 100644 --- a/sb3_contrib/tqc/tqc.py +++ b/sb3_contrib/tqc/tqc.py @@ -90,6 +90,7 @@ def __init__( replay_buffer_class: Optional[ReplayBuffer] = None, replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, + policy_delay: int = 1, ent_coef: Union[str, float] = "auto", target_update_interval: int = 1, target_entropy: Union[str, float] = "auto", @@ -142,6 +143,7 @@ def __init__( self.target_update_interval = target_update_interval self.ent_coef_optimizer = None self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net + self.policy_delay = policy_delay if _init_setup_model: self._setup_model() @@ -202,6 +204,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: actor_losses, critic_losses = [], [] for gradient_step in range(gradient_steps): + self._n_updates += 1 + update_actor = self._n_updates % self.policy_delay == 0 # Sample replay buffer replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env) @@ -219,8 +223,9 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: # so we don't change it with other losses # see https://github.com/rail-berkeley/softlearning/issues/60 ent_coef = th.exp(self.log_ent_coef.detach()) - ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean() - ent_coef_losses.append(ent_coef_loss.item()) + if update_actor: + ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean() + ent_coef_losses.append(ent_coef_loss.item()) else: ent_coef = self.ent_coef_tensor @@ -239,6 +244,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations) # Compute and cut quantiles at the next state # batch x nets x quantiles + # Note: in dropq dropout seems to be on for target net too next_quantiles = self.critic_target(replay_data.next_observations, next_actions) # Sort and drop top k quantiles to control overestimation. @@ -264,14 +270,15 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: self.critic.optimizer.step() # Compute actor loss - qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True) - actor_loss = (ent_coef * log_prob - qf_pi).mean() - actor_losses.append(actor_loss.item()) + if update_actor: + qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True) + actor_loss = (ent_coef * log_prob - qf_pi).mean() + actor_losses.append(actor_loss.item()) - # Optimize the actor - self.actor.optimizer.zero_grad() - actor_loss.backward() - self.actor.optimizer.step() + # Optimize the actor + self.actor.optimizer.zero_grad() + actor_loss.backward() + self.actor.optimizer.step() # Update target networks if gradient_step % self.target_update_interval == 0: @@ -279,8 +286,6 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996 polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0) - self._n_updates += gradient_steps - self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") self.logger.record("train/ent_coef", np.mean(ent_coefs)) self.logger.record("train/actor_loss", np.mean(actor_losses)) diff --git a/tests/test_run.py b/tests/test_run.py index 09a00904..69a18058 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -8,6 +8,17 @@ from sb3_contrib.common.vec_env import AsyncEval +def test_dropq(): + model = TQC( + "MlpPolicy", + "Pendulum-v1", + policy_kwargs=dict(net_arch=[64, 64], layer_norm=True, dropout_rate=0.005), + verbose=1, + buffer_size=250, + ) + model.learn(total_timesteps=300) + + @pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"]) def test_tqc(ent_coef): with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated