
Commit

Internal change
PiperOrigin-RevId: 637964477
Change-Id: I3c10221cd5f4452ba0ab6effc4d5d444aebd70b8
Brax Team authored and btaba committed May 29, 2024
1 parent e1af19f commit b164655
Showing 6 changed files with 42 additions and 8 deletions.
2 changes: 1 addition & 1 deletion brax/experimental/barkour/tutorial.ipynb
@@ -814,7 +814,7 @@
},
"outputs": [],
"source": [
"HTML(html.render(eval_env.sys.replace(dt=eval_env.dt), rollout))"
"HTML(html.render(eval_env.sys.tree_replace({'opt.timestep': eval_env.dt}), rollout))"
]
}
],
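For reference, this change (and the matching edits in brax/training/learner.py and notebooks/training.ipynb below) replaces the flat `sys.replace(dt=...)` call with `tree_replace`, which overrides the nested `opt.timestep` field on the system before rendering. A minimal sketch of the new pattern, assuming a current Brax install; the `ant` environment, the zero-action rollout, and the output path are illustrative placeholders, not part of this commit:

```python
# Sketch of the updated rendering pattern. The environment, the zero-action
# rollout, and the output path are placeholders for illustration only.
import jax
import jax.numpy as jnp
from brax import envs
from brax.io import html

env = envs.get_environment('ant')
jit_reset = jax.jit(env.reset)
jit_step = jax.jit(env.step)

state = jit_reset(jax.random.PRNGKey(0))
rollout = [state.pipeline_state]
for _ in range(10):
  state = jit_step(state, jnp.zeros(env.action_size))
  rollout.append(state.pipeline_state)

# tree_replace overrides the nested physics timestep so playback speed
# matches the environment's dt, replacing the old sys.replace(dt=...) call.
html.save('/tmp/trajectory.html',
          env.sys.tree_replace({'opt.timestep': env.dt}), rollout)
```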
38 changes: 34 additions & 4 deletions brax/training/agents/ppo/train.py
@@ -35,11 +35,13 @@
from brax.training.types import Params
from brax.training.types import PRNGKey
from brax.v1 import envs as envs_v1
from etils import epath
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from orbax import checkpoint as ocp


InferenceParams = Tuple[running_statistics.NestedMeanStd, Params]
@@ -103,6 +105,7 @@ def train(
randomization_fn: Optional[
Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]]
] = None,
restore_checkpoint_path: Optional[str] = None,
):
"""PPO training.
@@ -122,8 +125,8 @@
num_eval_envs: the number of envs to use for evaluation. Each env will run 1
episode, and all envs run in parallel during eval.
learning_rate: learning rate for ppo loss
entropy_cost: entropy reward for ppo loss, higher values increase entropy
of the policy
entropy_cost: entropy reward for ppo loss, higher values increase entropy of
the policy
discounting: discounting rate
seed: random seed
unroll_length: the number of timesteps to unroll in each environment. The
Expand Down Expand Up @@ -151,6 +154,7 @@ def train(
saving policy checkpoints
randomization_fn: a user-defined callback function that generates randomized
environments
restore_checkpoint_path: the path used to restore previous model params
Returns:
Tuple of (make_policy function, network params, metrics)
@@ -365,15 +369,40 @@ def training_epoch_with_timing(
}
return training_state, env_state, metrics # pytype: disable=bad-return-type # py311-upgrade

# Initialize model params and training state.
init_params = ppo_losses.PPONetworkParams(
policy=ppo_network.policy_network.init(key_policy),
value=ppo_network.value_network.init(key_value))
value=ppo_network.value_network.init(key_value),
)

training_state = TrainingState( # pytype: disable=wrong-arg-types # jax-ndarray
optimizer_state=optimizer.init(init_params), # pytype: disable=wrong-arg-types # numpy-scalars
params=init_params,
normalizer_params=running_statistics.init_state(
specs.Array(env_state.obs.shape[-1:], jnp.dtype('float32'))),
env_steps=0)

if num_timesteps == 0:
return (
make_policy,
(training_state.normalizer_params, training_state.params),
{},
)

if (
restore_checkpoint_path is not None
and epath.Path(restore_checkpoint_path).exists()
):
logging.info('restoring from checkpoint %s', restore_checkpoint_path)
orbax_checkpointer = ocp.PyTreeCheckpointer()
target = training_state.normalizer_params, init_params
(normalizer_params, init_params) = orbax_checkpointer.restore(
restore_checkpoint_path, item=target
)
training_state = training_state.replace(
normalizer_params=normalizer_params, params=init_params
)

training_state = jax.device_put_replicated(
training_state,
jax.local_devices()[:local_devices_to_use])
@@ -439,7 +468,8 @@ def training_epoch_with_timing(
logging.info(metrics)
progress_fn(current_step, metrics)
params = _unpmap(
(training_state.normalizer_params, training_state.params.policy))
(training_state.normalizer_params, training_state.params)
)
policy_params_fn(current_step, make_policy, params)

total_steps = current_step
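A hedged sketch of how the new `restore_checkpoint_path` argument might be used to resume training. The `ant` environment, the hyperparameters, and the checkpoint path are illustrative placeholders; the path is expected to hold the (normalizer params, network params) tuple that the restore branch above reads with Orbax, and if the path does not exist, training simply starts from scratch:

```python
# Sketch of resuming PPO training from a previously saved checkpoint.
# The environment name, hyperparameters, and path are placeholders.
import functools

from brax import envs
from brax.training.agents.ppo import train as ppo

env = envs.get_environment('ant')

train_fn = functools.partial(
    ppo.train,
    num_timesteps=1_000_000,
    episode_length=1000,
    num_envs=256,
    batch_size=256,
    # Points at an Orbax checkpoint containing (normalizer_params, PPONetworkParams),
    # e.g. one written by the policy_params_fn sketched after the release notes below.
    restore_checkpoint_path='/tmp/ppo_checkpoints/100',
)
make_policy, params, metrics = train_fn(environment=env)
```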
2 changes: 1 addition & 1 deletion brax/training/learner.py
@@ -272,7 +272,7 @@ def do_rollout(rng):
for i in range(FLAGS.num_videos):
html_path = f'{FLAGS.logdir}/saved_videos/trajectory_{i:04d}.html'
if isinstance(env, envs.Env):
html.save(html_path, env.sys.replace(dt=env.dt), trajectories[i])
html.save(html_path, env.sys.tree_replace({'opt.timestep': env.dt}), trajectories[i])
else:
html_v1.save_html(html_path, env.sys, trajectories[i], make_dir=True)
elif FLAGS.num_videos > 0:
3 changes: 3 additions & 0 deletions docs/release-notes/next-release.md
@@ -1 +1,4 @@
# Brax Release Notes

* Modify `policy_params_fn` in the PPO implementation to take in the full model params. This can be used for checkpointing models.
* Add `restore_checkpoint_path` in PPO implementation.
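Since `policy_params_fn` now receives the full model params (normalizer state plus policy and value network params), it can serve as a checkpointing hook whose output `restore_checkpoint_path` can later read back. A minimal sketch under that assumption; the save directory is a placeholder, and the `orbax`/`flax` helpers are suggested here only because this commit adds `orbax-checkpoint` to setup.py:

```python
# Sketch of a checkpointing callback for ppo.train(policy_params_fn=...).
# `params` is the (normalizer_params, PPONetworkParams) tuple the trainer now
# passes through; /tmp/ppo_checkpoints is a placeholder directory.
from etils import epath
from flax.training import orbax_utils
from orbax import checkpoint as ocp

ckpt_dir = epath.Path('/tmp/ppo_checkpoints')

def policy_params_fn(current_step, make_policy, params):
  del make_policy  # unused; kept to match the callback signature used by train()
  orbax_checkpointer = ocp.PyTreeCheckpointer()
  save_args = orbax_utils.save_args_from_target(params)
  orbax_checkpointer.save(
      ckpt_dir / f'{current_step}', params, force=True, save_args=save_args)
```

Pointing `restore_checkpoint_path` at one of the saved step directories should then round-trip the params through the restore branch added in brax/training/agents/ppo/train.py above.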
4 changes: 2 additions & 2 deletions notebooks/training.ipynb
@@ -424,7 +424,7 @@
" act, _ = jit_inference_fn(state.obs, act_rng)\n",
" state = jit_env_step(state, act)\n",
"\n",
"HTML(html.render(env.sys.replace(dt=env.dt), rollout))"
"HTML(html.render(env.sys.tree_replace({'opt.timestep': env.dt}), rollout))"
]
},
{
@@ -437,7 +437,7 @@
]
}
],
"metadata": {
"metadata": {
"colab": {
"provenance": [],
"toc_visible": true
1 change: 1 addition & 0 deletions setup.py
@@ -56,6 +56,7 @@
"mujoco-mjx",
"numpy",
"optax",
"orbax-checkpoint",
# TODO: remove pytinyrenderer after dropping legacy v1 code
"Pillow",
"pytinyrenderer",