add model folder with conv and residual modules
daitran2k1 committed Apr 17, 2023
1 parent d740fec commit 46c135c
Showing 4 changed files with 229 additions and 0 deletions.
Empty file added model/__init__.py
Empty file.
63 changes: 63 additions & 0 deletions model/conv.py
@@ -0,0 +1,63 @@
from torch import nn
from torch.nn import functional as F


class Conv1d(nn.Conv1d):
"""
Extend nn.Conv1d for incremental dilated convolutions
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.clear_buffer()
self._linearized_weight = None
        self.register_backward_hook(self._clear_linearized_weight)

def clear_buffer(self):
self.input_buffer = None

    def _clear_linearized_weight(self, *args):
self._linearized_weight = None

def incremental_forward(self, input):
# input: (B, T, C)
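        # Generates one output frame per call: a rolling buffer of the last
        # kw + (kw - 1) * (dilation - 1) input frames supplies the receptive field.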
if self.training:
raise RuntimeError('incremental_forward only supports eval mode')

# run forward pre hooks (e.g., weight norm)
for hook in self._forward_pre_hooks.values():
hook(self, input)

# reshape weight
weight = self._get_linearized_weight()
kw = self.kernel_size[0]
dilation = self.dilation[0]

batch_size = input.size(0) # input: batch_size x num_timestep x num_channel
if kw > 1:
input = input.data
if self.input_buffer is None:
                self.input_buffer = input.new(batch_size, kw + (kw - 1) * (dilation - 1), input.size(2))
                self.input_buffer.zero_()  # input.new() returns uninitialized memory, so zero the fresh buffer
else:
# shift buffer
self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
# append next input
self.input_buffer[:, -1, :] = input[:, -1, :]
input = self.input_buffer
if dilation > 1:
input = input[:, 0::dilation, :].contiguous()
output = F.linear(input.view(batch_size, -1), weight, self.bias)
return output.view(batch_size, 1, -1)

def _get_linearized_weight(self):
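        # Flatten the kernel to (out_channels, kw * in_channels) so a single
        # F.linear call computes one output timestep of the dilated convolution.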
if self._linearized_weight is None:
kw = self.kernel_size[0]
# nn.Conv1d
if self.weight.size() == (self.out_channels, self.in_channels, kw):
weight = self.weight.transpose(1, 2).contiguous()
else:
# fairseq.modules.conv_tbc.ConvTBC
weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
assert weight.size() == (self.out_channels, kw, self.in_channels)
self._linearized_weight = weight.view(self.out_channels, -1)
return self._linearized_weight
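
For reference, a minimal usage sketch of the incremental path (not part of the commit; the channel sizes, kernel size, and dilation below are arbitrary assumptions):

import torch
from model.conv import Conv1d

conv = Conv1d(in_channels=4, out_channels=8, kernel_size=3, dilation=2, padding=4)
conv.eval()  # incremental_forward raises a RuntimeError in training mode

x = torch.randn(1, 4, 10)  # (B, C, T) layout for the standard forward
with torch.no_grad():
    full = conv(x)  # ordinary dilated convolution over the whole padded sequence

    # incremental path: feed one (B, 1, C) frame at a time
    conv.clear_buffer()
    steps = [conv.incremental_forward(x[:, :, t].unsqueeze(1)) for t in range(x.size(2))]
    step_out = torch.cat(steps, dim=1)  # (B, T, out_channels)

Because the padding here equals (kernel_size - 1) * dilation, the first T output frames of full (read as (B, T, C)) should line up with step_out.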
Empty file added model/mixture.py
Empty file.
166 changes: 166 additions & 0 deletions model/modules.py
@@ -0,0 +1,166 @@
import math

import torch
from torch import nn
from torch.nn import functional as F

from model import conv


def Conv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
m = conv.Conv1d(in_channels, out_channels, kernel_size, **kwargs)
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
return nn.utils.weight_norm(m)


def Embedding(num_embeddings, embedding_dim, padding_idx, std=0.01):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx)
m.weight.data.normal_(0, std)
return m


def ConvTranspose2d(in_channels, out_channels, kernel_size, **kwargs):
freq_axis_kernel_size = kernel_size[0]
m = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, **kwargs)
m.weight.data.fill_(1.0 / freq_axis_kernel_size)
m.bias.data.zero_()
return nn.utils.weight_norm(m)


def Conv1d1x1(in_channels, out_channels, bias=True):
"""
1x1 convolution layer
"""
return Conv1d(in_channels, out_channels, kernel_size=1, padding=0,
dilation=1, bias=bias)


def _conv1x1_forward(conv, x, is_incremental):
"""
Conv1x1 forward
"""
if is_incremental:
x = conv.incremental_forward(x)
else:
x = conv(x)
return x


class ResidualConv1dGLU(nn.Module):
"""
Residual dilated conv1d + Gated linear unit
Args:
residual_channels (int): Residual input / output channels
gate_channels (int): Gated activation channels
kernel_size (int): Kernel size of convolution layers
        skip_out_channels (int): Skip connection channels. If None, set to the same as ``residual_channels``
        cin_channels (int): Local conditioning channels. If a negative value is given, local conditioning is disabled
        gin_channels (int): Global conditioning channels. If a negative value is given, global conditioning is disabled
        dropout (float): Dropout probability
        padding (int): Padding for convolution layers. If None, proper padding is computed based on the dilation and kernel_size
dilation (int): Dilation factor
"""

def __init__(self, residual_channels, gate_channels, kernel_size,
skip_out_channels=None,
cin_channels=-1, gin_channels=-1,
dropout=1 - 0.95, padding=None, dilation=1, causal=True,
bias=True, *args, **kwargs):
super(ResidualConv1dGLU, self).__init__()
self.dropout = dropout
if skip_out_channels is None:
skip_out_channels = residual_channels
if padding is None:
            # no future time steps available
if causal:
padding = (kernel_size - 1) * dilation
else:
padding = (kernel_size - 1) // 2 * dilation
self.causal = causal

self.conv = Conv1d(residual_channels, gate_channels, kernel_size,
padding=padding, dilation=dilation,
bias=bias, *args, **kwargs)

# local conditioning
if cin_channels > 0:
self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, bias=False)
else:
self.conv1x1c = None

# global conditioning
if gin_channels > 0:
self.conv1x1g = Conv1d1x1(gin_channels, gate_channels, bias=False)
else:
self.conv1x1g = None

# conv output is split into two groups
gate_out_channels = gate_channels // 2
self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias)
self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, bias=bias)

def forward(self, x, c=None, g=None):
return self._forward(x, c, g, False)

def incremental_forward(self, x, c=None, g=None):
return self._forward(x, c, g, True)

def _forward(self, x, c, g, is_incremental):
"""
Forward
Args:
x (Tensor): B x C x T
c (Tensor): B x C x T, Local conditioning features
g (Tensor): B x C x T, Expanded global conditioning features
            is_incremental (bool): Whether to run in incremental mode
Returns:
Tensor: output
"""
residual = x
x = F.dropout(x, p=self.dropout, training=self.training)
if is_incremental:
splitdim = -1
x = self.conv.incremental_forward(x)
else:
splitdim = 1
x = self.conv(x)
# remove future time steps
x = x[:, :, :residual.size(-1)] if self.causal else x

a, b = x.split(x.size(splitdim) // 2, dim=splitdim)

# local conditioning
if c is not None:
assert self.conv1x1c is not None
c = _conv1x1_forward(self.conv1x1c, c, is_incremental)
ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim)
a, b = a + ca, b + cb

# global conditioning
if g is not None:
assert self.conv1x1g is not None
g = _conv1x1_forward(self.conv1x1g, g, is_incremental)
ga, gb = g.split(g.size(splitdim) // 2, dim=splitdim)
a, b = a + ga, b + gb

x = torch.tanh(a) * torch.sigmoid(b)

# For skip connection
s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental)

# For residual connection
x = _conv1x1_forward(self.conv1x1_out, x, is_incremental)

x = (x + residual) * math.sqrt(0.5)
return x, s

def clear_buffer(self):
for c in [self.conv, self.conv1x1_out, self.conv1x1_skip,
self.conv1x1c, self.conv1x1g]:
if c is not None:
c.clear_buffer()
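
For reference, a minimal usage sketch of the residual block with local conditioning (not part of the commit; all channel sizes below are arbitrary assumptions):

import torch
from model.modules import ResidualConv1dGLU

block = ResidualConv1dGLU(residual_channels=64, gate_channels=128, kernel_size=3,
                          skip_out_channels=64, cin_channels=80, dilation=2)
block.eval()

x = torch.randn(2, 64, 100)  # (B, residual_channels, T)
c = torch.randn(2, 80, 100)  # (B, cin_channels, T), e.g. upsampled mel features
with torch.no_grad():
    out, skip = block(x, c=c)  # residual output and skip connection, both (B, 64, 100)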
