From 69e0242f4db7f56a7ba094f44ba99bb1faeb5513 Mon Sep 17 00:00:00 2001 From: Georg Martius Date: Fri, 21 Dec 2018 14:42:16 +0100 Subject: [PATCH] hand-crafted player added puck has maximal velocity rackets are thicker and round to avoid penetration --- Laser-Hockey-Env.ipynb | 349 ++++++++++++++++++++++++++++++++++++----- laser_hockey_env.py | 73 ++++++++- 2 files changed, 378 insertions(+), 44 deletions(-) diff --git a/Laser-Hockey-Env.ipynb b/Laser-Hockey-Env.ipynb index 7d88923..037f667 100644 --- a/Laser-Hockey-Env.ipynb +++ b/Laser-Hockey-Env.ipynb @@ -2,11 +2,11 @@ "cells": [ { "cell_type": "code", - "execution_count": 22, + "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2018-12-21T12:23:36.441678Z", - "start_time": "2018-12-21T12:23:36.437323Z" + "end_time": "2018-12-21T13:21:09.175545Z", + "start_time": "2018-12-21T13:21:09.023357Z" } }, "outputs": [], @@ -19,11 +19,11 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2018-12-21T12:23:36.859856Z", - "start_time": "2018-12-21T12:23:36.854044Z" + "end_time": "2018-12-21T13:21:09.356673Z", + "start_time": "2018-12-21T13:21:09.353373Z" } }, "outputs": [], @@ -182,11 +182,11 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 36, "metadata": { "ExecuteTime": { - "end_time": "2018-12-21T11:13:45.804563Z", - "start_time": "2018-12-21T11:13:45.796727Z" + "end_time": "2018-12-21T12:47:30.613234Z", + "start_time": "2018-12-21T12:47:30.601332Z" } }, "outputs": [ @@ -196,7 +196,7 @@ "" ] }, - "execution_count": 12, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -207,11 +207,11 @@ }, { "cell_type": "code", - "execution_count": 232, + "execution_count": 37, "metadata": { "ExecuteTime": { - "end_time": "2018-12-20T21:02:31.633135Z", - "start_time": "2018-12-20T21:02:31.623393Z" + "end_time": "2018-12-21T12:47:30.982302Z", + "start_time": "2018-12-21T12:47:30.974209Z" } }, "outputs": [], @@ -221,11 +221,11 @@ }, { "cell_type": "code", - "execution_count": 240, + "execution_count": 38, "metadata": { "ExecuteTime": { - "end_time": "2018-12-20T21:02:45.100192Z", - "start_time": "2018-12-20T21:02:45.092141Z" + "end_time": "2018-12-21T12:47:31.286102Z", + "start_time": "2018-12-21T12:47:31.277830Z" } }, "outputs": [], @@ -235,11 +235,11 @@ }, { "cell_type": "code", - "execution_count": 241, + "execution_count": 39, "metadata": { "ExecuteTime": { - "end_time": "2018-12-20T21:02:45.266057Z", - "start_time": "2018-12-20T21:02:45.243905Z" + "end_time": "2018-12-21T12:47:31.653479Z", + "start_time": "2018-12-21T12:47:31.601616Z" } }, "outputs": [ @@ -249,7 +249,7 @@ "True" ] }, - "execution_count": 241, + "execution_count": 39, "metadata": {}, "output_type": "execute_result" } @@ -402,11 +402,11 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 49, "metadata": { "ExecuteTime": { - "end_time": "2018-12-21T12:24:05.082438Z", - "start_time": "2018-12-21T12:24:05.072962Z" + "end_time": "2018-12-21T13:08:25.404585Z", + "start_time": "2018-12-21T13:08:25.392585Z" } }, "outputs": [ @@ -416,7 +416,7 @@ "" ] }, - "execution_count": 28, + "execution_count": 49, "metadata": {}, "output_type": "execute_result" } @@ -427,11 +427,11 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 50, "metadata": { "ExecuteTime": { - "end_time": "2018-12-21T12:24:05.821293Z", - "start_time": "2018-12-21T12:24:05.814344Z" + "end_time": "2018-12-21T13:08:25.605760Z", + "start_time": "2018-12-21T13:08:25.593455Z" } }, "outputs": [], @@ -441,11 +441,11 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 51, "metadata": { "ExecuteTime": { - "end_time": "2018-12-21T12:24:56.153283Z", - "start_time": "2018-12-21T12:24:56.148110Z" + "end_time": "2018-12-21T13:08:25.849615Z", + "start_time": "2018-12-21T13:08:25.845081Z" } }, "outputs": [], @@ -455,28 +455,301 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 52, "metadata": { "ExecuteTime": { - "end_time": "2018-12-21T12:25:42.213473Z", - "start_time": "2018-12-21T12:25:39.046109Z" + "end_time": "2018-12-21T13:08:29.582257Z", + "start_time": "2018-12-21T13:08:26.211196Z" + } + }, + "outputs": [], + "source": [ + "for _ in range(200):\n", + " env.render()\n", + " a1_discrete = random.randint(0,7)\n", + " a1 = env.discrete_to_continous_action(a1_discrete)\n", + " a2 = [0,0.,0] \n", + " obs, r, d, info = env.step(np.hstack([a1,a2])) \n", + " obs_agent2 = env.obs_agent_two()\n", + " if d: break" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "ExecuteTime": { + "end_time": "2018-12-20T20:37:41.013424Z", + "start_time": "2018-12-20T20:37:41.009298Z" + } + }, + "source": [ + "# Hand-crafted Opponent" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "metadata": { + "ExecuteTime": { + "end_time": "2018-12-21T13:36:52.918915Z", + "start_time": "2018-12-21T13:36:52.900424Z" + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 85, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "reload(lh)" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "metadata": { + "ExecuteTime": { + "end_time": "2018-12-21T13:36:53.074471Z", + "start_time": "2018-12-21T13:36:53.066921Z" + } + }, + "outputs": [], + "source": [ + "env = lh.LaserHockeyEnv(mode=lh.LaserHockeyEnv.TRAIN_DEFENCE)" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": { + "ExecuteTime": { + "end_time": "2018-12-21T13:36:53.235153Z", + "start_time": "2018-12-21T13:36:53.225682Z" + } + }, + "outputs": [], + "source": [ + "o = env.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": { + "ExecuteTime": { + "end_time": "2018-12-21T13:36:53.448943Z", + "start_time": "2018-12-21T13:36:53.381148Z" } }, "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 88, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "env.render()" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": { + "ExecuteTime": { + "end_time": "2018-12-21T13:37:26.317701Z", + "start_time": "2018-12-21T13:37:26.312334Z" + } + }, + "outputs": [], + "source": [ + "player1 = lh.BasicOpponent()\n", + "player2 = lh.BasicOpponent()" + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": { + "ExecuteTime": { + "end_time": "2018-12-21T13:38:57.583331Z", + "start_time": "2018-12-21T13:38:55.355794Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[-5.47181129 3.43619919 0. ] [547.18112946 343.61991882 0. ] [False False True]\n", + "[-5.45844746 -2.10290718 0. ] [8.29322179 3.10080166 0. ] [False False True]\n", + "[-5.43198776 -2.10263968 0. ] [ 4.13708146 625.11139188 0. ] [False False True]\n", + "[-5.39269352 -2.08901405 0. ] [2.75880376 3.1119949 0. ] [False False True]\n", + "[-5.34082127 -2.06229734 0. ] [2.06719421 1.55547131 0. ] [False False True]\n", + "[-5.27662277 -2.02275133 0. ] [1.64898661 1.0281868 0. ] [False False True]\n", + "[-5.20034504 -1.97063255 0. ] [1.36710808 0.75912236 0. ] [False False True]\n", + "[-5.11222935 -1.90619278 0. ] [1.16297895 0.59346001 0. ] [False False True]\n", + "[-5.01251221 -1.82967806 0. ] [1.00736371 0.47950801 0. ] [False False True]\n", + "[-4.90142536 -1.74133015 0. ] [0.88404197 0.39509175 0. ] [False False True]\n", + "[-4.77919674 -1.64138556 0. ] [0.78329117 0.32911698 0. ] [False False True]\n", + "[-4.6460495 -1.53007603 0. ] [0.69892902 0.27541756 0. ] [False False True]\n", + "[-4.50220108 -1.40762901 0. ] [0.6268364 0.23029262 0. ] [False False True]\n", + "[-4.34786606 -1.2742672 0. ] [0.56416326 0.19138645 0. ] [False False True]\n", + "[-4.22439766 -1.16757774 0. ] [0.68540038 0.21928534 0. ] [False False True]\n", + "[-4.09003544 -1.0496583 0. ] [0.60971454 0.17833255 0. ] [False False True]\n", + "[-3.94499683 -0.92073393 0. ] [0.54474347 0.14305501 0. ] [False False True]\n", + "[-3.78949547 -0.78102446 0. ] [0.48801768 0.11196712 0. ] [False False True]\n", + "[-3.66509438 -0.66925669 0. ] [0.59018637 0.11997322 0. ] [False False True]\n", + "[-3.52981758 -0.54636049 0. ] [0.52263838 0.08905924 0. ] [False True True]\n", + "[-3.38388252 -0.41255903 0. ] [0.46438857 0.06175956 0. ] [False True True]\n", + "[-3.22750282 -0.27100229 0. ] [0.41330579 0.03834307 0. ] [False True True]\n", + "[-3.10239887 -0.15775681 0. ] [0.49676565 0.02791037 0. ] [False True True]\n", + "[-2.96643353 -0.06014013 0. ] [0.43699378 0.01234698 0. ] [False True True]\n", + "[-2.81982327 0.02216053 0. ] [0.38519597 0.00539837 0. ] [False True True]\n", + "[-2.66278172 0.08945179 0. ] [0.33955121 0.02666575 0. ] [False True True]\n", + "[-2.49551773 0.14203358 0. ] [0.2987498 0.05423025 0. ] [False True True]\n", + "[-2.3182354 0.1802001 0. ] [0.2618256 0.09492606 0. ] [False True True]\n", + "[-2.13113499 0.20423937 0. ] [0.22805035 0.17134571 0. ] [False False True]\n", + "[-1.93441296 0.21443462 0. ] [0.19686476 0.42908271 0. ] [False False True]\n", + "[-1.72826195 0.21106195 0. ] [0.16783218 1.18158894 0. ] [False False True]\n", + "[-1.56334114 2.20194674 0. ] [ 0.18981691 15.19627407 0. ] [False False True]\n", + "[-1.38835526 2.11352873 0. ] [0.15886328 2.60805759 0. ] [False False True]\n", + "[-1.20350552 2.0121398 0. ] [0.13035526 1.37576726 0. ] [False False True]\n", + "[-1.00898933 1.89804077 0. ] [0.10384998 0.90310254 0. ] [False False True]\n", + "[-0.8049984 1.77148724 0. ] [0.07900273 0.64936332 0. ] [ True False True]\n", + "[-0.64180565 1.65587807 0. ] [0.07875301 0.7580368 0. ] [ True False True]\n", + "[-0.46851349 1.52784681 0. ] [0.05413468 0.54388746 0. ] [ True False True]\n", + "[-0.29397202 1.38764381 0. ] [0.03372376 0.40558696 0. ] [ True False True]\n", + "[-0.13628578 1.23551369 0. ] [0.01730759 0.30724406 0. ] [ True False True]\n", + "[0.00488281 1.07169628 0. ] [0.00069275 0.23251073 0. ] [ True False True]\n", + "[10.20042582 -2.70615294 -0.09278866] [1.63491488 0.52187621 9.27886561] [False False False]\n", + "[10.07951899 -2.81452831 -0.15974588] [1.85082719 0.63767418 5.12938847] [False False False]\n", + "[ 9.94329519 -2.90363758 -0.21966783] [2.12979969 0.79388933 2.15411499] [False False False]\n", + "[ 9.79206915 -2.97373721 -0.09899743] [2.50635377 1.01967687 0.51924205] [False False True]\n", + "[ 9.62614555 -3.02485334 -0.04670808] [3.04588358 1.38118386 0.24574502] [False False True]\n", + "[ 9.40846272 -3.04203498 -0.0965098 ] [1647.46030206 2.4635626 0.11205556] [False False True]\n", + "[ 9.166362 -3.03556099 0.50964663] [13.60037832 5.59928852 0.65966353] [False False True]\n", + "[ 8.91140766 -3.01112677 -0.01832059] [ 6.70595021 22.02869681 0.02678825] [False False True]\n", + "[ 8.64385719 -2.9690119 0.28063243] [4.38621984 3.70229117 0.47147487] [False False True]\n", + "[ 8.36395664 -2.90950076 0.33029869] [3.21733577 2.00119444 0.65206726] [False False True]\n", + "[ 8.0719511 -2.83289026 -0.11030367] [2.5099026 1.35364526 0.26397331] [False False True]\n", + "[ 7.76809044 -2.73958968 0.53949568] [2.03347686 1.00760491 1.45060595] [False False False]\n", + "[ 7.45260544 -2.62981378 0.34690648] [1.68913558 0.78913759 1.22482729] [False False False]\n", + "[ 7.12572784 -2.50389075 -0.20320788] [1.42736582 0.63649913 1.04451699] [False False False]\n", + "[ 6.78772945 -2.36259554 -0.38185701] [1.22064784 0.52233376 1.34822759] [False False False]\n", + "[ 6.43884392 -2.20629113 -0.26641921] [1.05245528 0.43254915 0.71635289] [False False True]\n", + "[ 6.07928104 -2.03522955 0.15324004] [0.91226394 0.35915984 0.37802335] [False False True]\n", + "[ 5.70922108 -1.84953423 0.45890096] [0.79305173 0.29729117 1.44905311] [False False False]\n", + "[ 5.3288805 -1.64967594 0.3403329 ] [0.68996086 0.24386012 1.49263214] [False False False]\n", + "[ 4.97959251 -1.48781651 -0.05733803] [0.8056599 0.27501859 0.41153603] [False False True]\n", + "[ 4.61963387 -1.31158547 -0.05682752] [0.68688038 0.21970855 0.43167355] [False False True]\n", + "[ 4.24922199 -1.1215521 0.29040192] [0.58534065 0.17206339 2.31508023] [False False False]\n", + "[ 3.86854191 -0.91822058 0.07777174] [0.49707616 0.13013577 2.11580111] [False False False]\n", + "[ 3.51916714 -0.75401867 0.40470972] [ 0.56504891 0.13362745 178.0548057 ] [False False False]\n", + "[ 3.15911026 -0.57639212 0.56969218] [0.46650309 0.09299916 6.59299454] [False True False]\n", + "[ 2.78858967 -0.38649777 -0.30780803] [0.38174551 0.05732818 1.7579951 ] [False True False]\n", + "[ 2.40787811 -0.19267684 0.30470357] [0.30763254 0.02778891 3.52630571] [False True False]\n", + "[ 2.05863113 -0.03807348 0.56996318] [0.32866057 0.00686643 3.25525119] [False True False]\n", + "[ 1.69872112 0.10225636 -0.1574153 ] [0.24956208 0.02145722 0.59678535] [False True True]\n", + "[1.32841892 0.22635327 0.14688007] [0.18100685 0.05656148 0.83888143] [False True True]\n", + "[0.94790096 0.3302229 0.61829109] [0.12058804 0.10149841 3.03653962] [False False False]\n", + "[0.55736084 0.40208181 0.11860643] [0.0665758 0.15955455 0.4057714 ] [ True False True]\n", + "[-2.64727497 1.6123929 0.17996708] [0.33036327 0.83893833 0.05668026] [False False True]\n", + "[-2.78925133 1.62978029 0.241896 ] [0.38820595 1.3412462 0.07837364] [False False True]\n", + "[-2.91461849 1.63362169 0.30205128] [0.45731723 3.12690987 0.10075885] [False False True]\n", + "[-3.02374458 1.62420702 0.36043292] [ 0.54209862 10.38555865 0.12389917] [False False True]\n", + "[-3.11699009 1.60181952 0.41704094] [0.64960151 1.94952784 0.14786584] [False False True]\n", + "[-3.19470692 1.56673098 0.47187534] [0.79187393 1.06320407 0.17273933] [False False True]\n", + "[-3.2572403 1.51920462 0.52493608] [0.99134008 0.71914883 0.19861091] [False False True]\n", + "[-3.30492783 1.45949602 0.57622319] [1.29503523 0.53292821 0.22558454] [False False True]\n", + "[-3.33809853 1.38785315 0.62573665] [1.82113527 0.41400715 0.25377914] [False False True]\n", + "[-3.35707188 1.30451441 0.67347652] [2.97524266 0.32995765 0.28333137] [False False True]\n", + "[-3.36216259 1.20971346 0.71944273] [7.67990494 0.2662871 0.31439905] [False False True]\n", + "[-3.3536768 1.10367632 0.76363528] [14.03502666 0.21554429 0.34716548] [False False True]\n", + "[-3.33191109 0.98662233 0.80605423] [3.69328726 0.17350596 0.38184478] [False False True]\n", + "[-3.29715729 0.8587656 0.84669954] [2.12433116 0.1375997 0.41868858] [False False True]\n", + "[-3.24969673 0.72031498 0.88557124] [1.48453608 0.10616906 0.45799472] [False False True]\n", + "[-3.18980694 0.5714736 0.92266929] [1.13385811 0.07809904 0.50011825] [False True True]\n", + "[-3.11775589 0.41243982 0.95799369] [0.91030695 0.05261162 0.54548597] [False True True]\n", + "[-3.03380394 0.25396919 0.99154449] [0.75384632 0.03246491 0.59461552] [False True True]\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "Player 2 scored\n" + "[-2.93820858 0.11220598 1.02332163] [0.6370893 0.01603299 0.64814072] [False True True]\n", + "[-2.83121777 -0.01319361 1.05332518] [0.54575914 0.00213127 0.7068464 ] [False True True]\n", + "[-2.71307373 -0.12256813 1.08155501] [0.47168325 0.02270326 0.77171588] [False True True]\n", + "[-2.58401394 -0.21624756 1.10801125] [0.40984599 0.04677902 0.84399836] [False True True]\n", + "[-2.44426823 -0.29455662 1.13269389] [0.3569981 0.07626436 0.92530494] [False True True]\n", + "[-2.29406166 -0.3578124 1.15560281] [0.31094227 0.11479081 1.01774992] [False False False]\n", + "[-2.13361549 -0.40632486 1.17673814] [0.27013947 0.17024126 1.12416436] [False False False]\n", + "[-1.96314144 -0.44039869 1.19609988] [0.2334789 0.26354734 1.24842727] [False False False]\n", + "[-1.78285027 -0.46033096 1.2136879 ] [0.20013956 0.47474445 1.39600036] [False False False]\n", + "[ 9.66290827 -4.95561387 1.02134196] [ 1.02820338 17.5566051 1.30820172] [False False False]\n", + "[ 9.36962004 -4.10897765 1.11740035] [ 1.09695016 10.49916973 1.61464605] [False False False]\n", + "[-1.26474476 -1.0004034 1.25581038] [0.16420338 0.95139083 2.08136583] [False False False]\n", + "[-1.0990448 -1.02108335 1.2663039 ] [0.13376488 0.60117897 2.46038488] [False False False]\n", + "[-0.92346954 -1.02895546 1.27502382] [0.10590368 0.44114212 2.99304535] [False False False]\n", + "[-0.73821831 -1.02427912 1.28197014] [0.08012369 0.34676507 3.80052658] [ True False False]\n", + "[-0.54349136 -1.00730658 1.28714275] [0.05604624 0.28273539 5.17689581] [ True False False]\n", + "[-0.38779259 -1.00450993 1.29054177] [0.05000061 0.35268549 8.06838295] [ True False False]\n", + "[ 7.95341606 -3.10766235 1.75840124] [ 1.04689712 0.89837785 24.67279702] [False False False]\n", + "[ 7.71903725 -3.01444978 1.63913433] [ 1.13903724 0.74284293 94.13368005] [False False False]\n", + "[ 7.46874733 -2.9048444 1.10136762] [ 1.250441 0.62539536 10.38100611] [False False False]\n", + "[ 7.20286961 -2.77919641 1.43180253] [1.38916075 0.53242346 7.35101539] [False False False]\n", + "[ 6.92172165 -2.63785919 0.80742269] [1.56849563 0.45610223 2.84847559] [False False False]\n", + "[ 6.6256115 -2.48119244 1.44327156] [1.81209466 0.39161506 3.87830797] [False False False]\n", + "[ 6.27558441 -2.29994915 1.56739715] [10.73803347 0.34616826 0.4102247 ] [False False True]\n", + "[ 5.89789314 -2.10134621 0.7404937 ] [61.66478816 0.29270239 0.19840973] [False False True]\n", + "[ 5.50769348 -1.88845797 0.92037113] [7.22686925 0.24514217 0.25260893] [False False True]\n", + "[ 5.10527058 -1.66179545 0.62703349] [3.60732056 0.20222804 0.17639161] [False False True]\n", + "[ 4.69089909 -1.42192959 0.75051569] [2.28230993 0.16304538 0.21653036] [False False True]\n", + "[ 4.26484032 -1.16953268 0.40130605] [1.58981327 0.12692132 0.11882039] [False False True]\n", + "[ 3.82735081 -0.90544518 0.69290583] [1.16074094 0.09336106 0.21069064] [False True True]\n", + "[ 3.40386696 -0.68505125 1.02640025] [1.28940835 0.08831789 0.32074467] [False True True]\n", + "[ 2.9689209 -0.45308447 0.20074037] [0.91198425 0.05478981 0.06451827] [False True True]\n", + "[ 2.52274914 -0.21900067 0.98870748] [0.65377797 0.02617986 0.32709517] [False True True]\n", + "[ 2.06557484 -0.0021792 0.65539633] [0.46418123 0.00028942 0.22337912] [False True True]\n", + "[1.5976202 0.19546977 0.87922891] [0.31766161 0.02912862 0.30900802] [False True True]\n", + "[1.11909504 0.3686747 0.31108879] [0.19994158 0.06240282 0.11285049] [False True True]\n", + "[0.63020153 0.50225062 0.26990705] [0.10241269 0.09806823 0.10116593] [False True True]\n", + "[-0.72562599 0.76656342 0.33829334] [0.10832081 0.17619596 0.13115795] [False False True]\n", + "[-0.8447628 1.00626945 0.30262232] [0.14325536 0.27988969 0.12150577] [False False True]\n", + "[-0.94832802 1.01595116 0.25478396] [0.18554712 0.35585612 0.10607518] [False False True]\n", + "[-1.03618813 1.01103973 0.20871922] [0.23871007 0.47478498 0.09022821] [False False True]\n", + "[-1.10867405 0.99180794 0.16442811] [0.30916873 0.69919609 0.07391505] [False False True]\n", + "Player 1 scored\n" ] } ], "source": [ - "for _ in range(200):\n", + "obs = env.reset()\n", + "obs_agent2 = env.obs_agent_two()\n", + "for _ in range(600):\n", " env.render()\n", - " a1_discrete = random.randint(0,7)\n", - " a1 = env.discrete_to_continous_action(a1_discrete)\n", - " a2 = [0,0.,0] \n", + " a1 = player1.act(obs)\n", + " a2 = player2.act(obs_agent2)\n", " obs, r, d, info = env.step(np.hstack([a1,a2])) \n", " obs_agent2 = env.obs_agent_two()\n", " if d: break" diff --git a/laser_hockey_env.py b/laser_hockey_env.py index 794dccc..4ef1f52 100644 --- a/laser_hockey_env.py +++ b/laser_hockey_env.py @@ -22,7 +22,7 @@ CENTER_X = W/2 CENTER_Y = H/2 -RACKETPOLY = [(-5,20),(+5,20),(+5,-20),(-5,-20)] +RACKETPOLY = [(-5,20),(+5,20),(+5,-20),(-5,-20),(-13,-10),(-15,0),(-13,10)] FORCEMULIPLAYER = 5000 TORQUEMULTIPLAYER = 100 @@ -81,6 +81,8 @@ def __init__(self, mode = NORMAL): self.done = False self.winner = 0 + self.max_puck_speed = 20 + self.timeStep = 1.0 / FPS self.time = 0 self.max_timesteps = 500 @@ -130,12 +132,13 @@ def _destroy(self): self.world_objects = [] self.drawlist = [] - def _create_player(self, position, color): + def _create_player(self, position, color, is_player_two): player = self.world.CreateDynamicBody( position=position, angle=0.0, fixtures=fixtureDef( - shape=polygonShape(vertices=[ (x/SCALE,y/SCALE) for x,y in RACKETPOLY ]), + shape=polygonShape(vertices=[ (-x/SCALE if is_player_two else x/SCALE , y/SCALE) + for x,y in RACKETPOLY ]), density=200.0, friction=0.1, categoryBits=0x0010, @@ -321,17 +324,20 @@ def reset(self): # Create players self.player1 = self._create_player( (W / 3, H / 2), - (1,0,0) + (1,0,0), + False ) if self.mode != self.NORMAL: self.player2 = self._create_player( (2* W / 3 + r_uniform(-W/6, W/6), H/2 + r_uniform(-H/4, H/4)), - (0,0,1) + (0,0,1), + True ) else: self.player2 = self._create_player( (2* W / 3, H / 2), - (0,0,1) + (0,0,1), + True ) if self.mode == self.NORMAL: self.puck = self._create_puck( (W / 2, H / 2 + r_uniform(-H/4, H/4)), (0,0,0) ) @@ -442,6 +448,13 @@ def _get_info(self): ) + def _limit_puck_speed(self): + puck_speed = np.sqrt(self.puck.linearVelocity[0]**2 + self.puck.linearVelocity[1]**2) + if puck_speed > self.max_puck_speed: + self.puck.linearDamping = 10.0 + else: + self.puck.linearDamping = 0.05 + def discrete_to_continous_action(self, discrete_action): ''' converts discrete actions into continuous ones (for each player) The actions allow only one operation each timestep, e.g. X or Y or angle change. @@ -469,6 +482,8 @@ def step(self, action): self._apply_action_with_max_speed(self.player2, action[3:5], 10, False) self.player2.ApplyTorque(-action[5] * TORQUEMULTIPLAYER, True) + self._limit_puck_speed() + self.world.Step(self.timeStep, 6 * 30, 2 * 30) obs = self._get_obs() @@ -514,3 +529,49 @@ def close(self): if self.viewer is not None: self.viewer.close() self.viewer = None + + +class BasicOpponent(): + def __init__(self): + self.mode = 0 + + def act(self, obs, verbose=False): + p1 = np.asarray(obs[0:3]) + v1 = np.asarray(obs[3:6]) + p2 = np.asarray(obs[6:9]) + v2 = np.asarray(obs[9:12]) + puck = np.asarray(obs[12:14]) + puckv = np.asarray(obs[14:16]) + # print(p1,v1,puck,puckv) + target_pos = p1[0:2] + target_angle = p1[2] + + time_to_break = 0.1 + kp = 10 + kd = 0.5 + + # if ball flies towards our goal or very slowly away: try to catch it + if puckv[0]<1.0: + dist = np.sqrt(np.sum((p1[0:2] - puck)**2)) + # Am I behind the ball? + if p1[0] < puck[0] and abs(p1[1] - puck[1]) < 1.0: + # Go and kick + target_pos = [puck[0]+0.2, puck[1] + puckv[1]*dist*0.1] + target_angle = r_uniform(-0.5,0.5) # calc proper angle here + else: + # get behind the ball first + target_pos = [-7, puck[1]] + target_angle = 0 + else: # go in front of the goal + target_pos = [-7,0] + target_angle = 0 + + + target = np.asarray([target_pos[0],target_pos[1], target_angle]) + # use PD control to get to target + error = target - p1 + need_break = abs((error / (v1+0.01))) < [time_to_break, time_to_break, time_to_break*10] + if verbose: + print(error, abs(error / (v1+0.01)), need_break) + + return error*[kp,kp,kp/2] - v1*need_break*[kd,kd,kd]