Skip to content

Commit

Permalink
Merge pull request #22 from NES-NN/version/0.0.8
Browse files Browse the repository at this point in the history
Updates and hardening emulator close functions
  • Loading branch information
ppaquette authored Jul 3, 2018
2 parents 6d426e8 + 9f9789e commit be8bca6
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 66 deletions.
33 changes: 0 additions & 33 deletions ppaquette_gym_super_mario/__init__.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from gym.envs.registration import register
from gym.scoreboard.registration import add_task, add_group
from .package_info import USERNAME
from .nes_env import NesEnv, MetaNesEnv
from .super_mario_bros import SuperMarioBrosEnv, MetaSuperMarioBrosEnv
Expand Down Expand Up @@ -40,35 +39,3 @@
# Seems to be non-deterministic about 5% of the time
nondeterministic=True,
)

# Scoreboard registration
# ==========================
add_group(
id= 'super-mario',
name= 'SuperMario',
description= '32 levels of the original Super Mario Bros game.'
)

add_task(
id='{}/meta-SuperMarioBros-v0'.format(USERNAME),
group='super-mario',
summary='Compilation of all 32 levels of Super Mario Bros. on Nintendo platform - Screen version.',
)
add_task(
id='{}/meta-SuperMarioBros-Tiles-v0'.format(USERNAME),
group='super-mario',
summary='Compilation of all 32 levels of Super Mario Bros. on Nintendo platform - Tiles version.',
)

for world in range(8):
for level in range(4):
add_task(
id='{}/SuperMarioBros-{}-{}-v0'.format(USERNAME, world + 1, level + 1),
group='super-mario',
summary='Level: {}-{} of Super Mario Bros. on Nintendo platform - Screen version.'.format(world + 1, level + 1),
)
add_task(
id='{}/SuperMarioBros-{}-{}-Tiles-v0'.format(USERNAME, world + 1, level + 1),
group='super-mario',
summary='Level: {}-{} of Super Mario Bros. on Nintendo platform - Tiles version.'.format(world + 1, level + 1),
)
62 changes: 33 additions & 29 deletions ppaquette_gym_super_mario/nes_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ def __init__(self):
self.rom_path = ''
self.screen_height = 224
self.screen_width = 256
self.action_space = spaces.MultiDiscrete([[0, 1]] * NUM_ACTIONS)
self.action_space = spaces.MultiDiscrete([1] * NUM_ACTIONS)
self.observation_space = spaces.Box(low=0, high=255, shape=(self.screen_height, self.screen_width, 3))
self.launch_vars = {}
self.cmd_args = ['--xscale 2', '--yscale 2', '-f 0']
self.lua_path = []
self.subprocess = None
self.fceux_pid = None
self.no_render = True
self.viewer = None

Expand Down Expand Up @@ -80,6 +81,8 @@ def __init__(self):
self.first_step = False
self.lock = (NesLock()).get_lock()

self.temp_lua_path = ""

# Seeding
self.curr_seed = 0
self._seed()
Expand Down Expand Up @@ -190,8 +193,8 @@ def _launch_fceux(self):
self._create_pipes()

