-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmodules.py
338 lines (286 loc) · 14.9 KB
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
import torch
import torch.nn as nn
import math
from torch.nn import functional as F, init
# @torch.jit.script  -- decorator currently disabled
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """Sum two pre-activations and apply the WaveNet-style gated activation.

    The sum ``x = input_a + input_b`` is split along dim 1 at
    ``n_channels[0]``: the leading slice goes through ``tanh`` (filter), the
    trailing slice through ``sigmoid`` (gate), and the two are multiplied.

    Args:
        input_a: pre-activation tensor.
        input_b: tensor added to ``input_a`` (broadcasting applies).
        n_channels: indexable whose first element is the split index
            (callers in this file pass ``torch.IntTensor([C])``).

    Returns:
        ``tanh(x[:, :n]) * sigmoid(x[:, n:])``.
    """
    split = n_channels[0]
    summed = input_a + input_b
    filt = torch.tanh(summed[:, :split])
    gate = torch.sigmoid(summed[:, split:])
    return filt * gate
# @torch.jit.script  -- decorator currently disabled
def fused_add_tanh_sigmoid_multiply_with_context(input_a, input_b, input_c, n_channels):
    """Three-input variant of the gated activation.

    Computes ``x = input_a + input_b + input_c`` (left to right), then returns
    ``tanh(x[:, :n]) * sigmoid(x[:, n:])`` with ``n = n_channels[0]``.

    Args:
        input_a, input_b, input_c: pre-activation tensors (broadcasting applies).
        n_channels: indexable whose first element is the split index along dim 1
            (callers in this file pass ``torch.IntTensor([C])``).
    """
    split = n_channels[0]
    summed = input_a + input_b + input_c
    return torch.tanh(summed[:, :split]) * torch.sigmoid(summed[:, split:])
# @torch.jit.script  -- decorator currently disabled
def fused_res_skip(tensor, res_skip, n_channels):
    """Split a res/skip projection and apply the scaled residual update.

    Args:
        tensor: residual-stream input.
        res_skip: projection whose dim-1 slice ``[:n]`` is the residual part
            and ``[n:]`` the skip part, with ``n = n_channels[0]``.
        n_channels: indexable whose first element is the split index.

    Returns:
        ``((tensor + res) * sqrt(0.5), skip)``.
    """
    split = n_channels[0]
    residual_part, skip_part = res_skip[:, :split], res_skip[:, split:]
    updated = (tensor + residual_part) * math.sqrt(0.5)
    return updated, skip_part
# @torch.jit.script  -- decorator currently disabled
def fused_res_skip_multgate(tensor, res_skip, n_channels, multgate):
    """Split a res/skip projection and apply a gated (unscaled) residual update.

    Args:
        tensor: residual-stream input.
        res_skip: projection split along dim 1 at ``n_channels[0]`` into the
            residual part (leading) and skip part (trailing).
        n_channels: indexable whose first element is the split index.
        multgate: 0-d tensor (one gate for everything) or a tensor with at
            least one dim (per-channel gating). In the latter case two
            trailing singleton dims are appended so its last dim lines up with
            dim -3 of ``res`` -- presumably the channel dim of a
            (B, C, H, W) tensor; confirm against callers.

    Returns:
        ``(tensor + res * multgate, skip)``.
    """
    split = n_channels[0]
    residual_part = res_skip[:, :split]
    skip_part = res_skip[:, split:]
    if multgate.shape:  # non-empty shape => per-channel gate; align with trailing (H, W) dims
        multgate = multgate[..., None, None]
    return tensor + residual_part * multgate, skip_part
