-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaction_selection.py
32 lines (27 loc) · 1.07 KB
/
action_selection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import random
def epsilon_greedy(q_network, env, state, epsilon, device):
    """
    Implements epsilon-greedy action selection.

    With probability `epsilon` a random action is sampled from the
    environment's action space (exploration); otherwise the action with
    the highest predicted Q-value is chosen (exploitation).

    Params:
        q_network: deep Q-learning network; maps state -> (1, n_actions) Q-values
        env: RL environment (must expose env.action_space.sample())
        state: current state of env, in whatever format q_network expects
        epsilon: float in [0, 1], exploration probability
        device: cpu or gpu (torch device for the returned action tensor)

    Returns:
        torch.long tensor of shape (1, 1) containing the selected action index.
    """
    if random.random() < epsilon:
        # Explore: uniformly random action from the action space.
        return torch.tensor(
            [[env.action_space.sample()]], device=device, dtype=torch.long
        )
    # Exploit: pick the action with the highest Q-value.
    # Action selection is pure inference -- wrap in no_grad so we don't
    # build (and retain) an autograd graph on every environment step.
    with torch.no_grad():
        # q_network(state) returns a tensor of shape (1, n_actions).
        # tensor.max(1) returns (values, indices) per row; .indices is a
        # one-dimensional tensor of length 1, e.g. [idx].
        # view(1, 1) reshapes the index tensor from [idx] to [[idx]].
        return q_network(state).max(1).indices.view(1, 1)
def greedy(q_network, state):
    """
    Selects the greedy (highest-Q) action for `state`.

    Params:
        q_network: deep Q-learning network; maps state -> (1, n_actions) Q-values
        state: current state of env, in whatever format q_network expects

    Returns:
        torch.long tensor of shape (1, 1) containing the argmax action index.
    """
    q_values = q_network(state)
    # max over the action dimension yields (values, indices) per row;
    # we only need the index of the best action.
    _, best_action = q_values.max(dim=1)
    return best_action.view(1, 1)