Skip to content

Commit

Permalink
Use the params returned by the train function.
Browse files: browse the repository at this point in the history.
  • Loading branch information
kevinzakka committed Dec 2, 2024
1 parent 8cd1660 commit 4c8c912
Showing 1 changed file with 10 additions and 35 deletions.
45 changes: 10 additions & 35 deletions brax/training/agents/ppo/train_test.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -135,8 +135,6 @@ def get_offset(rng):

def testTrainAsymmetricActorCritic(self):
"""Test PPO with asymmetric actor critic."""
key = jax.random.PRNGKey(0)

env = envs.get_environment('fast', asymmetric_obs=True, use_dict_obs=True)

network_factory = functools.partial(
Expand All @@ -146,40 +144,8 @@ def testTrainAsymmetricActorCritic(self):
policy_obs_key='state',
value_obs_key='privileged_state'
)
ppo_network = network_factory(
observation_size=env.observation_size,
action_size=env.action_size,
preprocess_observations_fn=lambda x, y: x,
)

key, reset_key = jax.random.split(key)
state = env.reset(reset_key)

key, key_policy = jax.random.split(key)
policy_params = ppo_network.policy_network.init(key_policy)
self.assertEqual(
policy_params['params']['hidden_0']['kernel'].shape,
(env.observation_size['state'], 32),
)
ppo_network.policy_network.apply(
processor_params=None,
policy_params=policy_params,
obs=state.obs,
)

key, key_value = jax.random.split(key)
value_params = ppo_network.value_network.init(key_value)
self.assertEqual(
value_params['params']['hidden_0']['kernel'].shape,
(env.observation_size['privileged_state'], 32),
)
ppo_network.value_network.apply(
processor_params=None,
value_params=value_params,
obs=state.obs,
)

_, _, _ = ppo.train(
_, (_, policy_params, value_params), _ = ppo.train(
env,
num_timesteps=2**15,
episode_length=1000,
Expand All @@ -198,6 +164,15 @@ def testTrainAsymmetricActorCritic(self):
network_factory=network_factory,
)

self.assertEqual(
policy_params['params']['hidden_0']['kernel'].shape,
(env.observation_size['state'], 32),
)
self.assertEqual(
value_params['params']['hidden_0']['kernel'].shape,
(env.observation_size['privileged_state'], 32),
)


if __name__ == '__main__':
absltest.main()

0 comments on commit 4c8c912

Please sign in to comment.