Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723378503
Change-Id: I1b63d2f9dc5f49009bbafe4875baadfe51ed9afd
  • Loading branch information
Brax Team authored and btaba committed Feb 5, 2025
1 parent 296184a commit 241f9bc
Show file tree
Hide file tree
Showing 7 changed files with 379 additions and 89 deletions.
102 changes: 20 additions & 82 deletions brax/training/agents/ppo/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,115 +14,53 @@

"""Checkpointing for PPO."""

import inspect
import json
import logging
from typing import Any, Dict, Tuple, Union
from typing import Any, Union

from brax.training import checkpoint
from brax.training import types
from brax.training.acme import running_statistics
from brax.training.agents.ppo import networks as ppo_networks
from etils import epath
from flax import linen
from flax.training import orbax_utils
from ml_collections import config_dict
from orbax import checkpoint as ocp

_CONFIG_FNAME = 'config.json'


def _get_default_kwargs(func: Any) -> Dict[str, Any]:
"""Returns the default kwargs of a function."""
return {
p.name: p.default
for p in inspect.signature(func).parameters.values()
if p.default is not inspect.Parameter.empty
}


def ppo_config(
observation_size: types.ObservationSize,
action_size: int,
normalize_observations: bool,
network_factory: types.NetworkFactory[ppo_networks.PPONetworks],
) -> config_dict.ConfigDict:
"""Returns a config dict for re-creating PPO params from a checkpoint."""
config = config_dict.ConfigDict()
kwargs = _get_default_kwargs(network_factory)

if (
kwargs.get('preprocess_observations_fn')
!= types.identity_observation_preprocessor
):
raise ValueError(
'preprocess_observations_fn must be identity_observation_preprocessor'
)
del kwargs['preprocess_observations_fn']
if kwargs.get('activation') != linen.swish:
raise ValueError('activation must be swish')
del kwargs['activation']

config.network_factory_kwargs = kwargs
config.normalize_observations = normalize_observations
config.observation_size = observation_size
config.action_size = action_size
return config
_CONFIG_FNAME = 'ppo_network_config.json'


def save(
path: Union[str, epath.Path],
step: int,
params: Tuple[Any, ...],
params: Any,
config: config_dict.ConfigDict,
):
"""Saves a checkpoint."""
ckpt_path = epath.Path(path) / f'{step:012d}'
logging.info('saving checkpoint to %s', ckpt_path.as_posix())

if not ckpt_path.exists():
ckpt_path.mkdir(parents=True)

config_path = epath.Path(path) / _CONFIG_FNAME
if not config_path.exists():
config_path.write_text(config.to_json())

orbax_checkpointer = ocp.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(params)
orbax_checkpointer.save(ckpt_path, params, force=True, save_args=save_args)
return checkpoint.save(path, step, params, config, _CONFIG_FNAME)


def load(
path: Union[str, epath.Path],
):
"""Loads PPO checkpoint."""
path = epath.Path(path)
if not path.exists():
raise ValueError(f'PPO checkpoint path does not exist: {path.as_posix()}')
"""Loads checkpoint."""
return checkpoint.load(path)

logging.info('restoring from checkpoint %s', path.as_posix())

orbax_checkpointer = ocp.PyTreeCheckpointer()
target = orbax_checkpointer.restore(path, item=None)
target[0] = running_statistics.RunningStatisticsState(**target[0])

return target
def network_config(
observation_size: types.ObservationSize,
action_size: int,
normalize_observations: bool,
network_factory: types.NetworkFactory[Union[ppo_networks.PPONetworks]],
) -> config_dict.ConfigDict:
"""Returns a config dict for re-creating a network from a checkpoint."""
return checkpoint.network_config(
observation_size, action_size, normalize_observations, network_factory
)


