diff --git a/TTS/tts/layers/matcha_tts/UNet.py b/TTS/tts/layers/matcha_tts/UNet.py
index 142ef98f30..b47db51cf5 100644
--- a/TTS/tts/layers/matcha_tts/UNet.py
+++ b/TTS/tts/layers/matcha_tts/UNet.py
@@ -1,7 +1,8 @@
 import math
-from einops import pack
+from einops import pack, rearrange
 import torch
 from torch import nn
+import conformer


 class PositionalEncoding(torch.nn.Module):
@@ -71,6 +72,40 @@ def __init__(self, channels):

     def forward(self, x):
         return self.conv(x)
+
+
+class ConformerBlock(conformer.ConformerBlock):
+    def __init__(
+        self,
+        dim: int,
+        dim_head: int = 64,
+        heads: int = 8,
+        ff_mult: int = 4,
+        conv_expansion_factor: int = 2,
+        conv_kernel_size: int = 31,
+        attn_dropout: float = 0.,
+        ff_dropout: float = 0.,
+        conv_dropout: float = 0.,
+        conv_causal: bool = False,
+    ):
+        super().__init__(
+            dim=dim,
+            dim_head=dim_head,
+            heads=heads,
+            ff_mult=ff_mult,
+            conv_expansion_factor=conv_expansion_factor,
+            conv_kernel_size=conv_kernel_size,
+            attn_dropout=attn_dropout,
+            ff_dropout=ff_dropout,
+            conv_dropout=conv_dropout,
+            conv_causal=conv_causal,
+        )
+
+    def forward(self, x, mask):
+        x = rearrange(x, "b c t -> b t c")
+        mask = rearrange(mask, "b 1 t -> b t")
+        output = super().forward(x=x, mask=mask.bool())
+        return rearrange(output, "b t c -> b c t")


 class UNet(nn.Module):
@@ -80,6 +115,12 @@ def __init__(
         model_channels: int,
         out_channels: int,
         num_blocks: int,
+        transformer_num_heads: int = 4,
+        transformer_dim_head: int = 64,
+        transformer_ff_mult: int = 1,
+        transformer_conv_expansion_factor: int = 2,
+        transformer_conv_kernel_size: int = 31,
+        transformer_dropout: float = 0.05,
     ):
         super().__init__()
         self.in_channels = in_channels
@@ -107,6 +148,18 @@ def __init__(
                 )
             )

+            block.append(
+                self._create_transformer_block(
+                    block_out_channels,
+                    dim_head=transformer_dim_head,
+                    num_heads=transformer_num_heads,
+                    ff_mult=transformer_ff_mult,
+                    conv_expansion_factor=transformer_conv_expansion_factor,
+                    conv_kernel_size=transformer_conv_kernel_size,
+                    dropout=transformer_dropout,
+                )
+            )
+
             if level != num_blocks - 1:
                 block.append(Downsample1D(block_out_channels))
             else:
@@ -116,6 +169,30 @@ def __init__(
             self.input_blocks.append(block)

         self.middle_blocks = nn.ModuleList([])
+        for i in range(2):
+            block = nn.ModuleList([])
+
+            block.append(
+                ResNetBlock1D(
+                    in_channels=block_out_channels,
+                    out_channels=block_out_channels,
+                    time_embed_channels=time_embed_channels
+                )
+            )
+
+            block.append(
+                self._create_transformer_block(
+                    block_out_channels,
+                    dim_head=transformer_dim_head,
+                    num_heads=transformer_num_heads,
+                    ff_mult=transformer_ff_mult,
+                    conv_expansion_factor=transformer_conv_expansion_factor,
+                    conv_kernel_size=transformer_conv_kernel_size,
+                    dropout=transformer_dropout,
+                )
+            )
+
+            self.middle_blocks.append(block)

         self.output_blocks = nn.ModuleList([])
         block_in_channels = block_out_channels * 2
@@ -131,6 +208,18 @@ def __init__(
                 )
             )

+            block.append(
+                self._create_transformer_block(
+                    block_out_channels,
+                    dim_head=transformer_dim_head,
+                    num_heads=transformer_num_heads,
+                    ff_mult=transformer_ff_mult,
+                    conv_expansion_factor=transformer_conv_expansion_factor,
+                    conv_kernel_size=transformer_conv_kernel_size,
+                    dropout=transformer_dropout,
+                )
+            )
+
             if level != num_blocks - 1:
                 block.append(Upsample1D(block_out_channels))
             else:
@@ -142,6 +231,29 @@ def __init__(
         self.conv_block = ConvBlock1D(model_channels, model_channels)
         self.conv = nn.Conv1d(model_channels, self.out_channels, 1)

+    def _create_transformer_block(
+        self,
+        dim,
+        dim_head: int = 64,
+        num_heads: int = 4,
+        ff_mult: int = 1,
+        conv_expansion_factor: int = 2,
+        conv_kernel_size: int = 31,
+        dropout: float = 0.05,
+    ):
+        return ConformerBlock(
+            dim=dim,
+            dim_head=dim_head,
+            heads=num_heads,
+            ff_mult=ff_mult,
+            conv_expansion_factor=conv_expansion_factor,
+            conv_kernel_size=conv_kernel_size,
+            attn_dropout=dropout,
+            ff_dropout=dropout,
+            conv_dropout=dropout,
+            conv_causal=False,
+        )
+
     def forward(self, x_t, mean, mask, t):
         t = self.time_encoder(t)
         t = self.time_embed(t)
@@ -152,9 +264,11 @@ def forward(self, x_t, mean, mask, t):
         mask_states = [mask]

         for block in self.input_blocks:
-            res_net_block, downsample = block
+            res_net_block, transformer, downsample = block
             x_t = res_net_block(x_t, mask, t)

+            x_t = transformer(x_t, mask)
+
             hidden_states.append(x_t)

             if downsample is not None:
@@ -162,20 +276,23 @@ def forward(self, x_t, mean, mask, t):
                 mask = mask[:, :, ::2]
                 mask_states.append(mask)

-        for _ in self.middle_blocks:
-            pass
+        for block in self.middle_blocks:
+            res_net_block, transformer = block
+            mask = mask_states[-1]
+            x_t = res_net_block(x_t, mask, t)
+            x_t = transformer(x_t, mask)

         for block in self.output_blocks:
-            res_net_block, upsample = block
-
+            res_net_block, transformer, upsample = block
+
             x_t = pack([x_t, hidden_states.pop()], "b * t")[0]
             mask = mask_states.pop()
             x_t = res_net_block(x_t, mask, t)
+            x_t = transformer(x_t, mask)

             if upsample is not None:
                 x_t = upsample(x_t * mask)

-        output = self.conv_block(x_t)
         output = self.conv(x_t)