Skip to content

Commit

Permalink
Merge pull request #4 from paul-krug/add-embeddings
Browse files Browse the repository at this point in the history
Add embeddings
  • Loading branch information
paul-krug authored Feb 9, 2024
2 parents ebc33ca + 01a38da commit 8a24d81
Show file tree
Hide file tree
Showing 3 changed files with 356 additions and 91 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ model = TCN(
kernel_initializer: str = 'xavier_uniform',
use_skip_connections: bool = False,
input_shape: str = 'NCL',
embedding_shapes: Optional[ ArrayLike ] = None,
use_gate: bool = False,
)
# Continue to train/use model for your task
```
Expand All @@ -52,3 +54,6 @@ The order of output dimensions will be the same as for the input tensors.
- `kernel_initializer`: The function used for initializing the network's weights. Currently, this can be 'uniform', 'normal', 'kaiming_uniform', 'kaiming_normal', 'xavier_uniform' or 'xavier_normal'. Kaiming and Xavier initialization are also known as He and Glorot initialization, respectively. While [Bai et al.](https://arxiv.org/abs/1803.01271) originally use normal initialization, this sometimes leads to divergent behaviour, and 'xavier_uniform' is usually a very good starting point, so it is used as the default here.
- `use_skip_connections`: If 'True', skip connections will be present from the output of each residual block (before the sum with the residual, similar to WaveNet) to the end of the network, where all the connections are summed. The sum then passes through another activation function. If the output of a residual block has a feature dimension different from the feature dimension of the last residual block, the respective skip connection will use a 1x1 convolution for downsampling the feature dimension. This procedure is similar to the way residual connections around each residual block are handled. Skip connections usually help to train deeper networks efficiently. However, the parameter defaults to 'False', because skip connections were not used in the original paper by [Bai et al.](https://arxiv.org/abs/1803.01271).
- `input_shape`: Defaults to 'NCL', which means input tensors are expected to have the shape (batch_size, feature_channels, time_steps). This corresponds to the input shape that is expected by 1D convolutions in PyTorch. However, a common convention for time series data is the shape (batch_size, time_steps, feature_channels). If you want to use this convention, set the parameter to 'NLC'.
- `embedding_shapes`: Accepts an Iterable that contains tuples or types that can be converted to tuples. The tuples should contain the number of embedding dimensions. Embeddings can be either 1D or 2D. As an example of the 1D case, let's say you train a TCN to generate speech samples and you want to condition the audio generation on a speaker embedding of shape (256,). Then you would pass [(256,)] to the argument. The TCN forward function will then accept tensors of shape (batch_size, 256) as the argument 'embeddings'. The embeddings will be automatically broadcast to the length of the input sequence and added to the input tensor right before the first activation function in each temporal block. Hence, 1D embedding shapes will lead to a global conditioning of the TCN. For local conditioning, an entry of 'embedding_shapes' should be 2D, with 'None' as its second dimension (time_steps). It may look like this: [(32, None)]. Then the forward function would accept tensors of shape (batch_size, 32, time_steps). If 'embedding_shapes' is set to None, no embeddings will be used.
- `embedding_mode`: Valid modes are 'add' and 'concat'. If 'add', the embeddings will be added to the input tensor before the first activation function in each temporal block. If 'concat', the embeddings will be concatenated to the input tensor along the feature dimension and then projected to the expected dimension via a 1x1 convolution. The default is 'add'.
- `use_gate`: If 'True', a gated linear unit (see [Dauphin et al.](https://arxiv.org/abs/1612.08083)) will be used as the first activation function in each temporal block. If 'False', the activation function will be the one specified by the 'activation' parameter. Gated units may be used as activation functions to feed in embeddings (see above). This may or may not lead to better results than the regular activation, but it is likely to increase the computational costs. The default is 'False'.
154 changes: 145 additions & 9 deletions pytorch_tcn/tcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Union
from typing import Optional
from numpy.typing import ArrayLike

from collections.abc import Iterable


activation_fn = dict(
Expand Down Expand Up @@ -152,15 +152,26 @@ def __init__(
use_norm,
activation,
kerner_initializer,
embedding_shapes,
embedding_mode,
use_gate,
):
super(TemporalBlock, self).__init__()
self.use_norm = use_norm
self.activation_name = activation
self.kernel_initializer = kerner_initializer
self.embedding_shapes = embedding_shapes
self.embedding_mode = embedding_mode
self.use_gate = use_gate

if self.use_gate:
conv1d_n_outputs = 2 * n_outputs
else:
conv1d_n_outputs = n_outputs

self.conv1 = TemporalConv1d(
in_channels=n_inputs,
out_channels=n_outputs,
out_channels=conv1d_n_outputs,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
Expand All @@ -177,10 +188,16 @@ def __init__(
)

if use_norm == 'batch_norm':
self.norm1 = nn.BatchNorm1d(n_outputs)
if self.use_gate:
self.norm1 = nn.BatchNorm1d(2 * n_outputs)
else:
self.norm1 = nn.BatchNorm1d(n_outputs)
self.norm2 = nn.BatchNorm1d(n_outputs)
elif use_norm == 'layer_norm':
self.norm1 = nn.LayerNorm(n_outputs)
if self.use_gate:
self.norm1 = nn.LayerNorm(2 * n_outputs)
else:
self.norm1 = nn.LayerNorm(n_outputs)
self.norm2 = nn.LayerNorm(n_outputs)
elif use_norm == 'weight_norm':
self.norm1 = None
Expand All @@ -199,6 +216,26 @@ def __init__(
self.dropout2 = nn.Dropout(dropout)

self.downsample = nn.Conv1d(n_inputs, n_outputs, 1) if n_inputs != n_outputs else None

if self.embedding_shapes is not None:
if self.use_gate:
embedding_layer_n_outputs = 2 * n_outputs
else:
embedding_layer_n_outputs = n_outputs

self.embedding_projection_1 = nn.Conv1d(
in_channels = sum( [ shape[0] for shape in self.embedding_shapes ] ),
out_channels = embedding_layer_n_outputs,
kernel_size = 1,
)

self.embedding_projection_2 = nn.Conv1d(
in_channels = 2 * embedding_layer_n_outputs,
out_channels = embedding_layer_n_outputs,
kernel_size = 1,
)

self.glu = nn.GLU(dim=1)

self.init_weights()
return
Expand Down Expand Up @@ -235,11 +272,66 @@ def apply_norm(
x = norm_fn( x.transpose(1, 2) )
x = x.transpose(1, 2)
return x

def apply_embeddings(
self,
x,
embeddings,
):

if not isinstance( embeddings, list ):
embeddings = [ embeddings ]

e = []
for embedding, expected_shape in zip( embeddings, self.embedding_shapes ):
if embedding.shape[1] != expected_shape[0]:
raise ValueError(
f"""
Embedding shape {embedding.shape} passed to 'forward' does not
match the expected shape {expected_shape} provided as input argument
'embedding_shapes'.
"""
)
if len( embedding.shape ) == 2:
# unsqueeze time dimension of e and repeat it to match x
e.append( embedding.unsqueeze(2).repeat(1, 1, x.shape[2]) )
elif len( embedding.shape ) == 3:
# check if time dimension of embedding matches x
if embedding.shape[2] != x.shape[2]:
raise ValueError(
f"""
Embedding time dimension {embedding.shape[2]} does not match
input time dimension {x.shape[2]}
"""
)
e.append( embedding )
e = torch.cat( e, dim=1 )
e = self.embedding_projection_1( e )
#print('shapes:', e.shape, x.shape)
if self.embedding_mode == 'concat':
x = self.embedding_projection_2(
torch.cat( [ x, e ], dim=1 )
)
elif self.embedding_mode == 'add':
x = x + e

def forward(self, x):
return x

def forward(
self,
x,
embeddings,
):
out = self.conv1(x)
out = self.apply_norm( self.norm1, out )
out = self.activation1(out)

if embeddings is not None:
out = self.apply_embeddings( out, embeddings )

if self.use_gate:
out = self.glu(out)
else:
out = self.activation1(out)
out = self.dropout1(out)

out = self.conv2(out)
Expand Down Expand Up @@ -267,6 +359,9 @@ def __init__(
kernel_initializer: str = 'xavier_uniform',
use_skip_connections: bool = False,
input_shape: str = 'NCL',
embedding_shapes: Optional[ ArrayLike ] = None,
embedding_mode: str = 'add',
use_gate: bool = False,
):
super(TCN, self).__init__()
if dilations is not None and len(dilations) != len(num_channels):
Expand Down Expand Up @@ -310,6 +405,39 @@ def __init__(
self.kernel_initializer = kernel_initializer
self.use_skip_connections = use_skip_connections
self.input_shape = input_shape
self.embedding_shapes = embedding_shapes
self.use_gate = use_gate

if embedding_shapes is not None:
if isinstance(embedding_shapes, Iterable):
for shape in embedding_shapes:
if not isinstance( shape, tuple ):
try:
shape = tuple( shape )
except Exception as e:
raise ValueError(
f"Each shape in argument 'embedding_shapes' must be an Iterable of tuples. "
f"Tried to convert {shape} to tuple, but failed with error: {e}"
)
if len( shape ) not in [ 1, 2 ]:
raise ValueError(
f"""
Tuples in argument 'embedding_shapes' must be of length 1 or 2.
One-dimensional tuples are interpreted as (embedding_dim,) and
two-dimensional tuples as (embedding_dim, time_steps).
"""
)
else:
raise ValueError(
f"Argument 'embedding_shapes' must be a list of tuples, "
f"but is {type(embedding_shapes)}"
)

if embedding_mode not in [ 'add', 'concat' ]:
raise ValueError(
f"Argument 'embedding_mode' must be one of: ['add', 'concat']"
)
self.embedding_mode = embedding_mode

if use_skip_connections:
self.downsample_skip_connection = nn.ModuleList()
Expand Down Expand Up @@ -347,6 +475,9 @@ def __init__(
use_norm=use_norm,
activation=activation,
kerner_initializer=self.kernel_initializer,
embedding_shapes=self.embedding_shapes,
embedding_mode=self.embedding_mode,
use_gate=self.use_gate,
)
]

Expand All @@ -366,15 +497,19 @@ def init_skip_connection_weights(self):
)
return

def forward(self, x):
def forward(
self,
x,
embeddings=None,
):
if self.input_shape == 'NLC':
x = x.transpose(1, 2)
if self.use_skip_connections:
skip_connections = []
# Adding skip connections from each layer to the output
# Excluding the last layer, as it would not skip trainable weights
for index, layer in enumerate( self.network ):
x, skip_out = layer(x)
x, skip_out = layer(x, embeddings )
if self.downsample_skip_connection[ index ] is not None:
skip_out = self.downsample_skip_connection[ index ]( skip_out )
if index < len( self.network ) - 1:
Expand All @@ -384,7 +519,8 @@ def forward(self, x):
x = self.activation_out( x )
else:
for layer in self.network:
x, _ = layer(x)
#print( 'TCN, embeddings:', embeddings.shape )
x, _ = layer( x, embeddings )
if self.input_shape == 'NLC':
x = x.transpose(1, 2)
return x
Loading

0 comments on commit 8a24d81

Please sign in to comment.