diff --git a/TTS/tts/configs/matcha_tts.py b/TTS/tts/configs/matcha_tts.py
new file mode 100644
index 0000000000..15bb91b829
--- /dev/null
+++ b/TTS/tts/configs/matcha_tts.py
@@ -0,0 +1,9 @@
+from dataclasses import dataclass
+
+from TTS.tts.configs.shared_configs import BaseTTSConfig
+
+
+@dataclass
+class MatchaTTSConfig(BaseTTSConfig):
+    model: str = "matcha_tts"
+    num_chars: int = None
diff --git a/TTS/tts/layers/matcha_tts/UNet.py b/TTS/tts/layers/matcha_tts/UNet.py
new file mode 100644
index 0000000000..b47db51cf5
--- /dev/null
+++ b/TTS/tts/layers/matcha_tts/UNet.py
@@ -0,0 +1,299 @@
+import math
+from einops import pack, rearrange
+import torch
+from torch import nn
+import conformer
+
+
+class PositionalEncoding(torch.nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.channels = channels
+
+    def forward(self, x, scale=1000):
+        if x.ndim < 1:
+            x = x.unsqueeze(0)
+        emb = math.log(10000) / (self.channels // 2 - 1)
+        emb = torch.exp(torch.arange(self.channels // 2, device=x.device).float() * -emb)
+        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+        return emb
+
+class ConvBlock1D(nn.Module):
+    def __init__(self, in_channels, out_channels, num_groups=8):
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.GroupNorm(num_groups, out_channels),
+            nn.Mish()
+        )
+
+    def forward(self, x, mask=None):
+        if mask is not None:
+            x = x * mask
+        output = self.block(x)
+        if mask is not None:
+            output = output * mask
+        return output
+
+
+class ResNetBlock1D(nn.Module):
+    def __init__(self, in_channels, out_channels, time_embed_channels, num_groups=8):
+        super().__init__()
+        self.block_1 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups)
+        self.mlp = nn.Sequential(
+            nn.Mish(),
+            nn.Linear(time_embed_channels, out_channels)
+        )
+        self.block_2 = ConvBlock1D(in_channels=out_channels, out_channels=out_channels, num_groups=num_groups)
+        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)
+
+    def forward(self, x, mask, t):
+        h = self.block_1(x, mask)
+        h += self.mlp(t).unsqueeze(-1)  # add time embedding per channel
+        h = self.block_2(h, mask)
+        output = h + self.conv(x * mask)  # residual connection
+        return output
+
+
+class Downsample1D(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.conv = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=3, stride=2, padding=1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class Upsample1D(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.conv = nn.ConvTranspose1d(in_channels=channels, out_channels=channels, kernel_size=4, stride=2, padding=1)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+class ConformerBlock(conformer.ConformerBlock):
+    def __init__(
+        self,
+        dim: int,
+        dim_head: int = 64,
+        heads: int = 8,
+        ff_mult: int = 4,
+        conv_expansion_factor: int = 2,
+        conv_kernel_size: int = 31,
+        attn_dropout: float = 0.,
+        ff_dropout: float = 0.,
+        conv_dropout: float = 0.,
+        conv_causal: bool = False,
+    ):
+        super().__init__(
+            dim=dim,
+            dim_head=dim_head,
+            heads=heads,
+            ff_mult=ff_mult,
+            conv_expansion_factor=conv_expansion_factor,
+            conv_kernel_size=conv_kernel_size,
+            attn_dropout=attn_dropout,
+            ff_dropout=ff_dropout,
+            conv_dropout=conv_dropout,
+            conv_causal=conv_causal,
+        )
+
+    def forward(self, x, mask):
+        x = rearrange(x, "b c t -> b t c")
+        mask = rearrange(mask, "b 1 t -> b t")
+        output = super().forward(x=x, mask=mask.bool())
+        return rearrange(output, "b t c -> b c t")
-> b c t") + + +class UNet(nn.Module): + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + num_blocks: int, + transformer_num_heads: int = 4, + transformer_dim_head: int = 64, + transformer_ff_mult: int = 1, + transformer_conv_expansion_factor: int = 2, + transformer_conv_kernel_size: int = 31, + transformer_dropout: float = 0.05, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.time_encoder = PositionalEncoding(in_channels) + time_embed_channels = model_channels * 4 + self.time_embed = nn.Sequential( + nn.Linear(in_channels, time_embed_channels), + nn.SiLU(), + nn.Linear(time_embed_channels, time_embed_channels), + ) + + self.input_blocks = nn.ModuleList([]) + block_in_channels = in_channels * 2 + block_out_channels = model_channels + for level in range(num_blocks): + block = nn.ModuleList([]) + + block.append( + ResNetBlock1D( + in_channels=block_in_channels, + out_channels=block_out_channels, + time_embed_channels=time_embed_channels + ) + ) + + block.append( + self._create_transformer_block( + block_out_channels, + dim_head=transformer_dim_head, + num_heads=transformer_num_heads, + ff_mult=transformer_ff_mult, + conv_expansion_factor=transformer_conv_expansion_factor, + conv_kernel_size=transformer_conv_kernel_size, + dropout=transformer_dropout, + ) + ) + + if level != num_blocks - 1: + block.append(Downsample1D(block_out_channels)) + else: + block.append(None) + + block_in_channels = block_out_channels + self.input_blocks.append(block) + + self.middle_blocks = nn.ModuleList([]) + for i in range(2): + block = nn.ModuleList([]) + + block.append( + ResNetBlock1D( + in_channels=block_out_channels, + out_channels=block_out_channels, + time_embed_channels=time_embed_channels + ) + ) + + block.append( + self._create_transformer_block( + block_out_channels, + dim_head=transformer_dim_head, + num_heads=transformer_num_heads, + ff_mult=transformer_ff_mult, + conv_expansion_factor=transformer_conv_expansion_factor, + conv_kernel_size=transformer_conv_kernel_size, + dropout=transformer_dropout, + ) + ) + + self.middle_blocks.append(block) + + self.output_blocks = nn.ModuleList([]) + block_in_channels = block_out_channels * 2 + block_out_channels = model_channels + for level in range(num_blocks): + block = nn.ModuleList([]) + + block.append( + ResNetBlock1D( + in_channels=block_in_channels, + out_channels=block_out_channels, + time_embed_channels=time_embed_channels + ) + ) + + block.append( + self._create_transformer_block( + block_out_channels, + dim_head=transformer_dim_head, + num_heads=transformer_num_heads, + ff_mult=transformer_ff_mult, + conv_expansion_factor=transformer_conv_expansion_factor, + conv_kernel_size=transformer_conv_kernel_size, + dropout=transformer_dropout, + ) + ) + + if level != num_blocks - 1: + block.append(Upsample1D(block_out_channels)) + else: + block.append(None) + + block_in_channels = block_out_channels * 2 + self.output_blocks.append(block) + + self.conv_block = ConvBlock1D(model_channels, model_channels) + self.conv = nn.Conv1d(model_channels, self.out_channels, 1) + + def _create_transformer_block( + self, + dim, + dim_head: int = 64, + num_heads: int = 4, + ff_mult: int = 1, + conv_expansion_factor: int = 2, + conv_kernel_size: int = 31, + dropout: float = 0.05, + ): + return ConformerBlock( + dim=dim, + dim_head=dim_head, + heads=num_heads, + ff_mult=ff_mult, + conv_expansion_factor=conv_expansion_factor, + conv_kernel_size=conv_kernel_size, + attn_dropout=dropout, + 
diff --git a/TTS/tts/layers/matcha_tts/decoder.py b/TTS/tts/layers/matcha_tts/decoder.py
new file mode 100644
index 0000000000..c87da9d559
--- /dev/null
+++ b/TTS/tts/layers/matcha_tts/decoder.py
@@ -0,0 +1,32 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from TTS.tts.layers.matcha_tts.UNet import UNet
+
+
+class Decoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.sigma_min = 1e-5
+        self.predictor = UNet(
+            in_channels=80,
+            model_channels=256,
+            out_channels=80,
+            num_blocks=2
+        )
+
+    def forward(self, x_1, mean, mask):
+        """
+        Shapes:
+            - x_1: :math:`[B, C, T]`
+            - mean: :math:`[B, C, T]`
+            - mask: :math:`[B, 1, T]`
+        """
+        t = torch.rand([x_1.size(0), 1, 1], device=x_1.device, dtype=x_1.dtype)
+        x_0 = torch.randn_like(x_1)  # noise sample
+        x_t = (1 - (1 - self.sigma_min) * t) * x_0 + t * x_1  # OT-CFM interpolant
+        u_t = x_1 - (1 - self.sigma_min) * x_0  # target conditional vector field
+        v_t = self.predictor(x_t, mean, mask, t.squeeze())
+        loss = F.mse_loss(v_t, u_t, reduction="sum") / (torch.sum(mask) * u_t.shape[1])
+        return loss
diff --git a/TTS/tts/models/matcha_tts.py b/TTS/tts/models/matcha_tts.py
new file mode 100644
index 0000000000..9bc3e0ffc4
--- /dev/null
+++ b/TTS/tts/models/matcha_tts.py
@@ -0,0 +1,84 @@
+import math
+import torch
+
+from TTS.tts.configs.matcha_tts import MatchaTTSConfig
+from TTS.tts.layers.glow_tts.encoder import Encoder
+from TTS.tts.layers.matcha_tts.decoder import Decoder
+from TTS.tts.models.base_tts import BaseTTS
+from TTS.tts.utils.helpers import maximum_path, sequence_mask
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
+
+
+class MatchaTTS(BaseTTS):
+
+    def __init__(
+        self,
+        config: MatchaTTSConfig,
+        ap: "AudioProcessor" = None,
+        tokenizer: "TTSTokenizer" = None,
+    ):
+        super().__init__(config, ap, tokenizer)
+        self.encoder = Encoder(
+            self.config.num_chars,
+            out_channels=80,
+            hidden_channels=192,
+            hidden_channels_dp=256,
+            encoder_type="rel_pos_transformer",
+            encoder_params={
+                "kernel_size": 3,
+                "dropout_p": 0.1,
+                "num_layers": 6,
+                "num_heads": 2,
+                "hidden_channels_ffn": 768,
+            }
+        )
+
+        self.decoder = Decoder()
+
+    def forward(self, x, x_lengths, y, y_lengths):
+        """
+        Args:
+            x (torch.Tensor):
+                Input text sequence ids. :math:`[B, T_en]`
+
+            x_lengths (torch.Tensor):
+                Lengths of input text sequences. :math:`[B]`
+
+            y (torch.Tensor):
+                Target mel-spectrogram frames. :math:`[B, T_de, C_mel]`
+
+            y_lengths (torch.Tensor):
+                Lengths of target mel-spectrogram frames. :math:`[B]`
+        """
+        y = y.transpose(1, 2)
+        y_max_length = y.size(2)
+
+        o_mean, o_log_scale, o_log_dur, o_mask = self.encoder(x, x_lengths, g=None)
+
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(o_mask.dtype)
+        attn_mask = torch.unsqueeze(o_mask, -1) * torch.unsqueeze(y_mask, 2)
+
+        with torch.no_grad():  # monotonic alignment search, as in GlowTTS
+            o_scale = torch.exp(-2 * o_log_scale)
+            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
+            logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (y**2))
+            logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), y)
+            logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)
+            logp = logp1 + logp2 + logp3 + logp4
+            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
+
+        # Align encoded text with mel-spectrogram and get mu_y segment
+        c_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(1, 2)
+
+        return self.decoder(x_1=y, mean=c_mean, mask=y_mask)
+
+    @torch.no_grad()
+    def inference(self):
+        pass
+
+    @staticmethod
+    def init_from_config(config: "MatchaTTSConfig"):
+        pass
+
+    def load_checkpoint(self, checkpoint_path):
+        pass
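At inference time, the stubbed inference() above will need to integrate the learned vector field from Gaussian noise (t=0) to a mel-spectrogram (t=1), conditioned on the aligned encoder means. A minimal fixed-step Euler sketch of that solver, assuming the Decoder.predictor signature from this diff (the name euler_sample and the value n_steps=10 are illustrative, not part of the PR):

    import torch

    @torch.no_grad()
    def euler_sample(predictor, mean, mask, n_steps=10):
        # Integrate dx/dt = v(x, t) from t=0 (noise) to t=1 (mel).
        x = torch.randn_like(mean) * mask
        t_grid = torch.linspace(0, 1, n_steps + 1, device=mean.device)
        for i in range(n_steps):
            t = t_grid[i].expand(mean.size(0))        # one time value per batch item
            dt = t_grid[i + 1] - t_grid[i]
            x = x + dt * predictor(x, mean, mask, t)  # Euler step along the flow
        return x * mask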
diff --git a/tests/tts_tests2/test_matcha_tts.py b/tests/tts_tests2/test_matcha_tts.py
new file mode 100644
index 0000000000..5fbe95377f
--- /dev/null
+++ b/tests/tts_tests2/test_matcha_tts.py
@@ -0,0 +1,36 @@
+import unittest
+
+import torch
+
+from TTS.tts.configs.matcha_tts import MatchaTTSConfig
+from TTS.tts.models.matcha_tts import MatchaTTS
+
+torch.manual_seed(1)
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+c = MatchaTTSConfig()
+
+
+class TestMatchaTTS(unittest.TestCase):
+    @staticmethod
+    def _create_inputs(batch_size=8):
+        input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device)
+        input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device)
+        input_lengths[-1] = 128
+        mel_spec = torch.rand(batch_size, 30, c.audio["num_mels"]).to(device)
+        mel_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
+        speaker_ids = torch.randint(0, 5, (batch_size,)).long().to(device)
+        return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids
+
+    def _test_forward(self, batch_size):
+        input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(batch_size)
+        config = MatchaTTSConfig(num_chars=32)
+        model = MatchaTTS(config).to(device)
+
+        model.train()
+
+        model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
+
+    def test_forward(self):
+        self._test_forward(1)
+        self._test_forward(3)
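Since MatchaTTS.forward() returns the flow-matching loss from the decoder, _test_forward could additionally check that value; a small optional extension (a suggested test style, not required by the PR):

    loss = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths)
    self.assertEqual(loss.ndim, 0)         # scalar loss
    self.assertTrue(torch.isfinite(loss))  # no NaN/Inf on random inputs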