From 241f9bc5bbd003f9cfc9ded7613388e2fe125af6 Mon Sep 17 00:00:00 2001 From: Brax Team Date: Tue, 4 Feb 2025 23:33:30 -0800 Subject: [PATCH] Internal change PiperOrigin-RevId: 723378503 Change-Id: I1b63d2f9dc5f49009bbafe4875baadfe51ed9afd --- brax/training/agents/ppo/checkpoint.py | 102 +++----------- brax/training/agents/ppo/checkpoint_test.py | 4 +- brax/training/agents/ppo/train.py | 16 ++- brax/training/agents/sac/checkpoint.py | 86 ++++++++++++ brax/training/agents/sac/checkpoint_test.py | 99 ++++++++++++++ brax/training/agents/sac/train.py | 22 +++- brax/training/checkpoint.py | 139 ++++++++++++++++++++ 7 files changed, 379 insertions(+), 89 deletions(-) create mode 100644 brax/training/agents/sac/checkpoint.py create mode 100644 brax/training/agents/sac/checkpoint_test.py create mode 100644 brax/training/checkpoint.py diff --git a/brax/training/agents/ppo/checkpoint.py b/brax/training/agents/ppo/checkpoint.py index e605109e..b27b1c61 100644 --- a/brax/training/agents/ppo/checkpoint.py +++ b/brax/training/agents/ppo/checkpoint.py @@ -14,115 +14,53 @@ """Checkpointing for PPO.""" -import inspect import json -import logging -from typing import Any, Dict, Tuple, Union +from typing import Any, Union +from brax.training import checkpoint from brax.training import types -from brax.training.acme import running_statistics from brax.training.agents.ppo import networks as ppo_networks from etils import epath -from flax import linen -from flax.training import orbax_utils from ml_collections import config_dict -from orbax import checkpoint as ocp -_CONFIG_FNAME = 'config.json' - - -def _get_default_kwargs(func: Any) -> Dict[str, Any]: - """Returns the default kwargs of a function.""" - return { - p.name: p.default - for p in inspect.signature(func).parameters.values() - if p.default is not inspect.Parameter.empty - } - - -def ppo_config( - observation_size: types.ObservationSize, - action_size: int, - normalize_observations: bool, - network_factory: types.NetworkFactory[ppo_networks.PPONetworks], -) -> config_dict.ConfigDict: - """Returns a config dict for re-creating PPO params from a checkpoint.""" - config = config_dict.ConfigDict() - kwargs = _get_default_kwargs(network_factory) - - if ( - kwargs.get('preprocess_observations_fn') - != types.identity_observation_preprocessor - ): - raise ValueError( - 'preprocess_observations_fn must be identity_observation_preprocessor' - ) - del kwargs['preprocess_observations_fn'] - if kwargs.get('activation') != linen.swish: - raise ValueError('activation must be swish') - del kwargs['activation'] - - config.network_factory_kwargs = kwargs - config.normalize_observations = normalize_observations - config.observation_size = observation_size - config.action_size = action_size - return config +_CONFIG_FNAME = 'ppo_network_config.json' def save( path: Union[str, epath.Path], step: int, - params: Tuple[Any, ...], + params: Any, config: config_dict.ConfigDict, ): """Saves a checkpoint.""" - ckpt_path = epath.Path(path) / f'{step:012d}' - logging.info('saving checkpoint to %s', ckpt_path.as_posix()) - - if not ckpt_path.exists(): - ckpt_path.mkdir(parents=True) - - config_path = epath.Path(path) / _CONFIG_FNAME - if not config_path.exists(): - config_path.write_text(config.to_json()) - - orbax_checkpointer = ocp.PyTreeCheckpointer() - save_args = orbax_utils.save_args_from_target(params) - orbax_checkpointer.save(ckpt_path, params, force=True, save_args=save_args) + return checkpoint.save(path, step, params, config, _CONFIG_FNAME) def load( path: Union[str, epath.Path], ): - """Loads PPO checkpoint.""" - path = epath.Path(path) - if not path.exists(): - raise ValueError(f'PPO checkpoint path does not exist: {path.as_posix()}') + """Loads checkpoint.""" + return checkpoint.load(path) - logging.info('restoring from checkpoint %s', path.as_posix()) - orbax_checkpointer = ocp.PyTreeCheckpointer() - target = orbax_checkpointer.restore(path, item=None) - target[0] = running_statistics.RunningStatisticsState(**target[0]) - - return target +def network_config( + observation_size: types.ObservationSize, + action_size: int, + normalize_observations: bool, + network_factory: types.NetworkFactory[Union[ppo_networks.PPONetworks]], +) -> config_dict.ConfigDict: + """Returns a config dict for re-creating a network from a checkpoint.""" + return checkpoint.network_config( + observation_size, action_size, normalize_observations, network_factory + ) -def _get_network( +def _get_ppo_network( config: config_dict.ConfigDict, network_factory: types.NetworkFactory[ppo_networks.PPONetworks], ) -> ppo_networks.PPONetworks: """Generates a PPO network given config.""" - normalize = lambda x, y: x - if config.normalize_observations: - normalize = running_statistics.normalize - ppo_network = network_factory( - config.to_dict()['observation_size'], - config.action_size, - preprocess_observations_fn=normalize, - **config.network_factory_kwargs, - ) - return ppo_network + return checkpoint.get_network(config, network_factory) # pytype: disable=bad-return-type def load_policy( @@ -142,7 +80,7 @@ def load_policy( config = config_dict.create(**json.loads(config_path.read_text())) params = load(path) - ppo_network = _get_network(config, network_factory) + ppo_network = _get_ppo_network(config, network_factory) make_inference_fn = ppo_networks.make_inference_fn(ppo_network) return make_inference_fn(params, deterministic=deterministic) diff --git a/brax/training/agents/ppo/checkpoint_test.py b/brax/training/agents/ppo/checkpoint_test.py index 9cc4e888..96e0d21c 100644 --- a/brax/training/agents/ppo/checkpoint_test.py +++ b/brax/training/agents/ppo/checkpoint_test.py @@ -38,7 +38,7 @@ def test_ppo_params_config(self): ppo_networks.make_ppo_networks, policy_hidden_layer_sizes=(16, 21, 13), ) - config = checkpoint.ppo_config( + config = checkpoint.network_config( action_size=3, observation_size=1, normalize_observations=True, @@ -57,7 +57,7 @@ def test_save_and_load_checkpoint(self): ppo_networks.make_ppo_networks, policy_hidden_layer_sizes=(16, 21, 13), ) - config = checkpoint.ppo_config( + config = checkpoint.network_config( observation_size=1, action_size=3, normalize_observations=True, diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 5481d01d..02a74f18 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -242,6 +242,7 @@ def train( # checkpointing save_checkpoint_path: Optional[str] = None, restore_checkpoint_path: Optional[str] = None, + restore_params: Optional[Any] = None, restore_value_fn: bool = True, ): """PPO training. @@ -306,6 +307,9 @@ def train( save_checkpoint_path: the path used to save checkpoints. If None, no checkpoints are saved. restore_checkpoint_path: the path used to restore previous model params + restore_params: raw network parameters to restore the TrainingState from. + These override `restore_checkpoint_path`. These paramaters can be obtained + from the return values of ppo.train(). restore_value_fn: whether to restore the value function from the checkpoint or use a random initialization @@ -422,7 +426,7 @@ def train( progress_fn=progress_fn, ) - ckpt_config = checkpoint.ppo_config( + ckpt_config = checkpoint.network_config( observation_size=obs_shape, action_size=env.action_size, normalize_observations=normalize_observations, @@ -619,6 +623,16 @@ def training_epoch_with_timing( ), ) + if restore_params is not None: + logging.info('Restoring TrainingState from `restore_params`.') + value_params = restore_params[2] if restore_value_fn else init_params.value + training_state = training_state.replace( + normalizer_params=restore_params[0], + params=training_state.params.replace( + policy=restore_params[1], value=value_params + ), + ) + if num_timesteps == 0: return ( make_policy, diff --git a/brax/training/agents/sac/checkpoint.py b/brax/training/agents/sac/checkpoint.py new file mode 100644 index 00000000..21e1a9f3 --- /dev/null +++ b/brax/training/agents/sac/checkpoint.py @@ -0,0 +1,86 @@ +# Copyright 2024 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpointing for SAC.""" + +import json +from typing import Any, Union + +from brax.training import checkpoint +from brax.training import types +from brax.training.agents.sac import networks as sac_networks +from etils import epath +from ml_collections import config_dict + +_CONFIG_FNAME = 'sac_network_config.json' + + +def save( + path: Union[str, epath.Path], + step: int, + params: Any, + config: config_dict.ConfigDict, +): + """Saves a checkpoint.""" + return checkpoint.save(path, step, params, config, _CONFIG_FNAME) + + +def load( + path: Union[str, epath.Path], +): + """Loads SAC checkpoint.""" + return checkpoint.load(path) + + +def network_config( + observation_size: types.ObservationSize, + action_size: int, + normalize_observations: bool, + network_factory: types.NetworkFactory[sac_networks.SACNetworks], +) -> config_dict.ConfigDict: + """Returns a config dict for re-creating a network from a checkpoint.""" + return checkpoint.network_config( + observation_size, action_size, normalize_observations, network_factory + ) + + +def _get_network( + config: config_dict.ConfigDict, + network_factory: types.NetworkFactory[sac_networks.SACNetworks], +) -> sac_networks.SACNetworks: + """Generates a SAC network given config.""" + return checkpoint.get_network(config, network_factory) # pytype: disable=bad-return-type + + +def load_policy( + path: Union[str, epath.Path], + network_factory: types.NetworkFactory[ + sac_networks.SACNetworks + ] = sac_networks.make_sac_networks, + deterministic: bool = True, +): + """Loads policy inference function from SAC checkpoint.""" + path = epath.Path(path) + + config_path = path.parent / _CONFIG_FNAME + if not config_path.exists(): + raise ValueError(f'SAC config file not found at {config_path.as_posix()}') + + config = config_dict.create(**json.loads(config_path.read_text())) + + params = load(path) + sac_network = _get_network(config, network_factory) + make_inference_fn = sac_networks.make_inference_fn(sac_network) + + return make_inference_fn(params, deterministic=deterministic) diff --git a/brax/training/agents/sac/checkpoint_test.py b/brax/training/agents/sac/checkpoint_test.py new file mode 100644 index 00000000..842da49e --- /dev/null +++ b/brax/training/agents/sac/checkpoint_test.py @@ -0,0 +1,99 @@ +# Copyright 2024 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test PPO checkpointing.""" + +import functools + +from absl import flags +from absl.testing import absltest +from brax.training.acme import running_statistics +from brax.training.agents.sac import checkpoint +from brax.training.agents.sac import losses as sac_losses +from brax.training.agents.sac import networks as sac_networks +from etils import epath +import jax +from jax import numpy as jp + + +class CheckpointTest(absltest.TestCase): + + def setUp(self): + super().setUp() + flags.FLAGS.mark_as_parsed() + + def test_sac_params_config(self): + network_factory = functools.partial( + sac_networks.make_sac_networks, + hidden_layer_sizes=(16, 21, 13), + ) + config = checkpoint.network_config( + action_size=3, + observation_size=1, + normalize_observations=True, + network_factory=network_factory, + ) + self.assertEqual( + config.network_factory_kwargs.to_dict()["hidden_layer_sizes"], + (16, 21, 13), + ) + self.assertEqual(config.action_size, 3) + self.assertEqual(config.observation_size, 1) + + def test_save_and_load_checkpoint(self): + path = self.create_tempdir("test") + network_factory = functools.partial( + sac_networks.make_sac_networks, + hidden_layer_sizes=(16, 21, 13), + ) + config = checkpoint.network_config( + observation_size=1, + action_size=3, + normalize_observations=True, + network_factory=network_factory, + ) + + # Generate network params for saving a dummy checkpoint. + normalize = lambda x, y: x + if config.normalize_observations: + normalize = running_statistics.normalize + sac_network = network_factory( + config.observation_size, + config.action_size, + preprocess_observations_fn=normalize, + **config.network_factory_kwargs, + ) + dummy_key = jax.random.PRNGKey(0) + normalizer_params = running_statistics.init_state( + jax.tree_util.tree_map(jp.zeros, config.observation_size) + ) + params = (normalizer_params, sac_network.policy_network.init(dummy_key)) + + # Save and load a checkpoint. + checkpoint.save( + path.full_path, + step=1, + params=params, + config=config, + ) + + policy_fn = checkpoint.load_policy( + epath.Path(path.full_path) / "000000000001", + ) + out = policy_fn(jp.zeros(1), jax.random.PRNGKey(0)) + self.assertEqual(out[0].shape, (3,)) + + +if __name__ == "__main__": + absltest.main() diff --git a/brax/training/agents/sac/train.py b/brax/training/agents/sac/train.py index d2217b35..deef6e0c 100644 --- a/brax/training/agents/sac/train.py +++ b/brax/training/agents/sac/train.py @@ -32,6 +32,7 @@ from brax.training import types from brax.training.acme import running_statistics from brax.training.acme import specs +from brax.training.agents.sac import checkpoint from brax.training.agents.sac import losses as sac_losses from brax.training.agents.sac import networks as sac_networks from brax.training.types import Params @@ -137,11 +138,12 @@ def train( sac_networks.SACNetworks ] = sac_networks.make_sac_networks, progress_fn: Callable[[int, Metrics], None] = lambda *args: None, - checkpoint_logdir: Optional[str] = None, eval_env: Optional[envs.Env] = None, randomization_fn: Optional[ Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]] ] = None, + checkpoint_logdir: Optional[str] = None, + restore_checkpoint_path: Optional[str] = None, ): """SAC training.""" process_id = jax.process_index() @@ -258,6 +260,13 @@ def train( actor_loss, policy_optimizer, pmap_axis_name=_PMAP_AXIS_NAME ) + ckpt_config = checkpoint.network_config( + observation_size=obs_size, + action_size=env.action_size, + normalize_observations=normalize_observations, + network_factory=network_factory, + ) + def sgd_step( carry: Tuple[TrainingState, PRNGKey], transitions: Transition ) -> Tuple[Tuple[TrainingState, PRNGKey], Metrics]: @@ -485,6 +494,13 @@ def training_epoch_with_timing( ) del global_key + if restore_checkpoint_path is not None: + params = checkpoint.load(restore_checkpoint_path) + training_state = training_state.replace( + normalizer_params=params[0], + policy_params=params[1], + ) + local_key, rb_key, env_key, eval_key = jax.random.split(local_key, 4) # Env init @@ -566,12 +582,10 @@ def training_epoch_with_timing( # Eval and logging if process_id == 0: if checkpoint_logdir: - # Save current policy. params = _unpmap( (training_state.normalizer_params, training_state.policy_params) ) - path = f'{checkpoint_logdir}_sac_{current_step}.pkl' - model.save_params(path, params) + checkpoint.save(checkpoint_logdir, current_step, params, ckpt_config) # Run evals. metrics = evaluator.run_evaluation( diff --git a/brax/training/checkpoint.py b/brax/training/checkpoint.py new file mode 100644 index 00000000..5ad60d26 --- /dev/null +++ b/brax/training/checkpoint.py @@ -0,0 +1,139 @@ +# Copyright 2024 The Brax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpointing functions.""" + +import inspect +import logging +from typing import Any, Dict, Tuple, Union + +from brax.training import types +from brax.training.acme import running_statistics +from brax.training.agents.ppo import networks as ppo_networks +from brax.training.agents.sac import networks as sac_networks +from etils import epath +from flax.training import orbax_utils +from ml_collections import config_dict +from orbax import checkpoint as ocp + + +def _get_function_kwargs(func: Any) -> Dict[str, Any]: + """Gets kwargs of a function.""" + return { + p.name: p.default + for p in inspect.signature(func).parameters.values() + if p.default is not inspect.Parameter.empty + } + + +def _get_function_defaults(func: Any) -> Dict[str, Any]: + """Gets default kwargs of a function potentially wrapped in partials.""" + kwargs = _get_function_kwargs(func) + if hasattr(func, 'func'): + kwargs.update(_get_function_defaults(func.func)) + return kwargs + + +def network_config( + observation_size: types.ObservationSize, + action_size: int, + normalize_observations: bool, + network_factory: types.NetworkFactory[ + Union[ppo_networks.PPONetworks, sac_networks.SACNetworks] + ], +) -> config_dict.ConfigDict: + """Returns a config dict for re-creating a network from a checkpoint.""" + config = config_dict.ConfigDict() + kwargs = _get_function_kwargs(network_factory) + defaults = _get_function_defaults(network_factory) + + if 'preprocess_observations_fn' in kwargs: + if ( + kwargs['preprocess_observations_fn'] + != defaults['preprocess_observations_fn'] + ): + raise ValueError( + 'checkpointing only supports identity_observation_preprocessor as the' + ' preprocess_observations_fn' + ) + del kwargs['preprocess_observations_fn'] + if 'activation' in kwargs: + if kwargs['activation'] != defaults['activation']: + raise ValueError('checkpointing only supports default activation') + del kwargs['activation'] + + config.network_factory_kwargs = kwargs + config.normalize_observations = normalize_observations + config.observation_size = observation_size + config.action_size = action_size + return config + + +def get_network( + config: config_dict.ConfigDict, + network_factory: types.NetworkFactory[ + Union[ppo_networks.PPONetworks, sac_networks.SACNetworks] + ], +) -> Union[ppo_networks.PPONetworks, sac_networks.SACNetworks]: + """Generates a network given config.""" + normalize = lambda x, y: x + if config.normalize_observations: + normalize = running_statistics.normalize + network = network_factory( + config.to_dict()['observation_size'], + config.action_size, + preprocess_observations_fn=normalize, + **config.network_factory_kwargs, + ) + return network + + +def save( + path: Union[str, epath.Path], + step: int, + params: Tuple[Any, ...], + config: config_dict.ConfigDict, + config_fname: str = 'config.json', +): + """Saves a checkpoint.""" + ckpt_path = epath.Path(path) / f'{step:012d}' + logging.info('saving checkpoint to %s', ckpt_path.as_posix()) + + if not ckpt_path.exists(): + ckpt_path.mkdir(parents=True) + + config_path = epath.Path(path) / config_fname + if not config_path.exists(): + config_path.write_text(config.to_json()) + + orbax_checkpointer = ocp.PyTreeCheckpointer() + save_args = orbax_utils.save_args_from_target(params) + orbax_checkpointer.save(ckpt_path, params, force=True, save_args=save_args) + + +def load( + path: Union[str, epath.Path], +): + """Loads checkpoint.""" + path = epath.Path(path) + if not path.exists(): + raise ValueError(f'checkpoint path does not exist: {path.as_posix()}') + + logging.info('restoring from checkpoint %s', path.as_posix()) + + orbax_checkpointer = ocp.PyTreeCheckpointer() + target = orbax_checkpointer.restore(path, item=None) + target[0] = running_statistics.RunningStatisticsState(**target[0]) + + return target