Commit
add model folder with conv and residual modules
1 parent d740fec · commit 46c135c
Showing 4 changed files with 229 additions and 0 deletions.
Empty file.
@@ -0,0 +1,63 @@
from torch import nn
from torch.nn import functional as F


class Conv1d(nn.Conv1d):
    """
    Extend nn.Conv1d for incremental dilated convolutions
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.clear_buffer()
        self._linearized_weight = None
        self.register_backward_hook(self._clear_linearized_weight)

    def clear_buffer(self):
        self.input_buffer = None

    def _clear_linearized_weight(self, *args):
        self._linearized_weight = None

    def incremental_forward(self, input):
        # input: (B, T, C)
        if self.training:
            raise RuntimeError('incremental_forward only supports eval mode')

        # run forward pre hooks (e.g., weight norm)
        for hook in self._forward_pre_hooks.values():
            hook(self, input)

        # reshape weight
        weight = self._get_linearized_weight()
        kw = self.kernel_size[0]
        dilation = self.dilation[0]

        batch_size = input.size(0)  # input: batch_size x num_timestep x num_channel
        if kw > 1:
            input = input.data
            if self.input_buffer is None:
                # buffer covers the receptive field of the current step, in (B, T, C) layout
                self.input_buffer = input.new(batch_size, kw + (kw - 1) * (dilation - 1), input.size(2))
                # zero-fill so the first steps behave like a zero-padded batched conv
                self.input_buffer.zero_()
            else:
                # shift buffer
                self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
            # append next input
            self.input_buffer[:, -1, :] = input[:, -1, :]
            input = self.input_buffer
            if dilation > 1:
                # keep only the time steps the dilated kernel actually touches
                input = input[:, 0::dilation, :].contiguous()
        output = F.linear(input.view(batch_size, -1), weight, self.bias)
        return output.view(batch_size, 1, -1)

    def _get_linearized_weight(self):
        if self._linearized_weight is None:
            kw = self.kernel_size[0]
            # nn.Conv1d
            if self.weight.size() == (self.out_channels, self.in_channels, kw):
                weight = self.weight.transpose(1, 2).contiguous()
            else:
                # fairseq.modules.conv_tbc.ConvTBC
                weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
            assert weight.size() == (self.out_channels, kw, self.in_channels)
            self._linearized_weight = weight.view(self.out_channels, -1)
        return self._linearized_weight
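For reference, a minimal usage sketch (not part of the diff, with arbitrary illustrative sizes) of how incremental_forward is meant to be driven: the layer receives one (B, 1, C) frame per call, and with causal padding the step-by-step outputs should agree with the batched (B, C, T) forward pass. It assumes the Conv1d class defined above is in scope.

import torch

torch.manual_seed(0)
B, C, T, kw, dilation = 2, 4, 16, 3, 2
layer = Conv1d(C, C, kernel_size=kw, padding=(kw - 1) * dilation, dilation=dilation)
layer.eval()

x = torch.randn(B, C, T)
with torch.no_grad():
    # batched forward: trim the extra frames introduced by the causal padding
    y_batch = layer(x)[:, :, :T]

    # incremental decoding: one (B, 1, C) frame per call
    layer.clear_buffer()
    frames = [layer.incremental_forward(x[:, :, t].unsqueeze(1)) for t in range(T)]
    y_inc = torch.cat(frames, dim=1).transpose(1, 2)  # back to (B, C, T)

assert torch.allclose(y_batch, y_inc, atol=1e-4)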
Empty file.
@@ -0,0 +1,166 @@
import math

import torch
from torch import nn
from torch.nn import functional as F

from model import conv


def Conv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
    m = conv.Conv1d(in_channels, out_channels, kernel_size, **kwargs)
    nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
    return nn.utils.weight_norm(m)


def Embedding(num_embeddings, embedding_dim, padding_idx, std=0.01):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx)
    m.weight.data.normal_(0, std)
    return m


def ConvTranspose2d(in_channels, out_channels, kernel_size, **kwargs):
    freq_axis_kernel_size = kernel_size[0]
    m = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, **kwargs)
    m.weight.data.fill_(1.0 / freq_axis_kernel_size)
    m.bias.data.zero_()
    return nn.utils.weight_norm(m)


def Conv1d1x1(in_channels, out_channels, bias=True):
    """
    1x1 convolution layer
    """
    return Conv1d(in_channels, out_channels, kernel_size=1, padding=0,
                  dilation=1, bias=bias)


def _conv1x1_forward(conv, x, is_incremental):
    """
    Conv1x1 forward
    """
    if is_incremental:
        x = conv.incremental_forward(x)
    else:
        x = conv(x)
    return x

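As a rough sketch (not part of the diff), the ConvTranspose2d helper can be used to upsample local conditioning features in time while keeping the frequency axis intact; the kernel size, stride, and padding below are hypothetical values chosen only to illustrate a 4x time upsampling.

import torch

# treat the mel axis as the "frequency" axis of the 2D transposed conv
upsample = ConvTranspose2d(1, 1, kernel_size=(3, 8), stride=(1, 4), padding=(1, 2))
mel = torch.randn(2, 80, 25)        # (B, num_mels, T_frames)
c = upsample(mel.unsqueeze(1))      # (B, 1, num_mels, T_frames * 4)
c = c.squeeze(1)                    # (B, num_mels, T_frames * 4)
print(c.shape)                      # torch.Size([2, 80, 100])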
class ResidualConv1dGLU(nn.Module):
    """
    Residual dilated conv1d + Gated linear unit

    Args:
        residual_channels (int): Residual input / output channels
        gate_channels (int): Gated activation channels
        kernel_size (int): Kernel size of convolution layers
        skip_out_channels (int): Skip connection channels. If None, set to the same as ``residual_channels``
        cin_channels (int): Local conditioning channels. If a negative value is given, local conditioning is disabled
        gin_channels (int): Global conditioning channels. If a negative value is given, global conditioning is disabled
        dropout (float): Dropout probability
        padding (int): Padding for convolution layers. If None, proper padding is computed depending on dilation and kernel_size
        dilation (int): Dilation factor
    """

    def __init__(self, residual_channels, gate_channels, kernel_size,
                 skip_out_channels=None,
                 cin_channels=-1, gin_channels=-1,
                 dropout=1 - 0.95, padding=None, dilation=1, causal=True,
                 bias=True, *args, **kwargs):
        super(ResidualConv1dGLU, self).__init__()
        self.dropout = dropout
        if skip_out_channels is None:
            skip_out_channels = residual_channels
        if padding is None:
            # no future time stamps available
            if causal:
                padding = (kernel_size - 1) * dilation
            else:
                padding = (kernel_size - 1) // 2 * dilation
        self.causal = causal

        self.conv = Conv1d(residual_channels, gate_channels, kernel_size,
                           padding=padding, dilation=dilation,
                           bias=bias, *args, **kwargs)

        # local conditioning
        if cin_channels > 0:
            self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, bias=False)
        else:
            self.conv1x1c = None

        # global conditioning
        if gin_channels > 0:
            self.conv1x1g = Conv1d1x1(gin_channels, gate_channels, bias=False)
        else:
            self.conv1x1g = None

        # conv output is split into two groups
        gate_out_channels = gate_channels // 2
        self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
        self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, bias=bias)

    def forward(self, x, c=None, g=None):
        return self._forward(x, c, g, False)

    def incremental_forward(self, x, c=None, g=None):
        return self._forward(x, c, g, True)

    def _forward(self, x, c, g, is_incremental):
        """
        Forward

        Args:
            x (Tensor): B x C x T
            c (Tensor): B x C x T, Local conditioning features
            g (Tensor): B x C x T, Expanded global conditioning features
            is_incremental (bool): Whether incremental mode or not

        Returns:
            (Tensor, Tensor): Output for the residual connection and output for the skip connection
        """
        residual = x
        x = F.dropout(x, p=self.dropout, training=self.training)
        if is_incremental:
            splitdim = -1
            x = self.conv.incremental_forward(x)
        else:
            splitdim = 1
            x = self.conv(x)
            # remove future time steps
            x = x[:, :, :residual.size(-1)] if self.causal else x

        a, b = x.split(x.size(splitdim) // 2, dim=splitdim)

        # local conditioning
        if c is not None:
            assert self.conv1x1c is not None
            c = _conv1x1_forward(self.conv1x1c, c, is_incremental)
            ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
            a, b = a + ca, b + cb

        # global conditioning
        if g is not None:
            assert self.conv1x1g is not None
            g = _conv1x1_forward(self.conv1x1g, g, is_incremental)
            ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim)
            a, b = a + ga, b + gb

        x = torch.tanh(a) * torch.sigmoid(b)

        # For skip connection
        s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental)

        # For residual connection
        x = _conv1x1_forward(self.conv1x1_out, x, is_incremental)

        x = (x + residual) * math.sqrt(0.5)
        return x, s

    def clear_buffer(self):
        for c in [self.conv, self.conv1x1_out, self.conv1x1_skip,
                  self.conv1x1c, self.conv1x1g]:
            if c is not None:
                c.clear_buffer()
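For orientation, a small shape-check sketch (not part of the diff, with hypothetical hyperparameters chosen only to illustrate the tensor shapes), assuming the module above is importable:

import torch

block = ResidualConv1dGLU(residual_channels=64, gate_channels=128, kernel_size=3,
                          skip_out_channels=64, cin_channels=80, dilation=2)
x = torch.randn(4, 64, 100)   # (B, residual_channels, T)
c = torch.randn(4, 80, 100)   # (B, cin_channels, T), e.g. upsampled mel features
out, skip = block(x, c=c)
print(out.shape, skip.shape)  # both torch.Size([4, 64, 100])

At synthesis time the same block would instead be driven through incremental_forward with one (B, 1, C) frame per call, and clear_buffer() resets the cached inputs between utterances.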