class PreprocessAtari(Wrapper):
    """Gym wrapper that crops, resizes, rescales and stacks image observations.

    Every raw frame goes through ``preproc_image`` and is pushed to the front
    of a rolling buffer along the channel axis, while the oldest frame falls
    off the end. ``reset`` and ``step`` return that buffer as the observation.
    """

    def __init__(self, env, height=42, width=42, color=False, crop=lambda img: img,
                 n_frames=4, dim_order='theano', reward_scale=1,):
        """Wrap ``env`` and allocate the frame buffer.

        :param env: environment to wrap.
        :param height: height of a processed frame.
        :param width: width of a processed frame.
        :param color: keep the colour channels (3 per frame) instead of
            averaging them into a single channel.
        :param crop: callable applied to the raw frame before resizing.
        :param n_frames: number of most recent frames kept in the buffer.
        :param dim_order: 'theano' for [channels, height, width] observations,
            'tensorflow' for [height, width, channels].
        :param reward_scale: factor every reward is multiplied by in ``step``.
        """
        super(PreprocessAtari, self).__init__(env)
        assert dim_order in ('theano', 'tensorflow')
        self.img_size = (height, width)
        self.crop = crop
        self.color = color
        self.dim_order = dim_order
        self.reward_scale = reward_scale

        n_channels = (3 * n_frames) if color else n_frames
        if dim_order == 'theano':
            obs_shape = [n_channels, height, width]
        else:
            obs_shape = [height, width, n_channels]
        # State the dtype explicitly (it matches the float32 frame buffer)
        # instead of letting gym autodetect it and print a warning per env.
        self.observation_space = Box(0.0, 1.0, obs_shape, dtype=np.float32)
        self.framebuffer = np.zeros(obs_shape, 'float32')

    def reset(self):
        """Reset the wrapped env, clear the buffer and return it with the first frame pushed in."""
        self.framebuffer = np.zeros_like(self.framebuffer)
        self.update_buffer(self.env.reset())
        return self.framebuffer

    def step(self, action):
        """Step the wrapped env once.

        :returns: (frame buffer, reward * reward_scale, done, info)
        """
        new_img, reward, done, info = self.env.step(action)
        self.update_buffer(new_img)
        return self.framebuffer, reward * self.reward_scale, done, info

    ### image processing ###

    def update_buffer(self, img):
        """Preprocess ``img`` and push it to the front of the frame buffer."""
        img = self.preproc_image(img)
        # Number of channels one frame occupies, i.e. how many to drop from the tail.
        offset = 3 if self.color else 1
        if self.dim_order == 'theano':
            axis = 0
            cropped_framebuffer = self.framebuffer[:-offset]
        else:
            axis = -1
            cropped_framebuffer = self.framebuffer[:, :, :-offset]
        # A new array is built on every update, so a previously returned
        # buffer is never modified in place.
        self.framebuffer = np.concatenate([img, cropped_framebuffer], axis=axis)

    def preproc_image(self, img):
        """Crop, resize, optionally grayscale, reorder axes and divide by 255."""
        img = self.crop(img)
        # NOTE(review): scipy.misc.imresize is deprecated and missing from
        # recent SciPy releases - confirm the SciPy version this file targets.
        img = imresize(img, self.img_size)
        if not self.color:
            # Average over the last (colour) axis but keep it as a size-1 channel.
            img = img.mean(-1, keepdims=True)
        if self.dim_order == 'theano':
            img = img.transpose([2, 0, 1])  # [h, w, c] to [c, h, w]
        img = img.astype('float32') / 255.
        return img
