From bfbe00be154b223e94952870ac8f5b97a1fc558f Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Thu, 8 Feb 2024 10:42:21 +0100
Subject: [PATCH] Fix off-by-one and improve type annotation

---
 sbx/sac/sac.py | 13 +++++++------
 sbx/td3/td3.py | 13 +++++++++----
 sbx/tqc/tqc.py | 19 ++++++++++++-------
 3 files changed, 28 insertions(+), 17 deletions(-)

diff --git a/sbx/sac/sac.py b/sbx/sac/sac.py
index 4835bdc..fc73455 100644
--- a/sbx/sac/sac.py
+++ b/sbx/sac/sac.py
@@ -1,6 +1,7 @@
 from functools import partial
 from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union
 
+import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -212,7 +213,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
             gradient_steps,
             data,
             self.policy_delay,
-            self._n_updates % self.policy_delay,
+            (self._n_updates + 1) % self.policy_delay,
             self.policy.qf_state,
             self.policy.actor_state,
             self.ent_coef_state,
@@ -259,7 +260,7 @@ def update_critic(
         # shape is (batch_size, 1)
         target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values
 
-        def mse_loss(params, dropout_key):
+        def mse_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array:
             # shape is (n_critics, batch_size, 1)
             current_q_values = qf_state.apply_fn(params, observations, actions, rngs={"dropout": dropout_key})
             return 0.5 * ((target_q_values - current_q_values) ** 2).mean(axis=1).sum()
@@ -284,7 +285,7 @@ def update_actor(
     ):
         key, dropout_key, noise_key = jax.random.split(key, 3)
 
-        def actor_loss(params):
+        def actor_loss(params: flax.core.FrozenDict) -> Tuple[jax.Array, jax.Array]:
             dist = actor_state.apply_fn(params, observations)
             actor_actions = dist.sample(seed=noise_key)
             log_prob = dist.log_prob(actor_actions).reshape(-1, 1)
@@ -308,16 +309,16 @@ def actor_loss(params):
 
     @staticmethod
     @jax.jit
-    def soft_update(tau: float, qf_state: RLTrainState):
+    def soft_update(tau: float, qf_state: RLTrainState) -> RLTrainState:
         qf_state = qf_state.replace(target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau))
         return qf_state
 
     @staticmethod
     @jax.jit
     def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
-        def temperature_loss(temp_params):
+        def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array:
             ent_coef_value = ent_coef_state.apply_fn({"params": temp_params})
-            ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean()
+            ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean()  # type: ignore[union-attr]
             return ent_coef_loss
 
         ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params)
diff --git a/sbx/td3/td3.py b/sbx/td3/td3.py
index eac8204..69de016 100644
--- a/sbx/td3/td3.py
+++ b/sbx/td3/td3.py
@@ -1,6 +1,7 @@
 from functools import partial
 from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union
 
+import flax
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -151,7 +152,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
             gradient_steps,
             data,
             self.policy_delay,
-            self._n_updates % self.policy_delay,
+            (self._n_updates + 1) % self.policy_delay,
             self.target_policy_noise,
             self.target_noise_clip,
             self.policy.qf_state,
