From 8cd166012db50e93738b2b26fa898666f9fab42f Mon Sep 17 00:00:00 2001 From: Kevin Zakka Date: Sun, 1 Dec 2024 18:21:06 -0800 Subject: [PATCH] Better env implementation. --- brax/envs/fast.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/brax/envs/fast.py b/brax/envs/fast.py index 9f4b3164..648b8883 100644 --- a/brax/envs/fast.py +++ b/brax/envs/fast.py @@ -34,7 +34,10 @@ def __init__(self, **kwargs): raise ValueError('asymmetric_obs requires use_dict_obs=True') def _get_obs(self): - obs = {'state': jp.zeros(2)} if self._use_dict_obs else jp.zeros(2) + if not self._use_dict_obs: + return jp.zeros(2) + + obs = {'state': jp.zeros(2)} if self._asymmetric_obs: obs['privileged_state'] = jp.zeros(4) return obs @@ -78,12 +81,13 @@ def step_count(self): @property def observation_size(self): - if self._use_dict_obs: - size = {'state': 2} - if self._asymmetric_obs: - size['privileged_state'] = 4 - return size - return 2 + if not self._use_dict_obs: + return 2 + + obs = {'state': 2} + if self._asymmetric_obs: + obs['privileged_state'] = 4 + return obs @property def action_size(self):