diff --git a/brax/experimental/barkour/tutorial.ipynb b/brax/experimental/barkour/tutorial.ipynb
index 34061b5f..6b0d5df1 100644
--- a/brax/experimental/barkour/tutorial.ipynb
+++ b/brax/experimental/barkour/tutorial.ipynb
@@ -814,7 +814,7 @@
    },
    "outputs": [],
    "source": [
-    "HTML(html.render(eval_env.sys.replace(dt=eval_env.dt), rollout))"
+    "HTML(html.render(eval_env.sys.tree_replace({'opt.timestep': eval_env.dt}), rollout))"
    ]
   }
  ],
diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py
index 93fd050f..90ce5cd4 100644
--- a/brax/training/agents/ppo/train.py
+++ b/brax/training/agents/ppo/train.py
@@ -35,11 +35,13 @@
 from brax.training.types import Params
 from brax.training.types import PRNGKey
 from brax.v1 import envs as envs_v1
+from etils import epath
 import flax
 import jax
 import jax.numpy as jnp
 import numpy as np
 import optax
+from orbax import checkpoint as ocp
 
 InferenceParams = Tuple[running_statistics.NestedMeanStd, Params]
 
@@ -103,6 +105,7 @@ def train(
     randomization_fn: Optional[
         Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]]
     ] = None,
+    restore_checkpoint_path: Optional[str] = None,
 ):
   """PPO training.
 
@@ -122,8 +125,8 @@ def train(
     num_eval_envs: the number of envs to use for evluation. Each env will run 1
       episode, and all envs run in parallel during eval.
     learning_rate: learning rate for ppo loss
-    entropy_cost: entropy reward for ppo loss, higher values increase entropy
-      of the policy
+    entropy_cost: entropy reward for ppo loss, higher values increase entropy of
+      the policy
     discounting: discounting rate
     seed: random seed
     unroll_length: the number of timesteps to unroll in each environment. The
@@ -151,6 +154,7 @@ def train(
       saving policy checkpoints
     randomization_fn: a user-defined callback function that generates
       randomized environments
+    restore_checkpoint_path: the path used to restore previous model params
 
   Returns:
     Tuple of (make_policy function, network params, metrics)
@@ -365,15 +369,40 @@ def training_epoch_with_timing(
     }
     return training_state, env_state, metrics  # pytype: disable=bad-return-type  # py311-upgrade
 
+  # Initialize model params and training state.
   init_params = ppo_losses.PPONetworkParams(
       policy=ppo_network.policy_network.init(key_policy),
-      value=ppo_network.value_network.init(key_value))
+      value=ppo_network.value_network.init(key_value),
+  )
+
   training_state = TrainingState(  # pytype: disable=wrong-arg-types  # jax-ndarray
       optimizer_state=optimizer.init(init_params),  # pytype: disable=wrong-arg-types  # numpy-scalars
       params=init_params,
       normalizer_params=running_statistics.init_state(
          specs.Array(env_state.obs.shape[-1:], jnp.dtype('float32'))),
      env_steps=0)
+
+  if num_timesteps == 0:
+    return (
+        make_policy,
+        (training_state.normalizer_params, training_state.params),
+        {},
+    )
+
+  if (
+      restore_checkpoint_path is not None
+      and epath.Path(restore_checkpoint_path).exists()
+  ):
+    logging.info('restoring from checkpoint %s', restore_checkpoint_path)
+    orbax_checkpointer = ocp.PyTreeCheckpointer()
+    target = training_state.normalizer_params, init_params
+    (normalizer_params, init_params) = orbax_checkpointer.restore(
+        restore_checkpoint_path, item=target
+    )
+    training_state = training_state.replace(
+        normalizer_params=normalizer_params, params=init_params
+    )
+
   training_state = jax.device_put_replicated(
       training_state,
       jax.local_devices()[:local_devices_to_use])
@@ -439,7 +468,8 @@ def training_epoch_with_timing(
       logging.info(metrics)
       progress_fn(current_step, metrics)
       params = _unpmap(
-          (training_state.normalizer_params, training_state.params.policy))
+          (training_state.normalizer_params, training_state.params)
+      )
       policy_params_fn(current_step, make_policy, params)
 
   total_steps = current_step
diff --git a/brax/training/learner.py b/brax/training/learner.py
index d3fce1be..9dd3c3c5 100644
--- a/brax/training/learner.py
+++ b/brax/training/learner.py
@@ -272,7 +272,7 @@ def do_rollout(rng):
     for i in range(FLAGS.num_videos):
       html_path = f'{FLAGS.logdir}/saved_videos/trajectory_{i:04d}.html'
       if isinstance(env, envs.Env):
-        html.save(html_path, env.sys.replace(dt=env.dt), trajectories[i])
+        html.save(html_path, env.sys.tree_replace({'opt.timestep': env.dt}), trajectories[i])
       else:
         html_v1.save_html(html_path, env.sys, trajectories[i], make_dir=True)
   elif FLAGS.num_videos > 0:
diff --git a/docs/release-notes/next-release.md b/docs/release-notes/next-release.md
index ca4ecb57..d1d52cac 100644
--- a/docs/release-notes/next-release.md
+++ b/docs/release-notes/next-release.md
@@ -1 +1,4 @@
 # Brax Release Notes
+
+* Modify `policy_params_fn` in the PPO implementation to take in the full model params. This can be used for checkpointing models.
+* Add `restore_checkpoint_path` in PPO implementation.
diff --git a/notebooks/training.ipynb b/notebooks/training.ipynb
index bde1e0c9..b71e640d 100644
--- a/notebooks/training.ipynb
+++ b/notebooks/training.ipynb
@@ -424,7 +424,7 @@
    "  act, _ = jit_inference_fn(state.obs, act_rng)\n",
    "  state = jit_env_step(state, act)\n",
    "\n",
-    "HTML(html.render(env.sys.replace(dt=env.dt), rollout))"
+    "HTML(html.render(env.sys.tree_replace({'opt.timestep': env.dt}), rollout))"
    ]
   },
   {
@@ -437,7 +437,7 @@
    ]
   }
  ],
- "metadata": {
+ "metadata": {
  "colab": {
   "provenance": [],
   "toc_visible": true
diff --git a/setup.py b/setup.py
index 97880626..270daf17 100644
--- a/setup.py
+++ b/setup.py
@@ -56,6 +56,7 @@
     "mujoco-mjx",
     "numpy",
     "optax",
+    "orbax-checkpoint",
    # TODO: remove pytinyrenderer after dropping legacy v1 code
     "Pillow",
     "pytinyrenderer",
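
For reference, a minimal usage sketch of the two new hooks introduced by this change: `policy_params_fn` now receives the full model params (normalizer params plus policy and value networks), and `restore_checkpoint_path` resumes training from an existing Orbax checkpoint. This sketch is not part of the diff; the environment name, checkpoint directory, and hyperparameters are illustrative assumptions.

```python
# Illustrative sketch only: the environment, checkpoint directory, and
# hyperparameters below are assumptions, not values from this change.
from brax import envs
from brax.training.agents.ppo import train as ppo
from etils import epath
from orbax import checkpoint as ocp

ckpt_dir = epath.Path('/tmp/ppo_checkpoints')  # hypothetical checkpoint root


def policy_params_fn(current_step, make_policy, params):
  # With this change, `params` is the full (normalizer_params, PPONetworkParams)
  # tuple, so it can be written out directly as a checkpoint.
  del make_policy  # unused in this sketch
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  orbax_checkpointer.save(ckpt_dir / f'{current_step}', params, force=True)


make_inference_fn, params, metrics = ppo.train(
    environment=envs.get_environment('ant'),  # assumed example environment
    num_timesteps=1_000_000,
    episode_length=1000,
    policy_params_fn=policy_params_fn,
    # To resume, point at a previously saved step; per the diff, training
    # starts from scratch if the path does not exist.
    restore_checkpoint_path=str(ckpt_dir / '500000'),
)
```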