Feat/dropq #100

Draft
wants to merge 5 commits into master
Changes from all commits
16 changes: 15 additions & 1 deletion sb3_contrib/tqc/policies.py
@@ -213,6 +213,8 @@ def __init__(
n_quantiles: int = 25,
n_critics: int = 2,
share_features_extractor: bool = False,
dropout_rate: float = 0.0,
layer_norm: bool = False,
):
super().__init__(
observation_space,
@@ -230,7 +232,14 @@ def __init__(
self.quantiles_total = n_quantiles * n_critics

for i in range(n_critics):
qf_net = create_mlp(features_dim + action_dim, n_quantiles, net_arch, activation_fn)
qf_net = create_mlp(
features_dim + action_dim,
n_quantiles,
net_arch,
activation_fn,
dropout_rate=dropout_rate,
layer_norm=layer_norm,
)
qf_net = nn.Sequential(*qf_net)
self.add_module(f"qf{i}", qf_net)
self.q_networks.append(qf_net)
@@ -298,6 +307,9 @@ def __init__(
n_quantiles: int = 25,
n_critics: int = 2,
share_features_extractor: bool = False,
# For the critic only
dropout_rate: float = 0.0,
layer_norm: bool = False,
):
super().__init__(
observation_space,
@@ -341,6 +353,8 @@ def __init__(
"n_critics": n_critics,
"net_arch": critic_arch,
"share_features_extractor": share_features_extractor,
"dropout_rate": dropout_rate,
"layer_norm": layer_norm,
}
self.critic_kwargs.update(tqc_kwargs)
self.actor, self.actor_target = None, None
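The policies.py change threads two new critic options, dropout_rate and layer_norm, through create_mlp so that every quantile critic network is built with dropout and layer normalization, the two regularizers DroQ adds to the Q-functions. The exact layer ordering produced by create_mlp with these keyword arguments is not visible in this diff, so the snippet below is only a sketch of the usual DroQ-style block (Linear -> Dropout -> LayerNorm -> activation) built directly with torch.nn; droq_style_mlp is an illustrative helper, not part of the PR.

import torch.nn as nn

def droq_style_mlp(input_dim, output_dim, net_arch, dropout_rate=0.01, layer_norm=True, activation_fn=nn.ReLU):
    """Sketch of a DroQ-style MLP: Linear -> Dropout -> LayerNorm -> activation per hidden layer."""
    layers = []
    last_dim = input_dim
    for hidden_dim in net_arch:
        layers.append(nn.Linear(last_dim, hidden_dim))
        if dropout_rate > 0.0:
            layers.append(nn.Dropout(p=dropout_rate))  # regularizes the critic under a high update-to-data ratio
        if layer_norm:
            layers.append(nn.LayerNorm(hidden_dim))  # stabilizes hidden activations
        layers.append(activation_fn())
        last_dim = hidden_dim
    layers.append(nn.Linear(last_dim, output_dim))  # output head (e.g. n_quantiles), no dropout/norm here
    return nn.Sequential(*layers)

For the critic built in the diff, input_dim corresponds to features_dim + action_dim and output_dim to n_quantiles, matching the arguments passed to create_mlp above.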
27 changes: 16 additions & 11 deletions sb3_contrib/tqc/tqc.py
@@ -90,6 +90,7 @@ def __init__(
replay_buffer_class: Optional[ReplayBuffer] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
policy_delay: int = 1,
ent_coef: Union[str, float] = "auto",
target_update_interval: int = 1,
target_entropy: Union[str, float] = "auto",
@@ -142,6 +143,7 @@ def __init__(
self.target_update_interval = target_update_interval
self.ent_coef_optimizer = None
self.top_quantiles_to_drop_per_net = top_quantiles_to_drop_per_net
self.policy_delay = policy_delay

if _init_setup_model:
self._setup_model()
@@ -202,6 +204,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
actor_losses, critic_losses = [], []

for gradient_step in range(gradient_steps):
self._n_updates += 1
update_actor = self._n_updates % self.policy_delay == 0
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

@@ -219,8 +223,9 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
# so we don't change it with other losses
# see https://github.com/rail-berkeley/softlearning/issues/60
ent_coef = th.exp(self.log_ent_coef.detach())
ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
ent_coef_losses.append(ent_coef_loss.item())
if update_actor:
ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
ent_coef_losses.append(ent_coef_loss.item())
else:
ent_coef = self.ent_coef_tensor

@@ -239,6 +244,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
# Compute and cut quantiles at the next state
# batch x nets x quantiles
# Note: in DroQ, dropout appears to be enabled for the target network as well
next_quantiles = self.critic_target(replay_data.next_observations, next_actions)

# Sort and drop top k quantiles to control overestimation.
@@ -264,23 +270,22 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
self.critic.optimizer.step()

# Compute actor loss
qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
actor_loss = (ent_coef * log_prob - qf_pi).mean()
actor_losses.append(actor_loss.item())
if update_actor:
qf_pi = self.critic(replay_data.observations, actions_pi).mean(dim=2).mean(dim=1, keepdim=True)
actor_loss = (ent_coef * log_prob - qf_pi).mean()
actor_losses.append(actor_loss.item())

# Optimize the actor
self.actor.optimizer.zero_grad()
actor_loss.backward()
self.actor.optimizer.step()
# Optimize the actor
self.actor.optimizer.zero_grad()
actor_loss.backward()
self.actor.optimizer.step()

# Update target networks
if gradient_step % self.target_update_interval == 0:
polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

self._n_updates += gradient_steps

self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/ent_coef", np.mean(ent_coefs))
self.logger.record("train/actor_loss", np.mean(actor_losses))
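The tqc.py change adds a policy_delay parameter (default 1) and moves the self._n_updates increment inside the gradient-step loop: the critic is still trained on every step, but the actor and the entropy coefficient are only optimized when the running update count is a multiple of policy_delay. This gives the many-critic-updates-per-actor-update schedule used by REDQ/DroQ. The toy function below is a self-contained sketch of just that scheduling logic (update_schedule is an illustrative helper, not part of the PR):

def update_schedule(gradient_steps, policy_delay, n_updates=0):
    """Return, per gradient step, whether the actor would be updated (mirrors the diff's modulo check)."""
    flags = []
    for _ in range(gradient_steps):
        n_updates += 1  # the counter is incremented first...
        flags.append(n_updates % policy_delay == 0)  # ...then the actor updates on multiples of policy_delay
    return flags

# Example: with policy_delay=4, the actor is updated on every 4th critic update.
print(update_schedule(gradient_steps=8, policy_delay=4))
# [False, False, False, True, False, False, False, True]

Because the counter is self._n_updates rather than the local gradient_step, the delay pattern stays aligned across successive calls to train(), even when gradient_steps is not a multiple of policy_delay.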
11 changes: 11 additions & 0 deletions tests/test_run.py
@@ -8,6 +8,17 @@
from sb3_contrib.common.vec_env import AsyncEval


def test_dropq():
model = TQC(
"MlpPolicy",
"Pendulum-v1",
policy_kwargs=dict(net_arch=[64, 64], layer_norm=True, dropout_rate=0.005),
verbose=1,
buffer_size=250,
)
model.learn(total_timesteps=300)


@pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"])
def test_tqc(ent_coef):
with pytest.warns(DeprecationWarning): # `create_eval_env` and `eval_freq` are deprecated
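Beyond this smoke test, a DroQ-style run would combine all three additions from the PR: dropout and layer normalization in the critic, many gradient steps per environment step, and delayed actor updates. The configuration below is only an illustration with DroQ-inspired values, not something prescribed by the diff:

from sb3_contrib import TQC

model = TQC(
    "MlpPolicy",
    "Pendulum-v1",
    policy_kwargs=dict(net_arch=[256, 256], dropout_rate=0.01, layer_norm=True),
    train_freq=1,
    gradient_steps=20,  # 20 critic updates per environment step (update-to-data ratio of 20)
    policy_delay=20,  # one actor/entropy-coefficient update per 20 critic updates
    learning_starts=100,
    verbose=1,
)
model.learn(total_timesteps=5_000)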