
Commit

Cleaner variable names for BatchReNorm
Co-authored-by: Jan Schneider <[email protected]>
araffin and jan1854 authored Mar 31, 2024
1 parent de80349 commit 55d60e3
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions sbx/common/jax_layers.py
@@ -147,46 +147,46 @@ def __call__(self, x, use_running_average: Optional[bool] = None):
         )

         if use_running_average:
-            mean, var = ra_mean.value, ra_var.value
-            custom_mean = mean
-            custom_var = var
+            custom_mean = ra_mean.value
+            custom_var = ra_var.value
         else:
-            mean, var = _compute_stats(
+            batch_mean, batch_var = _compute_stats(
                 x,
                 reduction_axes,
                 dtype=self.dtype,
                 axis_name=self.axis_name if not self.is_initializing() else None,
                 axis_index_groups=self.axis_index_groups,
                 # use_fast_variance=self.use_fast_variance,
             )
-            custom_mean = mean
-            custom_var = var
-            if not self.is_initializing():
-                std = jnp.sqrt(var + self.epsilon)
+            if self.is_initializing():
+                custom_mean = batch_mean
+                custom_var = batch_var
+            else:
+                std = jnp.sqrt(batch_var + self.epsilon)
                 ra_std = jnp.sqrt(ra_var.value + self.epsilon)
                 # scale
                 r = jax.lax.stop_gradient(std / ra_std)
                 r = jnp.clip(r, 1 / r_max.value, r_max.value)
                 # bias
-                d = jax.lax.stop_gradient((mean - ra_mean.value) / ra_std)
+                d = jax.lax.stop_gradient((batch_mean - ra_mean.value) / ra_std)
                 d = jnp.clip(d, -d_max.value, d_max.value)

                 # BatchNorm normalization, using minibatch stats and running average stats
                 # Because we use _normalize, this is equivalent to
                 # ((x - x_mean) / sigma) * r + d = ((x - x_mean) * r + d * sigma) / sigma
                 # where sigma = sqrt(var)
-                affine_mean = mean - d * jnp.sqrt(var) / r
-                affine_var = var / (r**2)
+                affine_mean = batch_mean - d * jnp.sqrt(batch_var) / r
+                affine_var = batch_var / (r**2)

                 # Note: in the original paper, after some warmup phase (batch norm phase of 5k steps)
                 # the constraints are linearly relaxed to r_max/d_max over 40k steps
                 # Here we only have a warmup phase
                 is_warmed_up = jnp.greater_equal(steps.value, self.warm_up_steps).astype(jnp.float32)
-                custom_var = is_warmed_up * affine_var + (1.0 - is_warmed_up) * custom_var
-                custom_mean = is_warmed_up * affine_mean + (1.0 - is_warmed_up) * custom_mean
+                custom_mean = is_warmed_up * affine_mean + (1.0 - is_warmed_up) * batch_mean
+                custom_var = is_warmed_up * affine_var + (1.0 - is_warmed_up) * batch_var

-                ra_mean.value = self.momentum * ra_mean.value + (1.0 - self.momentum) * mean
-                ra_var.value = self.momentum * ra_var.value + (1.0 - self.momentum) * var
+                ra_mean.value = self.momentum * ra_mean.value + (1.0 - self.momentum) * batch_mean
+                ra_var.value = self.momentum * ra_var.value + (1.0 - self.momentum) * batch_var
                 steps.value += 1

         return _normalize(
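For reference on the "equivalent to _normalize" comment in the diff: the Batch Renormalization output ((x - batch_mean) / sigma) * r + d can be rewritten as a plain normalization with shifted statistics, which is what affine_mean and affine_var encode. The snippet below is an illustrative, self-contained sketch of that identity and is not part of the commit; the input values and the r/d factors are made up, and epsilon is omitted for brevity.

import jax.numpy as jnp

# Made-up batch of activations and (already clipped) renorm factors, for illustration only.
x = jnp.array([0.5, 1.5, -2.0, 3.0])
batch_mean, batch_var = x.mean(), x.var()
r, d = 1.2, -0.3

batch_std = jnp.sqrt(batch_var)
# Batch Renormalization output: scale/shift the batch-normalized activations.
renorm = ((x - batch_mean) / batch_std) * r + d

# Same output expressed as a plain normalization with shifted statistics,
# matching affine_mean / affine_var in the diff above.
affine_mean = batch_mean - d * batch_std / r
affine_var = batch_var / (r**2)
normalized = (x - affine_mean) / jnp.sqrt(affine_var)

assert jnp.allclose(renorm, normalized, atol=1e-6)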

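The warm-up note in the diff says that, unlike the original paper's gradual relaxation of r_max/d_max, this implementation only gates on a step count. A small sketch of how that gate blends the two regimes (the warm_up_steps value and step values here are hypothetical, not taken from the commit):

import jax.numpy as jnp

warm_up_steps = 100  # hypothetical value; the layer reads self.warm_up_steps
for step in (0, 99, 100, 500):
    # 0.0 before the threshold, so plain batch statistics are used (regular
    # BatchNorm behaviour); 1.0 afterwards, so the renormalized affine
    # statistics take over.
    is_warmed_up = jnp.greater_equal(step, warm_up_steps).astype(jnp.float32)
    print(step, is_warmed_up)  # 0 -> 0.0, 99 -> 0.0, 100 -> 1.0, 500 -> 1.0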