-
Notifications
You must be signed in to change notification settings - Fork 2
/
agent_constraints.py
37 lines (27 loc) · 1.07 KB
/
agent_constraints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
"""
Q-learning agent whose action choice can be restricted, at each step, to the
subset of actions accepted by a caller-supplied validity predicate.
"""
from agent import IndQLearningAgent
import numpy as np
from numpy.random import choice
class RestrictedIndQLearningAgent(IndQLearningAgent):
    """Independent Q-learning agent whose actions can be restricted per step.

    Only ``act`` is overridden; everything else is inherited unchanged from
    ``IndQLearningAgent``.
    """

    def act(self, obs=None, valid_action=None, previous_action=None):
        """Choose an epsilon-greedy action, optionally restricted to valid ones.

        Args:
            obs: Current observation, used as the row index into ``self.Q``.
            valid_action: Optional predicate called as
                ``valid_action(obs, a, previous_action)`` for every ``a`` in
                ``self.action_space``; a truthy result marks ``a`` as allowed.
                When ``None``, the unrestricted parent ``act(obs)`` is used.
            previous_action: Passed through unchanged to ``valid_action``.

        Returns:
            An element of ``self.action_space``. With probability
            ``self.epsilon`` it is drawn uniformly from the valid actions;
            otherwise it is the valid action with the highest Q-value (first
            one on ties).

        Raises:
            ValueError: If ``valid_action`` rejects every action.
        """
        if valid_action is None:
            return super(RestrictedIndQLearningAgent, self).act(obs)

        # Force a boolean mask so ``~mask`` is a logical NOT even when the
        # predicate returns ints (bitwise NOT on ints would break indexing).
        mask = np.array(
            [valid_action(obs, a, previous_action) for a in self.action_space],
            dtype=bool,
        )
        if not mask.any():
            raise ValueError("valid_action rejected every action in the action space")

        if np.random.rand() < self.epsilon:
            # Explore: uniform distribution over the valid actions only.
            probs = mask.astype(float)
            probs /= probs.sum()
            return choice(self.action_space, p=probs)

        # Exploit: mask a *copy* of the Q row. Basic slicing of an ndarray
        # returns a view, so masking it in place would write -inf into the
        # learned Q-table. A float copy also accepts -inf for int tables.
        q_row = np.array(self.Q[obs, :], dtype=float, copy=True)
        q_row[~mask] = -np.inf
        return self.action_space[np.argmax(q_row)]