"![http://www.retroland.com/wp-content/uploads/2011/07/King-Fu-Master.jpg](http://www.retroland.com/wp-content/uploads/2011/07/King-Fu-Master.jpg)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import print_function, division\n", + "from IPython.core import display\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "import numpy as np\n", + "\n", + "#If you are running on a server, launch xvfb to record game videos\n", + "#Please make sure you have xvfb installed\n", + "import os\n", + "if os.environ.get(\"DISPLAY\") is str and len(os.environ.get(\"DISPLAY\"))!=0:\n", + " !bash ../xvfb start\n", + " %env DISPLAY=:1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For starters, let's take a look at the game itself:\n", + "* Image resized to 42x42 and grayscale to run faster\n", + "* Rewards divided by 100 'cuz they are all divisible by 100\n", + "* Agent sees last 4 frames of game to account for object velocity" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . 
Please provide explicit dtype.\u001b[0m\n", + "Observation shape: (42, 42, 4)\n", + "Num actions: 14\n", + "Action names: ['NOOP', 'UP', 'RIGHT', 'LEFT', 'DOWN', 'DOWNRIGHT', 'DOWNLEFT', 'RIGHTFIRE', 'LEFTFIRE', 'DOWNFIRE', 'UPRIGHTFIRE', 'UPLEFTFIRE', 'DOWNRIGHTFIRE', 'DOWNLEFTFIRE']\n" + ] + } + ], + "source": [ + "import gym\n", + "from atari_util import PreprocessAtari\n", + "\n", + "def make_env():\n", + " env = gym.make(\"KungFuMasterDeterministic-v0\")\n", + " env = PreprocessAtari(env, height=42, width=42,\n", + " crop = lambda img: img[60:-30, 5:],\n", + " dim_order = 'tensorflow',\n", + " color=False, n_frames=4,\n", + " reward_scale = 0.01)\n", + " return env\n", + "\n", + "env = make_env()\n", + "\n", + "obs_shape = env.observation_space.shape\n", + "n_actions = env.action_space.n\n", + "\n", + "print(\"Observation shape:\", obs_shape)\n", + "print(\"Num actions:\", n_actions)\n", + "print(\"Action names:\", env.env.env.get_action_meanings())" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAANEAAAEICAYAAADBfBG8AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAFmVJREFUeJzt3XvUHHV9x/H3hyBoASHcEgi3wAGO4CVGxFTKRbyFVAXaqsFWUWkJlVA80FMIKFLUAirQKBUImnIRQSqi1BNQCnhpEeRiCJcIJIAQckMIBAVpE7/9Y2Zhstl9nnl2dp+Z2f28ztmzszOzu99J5ru/3/xmnu8oIjCzzm1QdgBmdeckMivISWRWkJPIrCAnkVlBTiKzgpxEfUjSTpJ+J2lM2bEMAidRAZKmS7pd0u8lrUynPyVJZcYVEY9HxKYRsbbMOAaFk6hDkk4EZgNfBsYD44BjgP2AjUoMzUZbRPgxwgewOfB74C+HWe/PgV8Bq4EngNMzy3YBAvhEumwVSRK+FVgAPAuc3/R5nwQWpuv+CNi5zfc2PnvD9PVPgC8AtwK/A/4T2Aq4Io3tDmCXzPtnpzGtBu4C9s8sew1waRrDQuCfgCWZ5dsD1wBPAY8C/1D2/1fP94eyA6jjA5gKrGnspEOsdxDwBpIW/43ACuCwdFljR78QeDXwHuAPwPeBbYEJwErgwHT9w4BFwOuADYHPALe2+d5WSbQI2C39AXgAeAh4V/pZlwH/nnn/36RJtiFwIrAceHW67Czgp8BYYIc04ZekyzZIk+40ktZ4V+AR4L1l/5/1dH8oO4A6PtKdbHnTvFvT1uNF4IA27/tX4Lx0urGjT8gsfxr4cOb1NcCn0+nrgaMyyzYAXqBFa9QmiU7NLD8HuD7z+v3A/CG2dxXwpnR6naQA/jaTRG8DHm9676xsgvbjw8dEnXka2FrSho0ZEfH2iNgiXbYBgKS3SbpF0lOSniPprm3d9FkrMtMvtni9aTq9MzBb0rOSngWeAUTSYuWR93uQdKKkhZKeS79r80zc25N09Rqy0zsD2zdiTN97CsnxYt9yEnXmF8BLwKHDrPdt4Dpgx4jYnKTr1unI3RPAjIjYIvN4TUTc2uHntSRpf+Ak4EPA2PSH4TleiXsZSTeuYcemGB9tinGziJjWzRirxknUgYh4Fvhn4OuS/krSppI2kDQJ2CSz6mbAMxHxB0n7Ah8p8LUXArMk7Q0gaXNJHyzwee1sRnK89xSwoaTTgNdmll+dxjFW0gRgZmbZL4HVkk6S9BpJYyS9XtJbexBnZTiJOhQRXwJOIBmdWknSPbqI5Fe80Tp8CjhD0vMkB9tXF/i+a4GzgaskrQbuAw7peAPa+xHJ8ddDwG9IBjuyXbYzgCUkI2//BXyXpFUmkvNS7wcmpct/C3yDpDvYt5Qe/Jl1RNLfA9Mj4sCyYymLWyIbEUnbSdov7b7uSTIEfm3ZcZVpw+FXMVvHRiTd1okkQ/pXAV8vNaKS9aw7J2kqyZnvMcA3IuKsnnyRWcl6kkTp1cMPAe8mOQi9AzgiIh7o+peZlaxX3bl9gUUR8QiApKtIzqm0TCJJHt2wKvptRGwz3Eq9GliYwLrDoktoOrMu6WhJd0q6s0cxmBX1mzwr9aolanVWfp3WJiLmAHPALZHVW69aoiWseznIDsDSHn2XWal6lUR3ALtLmihpI2A6yTVkZn2nJ925iFgjaSbJJSRjgLkRcX8vvsusbJW47MfHRFZRd0XEPsOt5Mt+zAqqxWU/xx9/fNkh2ACaPXt2rvXcEpkVVIuWaLTMmDEDgIsuuqjtsqzm9ZrXGelyqye3RKlWSdJq2UUXXfTyzp+dn03ATpZbfTmJUm4VrFNOohyyCTZjxowhu3btllv/chKZFeSBhZyGGyRoXset0eBwS5RDnoRw0gyuWlz2MxonW0c6PJ1nHQ9x19vs2bNzXfbjJDJrI28SuTtnVpCTyKwgj85VyNh
ZY9ebt+rMVSVEYiPhlqgiGgm06sxVLz+y8626nERmBXWcRJJ2TG9gtVDS/ZKOT+efLulJSfPTR1/fm8asyDHRGuDEiLhb0mbAXZJuTJedFxFfKR6eWfV1nEQRsYzkrmlExPOSFpL/1odmfaMrx0SSdgHeDNyezpopaYGkuZJaHhm7Auq6sgMJjUd2vlVX4SFuSZvyyl2uV0u6APg8ScXTz5PcqfqTze9zBdT1OWHqqVBLJOlVJAl0RUR8DyAiVkTE2oj4I3AxSXF7s75VZHROwDeBhRFxbmb+dpnVDie5t6hZ3yrSndsP+Chwr6T56bxTgCPSu2gH8BjgvxGwvlZkdO6/aX33h3mdh2NV5D/hGNrAXjt374NHrPP6DXteOaLl3fiMPN9RthkzZrSsMeFEeoUv+7EhOVmG5ySy3IYqbjnInESWm4tOtuYksiE5YYbnGgs2rEEdnctbY2FgR+csv0FJmk65O2dWkJPIrCAnkVlBA3NM1HyPoVZn4lstzz5nNc9rfNasWQ/3ahO64swzdy87hL4zUC3RcAfIeQ6gszfpyvse628DlUTDnfNoXt5q/Tzr2GAZqCRqbkVaLW+ebl6/1fvdGg22gUqiZp3c1a75Pa2Ol2yw+IoFszZG7YoFSY8BzwNrgTURsY+kLYHvALuQ/HXrhyLCVTisL3WrO/eOiJiUydqTgZsiYnfgpvS1WV/q1XmiQ4GD0ulLgZ8AJ/Xou0ZkJOeDWs1v9Z6sQ37+89HZkA5dv//+ZYfQd7qRRAH8OD2uuSitJzcurZBKRCyTtG0Xvqdrit4m0iyrG925/SJiMnAIcKykA/K8qcwKqCM9X9TpOjYYCidRRCxNn1cC15IUa1zRqD+XPq9s8b45EbFPntGPbhvplQvtXvv8kEHxCqibpHeEQNImwHtIijVeBxyZrnYk8IMi39Ntrc71DLXcbCiFzhNJ2pWk9YHk+OrbEfFFSVsBVwM7AY8DH4yIZ4b4HJ8nssoZlfNEEfEI8KYW858G3lnks83qohZXLJiVpH9qLEz+wuSyQ7ABdPdn7s61Xi2SaNsdKnWayWwdtUiiDa4e6IvNreJqkUTzd5g//EpmJalFEo3faXzZIdgAWsrSXOu5n2RWUC1aIg8sWJX5PJFZe7nOE7k7Z1aQk8isoFocE90w2Vcs2Oibene+KxbcEpkV5CQyK8hJZFZQLY6JJs3zFQtWgpy7nVsis4I6bokk7UlS5bRhV+A0YAvg74Cn0vmnRMS8jiMEPvLx04ZdZ9aJxwFw5jlfK/JVhTiGfosh327bcRJFxIPAJABJY4AnSeotfAI4LyK+0ulnd2LtSWuTiRKvEHIMgxlDt46J3gksjojfSOrSR47MmLPHJBPnlPL1jmGAY+hWEk0Hrsy8ninpY8CdwImjUcx+0H79HEN1Yig8sCBpI+ADwH+ksy4AdiPp6i2jzW9Btyugjjl7zCu/PiVxDIMZQzdaokOAuyNiBUDjGUDSxcAPW70prdk9J12v8FXcg/br5xiqE0M3kugIMl05Sds1itkDh5NURO25QeuHO4bqxFAoiST9CfBuIFtz90uSJpHcLeKxpmU9M2i/fo6hOjEUrYD6ArBV07yPFoqoQ4P26+cYqhNDLS77yWPQfv0cQ3Vi6JskGrRfP8dQnRj6JokG7dfPMVQnhr5JokH79XMM1Ymhb5Jo0H79HEN1YuibJBq0Xz/HUJ0Y+iaJBu3XzzFUJ4ZaFG9cvnzaaIVi9rLx4+e5eKPZaKhFd+6Wyb61ilWXWyKzgpxEZgU5icwKqsUx0TvunlR2CDaIxvtOeWajohYtUZ66c2bdl6/unFsis4JyJZGkuZJWSrovM29LSTdKejh9HpvOl6SvSlokaYEk31zI+lrelugSYGrTvJOBmyJid+Cm9DUk1X92Tx9Hk5TQMutbuZIoIn4GPNM0+1Dg0nT6UuCwzPzLInEbsIWk7boRrFkVFTkmGtcojZU+N66XnQA8kVlvSTpvHd0u3mh
Wll6MzrUqxr3eVdrdLt5oVpYiLdGKRjctfV6Zzl8C7JhZbwcg31krsxoqkkTXAUem00cCP8jM/1g6SjcFeC5TEdWs7+Tqzkm6EjgI2FrSEuBzwFnA1ZKOAh4HPpiuPg+YBiwCXiC5X5FZ38qVRBFxRJtF72yxbgDHFgnKrE58xYJZQU4is4KcRGYFOYnMCnISmRXkJDIryElkVpCTyKwgJ5FZQU4is4KcRGYFOYnMCnISmRXkJDIryElkVpCTyKwgJ5FZQcMmUZvqp1+W9Ou0wum1krZI5+8i6UVJ89PHhb0M3qwK8rREl7B+9dMbgddHxBuBh4BZmWWLI2JS+jimO2GaVdewSdSq+mlE/Dgi1qQvbyMpi2U2kLpxTPRJ4PrM64mSfiXpp5L2b/cmV0C1flGoAqqkU4E1wBXprGXAThHxtKS3AN+XtHdErG5+bzcroN58w5SXpw+eeluRj6p1DEOpenx11nFLJOlI4H3AX6dlsoiIlyLi6XT6LmAxsEc3Am0nu3OUpQoxjETd4q26jpJI0lTgJOADEfFCZv42ksak07uS3F7lkW4EmlcVdpAqxJBVtXj6zbDduTbVT2cBGwM3SgK4LR2JOwA4Q9IaYC1wTEQ035KlJxpdlDJ3mCrE0E6VY6u7YZOoTfXTb7ZZ9xrgmqJBdaKxc5TZ369CDK0cPPU2J08P1eLGx0M5eOptfO3tZ7z8+rhbBzOG4Sz41rSXpz/9Ld9Iupt82Y9ZQX2RRMfdeto6z4Maw1AarY9boe6rfXcOYI97FnAc5e4cZcVw/rmvBWDmCeudimux3lc4P72X+3DrW361b4n2uGfBOs+DFEMjgZqnh1ovz/o2MrVPoqwyE6lKMTScf+5rnSyjoLbduarsrGXG0eiSNRJluIRpXt+6oy9aoofe9MayQyg1huzxzcwTVrd83ZxAPibqntq2RNZacyvjVqf3+qIlstYtS3OrNNS61rnaJ9Ggd+WympOjMbCQTSYnUPfVPomyB/Zl7cxViGEo2WSy7qt9Etm6nCijr/YDC1X45a9CDFl77bXXeleS33zDlMpdXd4v3BKZFVTbJFo790DWzj1wnddlxVF2DMNxK9Rbte/OAex2/NiyQ6hEDA0HT71t3fND5z7gY6Ue6rQC6umSnsxUOp2WWTZL0iJJD0p6b68Cb6UKO3IVYmjmBOqtTiugApyXqXQ6D0DSXsB0YO/0PV9vFC7ptsWzV7F49ip2O34si2ev6sVX5I6j7BisXHlqLPxM0i45P+9Q4KqIeAl4VNIiYF/gFx1HmEMVduIqxGDlKDKwMDMtaD9XUqMPMwF4IrPOknTeerpVAbWx45bZjapCDFaeTpPoAmA3YBJJ1dNz0vlqsW7L6qYRMSci9omIfTqMYT1V2ImrEIMvOh1dHSVRRKyIiLUR8UfgYpIuGyQtz46ZVXcAlhYL0YrwoELvdVoBdbvMy8OBxsjddcB0SRtLmkhSAfWXxUIcWhV++asQg5Wn0wqoB0maRNJVewyYARAR90u6GniApND9sRGxtjehWyvuyo2+rlZATdf/IvDFIkHlUZVf/6rEYeWp7WU/rVRhiLkKMdjoUnpXlHKDGOb+RENd97Xf8icB+J/xLUfSR0UVYsiqak3wurn5hil35Rk9rsW1cydMbn/r19vnfRZIduS3Tfv8aIVUuRiybr4heR7q382G1/h3HE7tu3NV2GmrEEMr7/uX+WWHMBBq0Z0zK0n/dOd+eMqkskOwAZS3Ja99d86sbE4is4KcRGYFeWDBrD0PLJgV4YEFs1FSi+7c8uXThlps1hPjx8/rn+7cLZN95t2qy905s4KcRGYFOYnMCuq0Aup3MtVPH5M0P52/i6QXM8su7GXwZlWQZ2DhEuB84LLGjIj4cGNa0jnAc5n1F0dEV0/svONunyeyEozPV6iqUAVUSQI+BBw8gtBGbPz4eb38eLNCig5x7w+siIiHM/MmSvoVsBr4TET8vNUbJR0NHJ3nS67cfvu
CYZqN3BFLu9QSDfc9wJWZ18uAnSLiaUlvAb4vae+IWK+CYETMAeaAr52zeus4iSRtCPwF8JbGvLSQ/Uvp9F2SFgN7AIXqbeeVPXZqnKBtNc8xlB/DaMTR7vu6/W9RZIj7XcCvI2JJY4akbRq3UpG0K0kF1EeKhTgyrf5RRvuKB8dQrRh6HUeeIe4rSW6NsqekJZKOShdNZ92uHMABwAJJ9wDfBY6JiGe6Fq1ZBXVaAZWI+HiLedcA1xQPy6w+fMWCWUF9mUTZ/m5ZV4A7hurE0Os4avGnECNRhasbHMNgxVCLP8rzyVYrwxFLl+b6o7xaJJFZSfrnL1uT619H5vI//WcAPvqLz3U7GMdQwxg6i2NmrrX6cmDBbDQ5icwKchKZFVSLY6Lx229Vynu7xTFUJwbIH8fyfH8J4ZbIrKhatETbjB/ZHbrPPfuznHDS5QBcfulnOeGk0b+TnWOoTgydxjGwLdEVl5zFuHGbvPx63LhNuOKSsxzDAMfQ6zjq0RJtu8WI39P8j9TJZxTlGKoTQy/jqMUVCyO9lfy3Lzljndcf+fhpIw+qIMdQnRg6jePmG6b0z2U/I00is27Im0R9d0xkNtry/Hn4jpJukbRQ0v2Sjk/nbynpRkkPp89j0/mS9FVJiyQtkDS51xthVqY8LdEa4MSIeB0wBThW0l7AycBNEbE7cFP6GuAQkgIlu5PUlbug61GbVciwSRQRyyLi7nT6eWAhMAE4FLg0Xe1S4LB0+lDgskjcBmwhabuuR25WESMa4k7LCb8ZuB0YFxHLIEk0Sdumq00Ansi8bUk6b1nTZ+WugHrzDVNGEqbZqMqdRJI2Jank8+mIWJ2U4W69aot5642+uQKq9Ytco3OSXkWSQFdExPfS2Ssa3bT0eWU6fwmwY+btOwA5L6Awq588o3MCvgksjIhzM4uuA45Mp48EfpCZ/7F0lG4K8Fyj22fWlyJiyAfwZyTdsQXA/PQxDdiKZFTu4fR5y3R9Af8GLAbuBfbJ8R3hhx8VfNw53L4bEfW4YsGsJL5iwWw0OInMCnISmRXkJDIrqCp/lPdb4Pfpc7/Ymv7Znn7aFsi/PTvn+bBKjM4BSLozz0hIXfTT9vTTtkD3t8fdObOCnERmBVUpieaUHUCX9dP29NO2QJe3pzLHRGZ1VaWWyKyWnERmBZWeRJKmSnowLWxy8vDvqB5Jj0m6V9J8SXem81oWcqkiSXMlrZR0X2ZebQvRtNme0yU9mf4fzZc0LbNsVro9D0p674i/MM+l3r16AGNI/mRiV2Aj4B5grzJj6nA7HgO2bpr3JeDkdPpk4Oyy4xwi/gOAycB9w8VP8mcw15P8ycsU4Pay48+5PacD/9hi3b3S/W5jYGK6P44ZyfeV3RLtCyyKiEci4n+Bq0gKnfSDdoVcKicifgY80zS7toVo2mxPO4cCV0XESxHxKLCIZL/MrewkalfUpG4C+LGku9ICLNBUyAXYtu27q6ld/HX+P5uZdkHnZrrXhben7CTKVdSkBvaLiMkkNfeOlXRA2QH1UF3/zy4AdgMmkVSeOiedX3h7yk6ivihqEhFL0+eVwLUk3YF2hVzqoq8K0UTEiohYGxF/BC7mlS5b4e0pO4nuAHaXNFHSRsB0kkIntSFpE0mbNaaB9wD30b6QS130VSGapuO2w0n+jyDZnumSNpY0kaRy7y9H9OEVGEmZBjxEMipyatnxdBD/riSjO/cA9ze2gTaFXKr4AK4k6eL8H8kv81Ht4qeDQjQV2Z7L03gXpImzXWb9U9PteRA4ZKTf58t+zAoquztnVntOIrOCnERmBTmJzApyEpkV5CQyK8hJZFbQ/wPTMFRqoBLrRQAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "s = env.reset()\n", + "for _ in range(100):\n", + " s, _, _, _ = env.step(env.action_space.sample())\n", + "\n", + "plt.title('Game image')\n", + "plt.imshow(env.render('rgb_array'))\n", + "plt.show()\n", + "\n", + "plt.title('Agent observation (4-frame buffer)')\n", + "plt.imshow(s.transpose([0,2,1]).reshape([42,-1]))\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Build an agent\n", + "\n", + "We now have to build an agent for actor-critic training - a convolutional neural network that converts states into action probabilities $\\pi$ and state values $V$.\n", + "\n", + "Your assignment here is to build and apply a neural network - with any framework you want. \n", + "\n", + "For starters, we want you to implement this architecture:\n", + "![nnet_arch.png](nnet_arch.png)\n", + "\n", + "After your agent gets mean reward above 50, we encourage you to experiment with model architecture to score even better." 
class Agent:
    """Actor-critic agent: a small conv net with a policy-logits head and a state-value head."""

    def __init__(self, name, state_shape, n_actions, reuse=False):
        """Build the network and the graph used by ``step``.

        :param name: variable scope the network is created under.
        :param state_shape: shape of one observation, without the batch axis.
        :param n_actions: number of discrete actions (width of the logits head).
        :param reuse: forwarded to ``tf.variable_scope``.
        """
        with tf.variable_scope(name, reuse=reuse):
            # Shared trunk: three stride-2 3x3 convolutions, then a dense layer.
            inputs = Input(shape=state_shape)
            x = Conv2D(32, (3, 3), strides=2, activation='relu')(inputs)
            x = Conv2D(32, (3, 3), strides=2, activation='relu')(x)
            x = Conv2D(32, (3, 3), strides=2, activation='relu')(x)
            x = Flatten()(x)
            x = Dense(128, activation='relu')(x)
            # Two linear heads on top of the shared features.
            logits = Dense(n_actions, activation='linear')(x)
            state_value = Dense(1, activation='linear')(x)

            self.network = Model(inputs=inputs, outputs=[logits, state_value])

            # prepare a graph for agent step
            self.state_t = tf.placeholder('float32', [None,] + list(state_shape))
            self.agent_outputs = self.symbolic_step(self.state_t)

    def symbolic_step(self, state_t):
        """Apply the network to a batch of states.

        :param state_t: tf tensor of states with a leading batch axis.
        :returns: (logits, state_value) tf tensors; logits is 2D, state_value is 1D.
        """
        logits, state_value = self.network(state_t)
        # The value head is Dense(1); drop its trailing axis to get one value per sample.
        state_value = state_value[:, 0]

        assert tf.is_numeric_tensor(state_value) and state_value.shape.ndims == 1, \
            "please return 1D tf tensor of state values [you got %s]" % repr(state_value)
        assert tf.is_numeric_tensor(logits) and logits.shape.ndims == 2, \
            "please return 2d tf tensor of logits [you got %s]" % repr(logits)

        return (logits, state_value)

    def step(self, state_t):
        """Same as symbolic step except it operates on numpy arrays"""
        sess = tf.get_default_session()
        return sess.run(self.agent_outputs, {self.state_t: state_t})

    def sample_actions(self, agent_outputs):
        """Sample one action per row from the softmax of the logits.

        :param agent_outputs: (logits, state_values) as numpy arrays; only the
            logits are used.
        :returns: numpy array with one sampled action index per row of logits.
        """
        logits, _ = agent_outputs
        logits = np.asarray(logits)
        # Softmax is invariant to a per-row shift. Subtracting the row maximum
        # keeps exp() from overflowing to inf, which would otherwise turn the
        # probabilities into NaN and make np.random.choice raise.
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        policy = np.exp(shifted)
        policy = policy / np.sum(policy, axis=-1, keepdims=True)
        return np.array([np.random.choice(len(p), p=p) for p in policy])
def evaluate(agent, env, n_games=1):
    """Play ``n_games`` episodes from reset until done and return the per-game total rewards."""

    def play_episode():
        # One full episode; at least one step is always taken.
        observation = env.reset()
        episode_return = 0
        finished = False
        while not finished:
            chosen = agent.sample_actions(agent.step([observation]))[0]
            observation, reward, finished, _ = env.step(chosen)
            episode_return += reward
        return episode_return

    return [play_episode() for _ in range(n_games)]
class EnvBatch:
    """A fixed pool of environments that are reset and stepped together."""

    def __init__(self, n_envs = 10):
        """Create ``n_envs`` environments with ``make_env``."""
        self.envs = [make_env() for _ in range(n_envs)]

    def reset(self):
        """Reset all games and return [n_envs, *obs_shape] observations"""
        return np.array([env.reset() for env in self.envs])

    def step(self, actions):
        """
        Send a vector[batch_size] of actions into respective environments.

        Environments that report done are reset right away, so for them the
        returned observation is the first observation of the next episode.

        :param actions: one action per environment, in the order of ``self.envs``.
        :returns: observations[n_envs, *obs_shape], rewards[n_envs], done[n_envs,], info[n_envs]
        :raises ValueError: if the number of actions differs from the number of environments.
        """
        actions = list(actions)
        # zip() would silently truncate to the shorter sequence; fail loudly instead.
        if len(actions) != len(self.envs):
            raise ValueError("expected %d actions, got %d" % (len(self.envs), len(actions)))

        results = [env.step(a) for env, a in zip(self.envs, actions)]
        new_obs, rewards, done, infos = map(np.array, zip(*results))

        # reset environments automatically
        for i, finished in enumerate(done):
            if finished:
                new_obs[i] = self.envs[i].reset()

        return new_obs, rewards, done, infos
Please provide explicit dtype.\u001b[0m\n", + "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", + "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", + "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", + "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", + "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", + "\u001b[33mWARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\u001b[0m\n", + "State shape: (10, 42, 42, 4)\n", + "Actions: [13 3 0]\n", + "Rewards: [0. 0. 0.]\n", + "Done: [False False False]\n" + ] + } + ], + "source": [ + "env_batch = EnvBatch(10)\n", + "\n", + "batch_states = env_batch.reset()\n", + "\n", + "batch_actions = agent.sample_actions(agent.step(batch_states))\n", + "\n", + "batch_next_states, batch_rewards, batch_done, _ = env_batch.step(batch_actions)\n", + "\n", + "print(\"State shape:\", batch_states.shape)\n", + "print(\"Actions:\", batch_actions[:3])\n", + "print(\"Rewards:\", batch_rewards[:3])\n", + "print(\"Done:\", batch_done[:3])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Actor-critic\n", + "\n", + "Here we define a loss functions and learning algorithms as usual." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# These placeholders mean exactly the same as in \"Let's try it out\" section above\n", + "states_ph = tf.placeholder('float32', [None,] + list(obs_shape)) \n", + "next_states_ph = tf.placeholder('float32', [None,] + list(obs_shape))\n", + "actions_ph = tf.placeholder('int32', (None,))\n", + "rewards_ph = tf.placeholder('float32', (None,))\n", + "is_done_ph = tf.placeholder('float32', (None,))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# logits[n_envs, n_actions] and state_values[n_envs]\n", + "logits, state_values = agent.symbolic_step(states_ph)\n", + "next_logits, next_state_values = agent.symbolic_step(next_states_ph)\n", + "next_state_values = next_state_values * (1 - is_done_ph)\n", + "\n", + "# probabilities and log-probabilities for all actions\n", + "probs = tf.nn.softmax(logits) # [n_envs, n_actions]\n", + "logprobs = tf.nn.log_softmax(logits) # [n_envs, n_actions]\n", + "\n", + "# log-probabilities only for agent's chosen actions\n", + "logp_actions = tf.reduce_sum(logprobs * tf.one_hot(actions_ph, n_actions), axis=-1) # [n_envs,]" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "# compute advantage using rewards_ph, state_values and next_state_values\n", + "gamma = 0.99\n", + "advantage = rewards_ph + gamma*next_state_values - state_values\n", + "\n", + "assert advantage.shape.ndims == 1, \"please compute advantage for each sample, vector of shape [n_envs,]\"\n", + "\n", + "# compute policy entropy given logits_seq. 
Mind the \"-\" sign!\n", + "entropy = -tf.reduce_sum(probs*logprobs, axis=1)\n", + "\n", + "assert entropy.shape.ndims == 1, \"please compute pointwise entropy vector of shape [n_envs,] \"\n", + "\n", + "\n", + "\n", + "actor_loss = - tf.reduce_mean(logp_actions * tf.stop_gradient(advantage)) - 0.001 * tf.reduce_mean(entropy)\n", + "\n", + "# compute target state values using temporal difference formula. Use rewards_ph and next_step_values\n", + "target_state_values = rewards_ph + gamma*next_state_values\n", + "\n", + "critic_loss = tf.reduce_mean((state_values - tf.stop_gradient(target_state_values))**2 )\n", + "\n", + "train_step = tf.train.AdamOptimizer(1e-4).minimize(actor_loss + critic_loss)\n", + "sess.run(tf.global_variables_initializer())" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "You just might be fine!\n" + ] + } + ], + "source": [ + "# Sanity checks to catch some errors. 
Specific to KungFuMaster in assignment's default setup.\n", + "l_act, l_crit, adv, ent = sess.run([actor_loss, critic_loss, advantage, entropy], feed_dict = {\n", + " states_ph: batch_states,\n", + " actions_ph: batch_actions,\n", + " next_states_ph: batch_states,\n", + " rewards_ph: batch_rewards,\n", + " is_done_ph: batch_done,\n", + " })\n", + "\n", + "assert abs(l_act) < 100 and abs(l_crit) < 100, \"losses seem abnormally large\"\n", + "assert 0 <= ent.mean() <= np.log(n_actions), \"impossible entropy value, double-check the formula pls\"\n", + "if ent.mean() < np.log(n_actions) / 2: print(\"Entropy is too low for untrained agent\")\n", + "print(\"You just might be fine!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train \n", + "\n", + "Just the usual - play a bit, compute loss, follow the gradients, repeat a few million times.\n", + "![img](http://images6.fanpop.com/image/photos/38900000/Daniel-san-training-the-karate-kid-38947361-499-288.gif)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\n", + "WARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\n", + "WARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\n", + "WARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\n", + "WARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\n", + "WARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\n", + "WARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\n", + "WARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\n", + "WARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\n", + "WARN: gym.spaces.Box autodetected dtype as . 
Please provide explicit dtype.\n" + ] + } + ], + "source": [ + "from IPython.display import clear_output\n", + "from tqdm import trange\n", + "from pandas import DataFrame\n", + "ewma = lambda x, span=100: DataFrame({'x':np.asarray(x)}).x.ewm(span=span).mean().values\n", + "\n", + "env_batch = EnvBatch(10)\n", + "batch_states = env_batch.reset()\n", + "\n", + "rewards_history = []\n", + "entropy_history = []" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████████████████████████████████████████████████████████████████████| 100000/100000 [35:42<00:00, 46.67it/s]\n" + ] + } + ], + "source": [ + "for i in trange(100000): \n", + " \n", + " batch_actions = agent.sample_actions(agent.step(batch_states))\n", + " batch_next_states, batch_rewards, batch_done, _ = env_batch.step(batch_actions)\n", + " \n", + " feed_dict = {\n", + " states_ph: batch_states,\n", + " actions_ph: batch_actions,\n", + " next_states_ph: batch_next_states,\n", + " rewards_ph: batch_rewards,\n", + " is_done_ph: batch_done,\n", + " }\n", + " batch_states = batch_next_states\n", + " \n", + " _, ent_t = sess.run([train_step, entropy], feed_dict)\n", + " entropy_history.append(np.mean(ent_t))\n", + "\n", + " if i % 500 == 0: \n", + " if i % 2500 == 0:\n", + " rewards_history.append(np.mean(evaluate(agent, env, n_games=3)))\n", + " if rewards_history[-1] >= 50:\n", + " print(\"Your agent has earned the yellow belt\")\n", + "\n", + " clear_output(True)\n", + " plt.figure(figsize=[8,4])\n", + " plt.subplot(1,2,1)\n", + " plt.plot(rewards_history, label='rewards')\n", + " plt.plot(ewma(np.array(rewards_history),span=10), marker='.', label='rewards ewma@10')\n", + " plt.title(\"Session rewards\"); plt.grid(); plt.legend()\n", + " \n", + " plt.subplot(1,2,2)\n", + " plt.plot(entropy_history, label='entropy')\n", + " plt.plot(ewma(np.array(entropy_history),span=1000), label='entropy ewma@1000')\n", + " plt.title(\"Policy entropy\"); plt.grid(); plt.legend() \n", + " plt.show()\n", + " \n", + " \n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Relax and grab some refreshments while your agent is locked in an infinite loop of violence and death.\n", + "\n", + "__How to interpret plots:__\n", + "\n", + "The session reward is the easy thing: it should in general go up over time, but it's 
okay if it fluctuates ~~like crazy~~. It's also OK if the reward doesn't increase substantially before some 10k initial steps. However, if reward reaches zero and doesn't seem to get up over 2-3 evaluations, there's something wrong happening.\n", + "\n", + "\n", + "Since we use a policy-based method, we also keep track of __policy entropy__ - the same one you used as a regularizer. The only important thing about it is that your entropy shouldn't drop too low (`< 0.1`) before your agent gets the yellow belt. Or at least it can drop there, but _it shouldn't stay there for long_.\n", + "\n", + "If it does, the culprit is likely:\n", + "* Some bug in entropy computation. Remember that it is $ - \\sum p(a_i) \\cdot \\log p(a_i) $\n", + "* Your agent architecture converges too fast. Increase entropy coefficient in actor loss. \n", + "* Gradient explosion - just [clip gradients](https://stackoverflow.com/a/43486487) and maybe use a smaller network\n", + "* Us. Or TF developers. Or aliens. Or lizardfolk. Contact us on forums before it's too late!\n", + "\n", + "If you're debugging, just run `logits, values = agent.step(batch_states)` and manually look into logits and values. This will reveal the problem 9 times out of 10: you'll likely see some NaNs or insanely large numbers or zeros. Try to catch the moment when this happens for the first time and investigate from there." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### \"Final\" evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Final mean reward: 227.0\n" + ] + } + ], + "source": [ + "env_monitor = gym.wrappers.Monitor(env, directory=\"kungfu_videos\", force=True)\n", + "final_rewards = evaluate(agent, env_monitor, n_games=20,)\n", + "env_monitor.close()\n", + "print(\"Final mean reward:\", np.mean(final_rewards))\n", + "\n", + "video_names = list(filter(lambda s:s.endswith(\".mp4\"),os.listdir(\"./kungfu_videos/\")))" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "HTML(\"\"\"\n", + "\n", + "\"\"\".format(\"./kungfu_videos/\"+video_names[-1])) " + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "HTML(\"\"\"\n", + "\n", + "\"\"\".format(\"./kungfu_videos/\"+video_names[-2])) #try other indices " + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "# if you don't see videos, just navigate to ./kungfu_videos and download .mp4 files from there." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARN: gym.spaces.Box autodetected dtype as . Please provide explicit dtype.\n", + "Submitted to Coursera platform. 
See results on assignment page!\n" + ] + } + ], + "source": [ + "from submit import submit_kungfu\n", + "env = make_env()\n", + "submit_kungfu(agent, env, evaluate, '', '')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```\n", + "\n", + "```\n", + "```\n", + "\n", + "```\n", + "```\n", + "\n", + "```\n", + "```\n", + "\n", + "```\n", + "```\n", + "\n", + "```\n", + "```\n", + "\n", + "```\n", + "```\n", + "\n", + "```\n", + "```\n", + "\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Now what?\n", + "Well, 5k reward is [just the beginning](https://www.buzzfeed.com/mattjayyoung/what-the-color-of-your-karate-belt-actually-means-lg3g). Can you get past 200? With recurrent neural network memory, chances are you can even beat 400!\n", + "\n", + "* Try n-step advantage and \"lambda\"-advantage (aka GAE) - see [this article](https://arxiv.org/abs/1506.02438)\n", + " * This change should improve early convergence a lot\n", + "* Try a recurrent neural network \n", + " * RNN memory will slow things down initially, but it will reach a better final reward in this game\n", + "* Implement an asynchronous version\n", + " * Remember [A3C](https://arxiv.org/abs/1602.01783)? The first \"A\" stands for asynchronous. 
It means there are several parallel actor-learners out there.\n", + " * You can write custom code for synchronization, but we recommend using [redis](https://redis.io/)\n", + " * You can store the full parameter set in redis, along with any other metadata\n", + " * Here's a _quick_ way to (de)serialize parameters for redis\n", + " ```\n", + " import joblib\n", + " from six import BytesIO\n", + "```\n", + "```\n", + " def dumps(data):\n", + " \"converts whatever to string\"\n", + " s = BytesIO()\n", + " joblib.dump(data,s)\n", + " return s.getvalue()\n", + "``` \n", + "```\n", + " def loads(string):\n", + " \"converts string to whatever was dumps'ed in it\"\n", + " return joblib.load(BytesIO(string))\n", + "```" + ] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/Practical Reinforcement Learning/Week5_policy_based/practice_reinforce.ipynb b/Practical Reinforcement Learning/Week5_policy_based/practice_reinforce.ipynb new file mode 100644 index 0000000..d72036c --- /dev/null +++ b/Practical Reinforcement Learning/Week5_policy_based/practice_reinforce.ipynb @@ -0,0 +1,450 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# REINFORCE in TensorFlow\n", + "\n", + "This notebook implements a basic REINFORCE algorithm, a.k.a. 
policy gradient for CartPole env.\n", + "\n", + "It has been deliberately written to be as simple and human-readable.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The notebook assumes that you have [openai gym](https://github.com/openai/gym) installed.\n", + "\n", + "In case you're running on a server, [use xvfb](https://github.com/openai/gym#rendering-on-a-server)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXgAAAD8CAYAAAB9y7/cAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAEpBJREFUeJzt3XGs3eV93/H3p5hAlmQ1hAtybTOTxmtDp8XQO+KIaaKQtsC6mUpNBZsaFCFdJhEpUaOt0ElrIg2pldawRetQ3ELjTFkII0lxEW3KHKIqfwRiJ45j41BuEie+tYfNAiRZNDaT7/64z01OzPG9x/fc6+v75P2Sjs7v9/ye8zvfBw6f+7vP/T2cVBWSpP781EoXIElaHga8JHXKgJekThnwktQpA16SOmXAS1Knli3gk1yf5Okk00nuXK73kSQNl+W4Dz7JOcDfAL8MzABfAG6pqqeW/M0kSUMt1xX8VcB0VX29qv4v8ACwbZneS5I0xJplOu964PDA/gzwllN1vuiii2rTpk3LVIokrT6HDh3iueeeyzjnWK6AH1bUj80FJZkCpgAuvfRSdu/evUylSNLqMzk5OfY5lmuKZgbYOLC/ATgy2KGqtlfVZFVNTkxMLFMZkvSTa7kC/gvA5iSXJXkVcDOwc5neS5I0xLJM0VTViSTvAj4NnAPcX1UHluO9JEnDLdccPFX1KPDocp1fkjQ/V7JKUqcMeEnqlAEvSZ0y4CWpUwa8JHXKgJekThnwktQpA16SOmXAS1KnDHhJ6pQBL0mdMuAlqVMGvCR1yoCXpE4Z8JLUKQNekjplwEtSpwx4SerUWF/Zl+QQ8F3gZeBEVU0muRD4OLAJOAT8ZlU9P16ZkqTTtRRX8L9UVVuqarLt3wnsqqrNwK62L0k6w5ZjimYbsKNt7wBuWob3kCQtYNyAL+CvkuxJMtXaLqmqowDt+eIx30OStAhjzcEDV1fVkSQXA48l+eqoL2w/EKYALr300jHLkCSdbKwr+Ko60p6PAZ8CrgKeTbIOoD0fO8Vrt1fVZFVNTkxMjFOGJGmIRQd8ktcked3cNvArwH5gJ3Br63Yr8PC4RUqSTt84UzSXAJ9KMnee/1ZVf5nkC8CDSW4DvgW8ffwyJUmna9EBX1VfB948pP1/AdeNU5QkaXyuZJWkThnwktQpA16SOmXAS1KnDHhJ6pQBL0mdMuAlqVMGvCR1yoCXpE4Z8JLUKQNekjplwEtSpwx4SeqUAS9JnTLgJalTBrwkdcqAl6ROGfCS1CkDXpI6tWDAJ7k/ybEk+wfaLkzyWJJn2vMFrT1JPphkOsm+JFcuZ/GSpFMb5Qr+w8D1J7XdCeyqqs3ArrYPcAOwuT2mgHuXpkxJ0ulaMOCr6q+Bb5/UvA3Y0bZ3ADcNtH+kZn0eWJtk3VIVK0ka3WLn4C+pqqMA7fni1r4eODzQb6a1vUKSqSS7k
+w+fvz4IsuQJJ3KUv+RNUPaaljHqtpeVZNVNTkxMbHEZUiSFhvwz85NvbTnY619Btg40G8DcGTx5UmSFmuxAb8TuLVt3wo8PND+jnY3zVbgxbmpHEnSmbVmoQ5JPgZcA1yUZAb4PeD3gQeT3AZ8C3h76/4ocCMwDXwfeOcy1CxJGsGCAV9Vt5zi0HVD+hZwx7hFSZLG50pWSeqUAS9JnTLgJalTBrwkdcqAl6ROGfCS1CkDXpI6ZcBLUqcMeEnqlAEvSZ0y4CWpUwa8JHXKgJekThnwktQpA16SOmXAS1KnDHhJ6pQBL0mdWjDgk9yf5FiS/QNt70vyt0n2tseNA8fuSjKd5Okkv7pchUuS5jfKFfyHgeuHtN9TVVva41GAJJcDNwO/0F7zX5Kcs1TFSpJGt2DAV9VfA98e8XzbgAeq6qWq+gYwDVw1Rn2SpEUaZw7+XUn2tSmcC1rbeuDwQJ+Z1vYKSaaS7E6y+/jx42OUIUkaZrEBfy/ws8AW4Cjwh609Q/rWsBNU1faqmqyqyYmJiUWWIUk6lUUFfFU9W1UvV9UPgD/mR9MwM8DGga4bgCPjlShJWoxFBXySdQO7vw7M3WGzE7g5yXlJLgM2A0+OV6IkaTHWLNQhyceAa4CLkswAvwdck2QLs9Mvh4DbAarqQJIHgaeAE8AdVfXy8pQuSZrPggFfVbcMab5vnv53A3ePU5QkaXyuZJWkThnwktQpA16SOmXAS1KnDHhJ6pQBL0mdWvA2SeknwZ7ttw9t/8WpD53hSqSl4xW8JHXKgJekThnwktQpA16SOmXAS1KnDHhJ6pQBL0mdMuAlqVMGvCR1yoCXpE4Z8JLUqQUDPsnGJI8nOZjkQJJ3t/YLkzyW5Jn2fEFrT5IPJplOsi/Jlcs9CEnSK41yBX8CeG9VvQnYCtyR5HLgTmBXVW0GdrV9gBuAze0xBdy75FVLkha0YMBX1dGq+mLb/i5wEFgPbAN2tG47gJva9jbgIzXr88DaJOuWvHJJ0rxOaw4+ySbgCuAJ4JKqOgqzPwSAi1u39cDhgZfNtLaTzzWVZHeS3cePHz/9yiVJ8xo54JO8FvgE8J6q+s58XYe01SsaqrZX1WRVTU5MTIxahiRpRCMFfJJzmQ33j1bVJ1vzs3NTL+35WGufATYOvHwDcGRpypUkjWqUu2gC3AccrKoPDBzaCdzatm8FHh5of0e7m2Yr8OLcVI4k6cwZ5Sv7rgZ+C/hKkr2t7XeB3wceTHIb8C3g7e3Yo8CNwDTwfeCdS1qxJGkkCwZ8VX2O4fPqANcN6V/AHWPWJUkakytZJalTBrwkdcqAl6ROGfCS1CkDXjqFX5z60EqXII3FgJekThnwktQpA16SOmXAS1KnDHhJ6pQBL0mdMuAlqVMGvCR1yoCXpE4Z8JLUKQNekjplwEtSpwx4SerUKF+6vTHJ40kOJjmQ5N2t/X1J/jbJ3va4ceA1dyWZTvJ0kl9dzgFIkoYb5Uu3TwDvraovJnkdsCfJY+3YPVX1HwY7J7kcuBn4BeBngP+R5O9X1ctLWbgkaX4LXsFX1dGq+mLb/i5wEFg/z0u2AQ9U1UtV9Q1gGrhqKYqVJI3utObgk2wCrgCeaE3vSrIvyf1JLmht64HDAy+bYf4fCJKkZTBywCd5LfAJ4D1V9R3gXuBngS3AUeAP57oOeXkNOd9Ukt1Jdh8/fvy0C5ckzW+kgE9yLrPh/tGq+iRAVT1bVS9X1Q+AP+ZH0zAzwMaBl28Ajpx8zqraXlWTVTU5MTExzhgkSUOMchdNgPuAg1X1gYH2dQPdfh3Y37Z3AjcnOS/JZcBm4MmlK1mSNIpR7qK5Gvgt4CtJ9ra23wVuSbKF2emXQ8DtAFV1IMmDwFPM3oFzh3fQSNKZt2DAV9XnGD6v/ug8r7kbuHuMuiRJY3IlqyR1yoCXpE4Z8PqJt2f77StdgrQsDHhJ6pQBL0mdMuAlqVMGvCR1yoCXpE4Z8JLUKQNekjplwEtSpwx4S
eqUAS9JnTLg1aUkIz+W8xzSSjLgJalTo3zhh9S9Pz8y9WP7/+xntq9QJdLS8QpeP/FODnepFwa8NIShrx6M8qXb5yd5MsmXkxxI8v7WflmSJ5I8k+TjSV7V2s9r+9Pt+KblHYK09JyiUQ9GuYJ/Cbi2qt4MbAGuT7IV+APgnqraDDwP3Nb63wY8X1VvBO5p/aSzlmGuXo3ypdsFfK/tntseBVwL/IvWvgN4H3AvsK1tAzwE/OckaeeRzjqTt28Hfjzk378ypUhLaqS7aJKcA+wB3gj8EfA14IWqOtG6zADr2/Z64DBAVZ1I8iLweuC5U51/z5493kusVc3Pr85GIwV8Vb0MbEmyFvgU8KZh3drzsE/6K67ek0wBUwCXXnop3/zmN0cqWBrFmQ5cf0HVUpucnBz7HKd1F01VvQB8FtgKrE0y9wNiA3Ckbc8AGwHa8Z8Gvj3kXNurarKqJicmJhZXvSTplEa5i2aiXbmT5NXA24CDwOPAb7RutwIPt+2dbZ92/DPOv0vSmTfKFM06YEebh/8p4MGqeiTJU8ADSf498CXgvtb/PuC/Jplm9sr95mWoW5K0gFHuotkHXDGk/evAVUPa/w/w9iWpTpK0aK5klaROGfCS1CkDXpI65f8uWF3yxi3JK3hJ6pYBL0mdMuAlqVMGvCR1yoCXpE4Z8JLUKQNekjplwEtSpwx4SeqUAS9JnTLgJalTBrwkdcqAl6ROGfCS1KlRvnT7/CRPJvlykgNJ3t/aP5zkG0n2tseW1p4kH0wynWRfkiuXexCSpFca5f8H/xJwbVV9L8m5wOeS/EU79q+r6qGT+t8AbG6PtwD3tmdJ0hm04BV8zfpe2z23Peb7NoVtwEfa6z4PrE2ybvxSJUmnY6Q5+CTnJNkLHAMeq6on2qG72zTMPUnOa23rgcMDL59pbZKkM2ikgK+ql6tqC7ABuCrJPwDuAn4e+EfAhcDvtO4ZdoqTG5JMJdmdZPfx48cXVbwk6dRO6y6aqnoB+CxwfVUdbdMwLwF/ClzVus0AGwdetgE4MuRc26tqsqomJyYmFlW8JOnURrmLZiLJ2rb9auBtwFfn5tWTBLgJ2N9eshN4R7ubZivwYlUdXZbqJUmnNMpdNOuAHUnOYfYHwoNV9UiSzySZYHZKZi/wr1r/R4EbgWng+8A7l75sSdJCFgz4qtoHXDGk/dpT9C/gjvFLkySNw5WsktQpA16SOmXAS1KnDHhJ6pQBL0mdMuAlqVMGvCR1yoCXpE4Z8JLUKQNekjplwEtSpwx4SeqUAS9JnTLgJalTBrwkdcqAl6ROGfCS1CkDXpI6ZcBLUqdGDvgk5yT5UpJH2v5lSZ5I8kySjyd5VWs/r+1Pt+Oblqd0SdJ8TucK/t3AwYH9PwDuqarNwPPAba39NuD5qnojcE/rJ0k6w0YK+CQbgH8K/EnbD3At8FDrsgO4qW1va/u049e1/pKkM2jNiP3+I/BvgNe1/dcDL1TVibY/A6xv2+uBwwBVdSLJi63/c4MnTDIFTLXdl5LsX9QIzn4XcdLYO9HruKDfsTmu1eXvJZmqqu2LPcGCAZ/k14BjVbUnyTVzzUO61gjHftQwW/T29h67q2pypIpXmV7H1uu4oN+xOa7VJ8luWk4uxihX8FcD/zzJjcD5wN9l9op+bZI17Sp+A3Ck9Z8BNgIzSdYAPw18e7EFSpIWZ8E5+Kq6q6o2VNUm4GbgM1X1L4HHgd9o3W4FHm7bO9s+7fhnquoVV/CSpOU1zn3wvwP8dpJpZufY72vt9wGvb+2/Ddw5wrkW/SvIKtDr2HodF/Q7Nse1+ow1tnhxLUl9ciWrJHVqxQM+yfVJnm4rX0eZzjmrJLk/ybHB2zyTXJjksbbK97EkF7T2JPlgG+u+JFeuXOXzS7IxyeNJDiY5kOTdrX1Vjy3J+UmeTPLlNq73t/YuVmb3uuI8yaEkX0myt91Zsuo/iwBJ1iZ5KMlX239rb13Kca1owCc5B/gj4AbgcuCWJ
JevZE2L8GHg+pPa7gR2tVW+u/jR3yFuADa3xxRw7xmqcTFOAO+tqjcBW4E72r+b1T62l4Brq+rNwBbg+iRb6Wdlds8rzn+pqrYM3BK52j+LAP8J+Muq+nngzcz+u1u6cVXVij2AtwKfHti/C7hrJWta5Dg2AfsH9p8G1rXtdcDTbftDwC3D+p3tD2bvkvrlnsYG/B3gi8BbmF0os6a1//BzCXwaeGvbXtP6ZaVrP8V4NrRAuBZ4hNk1Kat+XK3GQ8BFJ7Wt6s8is7ecf+Pkf+5LOa6VnqL54arXZnBF7Gp2SVUdBWjPF7f2VTne9uv7FcATdDC2No2xFzgGPAZ8jRFXZgNzK7PPRnMrzn/Q9kdecc7ZPS6YXSz5V0n2tFXwsPo/i28AjgN/2qbV/iTJa1jCca10wI+06rUjq268SV4LfAJ4T1V9Z76uQ9rOyrFV1ctVtYXZK96rgDcN69aeV8W4MrDifLB5SNdVNa4BV1fVlcxOU9yR5J/M03e1jG0NcCVwb1VdAfxv5r+t/LTHtdIBP7fqdc7gitjV7Nkk6wDa87HWvqrGm+RcZsP9o1X1ydbcxdgAquoF4LPM/o1hbVt5DcNXZnOWr8yeW3F+CHiA2WmaH644b31W47gAqKoj7fkY8ClmfzCv9s/iDDBTVU+0/YeYDfwlG9dKB/wXgM3tL/2vYnal7M4VrmkpDK7mPXmV7zvaX8O3Ai/O/Sp2tkkSZhetHayqDwwcWtVjSzKRZG3bfjXwNmb/sLWqV2ZXxyvOk7wmyevmtoFfAfazyj+LVfU/gcNJfq41XQc8xVKO6yz4Q8ONwN8wOw/6b1e6nkXU/zHgKPD/mP0Jexuzc5m7gGfa84Wtb5i9a+hrwFeAyZWuf55x/WNmf/3bB+xtjxtX+9iAfwh8qY1rP/DvWvsbgCeBaeC/A+e19vPb/nQ7/oaVHsMIY7wGeKSXcbUxfLk9DszlxGr/LLZatwC72+fxz4ALlnJcrmSVpE6t9BSNJGmZGPCS1CkDXpI6ZcBLUqcMeEnqlAEvSZ0y4CWpUwa8JHXq/wMtNoHN6fAOgwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import gym\n", + "import numpy as np, pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "%matplotlib inline\n", + "\n", + "env = gym.make(\"CartPole-v0\")\n", + "\n", + "#gym compatibility: unwrap TimeLimit\n", + "if hasattr(env,'env'):\n", + " env=env.env\n", + "\n", + "env.reset()\n", + "n_actions = env.action_space.n\n", + "state_dim = env.observation_space.shape\n", + "\n", + "plt.imshow(env.render(\"rgb_array\"))\n", + "env.close()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Building the policy network" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For REINFORCE algorithm, we'll need a model that predicts action probabilities given states.\n", + "\n", + "For numerical stability, please __do not include the softmax layer into your network architecture__. \n", + "\n", + "We'll use softmax or log-softmax where appropriate." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "\n", + "#create input variables. 
We only need states, actions and cumulative returns for REINFORCE\n", + "states = tf.placeholder('float32',(None,)+state_dim,name=\"states\")\n", + "actions = tf.placeholder('int32',name=\"action_ids\")\n", + "cumulative_rewards = tf.placeholder('float32', name=\"cumulative_returns\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using TensorFlow backend.\n" + ] + } + ], + "source": [ + "import keras\n", + "from keras.layers import Dense\n", + "\n", + "model = keras.models.Sequential()\n", + "model.add(Dense(128, activation='relu', input_shape=state_dim))\n", + "model.add(Dense(64, activation='relu'))\n", + "model.add(Dense(n_actions, activation='linear'))\n", + "\n", + "logits = model(states)\n", + "\n", + "policy = tf.nn.softmax(logits)\n", + "log_policy = tf.nn.log_softmax(logits)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "#utility function to get action probabilities in one given state\n", + "get_action_proba = lambda s: policy.eval({states:[s]})[0] " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Loss function and updates\n", + "\n", + "We now need to define the objective and the update rule for policy gradient.\n", + "\n", + "Our objective function is\n", + "\n", + "$$ J \\approx { 1 \\over N } \\sum _{s_i,a_i} \\pi_\\theta (a_i | s_i) \\cdot G(s_i,a_i) $$\n", + "\n", + "\n", + "Following the REINFORCE algorithm, we can define our objective as follows: \n", + "\n", + "$$ \\hat J \\approx { 1 \\over N } \\sum _{s_i,a_i} \\log \\pi_\\theta (a_i | s_i) \\cdot G(s_i,a_i) $$\n", + "\n", + "When you compute the gradient of that function with respect to the network weights $ \\theta $, it becomes exactly the policy gradient.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "#get log-probabilities for the actions that were actually taken\n", + "indices = 
tf.stack([tf.range(tf.shape(log_policy)[0]),actions],axis=-1)\n", + "log_policy_for_actions = tf.gather_nd(log_policy,indices)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# policy objective as in the last formula. please use mean, not sum.\n", + "# note: you need to use log_policy_for_actions to get log probabilities for actions taken\n", + "\n", + "J = tf.reduce_mean(log_policy_for_actions*cumulative_rewards)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "#regularize with entropy\n", + "entropy = -tf.reduce_sum(policy*log_policy, axis=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "#all network weights\n", + "all_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)\n", + "\n", + "#weight updates. maximizing J is same as minimizing -J. Adding negative entropy.\n", + "loss = -J -0.1 * entropy\n", + "\n", + "update = tf.train.AdamOptimizer().minimize(loss,var_list=all_weights)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Computing cumulative rewards" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "def get_cumulative_rewards(rewards, #rewards at each step\n", + " gamma = 0.99 #discount for reward\n", + " ):\n", + " \"\"\"\n", + " take a list of immediate rewards r(s,a) for the whole session \n", + " compute cumulative rewards R(s,a) (a.k.a. 
G(s,a) in Sutton '16)\n", + " R_t = r_t + gamma*r_{t+1} + gamma^2*r_{t+2} + ...\n", + " \n", + " The simple way to compute cumulative rewards is to iterate from last to first time tick\n", + " and compute R_t = r_t + gamma*R_{t+1} recurrently\n", + " \n", + " You must return an array/list of cumulative rewards with as many elements as in the initial rewards.\n", + " \"\"\"\n", + " \n", + " cumulative_rewards = np.array(rewards).astype(np.float32)\n", + " for i in range(len(rewards)-2, -1, -1):\n", + " cumulative_rewards[i] += gamma*cumulative_rewards[i+1]\n", + " \n", + " return cumulative_rewards\n", + " \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "looks good!\n" + ] + } + ], + "source": [ + "assert len(get_cumulative_rewards(range(100))) == 100\n", + "assert np.allclose(get_cumulative_rewards([0,0,1,0,0,1,0],gamma=0.9),[1.40049, 1.5561, 1.729, 0.81, 0.9, 1.0, 0.0])\n", + "assert np.allclose(get_cumulative_rewards([0,0,1,-2,3,-4,0],gamma=0.5), [0.0625, 0.125, 0.25, -1.5, 1.0, -4.0, 0.0])\n", + "assert np.allclose(get_cumulative_rewards([0,0,1,2,3,4,0],gamma=0), [0, 0, 1, 2, 3, 4, 0])\n", + "print(\"looks good!\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "def train_step(_states,_actions,_rewards):\n", + " \"\"\"given full session, trains agent with policy gradient\"\"\"\n", + " _cumulative_rewards = get_cumulative_rewards(_rewards)\n", + " update.run({states:_states,actions:_actions,cumulative_rewards:_cumulative_rewards})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Playing the game" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_session(t_max=1000):\n", + " \"\"\"play env with REINFORCE agent and train at the session end\"\"\"\n", + " \n", + " #arrays to 
record session\n", + " states,actions,rewards = [],[],[]\n", + " \n", + " s = env.reset()\n", + " \n", + " for t in range(t_max):\n", + " \n", + " #action probabilities array aka pi(a|s)\n", + " action_probas = get_action_proba(s)\n", + " \n", + " a = np.random.choice(n_actions, p=action_probas)\n", + " \n", + " new_s,r,done,info = env.step(a)\n", + " \n", + " #record session history to train later\n", + " states.append(s)\n", + " actions.append(a)\n", + " rewards.append(r)\n", + " \n", + " s = new_s\n", + " if done: break\n", + " \n", + " train_step(states,actions,rewards)\n", + " \n", + " return sum(rewards)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean reward:60.590\n", + "mean reward:160.860\n", + "mean reward:165.670\n", + "mean reward:451.360\n", + "You Win!\n" + ] + } + ], + "source": [ + "s = tf.InteractiveSession()\n", + "s.run(tf.global_variables_initializer())\n", + "\n", + "for i in range(100):\n", + " \n", + " rewards = [generate_session() for _ in range(100)] #generate new sessions\n", + " \n", + " print (\"mean reward:%.3f\"%(np.mean(rewards)))\n", + "\n", + " if np.mean(rewards) > 300:\n", + " print (\"You Win!\")\n", + " break\n", + " \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Results & video" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "#record sessions\n", + "import gym.wrappers\n", + "env = gym.wrappers.Monitor(gym.make(\"CartPole-v0\"),directory=\"videos\",force=True)\n", + "sessions = [generate_session() for _ in range(100)]\n", + "env.env.close()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], 
+ "source": [ + "#show video\n", + "from IPython.display import HTML\n", + "import os\n", + "\n", + "video_names = list(filter(lambda s:s.endswith(\".mp4\"),os.listdir(\"./videos/\")))\n", + "\n", + "HTML(\"\"\"\n", + "\n", + "\"\"\".format(\"./videos/\"+video_names[-1])) #this may or may not be _last_ video. Try other indices" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Submitted to Coursera platform. See results on assignment page!\n" + ] + } + ], + "source": [ + "from submit import submit_cartpole\n", + "submit_cartpole(generate_session, '', '')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# That's all, thank you for your attention!\n", + "# Not having enough? There's an actor-critic waiting for you in the honor section.\n", + "# But make sure you've seen the videos first." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.5" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/Practical Reinforcement Learning/Week5_policy_based/submit.py b/Practical Reinforcement Learning/Week5_policy_based/submit.py new file mode 100644 index 0000000..553e552 --- /dev/null +++ b/Practical Reinforcement Learning/Week5_policy_based/submit.py @@ -0,0 +1,20 @@ +import sys +import numpy as np +sys.path.append("..") +import grading + + +def submit_cartpole(generate_session, email, token): + sessions = [generate_session() for _ in range(100)] + session_rewards = np.array(sessions) + grader = grading.Grader("oyT3Bt7yEeeQvhJmhysb5g") + grader.set_answer("7QKmA", 
int(np.mean(session_rewards))) + grader.submit(email, token) + + +def submit_kungfu(agent, env, evaluate, email, token): + sessions = [evaluate(agent=agent, env=env, n_games=1) for _ in range(100)] + session_rewards = np.array(sessions) + grader = grading.Grader("6sPnVCn6EeieSRL7rCBNJA") + grader.set_answer("HhNVX", 100*int(np.mean(session_rewards))) + grader.submit(email, token)