# Creating temporary lua file
temp_lua_path = os.path.join('/tmp', str(seeding.hash_seed(None) % 2 ** 32) + '.lua')
temp_lua_file = open(temp_lua_path, 'w', 1)
self.temp_lua_path = os.path.join('/tmp', str(seeding.hash_seed(None) % 2 ** 32) + '.lua')
temp_lua_file = open(self.temp_lua_path, 'w', 1)
for k, v in list(self.launch_vars.items()):
temp_lua_file.write('%s = "%s";\n' % (k, v))
i = 0
Expand All @@ -212,11 +215,13 @@ def _launch_fceux(self):
# Loading fceux
args = [FCEUX_PATH]
args.extend(self.cmd_args[:])
args.extend(['--loadlua', temp_lua_path])
args.extend(['--loadlua', self.temp_lua_path])
args.append(self.rom_path)
args.extend(['>/dev/null', '2>/dev/null', '&'])
self.subprocess = subprocess.Popen(' '.join(args), shell=True)
self.subprocess.communicate()
self.subprocess = subprocess.Popen("/bin/bash", shell=False, universal_newlines=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = self.subprocess.communicate(' '.join(args) + "\necho $!")
self.fceux_pid = int(stdout)

if 0 == self.subprocess.returncode:
self.is_initialized = 1
if not self.disable_out_pipe:
Expand All @@ -227,9 +232,9 @@ def _launch_fceux(self):
self.pipe_out = None
# Removing lua file
sleep(1) # Sleeping to make sure fceux has time to load file before removing
if os.path.isfile(temp_lua_path):
if os.path.isfile(self.temp_lua_path):
try:
os.remove(temp_lua_path)
os.remove(self.temp_lua_path)
except OSError:
pass
else:
Expand Down Expand Up @@ -264,7 +269,7 @@ def _get_info(self):
# Overridable - Returns the other variables
return self.info

def _step(self, action):
def step(self, action):
if 0 == self.is_initialized:
return self._get_state(), 0, self._get_is_finished(), {}

Expand Down Expand Up @@ -324,13 +329,7 @@ def _step(self, action):
# Game stuck, returning
# Likely caused by fceux incoming pipe not working
logger.warn('Closing episode (appears to be stuck). See documentation for how to handle this issue.')
if self.subprocess is not None:
# Workaround, killing process with pid + 1 (shell = pid, shell + 1 = fceux)
try:
os.kill(self.subprocess.pid + 1, signal.SIGKILL)
except OSError:
pass
self.subprocess = None
self._terminate_fceux()
return self._get_state(), 0, True, {'ignore': True}

# Getting results
Expand All @@ -340,7 +339,9 @@ def _step(self, action):
info = self._get_info()
return state, reward, is_finished, info

def _reset(self):
def reset(self):
self._terminate_fceux()

if 1 == self.is_initialized:
self.close()
self.last_frame = 0
Expand Down Expand Up @@ -376,18 +377,11 @@ def _render(self, mode='human', close=False):
self.viewer = rendering.SimpleImageViewer()
self.viewer.imshow(img)

def _close(self):
# Terminating thread
def close(self):
self.is_exiting = 1
self._write_to_pipe('exit')
sleep(0.05)
if self.subprocess is not None:
# Workaround, killing process with pid + 1 (shell = pid, shell + 1 = fceux)
try:
os.kill(self.subprocess.pid + 1, signal.SIGKILL)
except OSError:
pass
self.subprocess = None
self._terminate_fceux()
sleep(0.001)
self._close_pipes()
self.last_frame = 0
Expand All @@ -398,6 +392,16 @@ def _close(self):
self._reset_info_vars()
self.is_initialized = 0

def _terminate_fceux(self):
if self.subprocess is not None:
try:
os.kill(self.fceux_pid, signal.SIGKILL)
except OSError as e:
cmd = "kill -9 $(ps -ef | grep 'fceux' | grep " + self.temp_lua_path + " | awk '{print $2}')"
os.system(cmd)
pass
self.subprocess = None

def _seed(self, seed=None):
self.curr_seed = seeding.hash_seed(seed) % 256
return [self.curr_seed]
Expand Down Expand Up @@ -637,7 +641,7 @@ def get_scores(self):
averages[i] = round(level_average, 4)
return averages

def _reset(self):
def reset(self):
# Reset is called on first step() after level is finished
# or when change_level() is called. Returning if neither have been called to
# avoid resetting the level twice
Expand All @@ -656,12 +660,12 @@ def _reset(self):
self.screen = np.zeros(shape=(self.screen_height, self.screen_width, 3), dtype=np.uint8)
return self._get_state()

def _step(self, action):
def step(self, action):
# Changing level
if self.find_new_level:
self.change_level()

obs, step_reward, is_finished, info = NesEnv._step(self, action)
obs, step_reward, is_finished, info = NesEnv.step(self, action)
reward, self.total_reward = self._calculate_reward(self._get_episode_reward(), self.total_reward)
# First step() after new episode returns the entire total reward
# because stats_recorder resets the episode score to 0 after reset() is called
Expand Down
2 changes: 1 addition & 1 deletion ppaquette_gym_super_mario/package_info.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
VERSION='0.0.7'
VERSION='0.0.8'
USERNAME='ppaquette'
2 changes: 1 addition & 1 deletion requirements.txt
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1 +1 @@
gym>=0.8.0
gym==0.10.5
5 changes: 3 additions & 2 deletions setup.py
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from setuptools import setup, find_packages
import sys, os

# Don't import gym module here, since deps may not be installed
for package in find_packages():
if '_gym_' in package:
sys.path.insert(0, os.path.join(os.path.dirname(__file__), package))
Expand All @@ -17,5 +16,7 @@
packages=[package for package in find_packages() if package.startswith(USERNAME)],
package_data={ '{}_{}'.format(USERNAME, 'gym_super_mario'): ['lua/*.lua', 'roms/*.nes' ] },
zip_safe=False,
install_requires=[ 'gym>=0.8.0' ],
install_requires=[
'gym==0.10.5'
],
)

0 comments on commit be8bca6

Please sign in to comment.