diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index f7a25c1bd5c..5b6763f6910 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -730,19 +730,20 @@ def _create_td(self) -> None: ) ) env_output_keys = env_output_keys.union(self.reward_keys + self.done_keys) - env_obs_keys = [ - key for key in env_obs_keys if key not in self._non_tensor_keys - ] - env_input_keys = [ - key for key in env_input_keys if key not in self._non_tensor_keys - ] - env_output_keys = [ - key for key in env_output_keys if key not in self._non_tensor_keys - ] self._env_obs_keys = sorted(env_obs_keys, key=_sort_keys) self._env_input_keys = sorted(env_input_keys, key=_sort_keys) self._env_output_keys = sorted(env_output_keys, key=_sort_keys) + self._env_obs_keys = [ + key for key in self._env_obs_keys if key not in self._non_tensor_keys + ] + self._env_input_keys = [ + key for key in self._env_input_keys if key not in self._non_tensor_keys + ] + self._env_output_keys = [ + key for key in self._env_output_keys if key not in self._non_tensor_keys + ] + reset_keys = self.reset_keys self._selected_keys = ( set(self._env_output_keys)