dist_agents.py
import numpy as np
import numpy.random as npr

from neuronav.agents.base_agent import BaseAgent


class DistQ(BaseAgent):
    """
    Implementation of the distributional Q-learning algorithm
    found in Dabney et al., 2019.

    `mirror` determines whether the same learning rates are used for
    positive and negative TD errors.
    """

    def __init__(
        self,
        state_size: int,
        action_size: int,
        lr: float = 1e-1,
        gamma: float = 0.99,
        beta: float = 1e4,
        poltype: str = "softmax",
        Q_init=None,
        epsilon: float = 1e-1,
        dist_cells: int = 16,
        mirror: bool = False,
        **kwargs
    ):
        super().__init__(state_size, action_size, lr, gamma, poltype, beta, epsilon)
        # Q holds one value estimate per (action, state) pair for each
        # distributional cell.
        if Q_init is None:
            self.Q = np.zeros((action_size, state_size, dist_cells))
        elif np.isscalar(Q_init):
            self.Q = Q_init * npr.randn(action_size, state_size, dist_cells)
        else:
            self.Q = Q_init
        self.dist_cells = dist_cells
        # Each cell gets its own learning rates for positive and negative
        # TD errors; `mirror` ties the two sets together.
        self.lrs_pos = npr.uniform(0.001, 0.02, dist_cells)
        if mirror:
            self.lrs_neg = self.lrs_pos
        else:
            self.lrs_neg = npr.uniform(0.001, 0.02, dist_cells)

    def sample_action(self, state):
        # act greedily/softly with respect to a randomly sampled distributional cell
        Qs = self.Q[:, state, npr.randint(0, self.dist_cells)]
        return self.base_sample_action(Qs)

    def update_q(self, current_exp, next_exp=None, prospective=False):
        s = current_exp[0]
        s_a = current_exp[1]
        s_1 = current_exp[2]
        # determines whether update is on-policy or off-policy
        if next_exp is None:
            # off-policy: bootstrap from the greedy action, taken with respect
            # to the mean over distributional cells (a bare argmax over the 2D
            # slice would return a flattened index)
            s_a_1 = np.argmax(self.Q[:, s_1].mean(axis=-1))
        else:
            s_a_1 = next_exp[1]
        r = current_exp[3]
        next_q = self.Q[s_a_1, s_1, npr.randint(0, self.dist_cells)]
        # per-cell TD error, scaled asymmetrically for positive vs. negative errors
        q_error = r + self.gamma * next_q - self.Q[s_a, s]
        qep = (q_error > 0.0) * 1.0
        if not prospective:
            # actually perform update to Q if not prospective
            self.Q[s_a, s] += (self.lrs_pos * qep + self.lrs_neg * (1 - qep)) * q_error
        return q_error

    def _update(self, current_exp, **kwargs):
        q_error = self.update_q(current_exp, **kwargs)
        td_error = {"q": np.linalg.norm(q_error)}
        return td_error

    def get_policy(self):
        Qs = self.Q[:, :, npr.randint(0, self.dist_cells)]
        return self.base_get_policy(Qs)
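

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the sizes and the transition below
# are made up, and the experience tuple layout (state, action, next_state,
# reward) simply mirrors what `update_q` unpacks above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    npr.seed(0)
    agent = DistQ(state_size=25, action_size=4, dist_cells=16)

    state, next_state, reward = 0, 1, 1.0
    action = agent.sample_action(state)

    # One off-policy update (next_exp=None bootstraps from the greedy action).
    q_error = agent.update_q((state, action, next_state, reward))
    print("sampled action:", action)
    print("per-cell TD error shape:", q_error.shape)  # (dist_cells,)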