class Conv2D(nn.Module):
    """Weight-normalised 2D convolution, causal along height, "same"-padded along width.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: side length of the square kernel.
        dilation_h: dilation along the height axis.
        dilation_w: dilation along the width axis.
        causal: if True, the height padding equals the full receptive field
            ``dilation_h * (kernel_size - 1)`` and :meth:`forward` trims the
            trailing rows, so output row ``t`` only depends on input rows
            ``<= t``. If False, height padding is symmetric.
        use_wn_bias: whether the wrapped ``nn.Conv2d`` has a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, dilation_h=1, dilation_w=1,
                 causal=True, use_wn_bias=True):
        super(Conv2D, self).__init__()
        self.causal = causal
        self.use_wn_bias = use_wn_bias
        self.dilation_h, self.dilation_w = dilation_h, dilation_w
        if self.causal:
            # Pad by the whole receptive field; forward() drops the look-ahead rows.
            self.padding_h = dilation_h * (kernel_size - 1)
        else:
            self.padding_h = dilation_h * (kernel_size - 1) // 2
        # Width is always non-causal ("same" output width for odd kernel sizes).
        self.padding_w = dilation_w * (kernel_size - 1) // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              dilation=(dilation_h, dilation_w),
                              padding=(self.padding_h, self.padding_w),
                              bias=use_wn_bias)
        self.conv = nn.utils.weight_norm(self.conv)
        # NOTE(review): after weight_norm, `self.conv.weight` appears to be
        # recomputed from weight_g / weight_v before every forward, so this
        # in-place init likely has no lasting effect. Kept unchanged so the
        # initialisation (and RNG stream) existing runs relied on is preserved.
        nn.init.kaiming_normal_(self.conv.weight)

    def forward(self, tensor):
        """Apply the convolution.

        In causal mode the trailing ``padding_h`` output rows (which would look
        ahead) are removed, so the output height equals the input height.
        """
        out = self.conv(tensor)
        if self.causal and self.padding_h != 0:
            out = out[:, :, :-self.padding_h, :]
        return out

    def reverse_fast(self, tensor):
        """Apply the convolution with no height padding (incremental inference).

        Fed exactly ``dilation_h * (kernel_size - 1) + 1`` rows (the height
        receptive field), this yields a single output row equal to the last
        row :meth:`forward` would produce for the same inputs.

        The wrapped conv's padding is only overridden for the duration of this
        call and is always restored, so later :meth:`forward` calls keep their
        causal height padding. (Previously the override leaked and corrupted
        every subsequent forward pass.)
        """
        saved_padding = self.conv.padding
        self.conv.padding = (0, self.padding_w)
        try:
            return self.conv(tensor)
        finally:
            self.conv.padding = saved_padding
class ZeroConv2d(nn.Module):
    """1x1 convolution with near-zero init and a learnable exponential output scale.

    Weights and bias are drawn from U(-1e-3, 1e-3). The output is multiplied
    by ``exp(3 * scale)``, where ``scale`` is a per-output-channel parameter
    initialised to zero (so the initial multiplier is 1).
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        projection = nn.Conv2d(in_channel, out_channel, 1, padding=0)
        init.uniform_(projection.weight, -1e-3, 1e-3)
        init.uniform_(projection.bias, -1e-3, 1e-3)
        self.conv = projection
        self.scale = nn.Parameter(torch.zeros(1, out_channel, 1, 1))

    def forward(self, x):
        """Return ``conv(x) * exp(3 * scale)``."""
        gain = torch.exp(self.scale * 3)
        return self.conv(x) * gain
class ResBlock2D(nn.Module):
    """Gated residual block of the 2D WaveNet.

    A dilated :class:`Conv2D` on the input and a 1x1 projection of the
    conditioning ``c`` are summed and passed through a tanh*sigmoid gate. A
    1x1 conv then produces ``2 * in_channels`` channels, which are split into
    a residual update (scaled by sqrt(0.5)) and a skip contribution.

    ``in_channels``, ``out_channels`` and ``skip_channels`` must be equal.
    ``local_conditioning`` and ``skip`` are stored but not read in this class.
    """

    def __init__(self, in_channels, out_channels, skip_channels, kernel_size,
                 cin_channels=None, local_conditioning=True, dilation_h=None, dilation_w=None,
                 causal=True):
        super(ResBlock2D, self).__init__()
        self.out_channels = out_channels
        self.local_conditioning = local_conditioning
        self.cin_channels = cin_channels
        self.skip = True
        assert in_channels == out_channels == skip_channels

        # First half of these channels feeds tanh (filter), second half sigmoid (gate).
        gate_width = 2 * out_channels
        self.filter_gate_conv = Conv2D(in_channels, gate_width, kernel_size,
                                       dilation_h, dilation_w, causal=causal)

        cond_proj = nn.utils.weight_norm(nn.Conv2d(cin_channels, gate_width, kernel_size=1))
        nn.init.kaiming_normal_(cond_proj.weight)
        self.filter_gate_conv_c = cond_proj

        out_proj = nn.utils.weight_norm(nn.Conv2d(out_channels, 2 * in_channels, kernel_size=1))
        nn.init.kaiming_normal_(out_proj.weight)
        self.res_skip_conv = out_proj

    def _gate_and_split(self, residual, h_pre, c_pre):
        # Shared tail: gated activation -> 1x1 res/skip projection -> residual update.
        split = torch.IntTensor([self.out_channels])
        activated = fused_add_tanh_sigmoid_multiply(h_pre, c_pre, split)
        return fused_res_skip(residual, self.res_skip_conv(activated), split)

    def forward(self, tensor, c=None):
        """Full-sequence pass; ``c`` is raw conditioning. Returns ``(residual_out, skip)``."""
        h_pre = self.filter_gate_conv(tensor)
        c_pre = self.filter_gate_conv_c(c)
        return self._gate_and_split(tensor, h_pre, c_pre)

    def reverse(self, tensor, c=None):
        """Like :meth:`forward`, but ``c`` is added directly to the gate pre-activations.

        So ``c`` must already be projected (a cached tensor supplied by the caller).
        """
        return self._gate_and_split(tensor, self.filter_gate_conv(tensor), c)

    def reverse_fast(self, tensor, c=None):
        """Incremental step: conv without height padding, residual update on the last row only.

        ``tensor`` presumably holds the cached past rows followed by the newest
        row; ``c`` is the already-projected conditioning.
        """
        h_pre = self.filter_gate_conv.reverse_fast(tensor)
        return self._gate_and_split(tensor[:, :, -1:, :], h_pre, c)
