Commit 692d626

Merge branch 'main' into vision_ppo_rebased

Andrew-Luo1 authored Dec 4, 2024
2 parents cd7aa15 + 417465c commit 692d626
Showing 13 changed files with 28 additions and 22 deletions.
4 changes: 3 additions & 1 deletion brax/envs/base.py
@@ -28,14 +28,16 @@
import jax
import numpy as np

Observation = Union[jax.Array, Mapping[str, jax.Array]]
ObservationSize = Union[int, Mapping[str, Union[Tuple[int, ...], int]]]


@struct.dataclass
class State(base.Base):
"""Environment state for training and inference."""

pipeline_state: Optional[base.State]
obs: Union[jax.Array, Mapping[str, jax.Array]]
obs: Observation
reward: jax.Array
done: jax.Array
metrics: Dict[str, jax.Array] = struct.field(default_factory=dict)
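
The `Observation` alias above means `State.obs` may be either a single array or a mapping of named arrays. A minimal sketch of both forms, assuming only what the diff shows (the 'pixels' key is illustrative, not mandated by this commit):

import jax.numpy as jp
from brax.envs.base import State

# Flat observation, as before.
flat = State(
    pipeline_state=None,
    obs=jp.zeros(8),
    reward=jp.zeros(()),
    done=jp.zeros(()),
)

# Dictionary observation, now also covered by the Observation alias.
# Key names ('state', 'pixels') are illustrative.
asymmetric = State(
    pipeline_state=None,
    obs={'state': jp.zeros(8), 'pixels': jp.zeros((64, 64, 3))},
    reward=jp.zeros(()),
    done=jp.zeros(()),
)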
3 changes: 2 additions & 1 deletion brax/envs/fast.py
@@ -60,7 +60,6 @@ def __init__(
raise ValueError("asymmetric_obs requires dictionary observations")

def reset(self, rng: jax.Array) -> State:
del rng # Unused.
self._reset_count += 1
pipeline_state = base.State(
q=jp.zeros(1),
@@ -94,6 +93,7 @@ def step(self, state: State, action: jax.Array) -> State:
self._step_count += 1
vel = state.pipeline_state.xd.vel + (action > 0) * self._dt
pos = state.pipeline_state.x.pos + vel * self._dt

qp = state.pipeline_state.replace(
x=state.pipeline_state.x.replace(pos=pos),
xd=state.pipeline_state.xd.replace(vel=vel),
@@ -116,6 +116,7 @@ def step(self, state: State, action: jax.Array) -> State:
obs = obs["state"]

reward = pos[0]

return state.replace(pipeline_state=qp, obs=obs, reward=reward)

@property
3 changes: 2 additions & 1 deletion brax/io/image.py
@@ -38,7 +38,8 @@ def render_array(
def get_image(state: base.State):
d = mujoco.MjData(sys.mj_model)
d.qpos, d.qvel = state.q, state.qd
d.mocap_pos, d.mocap_quat = state.mocap_pos, state.mocap_quat
if hasattr(state, 'mocap_pos') and hasattr(state, 'mocap_quat'):
d.mocap_pos, d.mocap_quat = state.mocap_pos, state.mocap_quat
mujoco.mj_forward(sys.mj_model, d)
renderer.update_scene(d, camera=camera)
return renderer.render()
2 changes: 1 addition & 1 deletion brax/training/agents/apg/train.py
@@ -133,7 +133,7 @@ def train(

obs_size = env.observation_size
if isinstance(obs_size, Dict):
raise NotImplementedError("Dictionary observations not implemented in APG")
raise NotImplementedError('Dictionary observations not implemented in APG')

normalize = lambda x, y: x
if normalize_observations:
2 changes: 1 addition & 1 deletion brax/training/agents/ars/train.py
@@ -120,7 +120,7 @@ def train(

obs_size = env.observation_size
if isinstance(obs_size, Dict):
raise NotImplementedError("Dictionary observations not implemented in ARS")
raise NotImplementedError('Dictionary observations not implemented in ARS')

normalize_fn = lambda x, y: x
if normalize_observations:
4 changes: 2 additions & 2 deletions brax/training/agents/es/train.py
@@ -147,8 +147,8 @@ def train(

obs_size = env.observation_size
if isinstance(obs_size, Dict):
raise NotImplementedError("Dictionary observations not implemented in ES")
raise NotImplementedError('Dictionary observations not implemented in ES')

normalize_fn = lambda x, y: x
if normalize_observations:
normalize_fn = running_statistics.normalize
3 changes: 1 addition & 2 deletions brax/training/agents/ppo/losses.py
@@ -136,8 +136,7 @@ def compute_ppo_loss(

baseline = value_apply(normalizer_params, params.value, data.observation)
terminal_obs = jax.tree_util.tree_map(lambda x: x[-1], data.next_observation)
bootstrap_value = value_apply(normalizer_params, params.value,
terminal_obs)
bootstrap_value = value_apply(normalizer_params, params.value, terminal_obs)

rewards = data.reward * reward_scaling
truncation = data.extras['state_extras']['truncation']
1 change: 1 addition & 0 deletions brax/training/agents/ppo/networks.py
@@ -87,6 +87,7 @@ def make_ppo_networks(
hidden_layer_sizes=value_hidden_layer_sizes,
activation=activation,
obs_key=value_obs_key)

return PPONetworks(
policy_network=policy_network,
value_network=value_network,
3 changes: 1 addition & 2 deletions brax/training/agents/ppo/train.py
@@ -285,7 +285,6 @@ def train(
key_envs = jnp.reshape(key_envs,
(local_devices_to_use, -1) + key_envs.shape[1:])
env_state = reset_fn(key_envs)

# Discard the batch axes over devices and envs.
obs_shape = jax.tree_util.tree_map(lambda x: x.shape[2:], env_state.obs)

@@ -458,7 +457,7 @@ def training_epoch_with_timing(
)

obs_shape = jax.tree_util.tree_map(
lambda x: specs.Array(x.shape[-1:], jnp.dtype('float32')), env_state.obs
lambda x: specs.Array(x.shape[-1:], jnp.dtype('float32')), env_state.obs
)
training_state = TrainingState( # pytype: disable=wrong-arg-types # jax-ndarray
optimizer_state=optimizer.init(init_params), # pytype: disable=wrong-arg-types # numpy-scalars
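
The `obs_shape` lines above use `jax.tree_util.tree_map`, which works whether `env_state.obs` is a single array or a dictionary of arrays. A minimal sketch with illustrative keys and shapes (the leading axes are devices and envs):

import jax
import jax.numpy as jp

# Observation batched over (devices, envs); keys and shapes are illustrative.
obs = {'state': jp.zeros((2, 4, 8)), 'pixels': jp.zeros((2, 4, 64, 64, 3))}

# Discard the leading device and env axes, mirroring the tree_map above.
obs_shape = jax.tree_util.tree_map(lambda x: x.shape[2:], obs)
# obs_shape == {'state': (8,), 'pixels': (64, 64, 3)}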
3 changes: 2 additions & 1 deletion brax/training/agents/ppo/train_test.py
@@ -71,7 +71,8 @@ def testTrainV2(self):
normalize_observations=True,
seed=2,
reward_scaling=10,
normalize_advantage=False)
normalize_advantage=False,
)

@parameterized.parameters(True, False)
def testNetworkEncoding(self, normalize_observations):
3 changes: 1 addition & 2 deletions brax/training/agents/sac/train.py
@@ -192,8 +192,7 @@ def train(

obs_size = env.observation_size
if isinstance(obs_size, Dict):
raise NotImplementedError("Dictionary observations not implemented in SAC")

raise NotImplementedError('Dictionary observations not implemented in SAC')
action_size = env.action_size

normalize_fn = lambda x, y: x
16 changes: 8 additions & 8 deletions brax/training/networks.py
@@ -166,9 +166,9 @@ def ln_per_chan(v: jax.Array):
)(hidden)


def get_obs_state_size(obs_size: types.ObservationSize, obs_key: str) -> int:
obs_size = obs_size[obs_key] if isinstance(obs_size, Mapping) else obs_size
return jax.tree_util.tree_flatten(obs_size)[0][-1] # Size can be tuple or int.
def _get_obs_state_size(obs_size: types.ObservationSize, obs_key: str) -> int:
obs_size = obs_size[obs_key] if isinstance(obs_size, Mapping) else obs_size
return jax.tree_util.tree_flatten(obs_size)[0][-1]

def make_policy_network(
param_size: int,
@@ -189,10 +189,10 @@ def make_policy_network(

def apply(processor_params, policy_params, obs):
obs = preprocess_observations_fn(obs, processor_params)
obs = obs if isinstance(obs, jnp.ndarray) else obs[obs_key]
obs = obs if isinstance(obs, jax.Array) else obs[obs_key]
return policy_module.apply(policy_params, obs)

obs_size = get_obs_state_size(obs_size, obs_key)
obs_size = _get_obs_state_size(obs_size, obs_key)
dummy_obs = jnp.zeros((1, obs_size))
return FeedForwardNetwork(
init=lambda key: policy_module.init(key, dummy_obs), apply=apply)
@@ -213,10 +213,10 @@ def make_value_network(

def apply(processor_params, value_params, obs):
obs = preprocess_observations_fn(obs, processor_params)
obs = obs if isinstance(obs, jnp.ndarray) else obs[obs_key]
obs = obs if isinstance(obs, jax.Array) else obs[obs_key]
return jnp.squeeze(value_module.apply(value_params, obs), axis=-1)

obs_size = get_obs_state_size(obs_size, obs_key)
obs_size = _get_obs_state_size(obs_size, obs_key)
dummy_obs = jnp.zeros((1, obs_size))
return FeedForwardNetwork(
init=lambda key: value_module.init(key, dummy_obs), apply=apply)
@@ -300,7 +300,7 @@ def apply(processor_params, policy_params, obs):


def make_q_network(
obs_size: int,
obs_size: types.ObservationSize,
action_size: int,
preprocess_observations_fn: types.PreprocessObservationFn = types
.identity_observation_preprocessor,
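
A rough illustration of what the renamed `_get_obs_state_size` helper computes: when the observation size is a mapping, it first selects `obs_key`, then takes the last entry of the flattened size, so both plain ints and shape tuples work. The values below are illustrative:

from typing import Mapping, Tuple, Union
import jax

ObservationSize = Union[int, Mapping[str, Union[Tuple[int, ...], int]]]

def obs_state_size(obs_size: ObservationSize, obs_key: str) -> int:
  # Mirrors _get_obs_state_size: select the key, then take the last dimension.
  obs_size = obs_size[obs_key] if isinstance(obs_size, Mapping) else obs_size
  return jax.tree_util.tree_flatten(obs_size)[0][-1]

assert obs_state_size(8, 'state') == 8
assert obs_state_size({'state': 8, 'pixels': (64, 64, 3)}, 'state') == 8
assert obs_state_size({'state': 8, 'pixels': (64, 64, 3)}, 'pixels') == 3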
3 changes: 3 additions & 0 deletions docs/release-notes/next-release.md
@@ -6,3 +6,6 @@
* Change PPO train function to return both value and policy network params, rather than just policy params.
* Merge https://github.com/google/brax/pull/561, adds grad norm clipping to PPO.
* Merge https://github.com/google/brax/issues/477, changes pusher vel damping.
* Merge https://github.com/google/brax/pull/558, adds `mocap_pos` and `mocap_quat` to the render function.
* Merge https://github.com/google/brax/pull/559, allows for dictionary observations in the environment `State`.
* Merge https://github.com/google/brax/pull/562, which supports asymmetric actor-critic for PPO.
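
For the asymmetric actor-critic item (PR 562), a hedged usage sketch: the diff above shows `make_ppo_networks` routing a `value_obs_key` into the value network; the `policy_obs_key` keyword and the observation key names below are assumptions, not confirmed by this page.

from brax.training.agents.ppo import networks as ppo_networks

# Assumed keywords (policy_obs_key / value_obs_key) and keys
# ('state', 'privileged_state'); only value_obs_key appears in the diff above.
nets = ppo_networks.make_ppo_networks(
    observation_size={'state': 8, 'privileged_state': 16},
    action_size=2,
    policy_obs_key='state',            # actor sees the regular observation
    value_obs_key='privileged_state',  # critic sees a richer observation
)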
