Skip to content

Commit

Permalink
Add one step lookahead function for easy comparison with Value Iteration
Browse files Browse the repository at this point in the history
  • Loading branch information
Sanyam Kapoor committed Feb 19, 2018
1 parent 5334a6f commit 6211e2d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 69 deletions.
42 changes: 16 additions & 26 deletions DP/Policy Evaluation Solution.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
"cells": [
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"collapsed": true
},
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from IPython.core.debugger import set_trace\n",
"import numpy as np\n",
"import pprint\n",
"import sys\n",
Expand All @@ -18,10 +17,8 @@
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {
"collapsed": true
},
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"pp = pprint.PrettyPrinter(indent=2)\n",
Expand All @@ -30,10 +27,8 @@
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"collapsed": true
},
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def policy_eval(policy, env, discount_factor=1.0, theta=0.00001):\n",
Expand Down Expand Up @@ -76,10 +71,8 @@
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"collapsed": true
},
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"random_policy = np.ones([env.nS, env.nA]) / env.nA\n",
Expand All @@ -88,7 +81,7 @@
},
{
"cell_type": "code",
"execution_count": 57,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -98,7 +91,8 @@
"Value Function:\n",
"[ 0. -13.99993529 -19.99990698 -21.99989761 -13.99993529\n",
" -17.9999206 -19.99991379 -19.99991477 -19.99990698 -19.99991379\n",
" -17.99992725 -13.99994569 -21.99989761 -19.99991477 -13.99994569 0. ]\n",
" -17.99992725 -13.99994569 -21.99989761 -19.99991477 -13.99994569\n",
" 0. ]\n",
"\n",
"Reshaped Grid Value Function:\n",
"[[ 0. -13.99993529 -19.99990698 -21.99989761]\n",
Expand All @@ -121,10 +115,8 @@
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"collapsed": true
},
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Test: Make sure the evaluated policy is what we expected\n",
Expand All @@ -135,9 +127,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": []
}
Expand All @@ -158,7 +148,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
"version": "3.6.4"
}
},
"nbformat": 4,
Expand Down
89 changes: 46 additions & 43 deletions DP/Policy Iteration Solution.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
Expand All @@ -19,9 +17,7 @@
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"pp = pprint.PrettyPrinter(indent=2)\n",
Expand All @@ -30,10 +26,8 @@
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"collapsed": true
},
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Taken from Policy Evaluation Exercise!\n",
Expand Down Expand Up @@ -78,10 +72,8 @@
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"collapsed": true
},
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def policy_improvement(env, policy_eval_fn=policy_eval, discount_factor=1.0):\n",
Expand All @@ -102,6 +94,24 @@
" V is the value function for the optimal policy.\n",
" \n",
" \"\"\"\n",
"\n",
" def one_step_lookahead(state, V):\n",
" \"\"\"\n",
" Helper function to calculate the value for all action in a given state.\n",
" \n",
" Args:\n",
" state: The state to consider (int)\n",
" V: The value to use as an estimator, Vector of length env.nS\n",
" \n",
" Returns:\n",
" A vector of length env.nA containing the expected value of each action.\n",
" \"\"\"\n",
" A = np.zeros(env.nA)\n",
" for a in range(env.nA):\n",
" for prob, next_state, reward, done in env.P[state][a]:\n",
" A[a] += prob * (reward + discount_factor * V[next_state])\n",
" return A\n",
" \n",
" # Start with a random policy\n",
" policy = np.ones([env.nS, env.nA]) / env.nA\n",
" \n",
Expand All @@ -119,10 +129,7 @@
" \n",
" # Find the best action by one-step lookahead\n",
" # Ties are resolved arbitarily\n",
" action_values = np.zeros(env.nA)\n",
" for a in range(env.nA):\n",
" for prob, next_state, reward, done in env.P[s][a]:\n",
" action_values[a] += prob * (reward + discount_factor * V[next_state])\n",
" action_values = one_step_lookahead(s, V)\n",
" best_a = np.argmax(action_values)\n",
" \n",
" # Greedily update the policy\n",
Expand All @@ -137,30 +144,30 @@
},
{
"cell_type": "code",
"execution_count": 64,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Policy Probability Distribution:\n",
"[[ 1. 0. 0. 0.]\n",
" [ 0. 0. 0. 1.]\n",
" [ 0. 0. 0. 1.]\n",
" [ 0. 0. 1. 0.]\n",
" [ 1. 0. 0. 0.]\n",
" [ 1. 0. 0. 0.]\n",
" [ 1. 0. 0. 0.]\n",
" [ 0. 0. 1. 0.]\n",
" [ 1. 0. 0. 0.]\n",
" [ 1. 0. 0. 0.]\n",
" [ 0. 1. 0. 0.]\n",
" [ 0. 0. 1. 0.]\n",
" [ 1. 0. 0. 0.]\n",
" [ 0. 1. 0. 0.]\n",
" [ 0. 1. 0. 0.]\n",
" [ 1. 0. 0. 0.]]\n",
"[[1. 0. 0. 0.]\n",
" [0. 0. 0. 1.]\n",
" [0. 0. 0. 1.]\n",
" [0. 0. 1. 0.]\n",
" [1. 0. 0. 0.]\n",
" [1. 0. 0. 0.]\n",
" [1. 0. 0. 0.]\n",
" [0. 0. 1. 0.]\n",
" [1. 0. 0. 0.]\n",
" [1. 0. 0. 0.]\n",
" [0. 1. 0. 0.]\n",
" [0. 0. 1. 0.]\n",
" [1. 0. 0. 0.]\n",
" [0. 1. 0. 0.]\n",
" [0. 1. 0. 0.]\n",
" [1. 0. 0. 0.]]\n",
"\n",
"Reshaped Grid Policy (0=up, 1=right, 2=down, 3=left):\n",
"[[0 3 3 2]\n",
Expand Down Expand Up @@ -202,10 +209,8 @@
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {
"collapsed": true
},
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Test the value function\n",
Expand All @@ -216,9 +221,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": []
}
Expand All @@ -239,7 +242,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
"version": "3.6.4"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 6211e2d

Please sign in to comment.