From de803494e2ebdb6f1cb93af9a2f99f182a1c4e3c Mon Sep 17 00:00:00 2001
From: Jan Schneider
Date: Sat, 30 Mar 2024 08:26:07 +0100
Subject: [PATCH] Clean-up: Removed unused variables and fixed typo

---
 sbx/common/jax_layers.py | 2 --
 sbx/crossq/crossq.py     | 4 +---
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/sbx/common/jax_layers.py b/sbx/common/jax_layers.py
index 5733d2c..2323019 100644
--- a/sbx/common/jax_layers.py
+++ b/sbx/common/jax_layers.py
@@ -162,8 +162,6 @@ def __call__(self, x, use_running_average: Optional[bool] = None):
         custom_mean = mean
         custom_var = var
         if not self.is_initializing():
-            r = jnp.array(1.0)
-            d = jnp.array(0.0)
             std = jnp.sqrt(var + self.epsilon)
             ra_std = jnp.sqrt(ra_var.value + self.epsilon)
             # scale
diff --git a/sbx/crossq/crossq.py b/sbx/crossq/crossq.py
index 52d040c..bb2d1ab 100644
--- a/sbx/crossq/crossq.py
+++ b/sbx/crossq/crossq.py
@@ -208,7 +208,6 @@ def train(self, gradient_steps: int, batch_size: int) -> None:
             (actor_loss_value, qf_loss_value, ent_coef_value),
         ) = self._train(
             self.gamma,
-            self.tau,
             self.target_entropy,
             gradient_steps,
             data,
@@ -260,7 +259,7 @@ def mse_loss(
         #
         # This has two reasons:
         # 1. According to the paper obs/actions and next_obs/next_state_actions are differently
-        #    distributed which is the reason why "naively" appling Batch Normalization in SAC fails.
+        #    distributed which is the reason why "naively" applying Batch Normalization in SAC fails.
         #    The batch statistics have to instead be calculated for the mixture distribution of obs/next_obs
         #    and actions/next_state_actions. Otherwise, next_obs/next_state_actions are perceived as
         #    out-of-distribution to the Batch Normalization layer, since running statistics are only polyak averaged
@@ -385,7 +384,6 @@ def update_actor_and_temperature(
     def _train(
         cls,
         gamma: float,
-        tau: float,
         target_entropy: ArrayLike,
         gradient_steps: int,
         data: ReplayBufferSamplesNp,
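
Note on the jax_layers.py hunk: the deleted `r = jnp.array(1.0)` and `d = jnp.array(0.0)` assignments were dead stores, as the commit message notes; both correction factors are recomputed from the batch and running statistics (the `std` and `ra_std` visible in the context lines) before they are ever read. Below is a minimal sketch of that computation, assuming the standard Batch Renormalization formulation from Ioffe (2017); the function name and the `r_max`/`d_max` bounds are illustrative, not the exact sbx code:

    import jax.numpy as jnp
    from jax import lax

    def batch_renorm(x, mean, var, ra_mean, ra_var, epsilon=1e-5, r_max=3.0, d_max=5.0):
        # Batch statistics vs. running-average statistics of the layer input.
        std = jnp.sqrt(var + epsilon)
        ra_std = jnp.sqrt(ra_var + epsilon)
        # r rescales and d shifts the batch-normalized value towards what the
        # running statistics would produce; gradients are blocked through both.
        r = lax.stop_gradient(jnp.clip(std / ra_std, 1.0 / r_max, r_max))
        d = lax.stop_gradient(jnp.clip((mean - ra_mean) / ra_std, -d_max, d_max))
        return (x - mean) / std * r + d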
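
Note on the crossq.py hunks: dropping `tau` is consistent with CrossQ training without target networks, so there is no polyak-averaged copy for `tau` to interpolate. The comment fixed in the second hunk describes the core trick: batch statistics must be computed over the mixture of obs/next_obs and actions/next_state_actions, which is achieved by pushing both pairs through the critic in a single forward pass. A sketch of that pattern, assuming a Flax critic whose variables carry a `batch_stats` collection; the helper name and the `train` keyword are assumptions, not the exact sbx signature:

    import jax.numpy as jnp

    def joint_q_forward(qf_state, params, batch_stats, obs, actions, next_obs, next_actions):
        # Concatenate current and next pairs so the normalization layers see
        # the mixture distribution instead of treating next_obs as
        # out-of-distribution relative to the running statistics.
        cat_obs = jnp.concatenate([obs, next_obs], axis=0)
        cat_actions = jnp.concatenate([actions, next_actions], axis=0)
        q_values, updates = qf_state.apply_fn(
            {"params": params, "batch_stats": batch_stats},
            cat_obs,
            cat_actions,
            train=True,               # compute and update batch statistics
            mutable=["batch_stats"],  # return the updated running stats
        )
        # First half of the batch is Q(s, a), second half is Q(s', a').
        current_q, next_q = jnp.split(q_values, 2, axis=0)
        return current_q, next_q, updates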