diff --git a/src/garage/_dtypes.py b/src/garage/_dtypes.py index a723b23480..f106e09203 100644 --- a/src/garage/_dtypes.py +++ b/src/garage/_dtypes.py @@ -491,10 +491,10 @@ def get_step_type(cls, step_cnt, max_episode_length, done): ValueError: if step_cnt is < 1. In this case a environment's `reset()` is likely not called yet and the step_cnt is None. """ - if done: - return StepType.TERMINAL - elif max_episode_length is not None and step_cnt >= max_episode_length: + if max_episode_length is not None and step_cnt >= max_episode_length: return StepType.TIMEOUT + elif done: + return StepType.TERMINAL elif step_cnt == 1: return StepType.FIRST elif step_cnt < 1: diff --git a/src/garage/sampler/default_worker.py b/src/garage/sampler/default_worker.py index a8ae8af68c..d36a44f3cd 100644 --- a/src/garage/sampler/default_worker.py +++ b/src/garage/sampler/default_worker.py @@ -110,7 +110,7 @@ def step_episode(self): self._agent_infos[k].append(v) self._eps_length += 1 - if not es.last: + if not es.terminal: self._prev_obs = es.observation return False self._lengths.append(self._eps_length) diff --git a/src/garage/sampler/utils.py b/src/garage/sampler/utils.py index 98ef124a63..243420cab3 100644 --- a/src/garage/sampler/utils.py +++ b/src/garage/sampler/utils.py @@ -63,7 +63,7 @@ def rollout(env, observations.append(last_obs) agent_infos.append(agent_info) episode_length += 1 - if es.terminal: + if es.last: break last_obs = es.observation diff --git a/tests/garage/test_dtypes.py b/tests/garage/test_dtypes.py index 37d00fd5a3..5f6dc1d20c 100644 --- a/tests/garage/test_dtypes.py +++ b/tests/garage/test_dtypes.py @@ -236,11 +236,11 @@ def test_get_step_type(): max_episode_length=5, done=False) assert step_type == StepType.TIMEOUT - step_type = StepType.get_step_type(step_cnt=1, + step_type = StepType.get_step_type(step_cnt=5, max_episode_length=5, done=True) - assert step_type == StepType.TERMINAL - step_type = StepType.get_step_type(step_cnt=5, + assert step_type == StepType.TIMEOUT + step_type = StepType.get_step_type(step_cnt=1, max_episode_length=5, done=True) assert step_type == StepType.TERMINAL