class Wavenet2D(nn.Module):
    """WaveNet-like stack of :class:`ResBlock2D` layers operating on 2D features.

    A 1x1 front :class:`Conv2D` lifts ``in_channels`` to ``residual_channels``.
    Each of the ``num_layers`` residual blocks then updates the residual stream
    and emits a skip tensor, and the sum of all skips is returned.
    ``out_channels`` is accepted but not used by this class.

    Args:
        dilation_h, dilation_w: per-layer dilations, indexed by layer; both
            are required.
    """

    def __init__(self, in_channels=1, out_channels=2, num_layers=6,
                 residual_channels=256, gate_channels=256, skip_channels=256,
                 kernel_size=3, cin_channels=80, dilation_h=None, dilation_w=None,
                 causal=True):
        super(Wavenet2D, self).__init__()
        assert dilation_h is not None and dilation_w is not None
        self.residual_channels = residual_channels
        self.skip = skip_channels is not None
        self.front_conv = nn.Sequential(
            Conv2D(in_channels, residual_channels, 1, 1, 1, causal=causal))
        self.res_blocks = nn.ModuleList()
        for n in range(num_layers):
            self.res_blocks.append(ResBlock2D(residual_channels, gate_channels, skip_channels, kernel_size,
                                              cin_channels=cin_channels, local_conditioning=True,
                                              dilation_h=dilation_h[n], dilation_w=dilation_w[n],
                                              causal=causal))

    def forward(self, x, c=None):
        """Full-sequence pass with raw conditioning ``c``; returns the summed skips."""
        h = self.front_conv(x)
        skip = 0
        for block in self.res_blocks:
            h, s = block(h, c)
            skip += s
        return skip

    def reverse(self, x, c=None):
        """Full-sequence pass with cached per-layer conditioning.

        ``c[i]`` is added directly to layer ``i``'s gate pre-activations, so it
        must already be projected (e.g. cached by the caller).
        """
        h = self.front_conv(x)
        skip = 0
        for i, block in enumerate(self.res_blocks):
            c_i = c[i]
            h, s = block.reverse(h, c_i)
            skip += s
        return skip

    def reverse_fast(self, x, c=None):
        """Incremental step using per-layer row queues.

        Requires :meth:`conv_queue_init` to have been called first. ``x``
        presumably holds one new row, ``[B, in_channels, 1, W]``. For each layer,
        the cached rows are concatenated with the new row along height and
        consumed by ``block.reverse_fast``. The oldest row is then dropped from
        the queue. ``c[i]`` is layer ``i``'s cached, already-projected
        conditioning.
        """
        h = self.front_conv(x)
        skip = 0
        for i, block in enumerate(self.res_blocks):
            c_i = c[i]
            window = torch.cat((self.conv_queues[i], h), dim=2)  # [B, C_res, depth + 1, W]
            h, s = block.reverse_fast(window, c_i)
            self.conv_queues[i] = window[:, :, 1:, :]  # keep the most recent `depth` rows
            skip += s
        return skip

    @staticmethod
    def _queue_depth(block):
        # Past rows the block's height conv needs besides the current row:
        # dilation_h * (kernel_size - 1). This is 2 for kernel 3 / dilation 1,
        # the only case the previous hard-coded depth supported.
        conv = block.filter_gate_conv
        return conv.dilation_h * (conv.conv.kernel_size[0] - 1)

    def conv_queue_init(self, x):
        """Reset the per-layer row queues used by :meth:`reverse_fast`.

        Each queue is zero-filled with shape ``[B, residual_channels, depth, W]``,
        where ``B`` and ``W`` come from ``x``. ``depth`` is derived from that
        layer's height dilation and kernel size. Queues match ``x``'s device and
        dtype; previously only CUDA half was matched, so CPU half or bfloat16
        inputs got float32 queues.
        """
        batch, _, _, width = x.size()
        self.conv_queues = [
            torch.zeros((batch, self.residual_channels, self._queue_depth(block), width),
                        device=x.device, dtype=x.dtype)
            for block in self.res_blocks
        ]
