diff --git a/brax/training/agents/ppo/train_test.py b/brax/training/agents/ppo/train_test.py
index 5392194c..f93445c8 100644
--- a/brax/training/agents/ppo/train_test.py
+++ b/brax/training/agents/ppo/train_test.py
@@ -135,8 +135,6 @@ def get_offset(rng):
 
   def testTrainAsymmetricActorCritic(self):
     """Test PPO with asymmetric actor critic."""
-    key = jax.random.PRNGKey(0)
-
     env = envs.get_environment('fast', asymmetric_obs=True, use_dict_obs=True)
 
     network_factory = functools.partial(
@@ -146,40 +144,8 @@ def testTrainAsymmetricActorCritic(self):
         policy_obs_key='state',
         value_obs_key='privileged_state'
     )
 
-    ppo_network = network_factory(
-        observation_size=env.observation_size,
-        action_size=env.action_size,
-        preprocess_observations_fn=lambda x, y: x,
-    )
-    key, reset_key = jax.random.split(key)
-    state = env.reset(reset_key)
-
-    key, key_policy = jax.random.split(key)
-    policy_params = ppo_network.policy_network.init(key_policy)
-    self.assertEqual(
-        policy_params['params']['hidden_0']['kernel'].shape,
-        (env.observation_size['state'], 32),
-    )
-    ppo_network.policy_network.apply(
-        processor_params=None,
-        policy_params=policy_params,
-        obs=state.obs,
-    )
-
-    key, key_value = jax.random.split(key)
-    value_params = ppo_network.value_network.init(key_value)
-    self.assertEqual(
-        value_params['params']['hidden_0']['kernel'].shape,
-        (env.observation_size['privileged_state'], 32),
-    )
-    ppo_network.value_network.apply(
-        processor_params=None,
-        value_params=value_params,
-        obs=state.obs,
-    )
-
-    _, _, _ = ppo.train(
+    _, (_, policy_params, value_params), _ = ppo.train(
         env,
         num_timesteps=2**15,
         episode_length=1000,
@@ -198,6 +164,15 @@ def testTrainAsymmetricActorCritic(self):
         network_factory=network_factory,
     )
+
+    self.assertEqual(
+        policy_params['params']['hidden_0']['kernel'].shape,
+        (env.observation_size['state'], 32),
+    )
+    self.assertEqual(
+        value_params['params']['hidden_0']['kernel'].shape,
+        (env.observation_size['privileged_state'], 32),
+    )
 
 
 if __name__ == '__main__':
   absltest.main()