def _get_network(
def _get_ppo_network(
config: config_dict.ConfigDict,
network_factory: types.NetworkFactory[ppo_networks.PPONetworks],
) -> ppo_networks.PPONetworks:
"""Generates a PPO network given config."""
normalize = lambda x, y: x
if config.normalize_observations:
normalize = running_statistics.normalize
ppo_network = network_factory(
config.to_dict()['observation_size'],
config.action_size,
preprocess_observations_fn=normalize,
**config.network_factory_kwargs,
)
return ppo_network
return checkpoint.get_network(config, network_factory) # pytype: disable=bad-return-type


def load_policy(
Expand All @@ -142,7 +80,7 @@ def load_policy(
config = config_dict.create(**json.loads(config_path.read_text()))

params = load(path)
ppo_network = _get_network(config, network_factory)
ppo_network = _get_ppo_network(config, network_factory)
make_inference_fn = ppo_networks.make_inference_fn(ppo_network)

return make_inference_fn(params, deterministic=deterministic)
4 changes: 2 additions & 2 deletions brax/training/agents/ppo/checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_ppo_params_config(self):
ppo_networks.make_ppo_networks,
policy_hidden_layer_sizes=(16, 21, 13),
)
config = checkpoint.ppo_config(
config = checkpoint.network_config(
action_size=3,
observation_size=1,
normalize_observations=True,
Expand All @@ -57,7 +57,7 @@ def test_save_and_load_checkpoint(self):
ppo_networks.make_ppo_networks,
policy_hidden_layer_sizes=(16, 21, 13),
)
config = checkpoint.ppo_config(
config = checkpoint.network_config(
observation_size=1,
action_size=3,
normalize_observations=True,
Expand Down
16 changes: 15 additions & 1 deletion brax/training/agents/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def train(
# checkpointing
save_checkpoint_path: Optional[str] = None,
restore_checkpoint_path: Optional[str] = None,
restore_params: Optional[Any] = None,
restore_value_fn: bool = True,
):
"""PPO training.
Expand Down Expand Up @@ -306,6 +307,9 @@ def train(
save_checkpoint_path: the path used to save checkpoints. If None, no
checkpoints are saved.
restore_checkpoint_path: the path used to restore previous model params
restore_params: raw network parameters to restore the TrainingState from.
These override `restore_checkpoint_path`. These paramaters can be obtained
from the return values of ppo.train().
restore_value_fn: whether to restore the value function from the checkpoint
or use a random initialization
Expand Down Expand Up @@ -422,7 +426,7 @@ def train(
progress_fn=progress_fn,
)

ckpt_config = checkpoint.ppo_config(
ckpt_config = checkpoint.network_config(
observation_size=obs_shape,
action_size=env.action_size,
normalize_observations=normalize_observations,
Expand Down Expand Up @@ -619,6 +623,16 @@ def training_epoch_with_timing(
),
)

if restore_params is not None:
logging.info('Restoring TrainingState from `restore_params`.')
value_params = restore_params[2] if restore_value_fn else init_params.value
training_state = training_state.replace(
normalizer_params=restore_params[0],
params=training_state.params.replace(
policy=restore_params[1], value=value_params
),
)

if num_timesteps == 0:
return (
make_policy,
Expand Down
86 changes: 86 additions & 0 deletions brax/training/agents/sac/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Checkpointing for SAC."""

import json
from typing import Any, Union

from brax.training import checkpoint
from brax.training import types
from brax.training.agents.sac import networks as sac_networks
from etils import epath
from ml_collections import config_dict

_CONFIG_FNAME = 'sac_network_config.json'


def save(
path: Union[str, epath.Path],
step: int,
params: Any,
config: config_dict.ConfigDict,
):
"""Saves a checkpoint."""
return checkpoint.save(path, step, params, config, _CONFIG_FNAME)


def load(
path: Union[str, epath.Path],
):
"""Loads SAC checkpoint."""
return checkpoint.load(path)


def network_config(
observation_size: types.ObservationSize,
action_size: int,
normalize_observations: bool,
network_factory: types.NetworkFactory[sac_networks.SACNetworks],
) -> config_dict.ConfigDict:
"""Returns a config dict for re-creating a network from a checkpoint."""
return checkpoint.network_config(
observation_size, action_size, normalize_observations, network_factory
)


def _get_network(
config: config_dict.ConfigDict,
network_factory: types.NetworkFactory[sac_networks.SACNetworks],
) -> sac_networks.SACNetworks:
"""Generates a SAC network given config."""
return checkpoint.get_network(config, network_factory) # pytype: disable=bad-return-type


def load_policy(
path: Union[str, epath.Path],
network_factory: types.NetworkFactory[
sac_networks.SACNetworks
] = sac_networks.make_sac_networks,
deterministic: bool = True,
):
"""Loads policy inference function from SAC checkpoint."""
path = epath.Path(path)

config_path = path.parent / _CONFIG_FNAME
if not config_path.exists():
raise ValueError(f'SAC config file not found at {config_path.as_posix()}')

config = config_dict.create(**json.loads(config_path.read_text()))

params = load(path)
sac_network = _get_network(config, network_factory)
make_inference_fn = sac_networks.make_inference_fn(sac_network)

return make_inference_fn(params, deterministic=deterministic)
99 changes: 99 additions & 0 deletions brax/training/agents/sac/checkpoint_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2024 The Brax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test PPO checkpointing."""

import functools

from absl import flags
from absl.testing import absltest
from brax.training.acme import running_statistics
from brax.training.agents.sac import checkpoint
from brax.training.agents.sac import losses as sac_losses
from brax.training.agents.sac import networks as sac_networks
from etils import epath
import jax
from jax import numpy as jp


class CheckpointTest(absltest.TestCase):

def setUp(self):
super().setUp()
flags.FLAGS.mark_as_parsed()

def test_sac_params_config(self):
network_factory = functools.partial(
sac_networks.make_sac_networks,
hidden_layer_sizes=(16, 21, 13),
)
config = checkpoint.network_config(
action_size=3,
observation_size=1,
normalize_observations=True,
network_factory=network_factory,
)
self.assertEqual(
config.network_factory_kwargs.to_dict()["hidden_layer_sizes"],
(16, 21, 13),
)
self.assertEqual(config.action_size, 3)
self.assertEqual(config.observation_size, 1)

def test_save_and_load_checkpoint(self):
path = self.create_tempdir("test")
network_factory = functools.partial(
sac_networks.make_sac_networks,
hidden_layer_sizes=(16, 21, 13),
)
config = checkpoint.network_config(
observation_size=1,
action_size=3,
normalize_observations=True,
network_factory=network_factory,
)

# Generate network params for saving a dummy checkpoint.
normalize = lambda x, y: x
if config.normalize_observations:
normalize = running_statistics.normalize
sac_network = network_factory(
config.observation_size,
config.action_size,
preprocess_observations_fn=normalize,
**config.network_factory_kwargs,
)
dummy_key = jax.random.PRNGKey(0)
normalizer_params = running_statistics.init_state(
jax.tree_util.tree_map(jp.zeros, config.observation_size)
)
params = (normalizer_params, sac_network.policy_network.init(dummy_key))

# Save and load a checkpoint.
checkpoint.save(
path.full_path,
step=1,
params=params,
config=config,
)

policy_fn = checkpoint.load_policy(
epath.Path(path.full_path) / "000000000001",
)
out = policy_fn(jp.zeros(1), jax.random.PRNGKey(0))
self.assertEqual(out[0].shape, (3,))


if __name__ == "__main__":
absltest.main()
Loading

0 comments on commit 241f9bc

Please sign in to comment.