diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py
index b6fd324b..a7839c47 100644
--- a/brax/training/agents/ppo/train.py
+++ b/brax/training/agents/ppo/train.py
@@ -107,6 +107,7 @@ def train(
         Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]]
     ] = None,
     restore_checkpoint_path: Optional[str] = None,
+    max_grad_norm: Optional[float] = None,
 ):
   """PPO training.
 
@@ -158,6 +159,7 @@ def train(
     randomization_fn: a user-defined callback function that generates randomized
       environments
     restore_checkpoint_path: the path used to restore previous model params
+    max_grad_norm: gradient clipping norm value. If None, no clipping is done
 
   Returns:
     Tuple of (make_policy function, network params, metrics)
@@ -241,7 +243,14 @@ def train(
       preprocess_observations_fn=normalize)
   make_policy = ppo_networks.make_inference_fn(ppo_network)
 
-  optimizer = optax.adam(learning_rate=learning_rate)
+  if max_grad_norm is not None:
+    # TODO(btaba): Move gradient clipping to `training/gradients.py`.
+    optimizer = optax.chain(
+        optax.clip_by_global_norm(max_grad_norm),
+        optax.adam(learning_rate=learning_rate)
+    )
+  else:
+    optimizer = optax.adam(learning_rate=learning_rate)
 
   loss_fn = functools.partial(
       ppo_losses.compute_ppo_loss,
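
Not part of the diff: a minimal standalone sketch of the optax composition this change relies on, i.e. that optax.chain applies clip_by_global_norm to the gradients before Adam computes its update. The toy parameters, gradient values, and the norm/learning-rate constants below are illustrative assumptions, not values from the brax code.

    # Sketch only: gradients whose global norm exceeds max_grad_norm are
    # rescaled before the Adam step, mirroring the optimizer built in train().
    import jax.numpy as jnp
    import optax

    max_grad_norm = 0.5      # assumed value for illustration
    learning_rate = 3e-4     # assumed value for illustration

    optimizer = optax.chain(
        optax.clip_by_global_norm(max_grad_norm),
        optax.adam(learning_rate=learning_rate)
    )

    params = {'w': jnp.ones((3,))}
    grads = {'w': jnp.array([3.0, 4.0, 0.0])}  # global norm 5.0 > 0.5, so clipped

    opt_state = optimizer.init(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

With max_grad_norm=None the chain is skipped entirely and the optimizer reduces to plain optax.adam, matching the else branch in the diff.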