-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathmodel.py
36 lines (31 loc) · 1.18 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn
import torch.optim as optim
from neuronav.deep_agents.modules import gen_encoder
class PPOModel(nn.Module):
    """Actor-critic network with a shared encoder and its own optimizer.

    An encoder built by ``gen_encoder`` feeds two linear heads: ``policy``
    (``h_size`` -> ``act_size`` action logits) and ``value`` (``h_size`` -> 1
    scalar estimate). An ``AdamW`` optimizer over all of the module's
    parameters is created at construction time and stored on the instance.

    Args:
        model_params: Mapping that must provide the keys ``"enc_type"``,
            ``"obs_size"``, ``"act_size"``, ``"h_size"``, ``"depth"`` and
            ``"lr"``. ``"obs_size"`` is used as the flattened per-sample
            observation length in :meth:`encode`, so it must be an int.
            NOTE(review): the accepted ``"enc_type"`` values and the meaning
            of ``"depth"`` are defined by ``gen_encoder`` -- confirm there.
    """

    def __init__(self, model_params):
        super().__init__()
        self.enc_type = model_params["enc_type"]
        self.obs_size = model_params["obs_size"]
        self.act_size = model_params["act_size"]
        self.h_size = model_params["h_size"]
        self.depth = model_params["depth"]

        self.encoder = gen_encoder(
            self.obs_size, self.h_size, self.depth, self.enc_type
        )
        self.policy = nn.Linear(self.h_size, self.act_size)
        self.value = nn.Linear(self.h_size, 1)

        # Created last so that self.parameters() already includes the
        # encoder and both heads.
        self.optimizer = optim.AdamW(self.parameters(), lr=model_params["lr"])

    def forward(self, x):
        """Return ``(logits, value)`` for a batch of observations.

        ``logits`` is the output of the policy head; ``value`` is the output
        of the value head flattened to one dimension (one entry per sample).
        """
        h = self.encode(x)
        logits = self.policy(h)
        value = self.value(h)
        return logits, value.view(-1)

    def encode(self, x):
        """Flatten ``x`` to rows of ``obs_size`` elements and encode them."""
        # reshape (not view) so non-contiguous inputs, e.g. permuted or
        # sliced observation tensors, are accepted; for contiguous inputs
        # it is equivalent to view and makes no copy.
        x = x.reshape(-1, self.obs_size)
        return self.encoder(x)

    def sample_action(self, obs):
        """Sample an action per observation from the categorical policy.

        Returns:
            ``(action, logits, value)`` where ``action`` is drawn from
            ``Categorical(logits=logits)`` and ``logits`` / ``value`` are
            exactly what :meth:`forward` returns for ``obs``.
        """
        # forward() already returns the value flattened to 1-D.
        logits, value = self.forward(obs)
        action = torch.distributions.Categorical(logits=logits).sample()
        return action, logits, value