Skip to content

Commit 9f81cb5

Browse files
committed
fix
1 parent d67fcbe commit 9f81cb5

File tree

2 files changed

+1
-19
lines changed

2 files changed

+1
-19
lines changed

network.py

-18
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,6 @@
55
class Network(nn.Module):
66
def __init__(self, input_dim):
77
super(Network, self).__init__()
8-
self.l1 = nn.Linear(input_dim, 128)
9-
self.l2 = nn.Linear(128, 128 - input_dim)
10-
self.l3 = nn.Linear(128, 128)
11-
self.l4 = nn.Linear(128, 128)
12-
self.l5 = nn.Linear(128, 1)
13-
14-
def forward(self, x):
15-
h = F.softplus(self.l1(x), beta=100)
16-
h = F.softplus(self.l2(h), beta=100)
17-
h = torch.cat((h, x), axis=1)
18-
h = F.softplus(self.l3(h), beta=100)
19-
h = F.softplus(self.l4(h), beta=100)
20-
h = self.l5(h)
21-
return h
22-
23-
class NetworkLarge(nn.Module):
24-
def __init__(self, input_dim):
25-
super(NetworkLarge, self).__init__()
268
self.l1 = nn.Linear(input_dim, 512)
279
self.l2 = nn.Linear(512, 512)
2810
self.l3 = nn.Linear(512, 512)

utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch.autograd as autograd
55
import torch.nn as nn
66

7-
from network import NetworkLarge as Network
7+
from network import Network
88

99
def sample_fake(pts, noise=0.3):
1010
sampled = pts + torch.normal(0, 1, pts.shape) * noise.unsqueeze(1)

0 commit comments

Comments
 (0)