diff --git a/sbx/common/jax_layers.py b/sbx/common/jax_layers.py
index 2323019..c3ebf4e 100644
--- a/sbx/common/jax_layers.py
+++ b/sbx/common/jax_layers.py
@@ -147,11 +147,10 @@ def __call__(self, x, use_running_average: Optional[bool] = None):
         )
 
         if use_running_average:
-            mean, var = ra_mean.value, ra_var.value
-            custom_mean = mean
-            custom_var = var
+            custom_mean = ra_mean.value
+            custom_var = ra_var.value
         else:
-            mean, var = _compute_stats(
+            batch_mean, batch_var = _compute_stats(
                 x,
                 reduction_axes,
                 dtype=self.dtype,
@@ -159,34 +158,35 @@ def __call__(self, x, use_running_average: Optional[bool] = None):
                 axis_index_groups=self.axis_index_groups,
                 # use_fast_variance=self.use_fast_variance,
             )
-            custom_mean = mean
-            custom_var = var
-            if not self.is_initializing():
-                std = jnp.sqrt(var + self.epsilon)
+            if self.is_initializing():
+                custom_mean = batch_mean
+                custom_var = batch_var
+            else:
+                std = jnp.sqrt(batch_var + self.epsilon)
                 ra_std = jnp.sqrt(ra_var.value + self.epsilon)
                 # scale
                 r = jax.lax.stop_gradient(std / ra_std)
                 r = jnp.clip(r, 1 / r_max.value, r_max.value)
                 # bias
-                d = jax.lax.stop_gradient((mean - ra_mean.value) / ra_std)
+                d = jax.lax.stop_gradient((batch_mean - ra_mean.value) / ra_std)
                 d = jnp.clip(d, -d_max.value, d_max.value)
 
                 # BatchNorm normalization, using minibatch stats and running average stats
                 # Because we use _normalize, this is equivalent to
                 # ((x - x_mean) / sigma) * r + d = ((x - x_mean) * r + d * sigma) / sigma
                 # where sigma = sqrt(var)
-                affine_mean = mean - d * jnp.sqrt(var) / r
-                affine_var = var / (r**2)
+                affine_mean = batch_mean - d * jnp.sqrt(batch_var) / r
+                affine_var = batch_var / (r**2)
 
                 # Note: in the original paper, after some warmup phase (batch norm phase of 5k steps)
                 # the constraints are linearly relaxed to r_max/d_max over 40k steps
                 # Here we only have a warmup phase
                 is_warmed_up = jnp.greater_equal(steps.value, self.warm_up_steps).astype(jnp.float32)
-                custom_var = is_warmed_up * affine_var + (1.0 - is_warmed_up) * custom_var
-                custom_mean = is_warmed_up * affine_mean + (1.0 - is_warmed_up) * custom_mean
+                custom_mean = is_warmed_up * affine_mean + (1.0 - is_warmed_up) * batch_mean
+                custom_var = is_warmed_up * affine_var + (1.0 - is_warmed_up) * batch_var
 
-                ra_mean.value = self.momentum * ra_mean.value + (1.0 - self.momentum) * mean
-                ra_var.value = self.momentum * ra_var.value + (1.0 - self.momentum) * var
+                ra_mean.value = self.momentum * ra_mean.value + (1.0 - self.momentum) * batch_mean
+                ra_var.value = self.momentum * ra_var.value + (1.0 - self.momentum) * batch_var
                 steps.value += 1
 
         return _normalize(
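
For context (not part of the patch): the comment block in the hunk states that feeding `affine_mean` and `affine_var` to a standard normalization is equivalent to the Batch Renormalization transform `((x - batch_mean) / sigma) * r + d`. Below is a minimal standalone sketch checking that identity numerically; the running statistics, `r_max`/`d_max` values, and array shapes are illustrative only, and `epsilon` is omitted for clarity (the real code folds it into `std` and `ra_std`).

```python
# Standalone check of the affine-stats trick used in the patch (illustrative only).
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 4))

batch_mean = x.mean(axis=0)
batch_var = x.var(axis=0)
sigma = jnp.sqrt(batch_var)

# Hypothetical running statistics and clipping bounds, chosen only for the demo.
ra_mean = batch_mean + 0.1
ra_var = batch_var * 1.2
ra_std = jnp.sqrt(ra_var)
r_max, d_max = 3.0, 5.0

# Batch Renormalization scale and bias corrections.
r = jnp.clip(sigma / ra_std, 1 / r_max, r_max)
d = jnp.clip((batch_mean - ra_mean) / ra_std, -d_max, d_max)

# Batch Renormalization written directly.
y_renorm = ((x - batch_mean) / sigma) * r + d

# Same result via the affine statistics computed in the diff.
affine_mean = batch_mean - d * sigma / r
affine_var = batch_var / (r**2)
y_affine = (x - affine_mean) / jnp.sqrt(affine_var)

assert jnp.allclose(y_renorm, y_affine, atol=1e-5)
```

This is why the patch only has to swap the mean/variance handed to `_normalize` instead of changing the normalization itself: the correction factors `r` and `d` are absorbed into `affine_mean` and `affine_var`.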