Skip to content

Commit

Permalink
fix rollout episode termination logic (#1907)
Browse files Browse the repository at this point in the history
* Rollout stops collecting steps upon timeout

addresses #1906

* Update step type timeout logic

Co-authored-by: Eric Yihan Chen <[email protected]>
  • Loading branch information
avnishn and AiRuiChen authored Aug 17, 2020
1 parent 6282048 commit 95b8275
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
6 changes: 3 additions & 3 deletions src/garage/_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,10 +491,10 @@ def get_step_type(cls, step_cnt, max_episode_length, done):
ValueError: if step_cnt is < 1. In this case an environment's
`reset()` has likely not been called yet and the step_cnt is None.
"""
if done:
return StepType.TERMINAL
elif max_episode_length is not None and step_cnt >= max_episode_length:
if max_episode_length is not None and step_cnt >= max_episode_length:
return StepType.TIMEOUT
elif done:
return StepType.TERMINAL
elif step_cnt == 1:
return StepType.FIRST
elif step_cnt < 1:
Expand Down
2 changes: 1 addition & 1 deletion src/garage/sampler/default_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def step_episode(self):
self._agent_infos[k].append(v)
self._eps_length += 1

if not es.last:
if not es.terminal:
self._prev_obs = es.observation
return False
self._lengths.append(self._eps_length)
Expand Down
2 changes: 1 addition & 1 deletion src/garage/sampler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def rollout(env,
observations.append(last_obs)
agent_infos.append(agent_info)
episode_length += 1
if es.terminal:
if es.last:
break
last_obs = es.observation

Expand Down
6 changes: 3 additions & 3 deletions tests/garage/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,11 @@ def test_get_step_type():
max_episode_length=5,
done=False)
assert step_type == StepType.TIMEOUT
step_type = StepType.get_step_type(step_cnt=1,
step_type = StepType.get_step_type(step_cnt=5,
max_episode_length=5,
done=True)
assert step_type == StepType.TERMINAL
step_type = StepType.get_step_type(step_cnt=5,
assert step_type == StepType.TIMEOUT
step_type = StepType.get_step_type(step_cnt=1,
max_episode_length=5,
done=True)
assert step_type == StepType.TERMINAL
Expand Down

0 comments on commit 95b8275

Please sign in to comment.