Updated atari wrappers, fixed pre-commit (#781)
This PR addresses #772 (updates Atari wrappers to work with new Gym API)
and some additional issues:

- Pre-commit was pulling flake8 from GitLab, which recently started requiring
authentication, so I replaced it with the GitHub repository
- Yapf was failing silently in pre-commit. I changed it so that it now fixes
formatting in place (via yapf's `-i` flag)
- There is an incompatibility between flake8 and yapf: yapf puts binary
operators after the line break, while flake8 wants them before the break.
I added W503 to the flake8 ignore list as an exception; see the sketch after
this list.
- Also require `packaging` in setup.py
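
As an illustration of the flake8/yapf conflict mentioned above (my own sketch, not code from this commit; `first_term` and `second_term` are made-up names):

```python
first_term, second_term = 1, 2

# yapf breaks long expressions so that the binary operator starts the
# continuation line; with W503 enabled, flake8 flags exactly this style
# ("line break before binary operator"), hence W503 is now ignored in setup.cfg.
total = (first_term
         + second_term)

# The opposite style (operator at the end of the broken line) would instead
# trigger W504, which was already in the ignore list.
print(total)
```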

My changes shouldn't change the behaviour of the wrappers for older Gym
versions, but please double-check. The sketch below shows the API difference
the wrappers now have to handle.
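
A minimal sketch of the two reset/step conventions (my illustration, assuming an Atari env id such as "PongNoFrameskip-v4" is installed; this is not code from the commit):

```python
import gym

env = gym.make("PongNoFrameskip-v4")

# Old Gym API: reset() returns obs; new API: reset() returns (obs, info).
reset_result = env.reset()
if isinstance(reset_result, tuple) and len(reset_result) == 2 \
        and isinstance(reset_result[1], dict):
    obs, info = reset_result
else:
    obs, info = reset_result, {}

# Old Gym API: step() returns (obs, reward, done, info); the new API returns
# (obs, reward, terminated, truncated, info).
step_result = env.step(env.action_space.sample())
if len(step_result) == 5:
    obs, reward, terminated, truncated, info = step_result
    done = terminated or truncated
else:
    obs, reward, done, info = step_result
```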
I don't know whether it's just me, but there are always some incompatibilities
between yapf and flake8 that need to be resolved manually. It might make
sense to try black instead.
Markus28 authored Dec 4, 2022
1 parent 662af52 commit 4c3791a
Showing 4 changed files with 72 additions and 19 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -9,19 +9,19 @@ repos:
# pass_filenames: false
# args: [--config-file=setup.cfg, tianshou]

- repo: https://github.com/pre-commit/mirrors-yapf
- repo: https://github.com/google/yapf
rev: v0.32.0
hooks:
- id: yapf
args: [-r]
args: [-r, -i]

- repo: https://github.com/pycqa/isort
rev: 5.10.1
hooks:
- id: isort
name: isort

- repo: https://gitlab.com/PyCQA/flake8
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
- id: flake8
82 changes: 67 additions & 15 deletions examples/atari/atari_wrapper.py
@@ -16,6 +16,16 @@
envpool = None


def _parse_reset_result(reset_result):
contains_info = (
isinstance(reset_result, tuple) and len(reset_result) == 2
and isinstance(reset_result[1], dict)
)
if contains_info:
return reset_result[0], reset_result[1], contains_info
return reset_result, {}, contains_info


class NoopResetEnv(gym.Wrapper):
"""Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
@@ -30,16 +40,23 @@ def __init__(self, env, noop_max=30):
self.noop_action = 0
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

def reset(self):
self.env.reset()
def reset(self, **kwargs):
_, info, return_info = _parse_reset_result(self.env.reset(**kwargs))
if hasattr(self.unwrapped.np_random, "integers"):
noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
else:
noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
for _ in range(noops):
obs, _, done, _ = self.env.step(self.noop_action)
step_result = self.env.step(self.noop_action)
if len(step_result) == 4:
obs, rew, done, info = step_result
else:
obs, rew, term, trunc, info = step_result
done = term or trunc
if done:
obs = self.env.reset()
obs, info, _ = _parse_reset_result(self.env.reset())
if return_info:
return obs, info
return obs


@@ -59,14 +76,24 @@ def step(self, action):
"""Step the environment with the given action. Repeat action, sum
reward, and max over last observations.
"""
obs_list, total_reward, done = [], 0., False
obs_list, total_reward = [], 0.
new_step_api = False
for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
step_result = self.env.step(action)
if len(step_result) == 4:
obs, reward, done, info = step_result
else:
obs, reward, term, trunc, info = step_result
done = term or trunc
new_step_api = True
obs_list.append(obs)
total_reward += reward
if done:
break
max_frame = np.max(obs_list[-2:], axis=0)
if new_step_api:
return max_frame, total_reward, term, trunc, info

return max_frame, total_reward, done, info


@@ -81,9 +108,18 @@ def __init__(self, env):
super().__init__(env)
self.lives = 0
self.was_real_done = True
self._return_info = False

def step(self, action):
obs, reward, done, info = self.env.step(action)
step_result = self.env.step(action)
if len(step_result) == 4:
obs, reward, done, info = step_result
new_step_api = False
else:
obs, reward, term, trunc, info = step_result
done = term or trunc
new_step_api = True

self.was_real_done = done
# check current lives, make loss of life terminal, then update lives to
# handle bonus lives
@@ -93,7 +129,10 @@ def step(self, action):
# frames, so it's important to keep lives > 0, so that we only reset
# once the environment is actually done.
done = True
term = True
self.lives = lives
if new_step_api:
return obs, reward, term, trunc, info
return obs, reward, done, info

def reset(self):
@@ -102,12 +141,16 @@ def reset(self):
the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
obs = self.env.reset()
obs, info, self._return_info = _parse_reset_result(self.env.reset())
else:
# no-op step to advance from terminal/lost life state
obs = self.env.step(0)[0]
step_result = self.env.step(0)
obs, info = step_result[0], step_result[-1]
self.lives = self.env.unwrapped.ale.lives()
return obs
if self._return_info:
return obs, info
else:
return obs


class FireResetEnv(gym.Wrapper):
@@ -123,8 +166,9 @@ def __init__(self, env):
assert len(env.unwrapped.get_action_meanings()) >= 3

def reset(self):
self.env.reset()
return self.env.step(1)[0]
_, _, return_info = _parse_reset_result(self.env.reset())
obs = self.env.step(1)[0]
return (obs, {}) if return_info else obs


class WarpFrame(gym.ObservationWrapper):
@@ -204,14 +248,22 @@ def __init__(self, env, n_frames):
)

def reset(self):
obs = self.env.reset()
obs, info, return_info = _parse_reset_result(self.env.reset())
for _ in range(self.n_frames):
self.frames.append(obs)
return self._get_ob()
return (self._get_ob(), info) if return_info else self._get_ob()

def step(self, action):
obs, reward, done, info = self.env.step(action)
step_result = self.env.step(action)
if len(step_result) == 4:
obs, reward, done, info = step_result
new_step_api = False
else:
obs, reward, term, trunc, info = step_result
new_step_api = True
self.frames.append(obs)
if new_step_api:
return self._get_ob(), reward, term, trunc, info
return self._get_ob(), reward, done, info

def _get_ob(self):
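
Stepping back from the diff for a moment: to make the new `_parse_reset_result` helper concrete, here is a standalone sketch of it applied to both possible `reset()` return shapes (my illustration; the function body is copied from the hunk above):

```python
def _parse_reset_result(reset_result):
    # An env following the new API returns (obs, info); the old API returns obs.
    contains_info = (
        isinstance(reset_result, tuple) and len(reset_result) == 2
        and isinstance(reset_result[1], dict)
    )
    if contains_info:
        return reset_result[0], reset_result[1], contains_info
    return reset_result, {}, contains_info


# Old-style reset result: just an observation.
print(_parse_reset_result([0.0, 0.0]))        # ([0.0, 0.0], {}, False)
# New-style reset result: (observation, info dict).
print(_parse_reset_result(([0.0, 0.0], {})))  # ([0.0, 0.0], {}, True)
```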
2 changes: 1 addition & 1 deletion setup.cfg
@@ -8,7 +8,7 @@ exclude =
dist
*.egg-info
max-line-length = 87
ignore = B305,W504,B006,B008,B024
ignore = B305,W504,B006,B008,B024,W503

[yapf]
based_on_style = pep8
1 change: 1 addition & 0 deletions setup.py
@@ -23,6 +23,7 @@ def get_install_requires() -> str:
"numba>=0.51.0",
"h5py>=2.10.0", # to match tensorflow's minimal requirements
"protobuf~=3.19.0", # breaking change, sphinx fail
"packaging",
]


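
The visible diff doesn't show where `packaging` is used; presumably it backs a Gym version check elsewhere in the codebase. A hedged sketch of that kind of check (the 0.26.0 threshold is my assumption, not taken from this commit):

```python
import gym
from packaging import version

# Assumed usage: branch on the installed Gym version to pick the reset/step API.
if version.parse(gym.__version__) >= version.parse("0.26.0"):
    new_step_api = True   # reset() -> (obs, info), step() -> 5-tuple
else:
    new_step_api = False  # reset() -> obs, step() -> 4-tuple
print(new_step_api)
```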