class ResBlock2DHyperMultGate(nn.Module):
    """Gated residual block with an extra ``context`` input and a multiplicative residual gate.

    The tanh*sigmoid gate sees the sum of three pre-activations:
    - a dilated :class:`Conv2D` of the input;
    - a 1x1 projection of ``c`` (``cin_channels``);
    - a 1x1 projection of ``context`` (``hyper_channels``).

    The 1x1 res/skip projection is split and the residual part is scaled by
    ``multgate`` (see ``fused_res_skip_multgate``). Unlike :class:`ResBlock2D`,
    there is no sqrt(0.5) scaling.

    ``in_channels``, ``out_channels`` and ``skip_channels`` must be equal.
    """

    def __init__(self, in_channels, out_channels, skip_channels, kernel_size,
                 cin_channels=None, hyper_channels=None, local_conditioning=True, dilation_h=None, dilation_w=None,
                 causal=True):
        super(ResBlock2DHyperMultGate, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.local_conditioning = local_conditioning
        self.cin_channels = cin_channels
        self.hyper_channels = hyper_channels
        self.skip = True
        assert self.in_channels == self.out_channels == skip_channels

        gate_width = 2 * out_channels  # tanh half + sigmoid half
        self.filter_gate_conv = Conv2D(in_channels, gate_width, kernel_size,
                                       dilation_h, dilation_w, causal=causal)

        cond_proj = nn.utils.weight_norm(nn.Conv2d(cin_channels, gate_width, kernel_size=1))
        nn.init.kaiming_normal_(cond_proj.weight)
        self.filter_gate_conv_c = cond_proj

        context_proj = nn.utils.weight_norm(nn.Conv2d(hyper_channels, gate_width, kernel_size=1))
        nn.init.kaiming_normal_(context_proj.weight)
        self.filter_gate_conv_h = context_proj

        out_proj = nn.utils.weight_norm(nn.Conv2d(out_channels, 2 * in_channels, kernel_size=1))
        nn.init.kaiming_normal_(out_proj.weight)
        self.res_skip_conv = out_proj

    def _project_and_gate(self, residual, activated, split, multgate):
        # Shared tail: 1x1 res/skip projection, then the multgate-scaled residual update.
        return fused_res_skip_multgate(residual, self.res_skip_conv(activated), split, multgate)

    def forward(self, tensor, c=None, context=None, multgate=None):
        """Full pass with raw ``c`` and ``context``; returns ``(residual_out, skip)``."""
        h_pre = self.filter_gate_conv(tensor)
        c_pre = self.filter_gate_conv_c(c)
        ctx_pre = self.filter_gate_conv_h(context)
        split = torch.IntTensor([self.out_channels])
        activated = fused_add_tanh_sigmoid_multiply_with_context(h_pre, c_pre, ctx_pre, split)
        return self._project_and_gate(tensor, activated, split, multgate)

    def reverse(self, tensor, c=None, context=None, multgate=None):
        """Like :meth:`forward`, but ``c`` and ``context`` are already-projected cached tensors."""
        h_pre = self.filter_gate_conv(tensor)
        split = torch.IntTensor([self.out_channels])
        activated = fused_add_tanh_sigmoid_multiply_with_context(h_pre, c, context, split)
        return self._project_and_gate(tensor, activated, split, multgate)

    def reverse_fast(self, tensor, c=None, context=None, multgate=None):
        """Incremental step: no height padding; residual update on the last input row only."""
        h_pre = self.filter_gate_conv.reverse_fast(tensor)
        split = torch.IntTensor([self.out_channels])
        activated = fused_add_tanh_sigmoid_multiply_with_context(h_pre, c, context, split)
        return self._project_and_gate(tensor[:, :, -1:, :], activated, split, multgate)

    def reverse_faster(self, tensor, c=None, multgate=None):
        """Like :meth:`reverse_fast`, but ``c`` already includes the projected context."""
        h_pre = self.filter_gate_conv.reverse_fast(tensor)
        split = torch.IntTensor([self.out_channels])
        activated = fused_add_tanh_sigmoid_multiply(h_pre, c, split)
        return self._project_and_gate(tensor[:, :, -1:, :], activated, split, multgate)
class Wavenet2DHyperMultGate(nn.Module):
    """WaveNet-like 2D stack built from :class:`ResBlock2DHyperMultGate` layers.

    Each layer additionally receives a ``context`` input (``hyper_channels``
    channels) and a per-layer ``multgate`` (indexed by layer) that scales its
    residual update. The sum of all layers' skips is returned.
    ``out_channels`` is accepted but not used by this class.

    Args:
        dilation_h, dilation_w: per-layer dilations, indexed by layer; both
            are required.
    """

    def __init__(self, in_channels=1, out_channels=2, num_layers=6,
                 residual_channels=256, gate_channels=256, skip_channels=256,
                 kernel_size=3, cin_channels=80, hyper_channels=7, dilation_h=None, dilation_w=None,
                 causal=True):
        super(Wavenet2DHyperMultGate, self).__init__()
        assert dilation_h is not None and dilation_w is not None
        self.residual_channels = residual_channels
        self.skip = skip_channels is not None
        self.front_conv = nn.Sequential(
            Conv2D(in_channels, residual_channels, 1, 1, 1, causal=causal))
        self.res_blocks = nn.ModuleList()
        for n in range(num_layers):
            self.res_blocks.append(ResBlock2DHyperMultGate(residual_channels, gate_channels, skip_channels, kernel_size,
                                                           cin_channels=cin_channels, hyper_channels=hyper_channels,
                                                           local_conditioning=True,
                                                           dilation_h=dilation_h[n], dilation_w=dilation_w[n],
                                                           causal=causal))

    def forward(self, x, c=None, context=None, multgate=None):
        """Full pass with raw ``c``/``context`` shared by all layers and per-layer ``multgate[i]``."""
        h = self.front_conv(x)
        skip = 0
        for i, block in enumerate(self.res_blocks):
            multgate_i = multgate[i]
            h, s = block(h, c, context, multgate_i)
            skip += s
        return skip

    def reverse(self, x, c=None, context=None, multgate=None):
        """Full pass with cached, already-projected per-layer ``c[i]`` and ``context[i]``."""
        h = self.front_conv(x)
        skip = 0
        for i, block in enumerate(self.res_blocks):
            c_i = c[i]
            context_i = context[i]
            multgate_i = multgate[i]
            h, s = block.reverse(h, c_i, context_i, multgate_i)
            skip += s
        return skip

    def reverse_fast(self, x, c=None, context=None, multgate=None):
        """Incremental step using per-layer row queues (call :meth:`conv_queue_init` first).

        ``x`` presumably holds one new row. Each layer consumes its cached rows
        plus the new row, then drops the oldest row from its queue.
        """
        h = self.front_conv(x)
        skip = 0
        for i, block in enumerate(self.res_blocks):
            c_i = c[i]
            context_i = context[i]
            multgate_i = multgate[i]
            window = torch.cat((self.conv_queues[i], h), dim=2)  # [B, C_res, depth + 1, W]
            h, s = block.reverse_fast(window, c_i, context_i, multgate_i)
            self.conv_queues[i] = window[:, :, 1:, :]  # keep the most recent `depth` rows
            skip += s
        return skip

    def reverse_faster(self, x, c=None, multgate=None):
        """Like :meth:`reverse_fast`, but each ``c[i]`` already includes the projected context."""
        h = self.front_conv(x)
        skip = 0
        for i, block in enumerate(self.res_blocks):
            c_i = c[i]
            multgate_i = multgate[i]
            window = torch.cat((self.conv_queues[i], h), dim=2)  # [B, C_res, depth + 1, W]
            h, s = block.reverse_faster(window, c_i, multgate_i)
            self.conv_queues[i] = window[:, :, 1:, :]  # keep the most recent `depth` rows
            skip += s
        return skip

    @staticmethod
    def _queue_depth(block):
        # Past rows the block's height conv needs besides the current row:
        # dilation_h * (kernel_size - 1). This is 2 for kernel 3 / dilation 1,
        # the only case the previous hard-coded depth supported.
        conv = block.filter_gate_conv
        return conv.dilation_h * (conv.conv.kernel_size[0] - 1)

    def conv_queue_init(self, x):
        """Reset the per-layer row queues used by the fast reverse methods.

        Each queue is zero-filled with shape ``[B, residual_channels, depth, W]``,
        where ``B`` and ``W`` come from ``x``. ``depth`` is derived from that
        layer's height dilation and kernel size. Queues match ``x``'s device and
        dtype; previously only CUDA half was matched.
        """
        batch, _, _, width = x.size()
        self.conv_queues = [
            torch.zeros((batch, self.residual_channels, self._queue_depth(block), width),
                        device=x.device, dtype=x.dtype)
            for block in self.res_blocks
        ]