-
Notifications
You must be signed in to change notification settings - Fork 0
/
INR.py
53 lines (39 loc) · 1.38 KB
/
INR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
import torch.nn as nn
import torchvision
import numpy as np
class Swish(nn.Module):
    """Swish activation: f(x) = x * sigmoid(x) (a.k.a. SiLU)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
class SirenLayer(nn.Module):
def __init__(self, in_f, out_f, w0=30, is_first=False, is_last=False):
super().__init__()
self.in_f = in_f
self.w0 = w0
self.linear = nn.Linear(in_f, out_f)
self.is_first = is_first
self.is_last = is_last
self.init_weights()
def init_weights(self):
b = 1 / \
self.in_f if self.is_first else np.sqrt(6 / self.in_f) / self.w0
with torch.no_grad():
self.linear.weight.uniform_(-b, b)
def forward(self, x):
x = self.linear(x)
# return x if self.is_last else torch.nn.functional.relu(x)
return x if self.is_last else torch.sin(self.w0 * x)
def input_mapping(x, B):
    """Fourier-feature mapping: return [sin(2*pi*x@B^T), cos(2*pi*x@B^T)].

    When B is None, x is returned unchanged (identity mapping).
    """
    if B is None:
        return x
    projected = (2. * np.pi * x) @ B.t()
    return torch.cat((torch.sin(projected), torch.cos(projected)), dim=-1)
def gon_model(num_layers, input_dim, hidden_dim, out_dim=3):
    """Build a SIREN MLP of `num_layers` SirenLayers.

    Args:
        num_layers: total number of layers (first + hidden + last).
        input_dim: size of the input coordinates.
        hidden_dim: width of every hidden layer.
        out_dim: size of the output (default 3, e.g. RGB) — previously
            hard-coded; parameterized for reuse, default preserves behavior.

    Returns:
        nn.Sequential of SirenLayers; the last layer emits raw linear output.
    """
    layers = [SirenLayer(input_dim, hidden_dim, is_first=True)]
    # num_layers - 2 hidden layers sit between the first and last layers.
    for _ in range(num_layers - 2):
        layers.append(SirenLayer(hidden_dim, hidden_dim))
    layers.append(SirenLayer(hidden_dim, out_dim, is_last=True))
    return nn.Sequential(*layers)