|
class Network(nn.Module):
    """Fully-connected network mapping (batch, input_dim) -> (batch, 1).

    Uses a DeepSDF-style skip connection: the raw input is concatenated
    back onto the hidden activations after the second layer.  To keep the
    hidden width at exactly 128 after that concat, ``l2`` outputs
    ``128 - input_dim`` features.

    Activations are ``softplus(beta=100)`` — a smooth approximation of
    ReLU, which keeps the network differentiable everywhere (commonly
    needed when gradients w.r.t. the *input* are taken, e.g. for SDFs).
    """

    def __init__(self, input_dim):
        # Zero-argument super() is the modern Python 3 idiom.
        super().__init__()
        self.l1 = nn.Linear(input_dim, 128)
        # Sized so that cat([h, x], dim=1) below is 128-wide again.
        self.l2 = nn.Linear(128, 128 - input_dim)
        self.l3 = nn.Linear(128, 128)
        self.l4 = nn.Linear(128, 128)
        self.l5 = nn.Linear(128, 1)

    def forward(self, x):
        """Forward pass.

        Args:
            x: tensor of shape (batch, input_dim).

        Returns:
            Tensor of shape (batch, 1).
        """
        h = F.softplus(self.l1(x), beta=100)
        h = F.softplus(self.l2(h), beta=100)
        # Skip connection; torch-native keyword is dim=, not the numpy
        # alias axis= (same behavior, idiomatic spelling).
        h = torch.cat((h, x), dim=1)
        h = F.softplus(self.l3(h), beta=100)
        h = F.softplus(self.l4(h), beta=100)
        return self.l5(h)
22 |
| - |
23 |
| -class NetworkLarge(nn.Module): |
24 |
| - def __init__(self, input_dim): |
25 |
| - super(NetworkLarge, self).__init__() |
26 | 8 | self.l1 = nn.Linear(input_dim, 512)
|
27 | 9 | self.l2 = nn.Linear(512, 512)
|
28 | 10 | self.l3 = nn.Linear(512, 512)
|
|
0 commit comments