Commit d86a95e (1 parent: e47da9e)
Practical Reinforcement Learning Week 5
Showing 7 changed files with 1,402 additions and 0 deletions.
Binary file added (+605 KB): Practical Reinforcement Learning/Week5_policy_based/A policy-based quiz.pdf (binary file not shown)
59 additions & 0 deletions: Practical Reinforcement Learning/Week5_policy_based/atari_util.py
"""Auxilary files for those who wanted to solve breakout with CEM or policy gradient""" | ||
import numpy as np | ||
import gym | ||
from scipy.misc import imresize | ||
from gym.core import Wrapper | ||
from gym.spaces.box import Box | ||
|
||
class PreprocessAtari(Wrapper): | ||
def __init__(self, env, height=42, width=42, color=False, crop=lambda img: img, | ||
n_frames=4, dim_order='theano', reward_scale=1,): | ||
"""A gym wrapper that reshapes, crops and scales image into the desired shapes""" | ||
super(PreprocessAtari, self).__init__(env) | ||
assert dim_order in ('theano', 'tensorflow') | ||
self.img_size = (height, width) | ||
self.crop=crop | ||
self.color=color | ||
self.dim_order = dim_order | ||
self.reward_scale = reward_scale | ||
|
||
n_channels = (3 * n_frames) if color else n_frames | ||
obs_shape = [n_channels,height,width] if dim_order == 'theano' else [height,width,n_channels] | ||
self.observation_space = Box(0.0, 1.0, obs_shape) | ||
self.framebuffer = np.zeros(obs_shape, 'float32') | ||
|
||
def reset(self): | ||
"""resets breakout, returns initial frames""" | ||
self.framebuffer = np.zeros_like(self.framebuffer) | ||
self.update_buffer(self.env.reset()) | ||
return self.framebuffer | ||
|
||
def step(self,action): | ||
"""plays breakout for 1 step, returns frame buffer""" | ||
new_img, reward, done, info = self.env.step(action) | ||
self.update_buffer(new_img) | ||
return self.framebuffer, reward * self.reward_scale, done, info | ||
|
||
### image processing ### | ||
|
||
def update_buffer(self,img): | ||
img = self.preproc_image(img) | ||
offset = 3 if self.color else 1 | ||
if self.dim_order == 'theano': | ||
axis = 0 | ||
cropped_framebuffer = self.framebuffer[:-offset] | ||
else: | ||
axis = -1 | ||
cropped_framebuffer = self.framebuffer[:,:,:-offset] | ||
self.framebuffer = np.concatenate([img, cropped_framebuffer], axis = axis) | ||
|
||
def preproc_image(self, img): | ||
"""what happens to the observation""" | ||
img = self.crop(img) | ||
img = imresize(img, self.img_size) | ||
if not self.color: | ||
img = img.mean(-1, keepdims=True) | ||
if self.dim_order == 'theano': | ||
img = img.transpose([2,0,1]) # [h, w, c] to [c, h, w] | ||
img = img.astype('float32') / 255. | ||
return img |
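For reference, a minimal usage sketch of the wrapper (not part of the commit). The env id BreakoutDeterministic-v4 is an assumption, and the snippet presumes the old gym/scipy stack the file imports (where scipy.misc.imresize still exists).

import gym
from atari_util import PreprocessAtari

# Hypothetical example: wrap Breakout, stack 4 grayscale frames in 'tensorflow' (HWC) order
env = PreprocessAtari(gym.make("BreakoutDeterministic-v4"),
                      height=42, width=42, color=False,
                      n_frames=4, dim_order='tensorflow')

obs = env.reset()
print(obs.shape)                  # (42, 42, 4): last 4 frames stacked on the channel axis
obs, reward, done, info = env.step(env.action_space.sample())
print(obs.min(), obs.max())       # pixel values scaled into [0, 1]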
873 additions & 0 deletions: Practical Reinforcement Learning/Week5_policy_based/practice_a3c.ipynb
Large diffs are not rendered by default.