diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 00cc0186..291dad87 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -385,13 +385,6 @@ def training_epoch_with_timing( specs.Array(env_state.obs.shape[-1:], jnp.dtype('float32'))), env_steps=0) - if num_timesteps == 0: - return ( - make_policy, - (training_state.normalizer_params, training_state.params), - {}, - ) - if ( restore_checkpoint_path is not None and epath.Path(restore_checkpoint_path).exists() @@ -406,6 +399,13 @@ def training_epoch_with_timing( normalizer_params=normalizer_params, params=init_params ) + if num_timesteps == 0: + return ( + make_policy, + (training_state.normalizer_params, training_state.params), + {}, + ) + training_state = jax.device_put_replicated( training_state, jax.local_devices()[:local_devices_to_use]) diff --git a/docs/release-notes/next-release.md b/docs/release-notes/next-release.md index 60a8358e..7e5c6680 100644 --- a/docs/release-notes/next-release.md +++ b/docs/release-notes/next-release.md @@ -1,3 +1,4 @@ # Brax Release Notes -* Add boolean `wrap_env` to all brax `train` functions, which optionally wraps the env for training, or uses the env as is. \ No newline at end of file +* Add boolean `wrap_env` to all brax `train` functions, which optionally wraps the env for training, or uses the env as is. +* Fix bug in PPO train to return loaded checkpoint when `num_timesteps` is 0. \ No newline at end of file