-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
299450f
commit 81e6ba1
Showing
98 changed files
with
129,995 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,235 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "7d1fa20c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os, sys\n", | ||
"import gc\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import numpy as np" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "32bb66c4", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"sys.path.append('../')\n", | ||
"import torch" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "c8c767df", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from get_algos import get_all_algos , run_experiment ,call_paths, plot_mean" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "c9fc3531", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import gym\n", | ||
"# from gym.wrappers import Monitor\n", | ||
"from stable_baselines3 import PPO, A2C, DDPG, DQN, SAC, TD3\n", | ||
"from stable_baselines3.common.vec_env import VecFrameStack,VecVideoRecorder\n", | ||
"from stable_baselines3.common.evaluation import evaluate_policy\n", | ||
"from stable_baselines3.common.env_util import make_atari_env\n", | ||
"from stable_baselines3.common.monitor import Monitor" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "b1e1b76a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"env_name='Breakout-v0'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "54c5cdd5", | ||
"metadata": {}, | ||
"source": [ | ||
"### 2. Testing\n", | ||
"<ol>\n", | ||
" <li>Create the environment</li>\n", | ||
" <li>For each algorithm:</li>\n", | ||
" <ol><li>Load the model</li>\n", | ||
" <li>Evaluate the model for 5 sample iterations </li>\n", | ||
" <li>Output the score for each algo and each iteration</li>\n", | ||
" <li>Save renders of the output </li>\n", | ||
" </ol>\n", | ||
"</ol> " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "65fae03b", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"['PPO', 'A2C', 'DQN']\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"algo_list=get_all_algos(gym.make(env_name))\n", | ||
"print(algo_list)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "7e8fab60", | ||
"metadata": {}, | ||
"source": [ | ||
"### 3. Testing" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"id": "972110da", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def testing_model(algo_list, env_name,n_steps,vid_length):\n", | ||
" total_rewards={}\n", | ||
" for algo_name in reversed(algo_list):\n", | ||
" algo_rewards=[]\n", | ||
" device='cpu' if algo_name=='DQN' else 'cuda'\n", | ||
" log_path, render_path, model_path=call_paths(algo_name,env_name,n_steps)\n", | ||
" env=VecVideoRecorder(VecFrameStack(make_atari_env(env_name,monitor_dir=render_path),n_stack=6),render_path,record_video_trigger=lambda step: step == 0,video_length=vid_length)\n", | ||
" model_name=env_name+\"_\"+algo_name+\"_model\"\n", | ||
" model=eval(algo_name).load(os.path.join(model_path, model_name),env,device=device)\n", | ||
" for i in range(5):\n", | ||
" state=env.reset()\n", | ||
" epi_rewards=0\n", | ||
" while True:\n", | ||
" action,_=model.predict(state)\n", | ||
" state,reward,done,_=env.step(action)\n", | ||
" epi_rewards+=reward[0]\n", | ||
" env.render()\n", | ||
" if done:\n", | ||
" algo_rewards.append(epi_rewards)\n", | ||
" break\n", | ||
" total_rewards[algo_name]=algo_rewards\n", | ||
" del model\n", | ||
" env.close()\n", | ||
" return total_rewards" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"id": "2f2d2d60", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Wrapping the env in a VecTransposeImage.\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"E:\\Anaconda\\envs\\ai_gym\\lib\\site-packages\\stable_baselines3\\common\\buffers.py:229: UserWarning: This system does not have apparently enough memory to store the complete replay buffer 84.69GB > 11.00GB\n", | ||
" \"This system does not have apparently enough memory to store the complete \"\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Wrapping the env in a VecTransposeImage.\n", | ||
"Wrapping the env in a VecTransposeImage.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"test=total_rewards=testing_model(algo_list, env_name,n_steps=200000,vid_length=1000)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"id": "63f6b45f", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'DQN': [0.0, 0.0, 0.0, 5.0, 0.0],\n", | ||
" 'A2C': [4.0, 0.0, 0.0, 0.0, 0.0],\n", | ||
" 'PPO': [3.0, 0.0, 0.0, 3.0, 0.0]}" | ||
] | ||
}, | ||
"execution_count": 15, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"test" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "870561de", | ||
"metadata": {}, | ||
"source": [ | ||
"We see that PPO seems to consistently provide rewards with the highest average reward. Hence, we will use PPO as the algorithm for subsequent training of <b> 5,000,000 </b>steps" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "13d01daf", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.0" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.