-
Notifications
You must be signed in to change notification settings - Fork 4.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'subuday/matcha_tts' into dev
- Loading branch information
Showing
5 changed files
with
461 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from dataclasses import dataclass, field | ||
|
||
from TTS.tts.configs.shared_configs import BaseTTSConfig | ||
|
||
|
||
@dataclass | ||
class MatchaTTSConfig(BaseTTSConfig): | ||
model: str = "matcha_tts" | ||
num_chars: int = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,299 @@ | ||
import math | ||
from einops import pack, rearrange | ||
import torch | ||
from torch import nn | ||
import conformer | ||
|
||
|
||
class PositionalEncoding(torch.nn.Module): | ||
def __init__(self, channels): | ||
super().__init__() | ||
self.channels = channels | ||
|
||
def forward(self, x, scale=1000): | ||
if x.ndim < 1: | ||
x = x.unsqueeze(0) | ||
emb = math.log(10000) / (self.channels // 2 - 1) | ||
emb = torch.exp(torch.arange(self.channels // 2, device=x.device).float() * -emb) | ||
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) | ||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1) | ||
return emb | ||
|
||
class ConvBlock1D(nn.Module): | ||
def __init__(self, in_channels, out_channels, num_groups=8): | ||
super().__init__() | ||
self.block = nn.Sequential( | ||
nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1), | ||
nn.GroupNorm(num_groups, out_channels), | ||
nn.Mish() | ||
) | ||
|
||
def forward(self, x, mask=None): | ||
if mask is not None: | ||
x = x * mask | ||
output = self.block(x) | ||
if mask is not None: | ||
output = output * mask | ||
return output | ||
|
||
|
||
class ResNetBlock1D(nn.Module): | ||
def __init__(self, in_channels, out_channels, time_embed_channels, num_groups=8): | ||
super().__init__() | ||
self.block_1 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups) | ||
self.mlp = nn.Sequential( | ||
nn.Mish(), | ||
nn.Linear(time_embed_channels, out_channels) | ||
) | ||
self.block_2 = ConvBlock1D(in_channels=out_channels, out_channels=out_channels, num_groups=num_groups) | ||
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1) | ||
|
||
def forward(self, x, mask, t): | ||
h = self.block_1(x, mask) | ||
h += self.mlp(t).unsqueeze(-1) | ||
h = self.block_2(h, mask) | ||
output = h + self.conv(x * mask) | ||
return output | ||
|
||
|
||
class Downsample1D(nn.Module): | ||
def __init__(self, channels): | ||
super().__init__() | ||
self.conv = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=3, stride=2, padding=1) | ||
|
||
def forward(self, x): | ||
return self.conv(x) | ||
|
||
|
||
class Upsample1D(nn.Module): | ||
def __init__(self, channels): | ||
super().__init__() | ||
self.conv = nn.ConvTranspose1d(in_channels=channels, out_channels=channels, kernel_size=4, stride=2, padding=1) | ||
|
||
def forward(self, x): | ||
return self.conv(x) | ||
|
||
|
||
class ConformerBlock(conformer.ConformerBlock): | ||
def __init__( | ||
self, | ||
dim: int, | ||
dim_head: int = 64, | ||
heads: int = 8, | ||
ff_mult: int = 4, | ||
conv_expansion_factor: int = 2, | ||
conv_kernel_size: int = 31, | ||
attn_dropout: float = 0., | ||
ff_dropout: float = 0., | ||
conv_dropout: float = 0., | ||
conv_causal: bool = False, | ||
): | ||
super().__init__( | ||
dim=dim, | ||
dim_head=dim_head, | ||
heads=heads, | ||
ff_mult=ff_mult, | ||
conv_expansion_factor=conv_expansion_factor, | ||
conv_kernel_size=conv_kernel_size, | ||
attn_dropout=attn_dropout, | ||
ff_dropout=ff_dropout, | ||
conv_dropout=conv_dropout, | ||
conv_causal=conv_causal, | ||
) | ||
|
||
def forward(self, x, mask,): | ||
x = rearrange(x, "b c t -> b t c") | ||
mask = rearrange(mask, "b 1 t -> b t") | ||
output = super().forward(x=x, mask=mask.bool()) | ||
return rearrange(output, "b t c -> b c t") | ||
|
||
|
||
class UNet(nn.Module): | ||
def __init__( | ||
self, | ||
in_channels: int, | ||
model_channels: int, | ||
out_channels: int, | ||
num_blocks: int, | ||
transformer_num_heads: int = 4, | ||
transformer_dim_head: int = 64, | ||
transformer_ff_mult: int = 1, | ||
transformer_conv_expansion_factor: int = 2, | ||
transformer_conv_kernel_size: int = 31, | ||
transformer_dropout: float = 0.05, | ||
): | ||
super().__init__() | ||
self.in_channels = in_channels | ||
self.out_channels = out_channels | ||
|
||
self.time_encoder = PositionalEncoding(in_channels) | ||
time_embed_channels = model_channels * 4 | ||
self.time_embed = nn.Sequential( | ||
nn.Linear(in_channels, time_embed_channels), | ||
nn.SiLU(), | ||
nn.Linear(time_embed_channels, time_embed_channels), | ||
) | ||
|
||
self.input_blocks = nn.ModuleList([]) | ||
block_in_channels = in_channels * 2 | ||
block_out_channels = model_channels | ||
for level in range(num_blocks): | ||
block = nn.ModuleList([]) | ||
|
||
block.append( | ||
ResNetBlock1D( | ||
in_channels=block_in_channels, | ||
out_channels=block_out_channels, | ||
time_embed_channels=time_embed_channels | ||
) | ||
) | ||
|
||
block.append( | ||
self._create_transformer_block( | ||
block_out_channels, | ||
dim_head=transformer_dim_head, | ||
num_heads=transformer_num_heads, | ||
ff_mult=transformer_ff_mult, | ||
conv_expansion_factor=transformer_conv_expansion_factor, | ||
conv_kernel_size=transformer_conv_kernel_size, | ||
dropout=transformer_dropout, | ||
) | ||
) | ||
|
||
if level != num_blocks - 1: | ||
block.append(Downsample1D(block_out_channels)) | ||
else: | ||
block.append(None) | ||
|
||
block_in_channels = block_out_channels | ||
self.input_blocks.append(block) | ||
|
||
self.middle_blocks = nn.ModuleList([]) | ||
for i in range(2): | ||
block = nn.ModuleList([]) | ||
|
||
block.append( | ||
ResNetBlock1D( | ||
in_channels=block_out_channels, | ||
out_channels=block_out_channels, | ||
time_embed_channels=time_embed_channels | ||
) | ||
) | ||
|
||
block.append( | ||
self._create_transformer_block( | ||
block_out_channels, | ||
dim_head=transformer_dim_head, | ||
num_heads=transformer_num_heads, | ||
ff_mult=transformer_ff_mult, | ||
conv_expansion_factor=transformer_conv_expansion_factor, | ||
conv_kernel_size=transformer_conv_kernel_size, | ||
dropout=transformer_dropout, | ||
) | ||
) | ||
|
||
self.middle_blocks.append(block) | ||
|
||
self.output_blocks = nn.ModuleList([]) | ||
block_in_channels = block_out_channels * 2 | ||
block_out_channels = model_channels | ||
for level in range(num_blocks): | ||
block = nn.ModuleList([]) | ||
|
||
block.append( | ||
ResNetBlock1D( | ||
in_channels=block_in_channels, | ||
out_channels=block_out_channels, | ||
time_embed_channels=time_embed_channels | ||
) | ||
) | ||
|
||
block.append( | ||
self._create_transformer_block( | ||
block_out_channels, | ||
dim_head=transformer_dim_head, | ||
num_heads=transformer_num_heads, | ||
ff_mult=transformer_ff_mult, | ||
conv_expansion_factor=transformer_conv_expansion_factor, | ||
conv_kernel_size=transformer_conv_kernel_size, | ||
dropout=transformer_dropout, | ||
) | ||
) | ||
|
||
if level != num_blocks - 1: | ||
block.append(Upsample1D(block_out_channels)) | ||
else: | ||
block.append(None) | ||
|
||
block_in_channels = block_out_channels * 2 | ||
self.output_blocks.append(block) | ||
|
||
self.conv_block = ConvBlock1D(model_channels, model_channels) | ||
self.conv = nn.Conv1d(model_channels, self.out_channels, 1) | ||
|
||
def _create_transformer_block( | ||
self, | ||
dim, | ||
dim_head: int = 64, | ||
num_heads: int = 4, | ||
ff_mult: int = 1, | ||
conv_expansion_factor: int = 2, | ||
conv_kernel_size: int = 31, | ||
dropout: float = 0.05, | ||
): | ||
return ConformerBlock( | ||
dim=dim, | ||
dim_head=dim_head, | ||
heads=num_heads, | ||
ff_mult=ff_mult, | ||
conv_expansion_factor=conv_expansion_factor, | ||
conv_kernel_size=conv_kernel_size, | ||
attn_dropout=dropout, | ||
ff_dropout=dropout, | ||
conv_dropout=dropout, | ||
conv_causal=False, | ||
) | ||
|
||
def forward(self, x_t, mean, mask, t): | ||
t = self.time_encoder(t) | ||
t = self.time_embed(t) | ||
|
||
x_t = pack([x_t, mean], "b * t")[0] | ||
|
||
hidden_states = [] | ||
mask_states = [mask] | ||
|
||
for block in self.input_blocks: | ||
res_net_block, transformer, downsample = block | ||
|
||
x_t = res_net_block(x_t, mask, t) | ||
x_t = transformer(x_t, mask) | ||
|
||
hidden_states.append(x_t) | ||
|
||
if downsample is not None: | ||
x_t = downsample(x_t * mask) | ||
mask = mask[:, :, ::2] | ||
mask_states.append(mask) | ||
|
||
for block in self.middle_blocks: | ||
res_net_block, transformer = block | ||
mask = mask_states[-1] | ||
x_t = res_net_block(x_t, mask, t) | ||
x_t = transformer(x_t, mask) | ||
|
||
for block in self.output_blocks: | ||
res_net_block, transformer, upsample = block | ||
|
||
x_t = pack([x_t, hidden_states.pop()], "b * t")[0] | ||
mask = mask_states.pop() | ||
x_t = res_net_block(x_t, mask, t) | ||
x_t = transformer(x_t, mask) | ||
|
||
if upsample is not None: | ||
x_t = upsample(x_t * mask) | ||
|
||
output = self.conv_block(x_t) | ||
output = self.conv(x_t) | ||
|
||
return output * mask |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import torch | ||
from torch import nn | ||
import torch.nn.functional as F | ||
|
||
from TTS.tts.layers.matcha_tts.UNet import UNet | ||
|
||
|
||
class Decoder(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.sigma_min = 1e-5 | ||
self.predictor = UNet( | ||
in_channels=80, | ||
model_channels=256, | ||
out_channels=80, | ||
num_blocks=2 | ||
) | ||
|
||
def forward(self, x_1, mean, mask): | ||
""" | ||
Shapes: | ||
- x_1: :math:`[B, C, T]` | ||
- mean: :math:`[B, C ,T]` | ||
- mask: :math:`[B, 1, T]` | ||
""" | ||
t = torch.rand([x_1.size(0), 1, 1], device=x_1.device, dtype=x_1.dtype) | ||
x_0 = torch.randn_like(x_1) | ||
x_t = (1 - (1 - self.sigma_min) * t) * x_0 + t * x_1 | ||
u_t = x_1 - (1 - self.sigma_min) * x_0 | ||
v_t = self.predictor(x_t, mean, mask, t.squeeze()) | ||
loss = F.mse_loss(v_t, u_t, reduction="sum") / (torch.sum(mask) * u_t.shape[1]) | ||
return loss |
Oops, something went wrong.