-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathqlearning_sarsa.py
57 lines (45 loc) · 1.92 KB
/
qlearning_sarsa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import random
# Learning-rate decay constant: update step size is C / (C + N(s, a)),
# so early visits move Q-values a lot and later visits move them less.
C = 50
# Discount factor applied to bootstrapped future value.
GAMMA = 0.7
# Exploration probability for the epsilon-greedy policy.
EPSILON = 0.01
# Reference implementations this agent is modeled on:
# https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/master/contents/3_Sarsa_maze/RL_brain.py
# https://github.com/studywolf/blog/blob/master/RL/Cat%20vs%20Mouse%20exploration/qlearn.py
class RL(object):
    """Tabular RL agent supporting both Q-learning and SARSA updates.

    Q-values live in a dict keyed by ``(state, action)``. Updates use a
    visit-count-based decaying step size ``alpha = C / (C + N(s, a))``,
    and actions are selected epsilon-greedily from ``self.actions``.
    """

    def __init__(self):
        self.actions = [-1, 0, 1]  # discrete action set
        self.c = C                 # learning-rate decay constant
        self.gamma = GAMMA         # discount factor
        self.epsilon = EPSILON     # exploration probability
        self.q_table = {}          # (state, action) -> Q-value estimate
        self.N = {}                # (state, action) -> visit count

    def get_q(self, state, action):
        """Return the Q-value estimate for (state, action); 0.0 if unseen."""
        return self.q_table.get((state, action), 0.0)

    def learn_q(self, state, action, reward, value):
        """Move Q(state, action) toward ``value`` with a decaying step size.

        On the first visit the entry is initialized to ``reward`` (matching
        the reference implementations cited at module level); afterwards it
        moves the fraction C / (C + N(s, a)) of the way toward ``value``.
        """
        self.N[(state, action)] = self.N.get((state, action), 0) + 1
        old_value = self.q_table.get((state, action))
        if old_value is None:
            self.q_table[(state, action)] = reward
        else:
            # Decaying learning rate: C / (C + N(s, a)).
            alpha = float(self.c) / float(self.c + self.N[(state, action)])
            self.q_table[(state, action)] = old_value + alpha * (value - old_value)

    def choose_action(self, state):
        """Epsilon-greedy selection; ties among maximal Q-values break randomly."""
        if random.random() < self.epsilon:
            return random.choice(self.actions)
        q = [self.get_q(state, a) for a in self.actions]
        max_q = max(q)
        # Fix: original used a hard-coded range(3); enumerate(q) stays correct
        # if the action set ever changes size. Unique maximum falls out of the
        # same path (best has one element).
        best = [i for i, v in enumerate(q) if v == max_q]
        return self.actions[random.choice(best)]

    def learn_qlearning(self, state1, action1, reward, state2):
        """Off-policy Q-learning update: bootstrap from max_a Q(state2, a)."""
        max_q_new = max(self.get_q(state2, a) for a in self.actions)
        self.learn_q(state1, action1, reward, reward + self.gamma * max_q_new)

    def learn_sarsa(self, state1, action1, reward, state2, action2):
        """On-policy SARSA update: bootstrap from Q(state2, action2)."""
        q_next = self.get_q(state2, action2)
        self.learn_q(state1, action1, reward, reward + self.gamma * q_next)