forked from cychai1995/DDPGfD
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
33 lines (27 loc) · 1.03 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch.nn as nn
class ActorNet(nn.Module):
    """Deterministic actor: a 3-layer MLP mapping a state to an action.

    The final ``Tanh`` bounds every action component to ``[-1, 1]``.
    """

    def __init__(self, in_dim: int, out_dim: int, device, hidden_dim: int = 512):
        """
        :param in_dim: number of input (state) features
        :param out_dim: number of output (action) features
        :param device: kept on the instance as ``self.device``; the module is
            NOT moved to this device here
        :param hidden_dim: width of both hidden layers (defaults to 512, the
            value that used to be hard-coded)
        """
        super(ActorNet, self).__init__()
        self.device = device
        # Layer order is unchanged so state_dict keys stay net.0 / net.2 / net.4
        # and existing checkpoints (with the default width) still load.
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
            nn.Tanh(),  # +-1 output
        )

    def forward(self, state):
        """
        :param state: N, in_dim
        :return: Action (deterministic), N,out_dim
        """
        return self.net(state)
class CriticNet(nn.Module):
    """Critic: a 3-layer MLP producing one scalar per state-action pair.

    The network consumes a single tensor of width ``s_dim + a_dim``; the
    caller is responsible for concatenating state and action beforehand.
    """

    def __init__(self, s_dim: int, a_dim: int, device, hidden_dim: int = 512):
        """
        :param s_dim: number of state features
        :param a_dim: number of action features
        :param device: kept on the instance as ``self.device``; the module is
            NOT moved to this device here
        :param hidden_dim: width of both hidden layers (defaults to 512, the
            value that used to be hard-coded)
        """
        super(CriticNet, self).__init__()
        self.device = device
        in_dim = s_dim + a_dim
        # Layer order is unchanged so state_dict keys stay net.0 / net.2 / net.4
        # and existing checkpoints (with the default width) still load.
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),  # unbounded scalar output, no activation
        )

    def forward(self, sa_pairs):
        """
        :param sa_pairs: state-action pairs, (N, s_dim + a_dim)
        :return: Q-values , N,1
        """
        return self.net(sa_pairs)