diff --git a/.gitignore b/.gitignore index 04efac8a..4fbf542f 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,5 @@ scripts/keras-rl/rl build/ *~ *.log -*.h5f *# *.pyc \ No newline at end of file diff --git a/models/example_actor.h5f b/models/example_actor.h5f new file mode 100644 index 00000000..37ced1c3 Binary files /dev/null and b/models/example_actor.h5f differ diff --git a/models/example_critic.h5f b/models/example_critic.h5f new file mode 100644 index 00000000..0e73f4dd Binary files /dev/null and b/models/example_critic.h5f differ diff --git a/scripts/train.arm.ipynb b/scripts/train.arm.ipynb new file mode 100644 index 00000000..c3f91f6a --- /dev/null +++ b/scripts/train.arm.ipynb @@ -0,0 +1,1126 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using TensorFlow backend.\n" + ] + } + ], + "source": [ + "# Derived from keras-rl\n", + "import opensim as osim\n", + "import numpy as np\n", + "import sys\n", + "\n", + "from keras.models import Sequential, Model\n", + "from keras.layers import Dense, Activation, Flatten, Input, concatenate\n", + "from keras.optimizers import Adam\n", + "\n", + "from rl.agents import DDPGAgent\n", + "from rl.memory import SequentialMemory\n", + "from rl.random import OrnsteinUhlenbeckProcess\n", + "\n", + "from osim.env.arm import ArmEnv\n", + "\n", + "from keras.optimizers import RMSprop\n", + "\n", + "import argparse\n", + "import math" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# Load the arm environment\n", + "env = ArmEnv(True)\n", + "env.reset()\n", + "nb_actions = env.noutput  # number of control outputs (actions) the agent produces\n", + "\n", + "# Total number of steps in training\n", + "nallsteps = 10000" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "_________________________________________________________________\n", + "Layer (type) Output Shape Param # \n", + "=================================================================\n", + "flatten_1 (Flatten) (None, 14) 0 \n", + "_________________________________________________________________\n", + "dense_1 (Dense) (None, 32) 480 \n", + "_________________________________________________________________\n", + "activation_1 (Activation) (None, 32) 0 \n", + "_________________________________________________________________\n", + "dense_2 (Dense) (None, 32) 1056 \n", + "_________________________________________________________________\n", + "activation_2 (Activation) (None, 32) 0 \n", + "_________________________________________________________________\n", + "dense_3 (Dense) (None, 32) 1056 \n", + "_________________________________________________________________\n", + "activation_3 (Activation) (None, 32) 0 \n", + "_________________________________________________________________\n", + "dense_4 (Dense) (None, 6) 198 \n", + "_________________________________________________________________\n", + "activation_4 (Activation) (None, 6) 0 \n", + "=================================================================\n", + "Total params: 2,790.0\n", + "Trainable params: 2,790.0\n", + "Non-trainable params: 0.0\n", + "_________________________________________________________________\n", + "None\n" + ] + } + ], + "source": [ + "# Create networks for DDPG\n", + "# Next, we build a very simple model.\n", + "actor = Sequential()\n", + 
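"# Actor network: the flattened observation passes through three Dense(32)/ReLU layers,\n", + "# and the final Dense(nb_actions) with a sigmoid keeps each action in [0, 1]\n", + 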
"actor.add(Flatten(input_shape=(1,) + env.observation_space.shape))\n", + "actor.add(Dense(32))\n", + "actor.add(Activation('relu'))\n", + "actor.add(Dense(32))\n", + "actor.add(Activation('relu'))\n", + "actor.add(Dense(32))\n", + "actor.add(Activation('relu'))\n", + "actor.add(Dense(nb_actions))\n", + "actor.add(Activation('sigmoid'))\n", + "print(actor.summary())" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "____________________________________________________________________________________________________\n", + "Layer (type) Output Shape Param # Connected to \n", + "====================================================================================================\n", + "observation_input (InputLayer) (None, 1, 14) 0 \n", + "____________________________________________________________________________________________________\n", + "action_input (InputLayer) (None, 6) 0 \n", + "____________________________________________________________________________________________________\n", + "flatten_2 (Flatten) (None, 14) 0 \n", + "____________________________________________________________________________________________________\n", + "concatenate_1 (Concatenate) (None, 20) 0 \n", + "____________________________________________________________________________________________________\n", + "dense_5 (Dense) (None, 64) 1344 \n", + "____________________________________________________________________________________________________\n", + "activation_5 (Activation) (None, 64) 0 \n", + "____________________________________________________________________________________________________\n", + "dense_6 (Dense) (None, 64) 4160 \n", + "____________________________________________________________________________________________________\n", + "activation_6 (Activation) (None, 64) 0 \n", + "____________________________________________________________________________________________________\n", + "dense_7 (Dense) (None, 64) 4160 \n", + "____________________________________________________________________________________________________\n", + "activation_7 (Activation) (None, 64) 0 \n", + "____________________________________________________________________________________________________\n", + "dense_8 (Dense) (None, 1) 65 \n", + "____________________________________________________________________________________________________\n", + "activation_8 (Activation) (None, 1) 0 \n", + "====================================================================================================\n", + "Total params: 9,729.0\n", + "Trainable params: 9,729.0\n", + "Non-trainable params: 0.0\n", + "____________________________________________________________________________________________________\n", + "None\n" + ] + } + ], + "source": [ + "action_input = Input(shape=(nb_actions,), name='action_input')\n", + "observation_input = Input(shape=(1,) + env.observation_space.shape, name='observation_input')\n", + "flattened_observation = Flatten()(observation_input)\n", + "x = concatenate([action_input, flattened_observation])\n", + "x = Dense(64)(x)\n", + "x = Activation('relu')(x)\n", + "x = Dense(64)(x)\n", + "x = Activation('relu')(x)\n", + "x = Dense(64)(x)\n", + "x = Activation('relu')(x)\n", + "x = Dense(1)(x)\n", + "x = Activation('linear')(x)\n", + "critic = Model(inputs=[action_input, observation_input], outputs=x)\n", + "print(critic.summary())" + ] + }, + { + "cell_type": "code", + "execution_count": 5, 
+ "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# Set up the agent for training\n", + "memory = SequentialMemory(limit=100000, window_length=1)\n", + "random_process = OrnsteinUhlenbeckProcess(theta=.15, mu=0., sigma=.2, size=env.noutput)\n", + "agent = DDPGAgent(nb_actions=nb_actions, actor=actor, critic=critic, critic_action_input=action_input,\n", + " memory=memory, nb_steps_warmup_critic=100, nb_steps_warmup_actor=100,\n", + " random_process=random_process, gamma=.99, target_model_update=1e-3,\n", + " delta_clip=1.)\n", + "agent.compile(Adam(lr=.001, clipnorm=1.), metrics=['mae'])" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training for 10000 steps ...\n", + "Interval 1 (0 steps performed)\n", + "\n", + "Distance: 1.462456\n", + "True positions: (-0.775025,-0.675188)\n", + "Reached: (-1.566736,-0.004444)\n", + " 100/10000 [..............................] - ETA: 182s - reward: -0.7521\n", + "Distance: 0.658063\n", + "True positions: (0.149549,-0.228965)\n", + "Reached: (-0.406994,-0.330485)\n", + " 200/10000 [..............................] - ETA: 194s - reward: -0.7154\n", + "Distance: 1.648409\n", + "True positions: (-0.436057,-0.518591)\n", + "Reached: (-1.567838,-0.001964)\n", + " 300/10000 [..............................] - ETA: 240s - reward: -0.7834\n", + "Distance: 1.039672\n", + "True positions: (0.041094,-0.158712)\n", + "Reached: (-0.781505,-0.375786)\n", + " 400/10000 [>.............................] - ETA: 227s - reward: -0.7174\n", + "Distance: 1.555081\n", + "True positions: (-0.456883,-0.446139)\n", + "Reached: (-1.567824,-0.001999)\n", + " 500/10000 [>.............................] - ETA: 229s - reward: -0.7441\n", + "Distance: 0.472841\n", + "True positions: (-0.756667,-0.663398)\n", + "Reached: (-0.869688,-0.303578)\n", + " 600/10000 [>.............................] - ETA: 218s - reward: -0.7246\n", + "Distance: 2.821004\n", + "True positions: (0.292063,-0.962931)\n", + "Reached: (-1.567876,-0.001867)\n", + " 700/10000 [=>............................] - ETA: 284s - reward: -0.9930\n", + "Distance: 6.004362\n", + "True positions: (0.182118,-0.878027)\n", + "Reached: (-2.981467,2.006734)\n", + " 800/10000 [=>............................] - ETA: 410s - reward: -1.2518\n", + "Distance: 1.402935\n", + "True positions: (-0.306940,-0.144541)\n", + "Reached: (-1.567689,-0.002355)\n", + " 900/10000 [=>............................] - ETA: 383s - reward: -1.2151\n", + "Distance: 0.715121\n", + "True positions: (-0.392536,-0.734260)\n", + "Reached: (-0.638096,-0.264699)\n", + " 1000/10000 [==>...........................] - ETA: 358s - reward: -1.1485\n", + "Distance: 2.128215\n", + "True positions: (0.013623,-0.559047)\n", + "Reached: (-1.565007,-0.009462)\n", + " 1100/10000 [==>...........................] - ETA: 343s - reward: -1.0932\n", + "Distance: 0.349503\n", + "True positions: (-0.607300,-0.538477)\n", + "Reached: (-0.760894,-0.342568)\n", + " 1200/10000 [==>...........................] - ETA: 325s - reward: -1.0290\n", + "Distance: 0.882774\n", + "True positions: (-0.736318,-0.063696)\n", + "Reached: (-1.564964,-0.009568)\n", + " 1300/10000 [==>...........................] - ETA: 309s - reward: -0.9823\n", + "Distance: 0.275539\n", + "True positions: (-0.601051,-0.504647)\n", + "Reached: (-0.724629,-0.352686)\n", + " 1400/10000 [===>..........................] 
- ETA: 294s - reward: -0.9902\n", + "Distance: 2.383734\n", + "True positions: (-0.139049,-0.967167)\n", + "Reached: (-1.565025,-0.009409)\n", + " 1500/10000 [===>..........................] - ETA: 284s - reward: -0.9737\n", + "Distance: 0.669710\n", + "True positions: (-0.518046,-0.836900)\n", + "Reached: (-0.720636,-0.369780)\n", + " 1600/10000 [===>..........................] - ETA: 273s - reward: -0.9865\n", + "Distance: 0.951457\n", + "True positions: (-1.066065,-0.461889)\n", + "Reached: (-1.565032,-0.009399)\n", + " 1700/10000 [====>.........................] - ETA: 269s - reward: -0.9623\n", + "Distance: 0.505641\n", + "True positions: (-0.953900,-0.051887)\n", + "Reached: (-0.723666,-0.327293)\n", + " 1800/10000 [====>.........................] - ETA: 257s - reward: -0.9479\n", + "Distance: 1.750347\n", + "True positions: (-0.548398,-0.743257)\n", + "Reached: (-1.564991,-0.009502)\n", + " 1900/10000 [====>.........................] - ETA: 253s - reward: -0.9339\n", + "Distance: 0.481539\n", + "True positions: (-0.430554,-0.645165)\n", + "Reached: (-0.635168,-0.368239)\n", + " 2000/10000 [=====>........................] - ETA: 245s - reward: -0.9211\n", + "Distance: 1.393284\n", + "True positions: (-1.066455,-0.904048)\n", + "Reached: (-1.565046,-0.009355)\n", + " 2100/10000 [=====>........................] - ETA: 243s - reward: -0.9058 ETA: 244s - reward: -\n", + "Distance: 0.498258\n", + "True positions: (-0.635067,-0.074559)\n", + "Reached: (-0.292713,-0.230464)\n", + " 2200/10000 [=====>........................] - ETA: 235s - reward: -0.8986\n", + "Distance: 1.005809\n", + "True positions: (-1.199449,-0.649655)\n", + "Reached: (-1.565022,-0.009419)\n", + " 2300/10000 [=====>........................] - ETA: 233s - reward: -0.9153\n", + "Distance: 0.890960\n", + "True positions: (-0.646296,-0.391640)\n", + "Reached: (0.227098,-0.409206)\n", + " 2400/10000 [======>.......................] - ETA: 225s - reward: -0.9060\n", + "Distance: 1.900609\n", + "True positions: (-0.374564,-0.719667)\n", + "Reached: (-1.564995,-0.009489)\n", + " 2500/10000 [======>.......................] - ETA: 220s - reward: -0.9286\n", + "Distance: 1.234273\n", + "True positions: (-1.065994,-0.339998)\n", + "Reached: (0.049915,-0.458362)\n", + " 2600/10000 [======>.......................] - ETA: 213s - reward: -0.9579 ETA: 213s - rew\n", + "Distance: 1.327021\n", + "True positions: (-0.473797,-0.245345)\n", + "Reached: (-1.564986,-0.009512)\n", + " 2700/10000 [=======>......................] - ETA: 212s - reward: -0.9875\n", + "Distance: 0.841418\n", + "True positions: (-0.926866,-0.719504)\n", + "Reached: (-0.116538,-0.688414)\n", + " 2800/10000 [=======>......................] - ETA: 219s - reward: -1.0514\n", + "Distance: 1.757212\n", + "True positions: (0.030436,-0.171236)\n", + "Reached: (-1.565006,-0.009466)\n", + " 2900/10000 [=======>......................] - ETA: 214s - reward: -1.0487\n", + "Distance: 1.179165\n", + "True positions: (-0.744743,-0.224667)\n", + "Reached: (0.294478,-0.364611)\n", + " 3000/10000 [========>.....................] - ETA: 210s - reward: -1.0467\n", + "Distance: 1.796911\n", + "True positions: (0.103846,-0.137347)\n", + "Reached: (-1.565053,-0.009335)\n", + " 3100/10000 [========>.....................] - ETA: 206s - reward: -1.0433\n", + "Distance: 0.843150\n", + "True positions: (-0.042112,-0.536499)\n", + "Reached: (-0.725420,-0.376656)\n", + " 3200/10000 [========>.....................] 
- ETA: 201s - reward: -1.0315\n", + "Distance: 2.481492\n", + "True positions: (0.142837,-0.782987)\n", + "Reached: (-1.565039,-0.009370)\n", + " 3300/10000 [========>.....................] - ETA: 196s - reward: -1.0219\n", + "Distance: 0.979552\n", + "True positions: (-0.343665,-0.708980)\n", + "Reached: (0.325115,-1.019753)\n", + " 3400/10000 [=========>....................] - ETA: 191s - reward: -1.0569\n", + "Distance: 2.097808\n", + "True positions: (0.166089,-0.376127)\n", + "Reached: (-1.565020,-0.009428)\n", + " 3500/10000 [=========>....................] - ETA: 188s - reward: -1.0602\n", + "Distance: 1.075332\n", + "True positions: (0.184927,-0.633335)\n", + "Reached: (-0.656767,-0.399697)\n", + " 3600/10000 [=========>....................] - ETA: 183s - reward: -1.0522\n", + "Distance: 1.038497\n", + "True positions: (-0.840122,-0.309307)\n", + "Reached: (-1.568703,0.000609)\n", + " 3700/10000 [==========>...................] - ETA: 179s - reward: -1.0507\n", + "Distance: 1.148837\n", + "True positions: (-0.616768,-0.256357)\n", + "Reached: (-1.524102,-0.497861)\n", + " 3800/10000 [==========>...................] - ETA: 175s - reward: -1.0394\n", + "Distance: 1.051871\n", + "True positions: (-0.859114,-0.355356)\n", + "Reached: (-1.565028,-0.009399)\n", + " 3900/10000 [==========>...................] - ETA: 172s - reward: -1.0386\n", + "Distance: 0.832561\n", + "True positions: (0.050397,-0.508607)\n", + "Reached: (-0.627451,-0.353895)\n", + " 4000/10000 [===========>..................] - ETA: 169s - reward: -1.0464\n", + "Distance: 1.590063\n", + "True positions: (-0.732395,-0.766896)\n", + "Reached: (-1.565010,-0.009449)\n", + " 4100/10000 [===========>..................] - ETA: 165s - reward: -1.0481\n", + "Distance: 0.993980\n", + "True positions: (-0.675373,-0.852840)\n", + "Reached: (0.168685,-1.002762)\n", + " 4200/10000 [===========>..................] - ETA: 160s - reward: -1.0723\n", + "Distance: 2.375452\n", + "True positions: (0.144393,-0.675761)\n", + "Reached: (-1.564778,-0.009480)\n", + " 4300/10000 [===========>..................] - ETA: 156s - reward: -1.0675\n", + "Distance: 0.787745\n", + "True positions: (-0.268879,-0.871835)\n", + "Reached: (0.090438,-0.443407)\n", + " 4400/10000 [============>.................] - ETA: 151s - reward: -1.0798\n", + "Distance: 0.989008\n", + "True positions: (-1.138425,-0.571705)\n", + "Reached: (-1.565005,-0.009277)\n", + " 4500/10000 [============>.................] - ETA: 147s - reward: -1.0730\n", + "Distance: 0.501576\n", + "True positions: (0.013873,-0.727211)\n", + "Reached: (-0.186000,-0.425507)\n", + " 4600/10000 [============>.................] - ETA: 147s - reward: -1.0865\n", + "Distance: 1.762822\n", + "True positions: (-0.720930,-0.917786)\n", + "Reached: (-1.567865,-0.001899)\n", + " 4700/10000 [=============>................] - ETA: 145s - reward: -1.0840\n", + "Distance: 1.450326\n", + "True positions: (-0.251281,-0.551979)\n", + "Reached: (0.212165,-1.538859)\n", + " 4800/10000 [=============>................] - ETA: 145s - reward: -1.1120\n", + "Distance: 1.415776\n", + "True positions: (-0.589973,-0.450436)\n", + "Reached: (-1.564786,-0.009473)\n", + " 4900/10000 [=============>................] - ETA: 141s - reward: -1.1080\n", + "Distance: 0.866850\n", + "True positions: (-0.271957,-0.752964)\n", + "Reached: (0.293206,-1.054652)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 5000/10000 [==============>...............] 
- ETA: 142s - reward: -1.1339\n", + "Distance: 1.857073\n", + "True positions: (-0.531581,-0.833223)\n", + "Reached: (-1.564839,-0.009408)\n", + " 5100/10000 [==============>...............] - ETA: 139s - reward: -1.1433\n", + "Distance: 2.968165\n", + "True positions: (-0.956974,-0.313639)\n", + "Reached: (-0.528744,-2.853575)\n", + " 5200/10000 [==============>...............] - ETA: 135s - reward: -1.1710\n", + "Distance: 1.761903\n", + "True positions: (-0.115508,-0.311346)\n", + "Reached: (-1.567891,-0.001827)\n", + " 5300/10000 [==============>...............] - ETA: 131s - reward: -1.1751\n", + "Distance: 1.845603\n", + "True positions: (-0.982332,-0.100574)\n", + "Reached: (0.112800,-0.851045)\n", + " 5400/10000 [===============>..............] - ETA: 127s - reward: -1.1884\n", + "Distance: 0.942085\n", + "True positions: (-1.059572,-0.446698)\n", + "Reached: (-1.564688,-0.009730)\n", + " 5500/10000 [===============>..............] - ETA: 124s - reward: -1.1822\n", + "Distance: 1.821335\n", + "True positions: (-0.077805,-0.355466)\n", + "Reached: (0.340218,-1.758778)\n", + " 5600/10000 [===============>..............] - ETA: 121s - reward: -1.1978\n", + "Distance: 1.848909\n", + "True positions: (-0.511823,-0.794789)\n", + "Reached: (-1.567859,-0.001916)\n", + " 5700/10000 [================>.............] - ETA: 117s - reward: -1.1912\n", + "Distance: 0.721368\n", + "True positions: (0.102164,-0.304120)\n", + "Reached: (-0.478851,-0.444473)\n", + " 5800/10000 [================>.............] - ETA: 114s - reward: -1.2048\n", + "Distance: 1.803081\n", + "True positions: (0.067109,-0.180366)\n", + "Reached: (-1.565021,-0.009416)\n", + " 5900/10000 [================>.............] - ETA: 111s - reward: -1.2002 ETA: \n", + "Distance: 0.878687\n", + "True positions: (0.189819,-0.222802)\n", + "Reached: (-0.456033,-0.455637)\n", + " 6000/10000 [=================>............] - ETA: 107s - reward: -1.1920\n", + "Distance: 1.675537\n", + "True positions: (-0.020565,-0.141019)\n", + "Reached: (-1.564724,-0.009641)\n", + " 6100/10000 [=================>............] - ETA: 105s - reward: -1.1997\n", + "Distance: 2.680756\n", + "True positions: (-1.108342,-0.750141)\n", + "Reached: (0.357185,-1.965370)\n", + " 6200/10000 [=================>............] - ETA: 104s - reward: -1.2265\n", + "Distance: 2.177150\n", + "True positions: (0.065758,-0.556331)\n", + "Reached: (-1.564714,-0.009652)\n", + " 6300/10000 [=================>............] - ETA: 101s - reward: -1.2244\n", + "Distance: 2.005363\n", + "True positions: (-0.509213,-0.446309)\n", + "Reached: (0.032482,-1.909977)\n", + " 6400/10000 [==================>...........] - ETA: 98s - reward: -1.2368\n", + "Distance: 1.504596\n", + "True positions: (-0.596605,-0.546096)\n", + "Reached: (-1.564744,-0.009639)\n", + " 6500/10000 [==================>...........] - ETA: 95s - reward: -1.2450\n", + "Distance: 3.206485\n", + "True positions: (-1.160324,-0.770575)\n", + "Reached: (-0.568856,-3.385592)\n", + " 6600/10000 [==================>...........] - ETA: 92s - reward: -1.2426\n", + "Distance: 1.569507\n", + "True positions: (-0.702374,-0.716388)\n", + "Reached: (-1.564991,-0.009498)\n", + " 6700/10000 [===================>..........] - ETA: 89s - reward: -1.2381\n", + "Distance: 2.123637\n", + "True positions: (-0.558666,-0.286198)\n", + "Reached: (0.466372,-1.384796)\n", + " 6800/10000 [===================>..........] 
- ETA: 88s - reward: -1.2681\n", + "Distance: 1.763844\n", + "True positions: (-0.668754,-0.877421)\n", + "Reached: (-1.564828,-0.009651)\n", + " 6900/10000 [===================>..........] - ETA: 85s - reward: -1.2608\n", + "Distance: 0.615627\n", + "True positions: (0.093075,-0.617224)\n", + "Reached: (-0.373101,-0.467773)\n", + " 7000/10000 [====================>.........] - ETA: 81s - reward: -1.2507\n", + "Distance: 0.588110\n", + "True positions: (-1.065392,-0.098003)\n", + "Reached: (-1.564993,-0.009495)\n", + " 7100/10000 [====================>.........] - ETA: 79s - reward: -1.2461\n", + "Distance: 0.839994\n", + "True positions: (-0.186739,-0.862378)\n", + "Reached: (-0.536173,-0.371818)\n", + " 7200/10000 [====================>.........] - ETA: 75s - reward: -1.2452\n", + "Distance: 2.204858\n", + "True positions: (0.242761,-0.406999)\n", + "Reached: (-1.564731,-0.009632)\n", + " 7300/10000 [====================>.........] - ETA: 72s - reward: -1.2374\n", + "Distance: 0.514505\n", + "True positions: (0.092020,-0.402023)\n", + "Reached: (-0.397748,-0.377285)\n", + " 7400/10000 [=====================>........] - ETA: 69s - reward: -1.2311\n", + "Distance: 1.350662\n", + "True positions: (-1.177139,-0.972360)\n", + "Reached: (-1.564977,-0.009536)\n", + " 7500/10000 [=====================>........] - ETA: 66s - reward: -1.2269\n", + "Distance: 0.795759\n", + "True positions: (0.168329,-0.075573)\n", + "Reached: (-0.312060,-0.390944)\n", + " 7600/10000 [=====================>........] - ETA: 63s - reward: -1.2179\n", + "Distance: 0.479833\n", + "True positions: (-1.105746,-0.030006)\n", + "Reached: (-1.565014,-0.009441)\n", + " 7700/10000 [======================>.......] - ETA: 60s - reward: -1.2075\n", + "Distance: 0.082770\n", + "True positions: (-0.623754,-0.254800)\n", + "Reached: (-0.650752,-0.310572)\n", + " 7800/10000 [======================>.......] - ETA: 58s - reward: -1.1952\n", + "Distance: 0.857551\n", + "True positions: (-0.765756,-0.068037)\n", + "Reached: (-1.564921,-0.009651)\n", + " 7900/10000 [======================>.......] - ETA: 55s - reward: -1.1863\n", + "Distance: 0.260719\n", + "True positions: (-1.101818,-0.300604)\n", + "Reached: (-0.870004,-0.329510)\n", + " 8000/10000 [=======================>......] - ETA: 52s - reward: -1.1766\n", + "Distance: 2.515288\n", + "True positions: (0.261544,-0.698659)\n", + "Reached: (-1.564722,-0.009638)\n", + " 8100/10000 [=======================>......] - ETA: 49s - reward: -1.1734\n", + "Distance: 0.883142\n", + "True positions: (0.052401,-0.062017)\n", + "Reached: (-0.490915,-0.401842)\n", + " 8200/10000 [=======================>......] - ETA: 46s - reward: -1.1614- ETA:\n", + "Distance: 1.940551\n", + "True positions: (-0.036622,-0.422005)\n", + "Reached: (-1.564747,-0.009580)\n", + " 8300/10000 [=======================>......] - ETA: 44s - reward: -1.1574\n", + "Distance: 1.039889\n", + "True positions: (-1.178959,-0.767795)\n", + "Reached: (-0.604128,-0.302737)\n", + " 8400/10000 [========================>.....] - ETA: 41s - reward: -1.1544\n", + "Distance: 1.795205\n", + "True positions: (-0.048117,-0.287719)\n", + "Reached: (-1.565021,-0.009418)\n", + " 8500/10000 [========================>.....] - ETA: 39s - reward: -1.1483\n", + "Distance: 0.432170\n", + "True positions: (-0.586678,-0.025857)\n", + "Reached: (-0.387922,-0.259272)\n", + " 8600/10000 [========================>.....] 
- ETA: 36s - reward: -1.1416\n", + "Distance: 1.348519\n", + "True positions: (-0.411534,-0.204388)\n", + "Reached: (-1.565039,-0.009374)\n", + " 8700/10000 [=========================>....] - ETA: 33s - reward: -1.1343\n", + "Distance: 0.086120\n", + "True positions: (-0.138389,-0.450312)\n", + "Reached: (-0.109487,-0.393095)\n", + " 8800/10000 [=========================>....] - ETA: 30s - reward: -1.1290\n", + "Distance: 1.118859\n", + "True positions: (-0.564185,-0.127702)\n", + "Reached: (-1.564895,-0.009553)\n", + " 8900/10000 [=========================>....] - ETA: 28s - reward: -1.1204\n", + "Distance: 0.352498\n", + "True positions: (-0.590085,-0.337947)\n", + "Reached: (-0.295942,-0.279592)\n", + " 9000/10000 [==========================>...] - ETA: 25s - reward: -1.1158\n", + "Distance: 1.242576\n", + "True positions: (-0.843296,-0.530437)\n", + "Reached: (-1.564850,-0.009415)\n", + " 9100/10000 [==========================>...] - ETA: 23s - reward: -1.1094\n", + "Distance: 0.461542\n", + "True positions: (-0.591572,-0.519359)\n", + "Reached: (-0.378585,-0.270804)\n", + " 9200/10000 [==========================>...] - ETA: 20s - reward: -1.1006\n", + "Distance: 2.083646\n", + "True positions: (-0.432670,-0.950337)\n", + "Reached: (-1.567867,-0.001888)\n", + " 9300/10000 [==========================>...] - ETA: 17s - reward: -1.1001- ETA: 18s - r\n", + "Distance: 1.068966\n", + "True positions: (-1.096158,-0.979813)\n", + "Reached: (-0.694457,-0.312549)\n", + " 9400/10000 [===========================>..] - ETA: 15s - reward: -1.1003\n", + "Distance: 1.887375\n", + "True positions: (-0.400751,-0.732703)\n", + "Reached: (-1.564902,-0.009479)\n", + " 9500/10000 [===========================>..] - ETA: 12s - reward: -1.0954\n", + "Distance: 0.713027\n", + "True positions: (-1.051577,-0.679967)\n", + "Reached: (-0.737376,-0.281141)\n", + " 9600/10000 [===========================>..] - ETA: 10s - reward: -1.0909\n", + "Distance: 0.683912\n", + "True positions: (-1.168939,-0.297389)\n", + "Reached: (-1.564983,-0.009521)\n", + " 9700/10000 [============================>.] - ETA: 7s - reward: -1.0853\n", + "Distance: 0.197110\n", + "True positions: (-0.192663,-0.181194)\n", + "Reached: (-0.061201,-0.246843)\n", + " 9800/10000 [============================>.] - ETA: 5s - reward: -1.0836\n", + "Distance: 1.406991\n", + "True positions: (-1.043423,-0.895327)\n", + "Reached: (-1.564724,-0.009636)\n", + " 9900/10000 [============================>.] - ETA: 2s - reward: -1.0811\n", + "Distance: 0.958584\n", + "True positions: (-0.488645,-0.959397)\n", + "Reached: (-0.172702,-0.316755)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "10000/10000 [==============================] - 250s - reward: -1.0771 \n", + "done, took 250.614 seconds\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Okay, now it's time to learn something! We visualize the training here for show, but this\n", + "# slows down training quite a lot. 
You can always safely abort the training prematurely using\n", + "# Ctrl + C.\n", + "agent.fit(env, nb_steps=50000, visualize=False, verbose=1, nb_max_episode_steps=200, log_interval=10000)\n", + "# After training is done, we save the final weights.\n", + "# agent.save_weights(args.model, overwrite=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Testing for 5 episodes ...\n", + "\n", + "Distance: 1.116082\n", + "True positions: (-0.803720,-0.364419)\n", + "Reached: (-1.564910,-0.009527)\n", + "\n", + "Distance: 0.168364\n", + "True positions: (-0.610786,-0.272151)\n", + "Reached: (-0.471284,-0.243289)\n", + "\n", + "Distance: 0.598890\n", + "True positions: (-1.192979,-0.217319)\n", + "Reached: (-0.700967,-0.324197)\n", + "\n", + "Distance: 0.647045\n", + "True positions: (-0.388098,-0.196909)\n", + "Reached: (-0.825401,-0.406651)\n", + "\n", + "Distance: 0.215215\n", + "True positions: (-0.556208,-0.379026)\n", + "Reached: (-0.399488,-0.320530)\n", + "\n", + "Distance: 0.670461\n", + "True positions: (-0.915811,-0.763724)\n", + "Reached: (-0.731732,-0.277342)\n", + "\n", + "Distance: 0.515910\n", + "True positions: (-0.753540,-0.108279)\n", + "Reached: (-0.991832,-0.385896)\n", + "\n", + "Distance: 0.670115\n", + "True positions: (-0.180781,-0.773335)\n", + "Reached: (-0.473330,-0.395769)\n", + "\n", + "Distance: 0.649070\n", + "True positions: (-0.451092,-0.961598)\n", + "Reached: (-0.442753,-0.320867)\n", + "\n", + "Distance: 0.796233\n", + "True positions: (-0.516566,-0.052818)\n", + "Reached: (-1.012637,-0.352980)\n", + "Episode 1: reward: -498.678, steps: 1000\n", + "\n", + "Distance: 2.459794\n", + "True positions: (0.169571,-0.735097)\n", + "Reached: (-1.564734,-0.009608)\n", + "\n", + "Distance: 0.801829\n", + "True positions: (0.034648,-0.287289)\n", + "Reached: (-0.604485,-0.449986)\n", + "\n", + "Distance: 1.404096\n", + "True positions: (0.245788,-0.097899)\n", + "Reached: (-0.784187,-0.472021)\n", + "\n", + "Distance: 0.715185\n", + "True positions: (-0.469800,-0.913573)\n", + "Reached: (-0.621891,-0.350479)\n", + "\n", + "Distance: 0.733306\n", + "True positions: (-0.847274,-0.955848)\n", + "Reached: (-0.770814,-0.299002)\n", + "\n", + "Distance: 0.230156\n", + "True positions: (-1.037488,-0.432727)\n", + "Reached: (-0.920186,-0.319873)\n", + "\n", + "Distance: 0.440740\n", + "True positions: (-0.844373,-0.608061)\n", + "Reached: (-0.699958,-0.311736)\n", + "\n", + "Distance: 0.924492\n", + "True positions: (-0.680619,-0.867557)\n", + "Reached: (-0.346047,-0.277638)\n", + "\n", + "Distance: 0.752919\n", + "True positions: (0.044109,-0.763698)\n", + "Reached: (-0.410758,-0.465646)\n", + "\n", + "Distance: 0.672073\n", + "True positions: (0.018677,-0.707216)\n", + "Reached: (-0.405336,-0.459156)\n", + "Episode 2: reward: -667.943, steps: 1000\n", + "\n", + "Distance: 0.668439\n", + "True positions: (-1.025154,-0.138046)\n", + "Reached: (-1.565007,-0.009460)\n", + "\n", + "Distance: 0.159730\n", + "True positions: (-0.277725,-0.443307)\n", + "Reached: (-0.320849,-0.326701)\n", + "\n", + "Distance: 0.334421\n", + "True positions: (-0.404840,-0.569673)\n", + "Reached: (-0.528528,-0.358939)\n", + "\n", + "Distance: 0.046731\n", + "True positions: (-0.976947,-0.312122)\n", + "Reached: (-0.963084,-0.344989)\n", + "\n", + "Distance: 1.008901\n", + "True positions: (0.062443,-0.827527)\n", + "Reached: (-0.534605,-0.415674)\n", + "\n", + "Distance: 0.773373\n", + "True 
positions: (-0.448384,-0.777123)\n", + "Reached: (-0.155566,-0.296568)\n", + "\n", + "Distance: 0.667422\n", + "True positions: (-0.435711,-0.786211)\n", + "Reached: (-0.709425,-0.392504)\n", + "\n", + "Distance: 0.352981\n", + "True positions: (-0.489889,-0.498809)\n", + "Reached: (-0.328137,-0.307581)\n", + "\n", + "Distance: 1.385501\n", + "True positions: (0.291133,-0.380907)\n", + "Reached: (-0.616283,-0.858993)\n", + "\n", + "Distance: 1.216234\n", + "True positions: (0.236555,-0.946088)\n", + "Reached: (-0.629405,-1.296363)\n", + "Episode 3: reward: -624.535, steps: 1000\n", + "\n", + "Distance: 1.111359\n", + "True positions: (-1.148266,-0.704407)\n", + "Reached: (-1.564798,-0.009579)\n", + "\n", + "Distance: 0.671940\n", + "True positions: (-0.648418,-0.711167)\n", + "Reached: (-0.402534,-0.285110)\n", + "\n", + "Distance: 0.633859\n", + "True positions: (0.063465,-0.561639)\n", + "Reached: (-0.542229,-0.589804)\n", + "\n", + "Distance: 0.290913\n", + "True positions: (-0.392989,-0.376063)\n", + "Reached: (-0.660109,-0.352271)\n", + "\n", + "Distance: 0.619160\n", + "True positions: (-0.591837,-0.926301)\n", + "Reached: (-0.604321,-0.319625)\n", + "\n", + "Distance: 1.041864\n", + "True positions: (0.182938,-0.081559)\n", + "Reached: (-0.598350,-0.342135)\n", + "\n", + "Distance: 0.505638\n", + "True positions: (-0.146238,-0.345219)\n", + "Reached: (-0.611294,-0.385801)\n", + "\n", + "Distance: 0.968334\n", + "True positions: (0.006998,-0.125341)\n", + "Reached: (-0.574171,-0.512506)\n", + "\n", + "Distance: 0.427463\n", + "True positions: (-0.453564,-0.068854)\n", + "Reached: (-0.596623,-0.353258)\n", + "\n", + "Distance: 0.893231\n", + "True positions: (-0.085683,-0.671953)\n", + "Reached: (-0.769541,-0.462580)\n", + "Episode 4: reward: -694.748, steps: 1000\n", + "\n", + "Distance: 1.864559\n", + "True positions: (-0.036882,-0.346283)\n", + "Reached: (-1.564756,-0.009598)\n", + "\n", + "Distance: 0.728857\n", + "True positions: (0.102162,-0.161375)\n", + "Reached: (-0.618213,-0.169857)\n", + "\n", + "Distance: 1.126113\n", + "True positions: (0.082182,-0.128944)\n", + "Reached: (-0.764064,-0.408811)\n", + "\n", + "Distance: 0.738605\n", + "True positions: (-0.795643,-0.996414)\n", + "Reached: (-0.736228,-0.317225)\n", + "\n", + "Distance: 0.794818\n", + "True positions: (-1.127481,-0.786967)\n", + "Reached: (-0.733205,-0.386426)\n", + "\n", + "Distance: 0.468622\n", + "True positions: (-0.229081,-0.704277)\n", + "Reached: (-0.308980,-0.315554)\n", + "\n", + "Distance: 0.239779\n", + "True positions: (-1.019121,-0.503951)\n", + "Reached: (-0.882423,-0.400869)\n", + "\n", + "Distance: 0.561899\n", + "True positions: (-0.935049,-0.565972)\n", + "Reached: (-0.662136,-0.276985)\n", + "\n", + "Distance: 0.225996\n", + "True positions: (-1.005050,-0.164818)\n", + "Reached: (-0.954448,-0.340212)\n", + "\n", + "Distance: 0.377026\n", + "True positions: (-1.002465,-0.094672)\n", + "Reached: (-0.830223,-0.299456)\n", + "Episode 5: reward: -621.735, steps: 1000\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# agent.load_weights(args.model)\n", + "# Finally, evaluate our algorithm for 1 episode.\n", + "agent.test(env, nb_episodes=2, visualize=False, nb_max_episode_steps=1000)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Testing for 5 episodes ...\n", + "\n", + 
"Distance: 1.242484\n", + "True positions: (-0.387711,-0.071820)\n", + "Reached: (-1.565560,-0.007185)\n", + "\n", + "Distance: 0.094699\n", + "True positions: (-0.306534,-0.320054)\n", + "Reached: (-0.279753,-0.387973)\n", + "\n", + "Distance: 0.342258\n", + "True positions: (-0.041413,-0.225383)\n", + "Reached: (-0.127826,-0.481228)\n", + "\n", + "Distance: 0.071426\n", + "True positions: (0.074357,-0.365959)\n", + "Reached: (0.126484,-0.346660)\n", + "\n", + "Distance: 0.273507\n", + "True positions: (-0.254339,-0.147204)\n", + "Reached: (-0.267362,-0.407688)\n", + "\n", + "Distance: 0.170239\n", + "True positions: (-0.658603,-0.784083)\n", + "Reached: (-0.772865,-0.728107)\n", + "\n", + "Distance: 0.219391\n", + "True positions: (0.178447,-0.411043)\n", + "Reached: (0.255614,-0.268819)\n", + "\n", + "Distance: 0.035296\n", + "True positions: (-0.170949,-0.637620)\n", + "Reached: (-0.139976,-0.641944)\n", + "\n", + "Distance: 0.231731\n", + "True positions: (-0.592417,-0.710361)\n", + "Reached: (-0.479023,-0.592024)\n", + "\n", + "Distance: 0.379548\n", + "True positions: (-0.430683,-0.941759)\n", + "Reached: (-0.332039,-0.660855)\n", + "Episode 1: reward: -253.070, steps: 1000\n", + "\n", + "Distance: 1.895049\n", + "True positions: (-0.574247,-0.910968)\n", + "Reached: (-1.565549,-0.007221)\n", + "\n", + "Distance: 0.187712\n", + "True positions: (0.129764,-0.243701)\n", + "Reached: (0.206473,-0.354703)\n", + "\n", + "Distance: 0.208386\n", + "True positions: (-0.950850,-0.633929)\n", + "Reached: (-0.804989,-0.571404)\n", + "\n", + "Distance: 0.271093\n", + "True positions: (-0.720792,-0.843541)\n", + "Reached: (-0.896787,-0.938639)\n", + "\n", + "Distance: 0.123235\n", + "True positions: (-0.540772,-0.374541)\n", + "Reached: (-0.663836,-0.374711)\n", + "\n", + "Distance: 0.330008\n", + "True positions: (0.003994,-0.294434)\n", + "Reached: (-0.172817,-0.447631)\n", + "\n", + "Distance: 0.516125\n", + "True positions: (-0.975207,-0.249345)\n", + "Reached: (-0.580898,-0.371160)\n", + "\n", + "Distance: 0.050081\n", + "True positions: (-0.663840,-0.934810)\n", + "Reached: (-0.680518,-0.968212)\n", + "\n", + "Distance: 0.059912\n", + "True positions: (-0.096554,-0.820200)\n", + "Reached: (-0.037651,-0.819190)\n", + "\n", + "Distance: 0.161258\n", + "True positions: (-0.281086,-0.132806)\n", + "Reached: (-0.300719,-0.274431)\n", + "Episode 2: reward: -347.761, steps: 1000\n", + "\n", + "Distance: 1.268470\n", + "True positions: (-0.911250,-0.624593)\n", + "Reached: (-1.564734,-0.009608)\n", + "\n", + "Distance: 0.141947\n", + "True positions: (-0.180663,-0.596614)\n", + "Reached: (-0.144050,-0.491280)\n", + "\n", + "Distance: 0.058636\n", + "True positions: (-0.264016,-0.857132)\n", + "Reached: (-0.296366,-0.830846)\n", + "\n", + "Distance: 0.426107\n", + "True positions: (-0.735756,-0.008678)\n", + "Reached: (-0.678528,-0.377557)\n", + "\n", + "Distance: 0.374492\n", + "True positions: (-0.006285,-0.065618)\n", + "Reached: (-0.041433,-0.404961)\n", + "\n", + "Distance: 0.113877\n", + "True positions: (0.030428,-0.725052)\n", + "Reached: (0.097678,-0.678424)\n", + "\n", + "Distance: 0.385332\n", + "True positions: (-0.705671,-0.029285)\n", + "Reached: (-0.628504,-0.337451)\n", + "\n", + "Distance: 0.032817\n", + "True positions: (-0.327528,-0.229676)\n", + "Reached: (-0.334632,-0.255389)\n", + "\n", + "Distance: 0.040550\n", + "True positions: (-0.883968,-0.641806)\n", + "Reached: (-0.901986,-0.619274)\n", + "\n", + "Distance: 0.056513\n", + "True positions: (-0.806183,-0.478284)\n", + 
"Reached: (-0.829101,-0.511879)\n", + "Episode 3: reward: -283.217, steps: 1000\n", + "\n", + "Distance: 0.614035\n", + "True positions: (-0.971813,-0.030721)\n", + "Reached: (-1.564734,-0.009608)\n", + "\n", + "Distance: 0.096513\n", + "True positions: (-0.049614,-0.954414)\n", + "Reached: (0.035289,-0.942804)\n", + "\n", + "Distance: 0.052221\n", + "True positions: (-1.005782,-0.929025)\n", + "Reached: (-0.992384,-0.967848)\n", + "\n", + "Distance: 0.659714\n", + "True positions: (-1.075068,-0.010490)\n", + "Reached: (-0.762229,-0.357366)\n", + "\n", + "Distance: 0.215477\n", + "True positions: (-0.900503,-0.620192)\n", + "Reached: (-0.803456,-0.501762)\n", + "\n", + "Distance: 0.083510\n", + "True positions: (-0.355865,-0.730697)\n", + "Reached: (-0.419280,-0.750792)\n", + "\n", + "Distance: 0.065394\n", + "True positions: (0.018987,-0.339078)\n", + "Reached: (0.037223,-0.291920)\n", + "\n", + "Distance: 0.102855\n", + "True positions: (-0.002792,-0.706756)\n", + "Reached: (0.086531,-0.693224)\n", + "\n", + "Distance: 0.086674\n", + "True positions: (-1.024711,-0.524868)\n", + "Reached: (-1.030540,-0.605714)\n", + "\n", + "Distance: 0.200725\n", + "True positions: (0.136057,-0.605811)\n", + "Reached: (0.235742,-0.504771)\n", + "Episode 4: reward: -335.662, steps: 1000\n", + "\n", + "Distance: 1.932100\n", + "True positions: (0.209470,-0.164254)\n", + "Reached: (-1.565560,-0.007185)\n", + "\n", + "Distance: 0.239532\n", + "True positions: (0.258944,-0.413510)\n", + "Reached: (0.353644,-0.268678)\n", + "\n", + "Distance: 0.046742\n", + "True positions: (-0.128176,-0.951958)\n", + "Reached: (-0.088934,-0.944457)\n", + "\n", + "Distance: 0.153107\n", + "True positions: (-0.697964,-0.421781)\n", + "Reached: (-0.787300,-0.485551)\n", + "\n", + "Distance: 0.003229\n", + "True positions: (-0.136660,-0.346937)\n", + "Reached: (-0.139087,-0.346134)\n", + "\n", + "Distance: 0.168435\n", + "True positions: (0.175476,-0.870530)\n", + "Reached: (0.320923,-0.893518)\n", + "\n", + "Distance: 0.179364\n", + "True positions: (-1.127810,-0.790026)\n", + "Reached: (-0.956237,-0.797817)\n", + "\n", + "Distance: 0.264297\n", + "True positions: (0.237749,-0.140255)\n", + "Reached: (0.335439,-0.306863)\n", + "\n", + "Distance: 0.080070\n", + "True positions: (0.133837,-0.230604)\n", + "Reached: (0.177112,-0.267399)\n", + "\n", + "Distance: 0.316193\n", + "True positions: (-1.191011,-0.135528)\n", + "Reached: (-1.081113,-0.341824)\n", + "Episode 5: reward: -362.485, steps: 1000\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "agent.load_weights(\"../models/example.h5f\")\n", + "# Finally, evaluate our algorithm for 1 episode.\n", + "agent.test(env, nb_episodes=2, visualize=False, nb_max_episode_steps=1000)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}