Commit 7f9f808

Add gradient clipping to PPO. (#561)
* Add gradient clipping to PPO.

* Add TODO.
kevinzakka authored Nov 27, 2024
1 parent f43727e commit 7f9f808
Showing 1 changed file with 10 additions and 1 deletion.
brax/training/agents/ppo/train.py
@@ -107,6 +107,7 @@ def train(
         Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]]
     ] = None,
     restore_checkpoint_path: Optional[str] = None,
+    max_grad_norm: Optional[float] = None,
 ):
   """PPO training.
@@ -158,6 +159,7 @@ def train(
     randomization_fn: a user-defined callback function that generates randomized
       environments
     restore_checkpoint_path: the path used to restore previous model params
+    max_grad_norm: gradient clipping norm value. If None, no clipping is done
 
   Returns:
     Tuple of (make_policy function, network params, metrics)
@@ -241,7 +243,14 @@ def train(
       preprocess_observations_fn=normalize)
   make_policy = ppo_networks.make_inference_fn(ppo_network)
 
-  optimizer = optax.adam(learning_rate=learning_rate)
+  if max_grad_norm is not None:
+    # TODO(btaba): Move gradient clipping to `training/gradients.py`.
+    optimizer = optax.chain(
+        optax.clip_by_global_norm(max_grad_norm),
+        optax.adam(learning_rate=learning_rate)
+    )
+  else:
+    optimizer = optax.adam(learning_rate=learning_rate)
 
   loss_fn = functools.partial(
       ppo_losses.compute_ppo_loss,
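
For reference, optax.chain applies gradient transformations in sequence, so the raw gradient pytree is first rescaled by optax.clip_by_global_norm and the clipped result is then fed to the Adam update. A minimal standalone sketch of the clipping step (the example values are illustrative, not from this commit):

import jax.numpy as jnp
import optax

grads = {'w': jnp.array([3.0, 4.0])}   # global L2 norm is 5.0
clip = optax.clip_by_global_norm(1.0)  # same transform the diff chains before Adam
clip_state = clip.init(grads)
clipped, _ = clip.update(grads, clip_state)
# clipped['w'] is [0.6, 0.8]: the whole tree is scaled by 1.0 / 5.0 so its
# global norm equals the threshold; trees already under it pass through unchanged.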
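And a sketch of a training call that opts into the new flag; the environment name and hyperparameter values here are illustrative assumptions, not part of the commit:

from brax import envs
from brax.training.agents.ppo import train as ppo

env = envs.get_environment('ant')  # any registered brax environment
make_policy, params, metrics = ppo.train(
    environment=env,
    num_timesteps=1_000_000,
    episode_length=1000,
    learning_rate=3e-4,
    max_grad_norm=1.0,  # clip the global grad norm at 1.0; the default None keeps the old behavior
)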
