-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrtdp_agent.py
120 lines (110 loc) · 4.86 KB
/
rtdp_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from Environment import Environment
import numpy as np
# Number of sampled successor states averaged per action when the agent
# estimates Q-values.
num_samples = 5
# Discount factor multiplied into the successor state's value in each backup
# (1 means future value is not discounted).
gamma = 1
class Agent:
    """Tabular agent that acts greedily on a sampled one-step lookahead.

    For every action the agent asks the environment for a few sampled
    successor states, averages ``reward + gamma * V(successor)`` into a
    Q-estimate, picks a best action (ties broken uniformly at random) and
    writes the best Q back as the value of the current state.

    Attributes:
        value_dict: maps an abstracted state tuple (see ``parseState``) to its
            current value estimate.
        need_update: raw states queued for the extra sweeps performed by
            ``updateEndEpisode``. Nothing in this class appends to it, so the
            caller has to fill it for those sweeps to do any work.
    """

    def __init__(self):
        self.value_dict = {}
        self.need_update = []

    def getValue(self, state, env):
        """Return the stored value of ``state``, or the heuristic if unseen.

        ``state`` must be an abstracted (hashable) state as returned by
        ``parseState``.
        """
        try:
            return self.value_dict[state]
        except KeyError:
            # Only consult the heuristic when the state is not in the table.
            return self.getHeuristic(state, env)

    def updateValue(self, state, value):
        """Store ``value`` as the estimate for the abstracted ``state``."""
        self.value_dict[state] = value

    def _sampleQValues(self, env, n_samples):
        """Return one Q-estimate per action, each averaged over ``n_samples``.

        Successors come from ``env.checkNextState(action)``, which is expected
        to return ``(next_state, reward)`` for the environment's current state.
        Estimates are rounded to 5 decimals so that ties compare equal.
        """
        q_values = []
        for action in range(env.num_actions):
            total = 0
            for _ in range(n_samples):
                next_state, reward = env.checkNextState(action)
                next_key = self.parseState(next_state, env.num_cars)
                total += reward + gamma * self.getValue(next_key, env)
            # Average over the number of samples actually drawn.
            q_values.append(round(float(total) / n_samples, 5))
        return q_values

    def getAction(self, state, env):
        """Back up the value of ``state`` and return a greedy action index.

        Args:
            state: raw state ``(positions, velocity, orientation)`` that the
                environment is currently in.
            env: environment exposing ``num_actions``, ``num_cars`` and
                ``checkNextState(action)``.

        Returns:
            Index of an action with maximal sampled Q-value; ties are broken
            uniformly at random.
        """
        q_values = self._sampleQValues(env, num_samples)
        best_q = max(q_values)
        best_actions = [a for a, q in enumerate(q_values) if q == best_q]
        pick = np.random.randint(0, len(best_actions), 1)[0]
        self.updateValue(self.parseState(state, env.num_cars), best_q)
        return best_actions[pick]

    def updateEndEpisode(self, env):
        """Run extra value backups over the queued states, then clear the queue.

        The states in ``need_update`` are swept five times in reverse order.
        Each backup resets the environment to the state via ``env.setState``
        and uses a fifth of ``num_samples`` samples per action (at least one).
        """
        # Floor division: range() requires an int, and `/` yields a float on
        # Python 3.
        n_samples = max(num_samples // 5, 1)
        sweeps = 5
        for _ in range(sweeps):
            for state in reversed(self.need_update):
                env.setState(state)
                q_values = self._sampleQValues(env, n_samples)
                # Every tied best action carries the same Q, so no random
                # tie-break is needed to pick the value to store.
                self.updateValue(self.parseState(state, env.num_cars),
                                 max(q_values))
        self.need_update = []

    def parseState(self, state, num_cars):
        """Abstract a raw state into a hashable 13-tuple used as table key.

        Args:
            state: indexable with ``state[0]`` the list of ``[x, y]`` positions
                (index 0 is the agent's own car), ``state[1]`` a sequence whose
                first element is used as the velocity feature, and ``state[2]``
                the car orientation (presumably degrees; confirm against the
                environment).
            num_cars: total number of cars in ``state[0]``, including the
                agent's own.

        Returns:
            Tuple of 13 rounded numbers:
            ``(x, y, velocity, heading, dx1, dy1, ..., dx4, dy4, front)``.
            ``heading`` is the orientation snapped to a multiple of 9.
            ``dxk, dyk`` is the offset of the k-th nearest other car, or
            ``-1, -1`` when that slot is empty or the car is not within the
            squared-distance cutoff. ``front`` is the squared distance to the
            nearest car with a larger x than the agent, or ``-1`` if none.
        """
        positions, velocity, orientation = state[0], state[1], state[2]
        ego_x, ego_y = positions[0]

        neighbour_slots = 4
        tile_size = 1
        dist_max = 400  # cutoff on the *squared* distance

        heading = round(round((float(orientation) / 45) * 5, 0) * 9, 0)

        # (squared distance, car index) for every other car, nearest first.
        by_distance = sorted(
            ((positions[i][0] - ego_x) ** 2 + (positions[i][1] - ego_y) ** 2, i)
            for i in range(1, num_cars)
        )

        features = [ego_x, ego_y, velocity[0], heading]
        for slot in range(neighbour_slots):
            if slot < len(by_distance) and by_distance[slot][0] < dist_max:
                idx = by_distance[slot][1]
                features += [positions[idx][0] - ego_x,
                             positions[idx][1] - ego_y]
            else:
                # Empty slot or car too far away: same sentinel for both.
                features += [-1, -1]

        # by_distance is sorted, so the first car ahead is the nearest one.
        front = -1
        for dist, idx in by_distance:
            if positions[idx][0] > ego_x:
                front = dist
                break
        features.append(front)

        return tuple(round(round(f, 0) / tile_size, 0) for f in features)

    def getHeuristic(self, state, env):
        """Return the initial value estimate for an unseen state (always 0).

        A distance-based heuristic (-500 on collision, 0 once ``state[0]``
        reaches 200, otherwise ``(state[0] - 200) / 0.5``) was sketched here
        previously but is disabled.
        """
        return 0