@@ -197,7 +198,7 @@ def update_critic(
         # shape is (batch_size, 1)
         target_q_values = rewards.reshape(-1, 1) + (1 - dones.reshape(-1, 1)) * gamma * next_q_values
 
-        def mse_loss(params, dropout_key):
+        def mse_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array:
             # shape is (n_critics, batch_size, 1)
             current_q_values = qf_state.apply_fn(params, observations, actions, rngs={"dropout": dropout_key})
             return 0.5 * ((target_q_values - current_q_values) ** 2).mean(axis=1).sum()
@@ -221,7 +222,7 @@ def update_actor(
     ):
         key, dropout_key = jax.random.split(key, 2)
 
-        def actor_loss(params):
+        def actor_loss(params: flax.core.FrozenDict) -> jax.Array:
             actor_actions = actor_state.apply_fn(params, observations)
 
             qf_pi = qf_state.apply_fn(
@@ -242,7 +243,7 @@ def actor_loss(params):
     @staticmethod
     @jax.jit
-    def soft_update(tau: float, qf_state: RLTrainState, actor_state: RLTrainState):
+    def soft_update(tau: float, qf_state: RLTrainState, actor_state: RLTrainState) -> Tuple[RLTrainState, RLTrainState]:
         qf_state = qf_state.replace(target_params=optax.incremental_update(qf_state.params, qf_state.target_params, tau))
         actor_state = actor_state.replace(
             target_params=optax.incremental_update(actor_state.params, actor_state.target_params, tau)
         )
@@ -279,6 +280,8 @@ def _train(
         }
 
         def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
+            # Note: this function must be defined inline because
+            # `fori_loop` expects a signature fn(index, carry) -> carry
             actor_state = carry["actor_state"]
             qf_state = carry["qf_state"]
             key = carry["key"]
@@ -309,7 +312,9 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
 
             (actor_state, qf_state, actor_loss_value, key) = jax.lax.cond(
                 (policy_delay_offset + i) % policy_delay == 0,
+                # If True:
                 cls.update_actor,
+                # If False:
                 lambda *_: (actor_state, qf_state, info["actor_loss"], key),
                 actor_state,
                 qf_state,
diff --git a/sbx/tqc/tqc.py b/sbx/tqc/tqc.py
index ce50e95..d2beca0 100644
--- a/sbx/tqc/tqc.py
+++ b/sbx/tqc/tqc.py
@@ -1,6 +1,7 @@
 from functools import partial
 from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union
 
+import flax
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
@@ -215,7 +216,7 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
             self.policy.n_target_quantiles,
             data,
             self.policy_delay,
-            self._n_updates % self.policy_delay,
+            (self._n_updates + 1) % self.policy_delay,
             self.policy.qf1_state,
             self.policy.qf2_state,
             self.policy.actor_state,
@@ -284,9 +285,9 @@ def update_critic(
         # Make target_quantiles broadcastable to (batch_size, n_quantiles, n_target_quantiles).
         target_quantiles = jnp.expand_dims(target_quantiles, axis=1)
 
-        def huber_quantile_loss(params, noise_key):
+        def huber_quantile_loss(params: flax.core.FrozenDict, dropout_key: jax.Array) -> jax.Array:
             # Compute huber quantile loss
-            current_quantiles = qf1_state.apply_fn(params, observations, actions, True, rngs={"dropout": noise_key})
+            current_quantiles = qf1_state.apply_fn(params, observations, actions, True, rngs={"dropout": dropout_key})
             # convert to shape: (batch_size, n_quantiles, 1) for broadcast
             current_quantiles = jnp.expand_dims(current_quantiles, axis=-1)
 
@@ -327,7 +328,7 @@ def update_actor(
     ):
         key, dropout_key_1, dropout_key_2, noise_key = jax.random.split(key, 4)
 
-        def actor_loss(params):
+        def actor_loss(params: flax.core.FrozenDict) -> Tuple[jax.Array, jax.Array]:
             dist = actor_state.apply_fn(params, observations)
             actor_actions = dist.sample(seed=noise_key)
             log_prob = dist.log_prob(actor_actions).reshape(-1, 1)
@@ -364,7 +365,7 @@ def actor_loss(params):
 
     @staticmethod
     @jax.jit
-    def soft_update(tau: float, qf1_state: RLTrainState, qf2_state: RLTrainState):
+    def soft_update(tau: float, qf1_state: RLTrainState, qf2_state: RLTrainState) -> Tuple[RLTrainState, RLTrainState]:
         qf1_state = qf1_state.replace(target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, tau))
         qf2_state = qf2_state.replace(target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, tau))
         return qf1_state, qf2_state
@@ -372,10 +373,10 @@ def soft_update(tau: float, qf1_state: RLTrainState, qf2_state: RLTrainState):
     @staticmethod
     @jax.jit
     def update_temperature(target_entropy: ArrayLike, ent_coef_state: TrainState, entropy: float):
-        def temperature_loss(temp_params):
+        def temperature_loss(temp_params: flax.core.FrozenDict) -> jax.Array:
             ent_coef_value = ent_coef_state.apply_fn({"params": temp_params})
             # ent_coef_loss = (jnp.log(ent_coef_value) * (entropy - target_entropy)).mean()
-            ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean()
+            ent_coef_loss = ent_coef_value * (entropy - target_entropy).mean()  # type: ignore[union-attr]
             return ent_coef_loss
 
         ent_coef_loss, grads = jax.value_and_grad(temperature_loss)(ent_coef_state.params)
@@ -444,6 +445,8 @@ def _train(
         }
 
         def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
+            # Note: this function must be defined inline because
+            # `fori_loop` expects a signature fn(index, carry) -> carry
             actor_state = carry["actor_state"]
             qf1_state = carry["qf1_state"]
             qf2_state = carry["qf2_state"]
@@ -477,7 +480,9 @@ def one_update(i: int, carry: Dict[str, Any]) -> Dict[str, Any]:
 
             (actor_state, (qf1_state, qf2_state), ent_coef_state, actor_loss_value, ent_coef_loss_value, key) = jax.lax.cond(
                 (policy_delay_offset + i) % policy_delay == 0,
+                # If True:
                 cls.update_actor_and_temperature,
+                # If False:
                 lambda *_: (
                     actor_state,
                     (qf1_state, qf2_state),