From 1b6d3ebd3317f1d4017104b5208cbc8210780e0e Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Wed, 20 Nov 2024 18:39:55 +0100 Subject: [PATCH 01/25] refactor(xtts): remove duplicate hifigan generator --- TTS/tts/layers/xtts/hifigan_decoder.py | 609 +----------------------- TTS/vocoder/models/hifigan_generator.py | 13 + 2 files changed, 15 insertions(+), 607 deletions(-) diff --git a/TTS/tts/layers/xtts/hifigan_decoder.py b/TTS/tts/layers/xtts/hifigan_decoder.py index 5ef0030b8b..2e6ac01a87 100644 --- a/TTS/tts/layers/xtts/hifigan_decoder.py +++ b/TTS/tts/layers/xtts/hifigan_decoder.py @@ -1,618 +1,13 @@ import logging import torch -import torchaudio -from torch import nn -from torch.nn import Conv1d, ConvTranspose1d -from torch.nn import functional as F -from torch.nn.utils.parametrizations import weight_norm -from torch.nn.utils.parametrize import remove_parametrizations from trainer.io import load_fsspec -from TTS.utils.generic_utils import is_pytorch_at_least_2_4 -from TTS.vocoder.models.hifigan_generator import get_padding +from TTS.encoder.models.resnet import ResNetSpeakerEncoder +from TTS.vocoder.models.hifigan_generator import HifiganGenerator logger = logging.getLogger(__name__) -LRELU_SLOPE = 0.1 - - -class ResBlock1(torch.nn.Module): - """Residual Block Type 1. It has 3 convolutional layers in each convolutional block. - - Network:: - - x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o - |--------------------------------------------------------------------------------------------------| - - - Args: - channels (int): number of hidden channels for the convolutional layers. - kernel_size (int): size of the convolution filter in each layer. - dilations (list): list of dilation value for each conv layer in a block. - """ - - def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): - super().__init__() - self.convs1 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[2], - padding=get_padding(kernel_size, dilation[2]), - ) - ), - ] - ) - - self.convs2 = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=1, - padding=get_padding(kernel_size, 1), - ) - ), - ] - ) - - def forward(self, x): - """ - Args: - x (Tensor): input tensor. - Returns: - Tensor: output tensor. - Shapes: - x: [B, C, T] - """ - for c1, c2 in zip(self.convs1, self.convs2): - xt = F.leaky_relu(x, LRELU_SLOPE) - xt = c1(xt) - xt = F.leaky_relu(xt, LRELU_SLOPE) - xt = c2(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs1: - remove_parametrizations(l, "weight") - for l in self.convs2: - remove_parametrizations(l, "weight") - - -class ResBlock2(torch.nn.Module): - """Residual Block Type 2. It has 1 convolutional layers in each convolutional block. 
- - Network:: - - x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o - |---------------------------------------------------| - - - Args: - channels (int): number of hidden channels for the convolutional layers. - kernel_size (int): size of the convolution filter in each layer. - dilations (list): list of dilation value for each conv layer in a block. - """ - - def __init__(self, channels, kernel_size=3, dilation=(1, 3)): - super().__init__() - self.convs = nn.ModuleList( - [ - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[0], - padding=get_padding(kernel_size, dilation[0]), - ) - ), - weight_norm( - Conv1d( - channels, - channels, - kernel_size, - 1, - dilation=dilation[1], - padding=get_padding(kernel_size, dilation[1]), - ) - ), - ] - ) - - def forward(self, x): - for c in self.convs: - xt = F.leaky_relu(x, LRELU_SLOPE) - xt = c(xt) - x = xt + x - return x - - def remove_weight_norm(self): - for l in self.convs: - remove_parametrizations(l, "weight") - - -class HifiganGenerator(torch.nn.Module): - def __init__( - self, - in_channels, - out_channels, - resblock_type, - resblock_dilation_sizes, - resblock_kernel_sizes, - upsample_kernel_sizes, - upsample_initial_channel, - upsample_factors, - inference_padding=5, - cond_channels=0, - conv_pre_weight_norm=True, - conv_post_weight_norm=True, - conv_post_bias=True, - cond_in_each_up_layer=False, - ): - r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) - - Network: - x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o - .. -> zI ---| - resblockN_kNx1 -> zN ---' - - Args: - in_channels (int): number of input tensor channels. - out_channels (int): number of output tensor channels. - resblock_type (str): type of the `ResBlock`. '1' or '2'. - resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`. - resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`. - upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution. - upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2 - for each consecutive upsampling layer. - upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer. - inference_padding (int): constant padding applied to the input at inference time. Defaults to 5. 
- """ - super().__init__() - self.inference_padding = inference_padding - self.num_kernels = len(resblock_kernel_sizes) - self.num_upsamples = len(upsample_factors) - self.cond_in_each_up_layer = cond_in_each_up_layer - - # initial upsampling layers - self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) - resblock = ResBlock1 if resblock_type == "1" else ResBlock2 - # upsampling layers - self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)): - self.ups.append( - weight_norm( - ConvTranspose1d( - upsample_initial_channel // (2**i), - upsample_initial_channel // (2 ** (i + 1)), - k, - u, - padding=(k - u) // 2, - ) - ) - ) - # MRF blocks - self.resblocks = nn.ModuleList() - for i in range(len(self.ups)): - ch = upsample_initial_channel // (2 ** (i + 1)) - for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): - self.resblocks.append(resblock(ch, k, d)) - # post convolution layer - self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3, bias=conv_post_bias)) - if cond_channels > 0: - self.cond_layer = nn.Conv1d(cond_channels, upsample_initial_channel, 1) - - if not conv_pre_weight_norm: - remove_parametrizations(self.conv_pre, "weight") - - if not conv_post_weight_norm: - remove_parametrizations(self.conv_post, "weight") - - if self.cond_in_each_up_layer: - self.conds = nn.ModuleList() - for i in range(len(self.ups)): - ch = upsample_initial_channel // (2 ** (i + 1)) - self.conds.append(nn.Conv1d(cond_channels, ch, 1)) - - def forward(self, x, g=None): - """ - Args: - x (Tensor): feature input tensor. - g (Tensor): global conditioning input tensor. - - Returns: - Tensor: output waveform. - - Shapes: - x: [B, C, T] - Tensor: [B, 1, T] - """ - o = self.conv_pre(x) - if hasattr(self, "cond_layer"): - o = o + self.cond_layer(g) - for i in range(self.num_upsamples): - o = F.leaky_relu(o, LRELU_SLOPE) - o = self.ups[i](o) - - if self.cond_in_each_up_layer: - o = o + self.conds[i](g) - - z_sum = None - for j in range(self.num_kernels): - if z_sum is None: - z_sum = self.resblocks[i * self.num_kernels + j](o) - else: - z_sum += self.resblocks[i * self.num_kernels + j](o) - o = z_sum / self.num_kernels - o = F.leaky_relu(o) - o = self.conv_post(o) - o = torch.tanh(o) - return o - - @torch.no_grad() - def inference(self, c): - """ - Args: - x (Tensor): conditioning input tensor. - - Returns: - Tensor: output waveform. 
- - Shapes: - x: [B, C, T] - Tensor: [B, 1, T] - """ - c = c.to(self.conv_pre.weight.device) - c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") - return self.forward(c) - - def remove_weight_norm(self): - logger.info("Removing weight norm...") - for l in self.ups: - remove_parametrizations(l, "weight") - for l in self.resblocks: - l.remove_weight_norm() - remove_parametrizations(self.conv_pre, "weight") - remove_parametrizations(self.conv_post, "weight") - - def load_checkpoint( - self, config, checkpoint_path, eval=False, cache=False - ): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=is_pytorch_at_least_2_4()) - self.load_state_dict(state["model"]) - if eval: - self.eval() - assert not self.training - self.remove_weight_norm() - - -class SELayer(nn.Module): - def __init__(self, channel, reduction=8): - super(SELayer, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.fc = nn.Sequential( - nn.Linear(channel, channel // reduction), - nn.ReLU(inplace=True), - nn.Linear(channel // reduction, channel), - nn.Sigmoid(), - ) - - def forward(self, x): - b, c, _, _ = x.size() - y = self.avg_pool(x).view(b, c) - y = self.fc(y).view(b, c, 1, 1) - return x * y - - -class SEBasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8): - super(SEBasicBlock, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.relu = nn.ReLU(inplace=True) - self.se = SELayer(planes, reduction) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.relu(out) - out = self.bn1(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.se(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - return out - - -def set_init_dict(model_dict, checkpoint_state, c): - # Partial initialization: if there is a mismatch with new and old layer, it is skipped. - for k, v in checkpoint_state.items(): - if k not in model_dict: - logger.warning("Layer missing in the model definition: %s", k) - # 1. filter out unnecessary keys - pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} - # 2. filter out different size layers - pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} - # 3. skip reinit layers - if c.has("reinit_layers") and c.reinit_layers is not None: - for reinit_layer_name in c.reinit_layers: - pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} - # 4. 
overwrite entries in the existing state dict - model_dict.update(pretrained_dict) - logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict)) - return model_dict - - -class PreEmphasis(nn.Module): - def __init__(self, coefficient=0.97): - super().__init__() - self.coefficient = coefficient - self.register_buffer("filter", torch.FloatTensor([-self.coefficient, 1.0]).unsqueeze(0).unsqueeze(0)) - - def forward(self, x): - assert len(x.size()) == 2 - - x = torch.nn.functional.pad(x.unsqueeze(1), (1, 0), "reflect") - return torch.nn.functional.conv1d(x, self.filter).squeeze(1) - - -class ResNetSpeakerEncoder(nn.Module): - """This is copied from 🐸TTS to remove it from the dependencies.""" - - # pylint: disable=W0102 - def __init__( - self, - input_dim=64, - proj_dim=512, - layers=[3, 4, 6, 3], - num_filters=[32, 64, 128, 256], - encoder_type="ASP", - log_input=False, - use_torch_spec=False, - audio_config=None, - ): - super(ResNetSpeakerEncoder, self).__init__() - - self.encoder_type = encoder_type - self.input_dim = input_dim - self.log_input = log_input - self.use_torch_spec = use_torch_spec - self.audio_config = audio_config - self.proj_dim = proj_dim - - self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=1, padding=1) - self.relu = nn.ReLU(inplace=True) - self.bn1 = nn.BatchNorm2d(num_filters[0]) - - self.inplanes = num_filters[0] - self.layer1 = self.create_layer(SEBasicBlock, num_filters[0], layers[0]) - self.layer2 = self.create_layer(SEBasicBlock, num_filters[1], layers[1], stride=(2, 2)) - self.layer3 = self.create_layer(SEBasicBlock, num_filters[2], layers[2], stride=(2, 2)) - self.layer4 = self.create_layer(SEBasicBlock, num_filters[3], layers[3], stride=(2, 2)) - - self.instancenorm = nn.InstanceNorm1d(input_dim) - - if self.use_torch_spec: - self.torch_spec = torch.nn.Sequential( - PreEmphasis(audio_config["preemphasis"]), - torchaudio.transforms.MelSpectrogram( - sample_rate=audio_config["sample_rate"], - n_fft=audio_config["fft_size"], - win_length=audio_config["win_length"], - hop_length=audio_config["hop_length"], - window_fn=torch.hamming_window, - n_mels=audio_config["num_mels"], - ), - ) - - else: - self.torch_spec = None - - outmap_size = int(self.input_dim / 8) - - self.attention = nn.Sequential( - nn.Conv1d(num_filters[3] * outmap_size, 128, kernel_size=1), - nn.ReLU(), - nn.BatchNorm1d(128), - nn.Conv1d(128, num_filters[3] * outmap_size, kernel_size=1), - nn.Softmax(dim=2), - ) - - if self.encoder_type == "SAP": - out_dim = num_filters[3] * outmap_size - elif self.encoder_type == "ASP": - out_dim = num_filters[3] * outmap_size * 2 - else: - raise ValueError("Undefined encoder") - - self.fc = nn.Linear(out_dim, proj_dim) - - self._init_layers() - - def _init_layers(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - nn.init.constant_(m.bias, 0) - - def create_layer(self, block, planes, blocks, stride=1): - downsample = None - if stride != 1 or self.inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(planes * block.expansion), - ) - - layers = [] - layers.append(block(self.inplanes, planes, stride, downsample)) - self.inplanes = planes * block.expansion - for _ in range(1, blocks): - layers.append(block(self.inplanes, planes)) - - return 
nn.Sequential(*layers) - - # pylint: disable=R0201 - def new_parameter(self, *size): - out = nn.Parameter(torch.FloatTensor(*size)) - nn.init.xavier_normal_(out) - return out - - def forward(self, x, l2_norm=False): - """Forward pass of the model. - - Args: - x (Tensor): Raw waveform signal or spectrogram frames. If input is a waveform, `torch_spec` must be `True` - to compute the spectrogram on-the-fly. - l2_norm (bool): Whether to L2-normalize the outputs. - - Shapes: - - x: :math:`(N, 1, T_{in})` or :math:`(N, D_{spec}, T_{in})` - """ - x.squeeze_(1) - # if you torch spec compute it otherwise use the mel spec computed by the AP - if self.use_torch_spec: - x = self.torch_spec(x) - - if self.log_input: - x = (x + 1e-6).log() - x = self.instancenorm(x).unsqueeze(1) - - x = self.conv1(x) - x = self.relu(x) - x = self.bn1(x) - - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - - x = x.reshape(x.size()[0], -1, x.size()[-1]) - - w = self.attention(x) - - if self.encoder_type == "SAP": - x = torch.sum(x * w, dim=2) - elif self.encoder_type == "ASP": - mu = torch.sum(x * w, dim=2) - sg = torch.sqrt((torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-5)) - x = torch.cat((mu, sg), 1) - - x = x.view(x.size()[0], -1) - x = self.fc(x) - - if l2_norm: - x = torch.nn.functional.normalize(x, p=2, dim=1) - return x - - def load_checkpoint( - self, - checkpoint_path: str, - eval: bool = False, - use_cuda: bool = False, - criterion=None, - cache=False, - ): - state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache) - try: - self.load_state_dict(state["model"]) - logger.info("Model fully restored.") - except (KeyError, RuntimeError) as error: - # If eval raise the error - if eval: - raise error - - logger.info("Partial model initialization.") - model_dict = self.state_dict() - model_dict = set_init_dict(model_dict, state["model"]) - self.load_state_dict(model_dict) - del model_dict - - # load the criterion for restore_path - if criterion is not None and "criterion" in state: - try: - criterion.load_state_dict(state["criterion"]) - except (KeyError, RuntimeError) as error: - logger.exception("Criterion load ignored because of: %s", error) - - if use_cuda: - self.cuda() - if criterion is not None: - criterion = criterion.cuda() - - if eval: - self.eval() - assert not self.training - - if not eval: - return criterion, state["step"] - return criterion - class HifiDecoder(torch.nn.Module): def __init__( diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py index afdd59a859..8273d02037 100644 --- a/TTS/vocoder/models/hifigan_generator.py +++ b/TTS/vocoder/models/hifigan_generator.py @@ -178,6 +178,7 @@ def __init__( conv_pre_weight_norm=True, conv_post_weight_norm=True, conv_post_bias=True, + cond_in_each_up_layer=False, ): r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF) @@ -202,6 +203,8 @@ def __init__( self.inference_padding = inference_padding self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_factors) + self.cond_in_each_up_layer = cond_in_each_up_layer + # initial upsampling layers self.conv_pre = weight_norm(Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)) resblock = ResBlock1 if resblock_type == "1" else ResBlock2 @@ -236,6 +239,12 @@ def __init__( if not conv_post_weight_norm: remove_parametrizations(self.conv_post, "weight") + if self.cond_in_each_up_layer: + self.conds = nn.ModuleList() + for i in range(len(self.ups)): + ch = 
upsample_initial_channel // (2 ** (i + 1)) + self.conds.append(nn.Conv1d(cond_channels, ch, 1)) + def forward(self, x, g=None): """ Args: @@ -255,6 +264,10 @@ def forward(self, x, g=None): for i in range(self.num_upsamples): o = F.leaky_relu(o, LRELU_SLOPE) o = self.ups[i](o) + + if self.cond_in_each_up_layer: + o = o + self.conds[i](g) + z_sum = None for j in range(self.num_kernels): if z_sum is None: From 1f27f994a1d7f30400172c341b0256a350749055 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Wed, 20 Nov 2024 18:40:28 +0100 Subject: [PATCH 02/25] refactor(utils): remove duplicate set_partial_state_dict --- TTS/encoder/models/base_encoder.py | 4 ++-- TTS/utils/generic_utils.py | 19 ------------------- 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/TTS/encoder/models/base_encoder.py b/TTS/encoder/models/base_encoder.py index f7137c2186..2082019aad 100644 --- a/TTS/encoder/models/base_encoder.py +++ b/TTS/encoder/models/base_encoder.py @@ -5,10 +5,10 @@ import torchaudio from coqpit import Coqpit from torch import nn +from trainer.generic_utils import set_partial_state_dict from trainer.io import load_fsspec from TTS.encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss -from TTS.utils.generic_utils import set_init_dict logger = logging.getLogger(__name__) @@ -130,7 +130,7 @@ def load_checkpoint( logger.info("Partial model initialization.") model_dict = self.state_dict() - model_dict = set_init_dict(model_dict, state["model"], c) + model_dict = set_partial_state_dict(model_dict, state["model"], config) self.load_state_dict(model_dict) del model_dict diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index 3ee285232f..c38282248d 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -54,25 +54,6 @@ def get_import_path(obj: object) -> str: return ".".join([type(obj).__module__, type(obj).__name__]) -def set_init_dict(model_dict, checkpoint_state, c): - # Partial initialization: if there is a mismatch with new and old layer, it is skipped. - for k, v in checkpoint_state.items(): - if k not in model_dict: - logger.warning("Layer missing in the model finition %s", k) - # 1. filter out unnecessary keys - pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} - # 2. filter out different size layers - pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} - # 3. skip reinit layers - if c.has("reinit_layers") and c.reinit_layers is not None: - for reinit_layer_name in c.reinit_layers: - pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} - # 4. overwrite entries in the existing state dict - model_dict.update(pretrained_dict) - logger.info("%d / %d layers are restored.", len(pretrained_dict), len(model_dict)) - return model_dict - - def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict: """Format kwargs to hande auxilary inputs to models. 
From 66701e1e5120d1db44023de09c5b46f0aa56eba3 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 21 Nov 2024 12:21:38 +0100 Subject: [PATCH 03/25] refactor(xtts): reuse functions/classes from tortoise --- TTS/tts/layers/xtts/latent_encoder.py | 23 +---------------------- TTS/tts/layers/xtts/perceiver_encoder.py | 8 ++------ 2 files changed, 3 insertions(+), 28 deletions(-) diff --git a/TTS/tts/layers/xtts/latent_encoder.py b/TTS/tts/layers/xtts/latent_encoder.py index f9d62a36f1..7d385ec46a 100644 --- a/TTS/tts/layers/xtts/latent_encoder.py +++ b/TTS/tts/layers/xtts/latent_encoder.py @@ -6,10 +6,7 @@ from torch import nn from torch.nn import functional as F - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - return super().forward(x.float()).type(x.dtype) +from TTS.tts.layers.tortoise.arch_utils import normalization, zero_module def conv_nd(dims, *args, **kwargs): @@ -22,24 +19,6 @@ def conv_nd(dims, *args, **kwargs): raise ValueError(f"unsupported dimensions: {dims}") -def normalization(channels): - groups = 32 - if channels <= 16: - groups = 8 - elif channels <= 64: - groups = 16 - while channels % groups != 0: - groups = int(groups / 2) - assert groups > 2 - return GroupNorm32(groups, channels) - - -def zero_module(module): - for p in module.parameters(): - p.detach().zero_() - return module - - class QKVAttention(nn.Module): def __init__(self, n_heads): super().__init__() diff --git a/TTS/tts/layers/xtts/perceiver_encoder.py b/TTS/tts/layers/xtts/perceiver_encoder.py index f4b6e84123..4b42a0e467 100644 --- a/TTS/tts/layers/xtts/perceiver_encoder.py +++ b/TTS/tts/layers/xtts/perceiver_encoder.py @@ -9,6 +9,8 @@ from einops.layers.torch import Rearrange from torch import einsum, nn +from TTS.tts.layers.tortoise.transformer import GEGLU + def exists(val): return val is not None @@ -194,12 +196,6 @@ def forward(self, x): return super().forward(causal_padded_x) -class GEGLU(nn.Module): - def forward(self, x): - x, gate = x.chunk(2, dim=-1) - return F.gelu(gate) * x - - def FeedForward(dim, mult=4, causal_conv=False): dim_inner = int(dim * mult * 2 / 3) From 4ba83f42ab4287430f47f9f17031222e7bbb3086 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 21 Nov 2024 12:28:03 +0100 Subject: [PATCH 04/25] chore(tortoise): remove unused AudioMiniEncoder There's one in tortoise.classifier that's actually used --- TTS/tts/layers/tortoise/arch_utils.py | 108 -------------------------- 1 file changed, 108 deletions(-) diff --git a/TTS/tts/layers/tortoise/arch_utils.py b/TTS/tts/layers/tortoise/arch_utils.py index 8eda251f93..c9abcf6094 100644 --- a/TTS/tts/layers/tortoise/arch_utils.py +++ b/TTS/tts/layers/tortoise/arch_utils.py @@ -185,114 +185,6 @@ def forward(self, x): return self.op(x) -class ResBlock(nn.Module): - def __init__( - self, - channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - up=False, - down=False, - kernel_size=3, - ): - super().__init__() - self.channels = channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_scale_shift_norm = use_scale_shift_norm - padding = 1 if kernel_size == 3 else 2 - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, False) - self.x_upd = Upsample(channels, False) - elif down: - self.h_upd = Downsample(channels, False) - self.x_upd = Downsample(channels, False) - else: 
- self.h_upd = self.x_upd = nn.Identity() - - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module(nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding) - else: - self.skip_connection = nn.Conv1d(channels, self.out_channels, 1) - - def forward(self, x): - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class AudioMiniEncoder(nn.Module): - def __init__( - self, - spec_dim, - embedding_dim, - base_channels=128, - depth=2, - resnet_blocks=2, - attn_blocks=4, - num_attn_heads=4, - dropout=0, - downsample_factor=2, - kernel_size=3, - ): - super().__init__() - self.init = nn.Sequential(nn.Conv1d(spec_dim, base_channels, 3, padding=1)) - ch = base_channels - res = [] - for l in range(depth): - for r in range(resnet_blocks): - res.append(ResBlock(ch, dropout, kernel_size=kernel_size)) - res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor)) - ch *= 2 - self.res = nn.Sequential(*res) - self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)) - attn = [] - for a in range(attn_blocks): - attn.append( - AttentionBlock( - embedding_dim, - num_attn_heads, - ) - ) - self.attn = nn.Sequential(*attn) - self.dim = embedding_dim - - def forward(self, x): - h = self.init(x) - h = self.res(h) - h = self.final(h) - h = self.attn(h) - return h[:, :, 0] - - DEFAULT_MEL_NORM_FILE = "https://github.com/coqui-ai/TTS/releases/download/v0.14.1_models/mel_norms.pth" From 705551c60c84ff8856efc3cf428ecf817a4f7f72 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 21 Nov 2024 12:40:12 +0100 Subject: [PATCH 05/25] refactor(tortoise): remove unused do_checkpoint arguments These are assigned but not used for anything. 
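After this change the touched call sites reduce to, e.g. (taken from the
classifier.py hunk below):

    res.append(ResBlock(ch, dropout, kernel_size=kernel_size))
    attn.append(AttentionBlock(embedding_dim, num_attn_heads))

Callers still passing do_checkpoint will now fail with a TypeError, which is
acceptable because the flag never had any effect.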
--- TTS/tts/layers/tortoise/arch_utils.py | 2 -- TTS/tts/layers/tortoise/autoregressive.py | 2 -- TTS/tts/layers/tortoise/classifier.py | 6 ++---- TTS/tts/layers/tortoise/diffusion_decoder.py | 5 ----- 4 files changed, 2 insertions(+), 13 deletions(-) diff --git a/TTS/tts/layers/tortoise/arch_utils.py b/TTS/tts/layers/tortoise/arch_utils.py index c9abcf6094..4c3733e691 100644 --- a/TTS/tts/layers/tortoise/arch_utils.py +++ b/TTS/tts/layers/tortoise/arch_utils.py @@ -93,12 +93,10 @@ def __init__( channels, num_heads=1, num_head_channels=-1, - do_checkpoint=True, relative_pos_embeddings=False, ): super().__init__() self.channels = channels - self.do_checkpoint = do_checkpoint if num_head_channels == -1: self.num_heads = num_heads else: diff --git a/TTS/tts/layers/tortoise/autoregressive.py b/TTS/tts/layers/tortoise/autoregressive.py index aaae695516..e3ffd4d1f6 100644 --- a/TTS/tts/layers/tortoise/autoregressive.py +++ b/TTS/tts/layers/tortoise/autoregressive.py @@ -175,7 +175,6 @@ def __init__( embedding_dim, attn_blocks=6, num_attn_heads=4, - do_checkpointing=False, mean=False, ): super().__init__() @@ -185,7 +184,6 @@ def __init__( attn.append(AttentionBlock(embedding_dim, num_attn_heads)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim - self.do_checkpointing = do_checkpointing self.mean = mean def forward(self, x): diff --git a/TTS/tts/layers/tortoise/classifier.py b/TTS/tts/layers/tortoise/classifier.py index 8764bb070b..c72834e9a8 100644 --- a/TTS/tts/layers/tortoise/classifier.py +++ b/TTS/tts/layers/tortoise/classifier.py @@ -16,7 +16,6 @@ def __init__( up=False, down=False, kernel_size=3, - do_checkpoint=True, ): super().__init__() self.channels = channels @@ -24,7 +23,6 @@ def __init__( self.out_channels = out_channels or channels self.use_conv = use_conv self.use_scale_shift_norm = use_scale_shift_norm - self.do_checkpoint = do_checkpoint padding = 1 if kernel_size == 3 else 2 self.in_layers = nn.Sequential( @@ -92,14 +90,14 @@ def __init__( self.layers = depth for l in range(depth): for r in range(resnet_blocks): - res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size)) + res.append(ResBlock(ch, dropout, kernel_size=kernel_size)) res.append(Downsample(ch, use_conv=True, out_channels=ch * 2, factor=downsample_factor)) ch *= 2 self.res = nn.Sequential(*res) self.final = nn.Sequential(normalization(ch), nn.SiLU(), nn.Conv1d(ch, embedding_dim, 1)) attn = [] for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads, do_checkpoint=False)) + attn.append(AttentionBlock(embedding_dim, num_attn_heads)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim diff --git a/TTS/tts/layers/tortoise/diffusion_decoder.py b/TTS/tts/layers/tortoise/diffusion_decoder.py index f71eaf1718..15bbfb7121 100644 --- a/TTS/tts/layers/tortoise/diffusion_decoder.py +++ b/TTS/tts/layers/tortoise/diffusion_decoder.py @@ -196,31 +196,26 @@ def __init__( model_channels * 2, num_heads, relative_pos_embeddings=True, - do_checkpoint=False, ), AttentionBlock( model_channels * 2, num_heads, relative_pos_embeddings=True, - do_checkpoint=False, ), AttentionBlock( model_channels * 2, num_heads, relative_pos_embeddings=True, - do_checkpoint=False, ), AttentionBlock( model_channels * 2, num_heads, relative_pos_embeddings=True, - do_checkpoint=False, ), AttentionBlock( model_channels * 2, num_heads, relative_pos_embeddings=True, - do_checkpoint=False, ), ) self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1)) From 
5ffc0543b76858ee5a25fc60c3de9d0369e43dd5 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 21 Nov 2024 13:06:20 +0100 Subject: [PATCH 06/25] refactor(bark): remove custom layer norm Pytorch LayerNorm supports bias=False since version 2.1 --- TTS/tts/layers/bark/model.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/TTS/tts/layers/bark/model.py b/TTS/tts/layers/bark/model.py index 68c50dbdbd..54a9cecec0 100644 --- a/TTS/tts/layers/bark/model.py +++ b/TTS/tts/layers/bark/model.py @@ -12,18 +12,6 @@ from torch.nn import functional as F -class LayerNorm(nn.Module): - """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" - - def __init__(self, ndim, bias): - super().__init__() - self.weight = nn.Parameter(torch.ones(ndim)) - self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None - - def forward(self, x): - return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5) - - class CausalSelfAttention(nn.Module): def __init__(self, config): super().__init__() @@ -119,9 +107,9 @@ def forward(self, x): class Block(nn.Module): def __init__(self, config, layer_idx): super().__init__() - self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) + self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias) self.attn = CausalSelfAttention(config) - self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) + self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias) self.mlp = MLP(config) self.layer_idx = layer_idx @@ -158,7 +146,7 @@ def __init__(self, config): wpe=nn.Embedding(config.block_size, config.n_embd), drop=nn.Dropout(config.dropout), h=nn.ModuleList([Block(config, idx) for idx in range(config.n_layer)]), - ln_f=LayerNorm(config.n_embd, bias=config.bias), + ln_f=nn.LayerNorm(config.n_embd, bias=config.bias), ) ) self.lm_head = nn.Linear(config.n_embd, config.output_vocab_size, bias=False) From 490c973371c4a5ae345982325324efd0ece7f4af Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 21 Nov 2024 15:05:37 +0100 Subject: [PATCH 07/25] refactor(xtts): use position embedding from tortoise --- TTS/tts/layers/tortoise/autoregressive.py | 15 ++++++++++---- TTS/tts/layers/xtts/gpt.py | 24 +---------------------- 2 files changed, 12 insertions(+), 27 deletions(-) diff --git a/TTS/tts/layers/tortoise/autoregressive.py b/TTS/tts/layers/tortoise/autoregressive.py index e3ffd4d1f6..3463e63b39 100644 --- a/TTS/tts/layers/tortoise/autoregressive.py +++ b/TTS/tts/layers/tortoise/autoregressive.py @@ -1,5 +1,6 @@ # AGPL: a notification must be added stating that changes have been made to that file. 
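# Note on this hunk (illustrative, derived from the two versions visible in
# this diff): the LearnedPositionEmbeddings implementation adopted here
# returns the embedding at the requested index directly,
#     self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)     # shape [1, 1, D]
# while the variant it replaces sliced an arange and thus returned the
# embedding at position ind - 1,
#     self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind]  # shape [1, D]
# The call site below therefore switches from
#     attention_mask.shape[1] - mel_len
# to
#     attention_mask.shape[1] - (mel_len + 1)
# so that the same position is looked up as before.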
import functools +import random from typing import Optional import torch @@ -123,7 +124,7 @@ def forward( else: emb = self.embeddings(input_ids) emb = emb + self.text_pos_embedding.get_fixed_embedding( - attention_mask.shape[1] - mel_len, attention_mask.device + attention_mask.shape[1] - (mel_len + 1), attention_mask.device ) transformer_outputs = self.transformer( @@ -196,18 +197,24 @@ def forward(self, x): class LearnedPositionEmbeddings(nn.Module): - def __init__(self, seq_len, model_dim, init=0.02): + def __init__(self, seq_len, model_dim, init=0.02, relative=False): super().__init__() self.emb = nn.Embedding(seq_len, model_dim) # Initializing this way is standard for GPT-2 self.emb.weight.data.normal_(mean=0.0, std=init) + self.relative = relative + self.seq_len = seq_len def forward(self, x): sl = x.shape[1] - return self.emb(torch.arange(0, sl, device=x.device)) + if self.relative: + start = random.randint(sl, self.seq_len) - sl + return self.emb(torch.arange(start, start + sl, device=x.device)) + else: + return self.emb(torch.arange(0, sl, device=x.device)) def get_fixed_embedding(self, ind, dev): - return self.emb(torch.arange(0, ind, device=dev))[ind - 1 : ind] + return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing): diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index b3c3b31b47..f93287619e 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from transformers import GPT2Config -from TTS.tts.layers.tortoise.autoregressive import _prepare_attention_mask_for_generation +from TTS.tts.layers.tortoise.autoregressive import LearnedPositionEmbeddings, _prepare_attention_mask_for_generation from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler @@ -18,28 +18,6 @@ def null_position_embeddings(range, dim): return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) -class LearnedPositionEmbeddings(nn.Module): - def __init__(self, seq_len, model_dim, init=0.02, relative=False): - super().__init__() - # nn.Embedding - self.emb = torch.nn.Embedding(seq_len, model_dim) - # Initializing this way is standard for GPT-2 - self.emb.weight.data.normal_(mean=0.0, std=init) - self.relative = relative - self.seq_len = seq_len - - def forward(self, x): - sl = x.shape[1] - if self.relative: - start = random.randint(sl, self.seq_len) - sl - return self.emb(torch.arange(start, start + sl, device=x.device)) - else: - return self.emb(torch.arange(0, sl, device=x.device)) - - def get_fixed_embedding(self, ind, dev): - return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) - - def build_hf_gpt_transformer( layers, model_dim, From 33ac0d6ee179b9959d86130bdfbff2abad30c587 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 21 Nov 2024 15:33:36 +0100 Subject: [PATCH 08/25] refactor(xtts): use build_hf_gpt_transformer from tortoise --- TTS/tts/layers/tortoise/autoregressive.py | 43 +++++++++----- TTS/tts/layers/xtts/gpt.py | 70 ++++------------------- 2 files changed, 40 insertions(+), 73 deletions(-) diff --git a/TTS/tts/layers/tortoise/autoregressive.py b/TTS/tts/layers/tortoise/autoregressive.py index 3463e63b39..19c1adc0a6 100644 --- a/TTS/tts/layers/tortoise/autoregressive.py +++ b/TTS/tts/layers/tortoise/autoregressive.py @@ -217,7 
+217,15 @@ def get_fixed_embedding(self, ind, dev): return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0) -def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing): +def build_hf_gpt_transformer( + layers: int, + model_dim: int, + heads: int, + max_mel_seq_len: int, + max_text_seq_len: int, + checkpointing: bool, + max_prompt_len: int = 0, +): """ GPT-2 implemented by the HuggingFace library. """ @@ -225,8 +233,8 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text gpt_config = GPT2Config( vocab_size=256, # Unused. - n_positions=max_mel_seq_len + max_text_seq_len, - n_ctx=max_mel_seq_len + max_text_seq_len, + n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len, + n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len, n_embd=model_dim, n_layer=layers, n_head=heads, @@ -239,13 +247,18 @@ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) # Built-in token embeddings are unused. del gpt.wte - return ( - gpt, - LearnedPositionEmbeddings(max_mel_seq_len, model_dim), - LearnedPositionEmbeddings(max_text_seq_len, model_dim), - None, - None, + + mel_pos_emb = ( + LearnedPositionEmbeddings(max_mel_seq_len, model_dim) + if max_mel_seq_len != -1 + else functools.partial(null_position_embeddings, dim=model_dim) + ) + text_pos_emb = ( + LearnedPositionEmbeddings(max_text_seq_len, model_dim) + if max_mel_seq_len != -1 + else functools.partial(null_position_embeddings, dim=model_dim) ) + return gpt, mel_pos_emb, text_pos_emb, None, None class MelEncoder(nn.Module): @@ -339,12 +352,12 @@ def __init__( self.mel_layer_pos_embedding, self.text_layer_pos_embedding, ) = build_hf_gpt_transformer( - layers, - model_dim, - heads, - self.max_mel_tokens + 2 + self.max_conditioning_inputs, - self.max_text_tokens + 2, - checkpointing, + layers=layers, + model_dim=model_dim, + heads=heads, + max_mel_seq_len=self.max_mel_tokens + 2 + self.max_conditioning_inputs, + max_text_seq_len=self.max_text_tokens + 2, + checkpointing=checkpointing, ) if train_solo_embeddings: self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index f93287619e..899522e091 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -1,6 +1,5 @@ # ported from: https://github.com/neonbjb/tortoise-tts -import functools import random import torch @@ -8,61 +7,16 @@ import torch.nn.functional as F from transformers import GPT2Config -from TTS.tts.layers.tortoise.autoregressive import LearnedPositionEmbeddings, _prepare_attention_mask_for_generation +from TTS.tts.layers.tortoise.autoregressive import ( + LearnedPositionEmbeddings, + _prepare_attention_mask_for_generation, + build_hf_gpt_transformer, +) from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler -def null_position_embeddings(range, dim): - return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device) - - -def build_hf_gpt_transformer( - layers, - model_dim, - heads, - max_mel_seq_len, - max_text_seq_len, - max_prompt_len, - checkpointing, -): - """ - GPT-2 implemented by the HuggingFace library. - """ - from transformers import GPT2Config, GPT2Model - - gpt_config = GPT2Config( - vocab_size=256, # Unused. 
- n_positions=max_mel_seq_len + max_text_seq_len + max_prompt_len, - n_ctx=max_mel_seq_len + max_text_seq_len + max_prompt_len, - n_embd=model_dim, - n_layer=layers, - n_head=heads, - gradient_checkpointing=checkpointing, - use_cache=not checkpointing, - ) - gpt = GPT2Model(gpt_config) - # Override the built in positional embeddings - del gpt.wpe - gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim) - # Built-in token embeddings are unused. - del gpt.wte - - mel_pos_emb = ( - LearnedPositionEmbeddings(max_mel_seq_len, model_dim) - if max_mel_seq_len != -1 - else functools.partial(null_position_embeddings, dim=model_dim) - ) - text_pos_emb = ( - LearnedPositionEmbeddings(max_text_seq_len, model_dim) - if max_mel_seq_len != -1 - else functools.partial(null_position_embeddings, dim=model_dim) - ) - # gpt = torch.compile(gpt, mode="reduce-overhead", fullgraph=True) - return gpt, mel_pos_emb, text_pos_emb, None, None - - class GPT(nn.Module): def __init__( self, @@ -127,13 +81,13 @@ def __init__( self.mel_layer_pos_embedding, self.text_layer_pos_embedding, ) = build_hf_gpt_transformer( - layers, - model_dim, - heads, - self.max_mel_tokens, - self.max_text_tokens, - self.max_prompt_tokens, - checkpointing, + layers=layers, + model_dim=model_dim, + heads=heads, + max_mel_seq_len=self.max_mel_tokens, + max_text_seq_len=self.max_text_tokens, + max_prompt_len=self.max_prompt_tokens, + checkpointing=checkpointing, ) if train_solo_embeddings: self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * 0.02, requires_grad=True) From 7cdfde226bc03cc792424c4f3a93741150213cfc Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 22 Nov 2024 21:30:21 +0100 Subject: [PATCH 09/25] refactor: move amp_to_db/db_to_amp into torch_transforms --- TTS/tts/layers/tortoise/audio_utils.py | 22 ++------------- TTS/tts/models/delightful_tts.py | 19 +------------ TTS/tts/models/vits.py | 19 +------------ TTS/utils/audio/numpy_transforms.py | 2 +- TTS/utils/audio/torch_transforms.py | 18 ++++++------ TTS/vc/modules/freevc/mel_processing.py | 35 +++--------------------- tests/aux_tests/test_stft_torch.py | 0 tests/aux_tests/test_torch_transforms.py | 16 +++++++++++ tests/tts_tests/test_vits.py | 3 +- 9 files changed, 36 insertions(+), 98 deletions(-) delete mode 100644 tests/aux_tests/test_stft_torch.py create mode 100644 tests/aux_tests/test_torch_transforms.py diff --git a/TTS/tts/layers/tortoise/audio_utils.py b/TTS/tts/layers/tortoise/audio_utils.py index 4f299a8fd9..c67ee6c44b 100644 --- a/TTS/tts/layers/tortoise/audio_utils.py +++ b/TTS/tts/layers/tortoise/audio_utils.py @@ -9,7 +9,7 @@ import torchaudio from scipy.io.wavfile import read -from TTS.utils.audio.torch_transforms import TorchSTFT +from TTS.utils.audio.torch_transforms import TorchSTFT, amp_to_db from TTS.utils.generic_utils import is_pytorch_at_least_2_4 logger = logging.getLogger(__name__) @@ -88,24 +88,6 @@ def normalize_tacotron_mel(mel): return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1 -def dynamic_range_compression(x, C=1, clip_val=1e-5): - """ - PARAMS - ------ - C: compression factor - """ - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def dynamic_range_decompression(x, C=1): - """ - PARAMS - ------ - C: compression factor used to compress - """ - return torch.exp(x) / C - - def get_voices(extra_voice_dirs: List[str] = []): dirs = extra_voice_dirs voices: Dict[str, List[str]] = {} @@ -175,7 +157,7 @@ def wav_to_univnet_mel(wav, do_normalization=False, device="cuda"): ) stft = 
stft.to(device) mel = stft(wav) - mel = dynamic_range_compression(mel) + mel = amp_to_db(mel) if do_normalization: mel = normalize_tacotron_mel(mel) return mel diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py index c6f15a7952..880ea4ae26 100644 --- a/TTS/tts/models/delightful_tts.py +++ b/TTS/tts/models/delightful_tts.py @@ -32,6 +32,7 @@ from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy from TTS.utils.audio.processor import AudioProcessor +from TTS.utils.audio.torch_transforms import amp_to_db from TTS.vocoder.layers.losses import MultiScaleSTFTLoss from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results @@ -136,24 +137,6 @@ def load_audio(file_path: str): return x, sr -def _amp_to_db(x, C=1, clip_val=1e-5): - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def _db_to_amp(x, C=1): - return torch.exp(x) / C - - -def amp_to_db(magnitudes): - output = _amp_to_db(magnitudes) - return output - - -def db_to_amp(magnitudes): - output = _db_to_amp(magnitudes) - return output - - def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): y = y.squeeze(1) diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 432b29f5e1..aea0f4e4f8 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -35,6 +35,7 @@ from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment +from TTS.utils.audio.torch_transforms import amp_to_db from TTS.utils.samplers import BucketBatchSampler from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results @@ -78,24 +79,6 @@ def load_audio(file_path): return x, sr -def _amp_to_db(x, C=1, clip_val=1e-5): - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def _db_to_amp(x, C=1): - return torch.exp(x) / C - - -def amp_to_db(magnitudes): - output = _amp_to_db(magnitudes) - return output - - -def db_to_amp(magnitudes): - output = _db_to_amp(magnitudes) - return output - - def wav_to_spec(y, n_fft, hop_length, win_length, center=False): """ Args Shapes: diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py index 203091ea88..9c83009b0f 100644 --- a/TTS/utils/audio/numpy_transforms.py +++ b/TTS/utils/audio/numpy_transforms.py @@ -59,7 +59,7 @@ def _exp(x, base): return np.exp(x) -def amp_to_db(*, x: np.ndarray, gain: float = 1, base: int = 10, **kwargs) -> np.ndarray: +def amp_to_db(*, x: np.ndarray, gain: float = 1, base: float = 10, **kwargs) -> np.ndarray: """Convert amplitude values to decibels. 
Args: diff --git a/TTS/utils/audio/torch_transforms.py b/TTS/utils/audio/torch_transforms.py index 632969c51a..dda4c0a419 100644 --- a/TTS/utils/audio/torch_transforms.py +++ b/TTS/utils/audio/torch_transforms.py @@ -3,6 +3,16 @@ from torch import nn +def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor: + """Spectral normalization / dynamic range compression.""" + return torch.log(torch.clamp(x, min=clip_val) * spec_gain) + + +def db_to_amp(x: torch.Tensor, *, spec_gain: float = 1.0) -> torch.Tensor: + """Spectral denormalization / dynamic range decompression.""" + return torch.exp(x) / spec_gain + + class TorchSTFT(nn.Module): # pylint: disable=abstract-method """Some of the audio processing funtions using Torch for faster batch processing. @@ -157,11 +167,3 @@ def _build_mel_basis(self): norm=self.mel_norm, ) self.mel_basis = torch.from_numpy(mel_basis).float() - - @staticmethod - def _amp_to_db(x, spec_gain=1.0): - return torch.log(torch.clamp(x, min=1e-5) * spec_gain) - - @staticmethod - def _db_to_amp(x, spec_gain=1.0): - return torch.exp(x) / spec_gain diff --git a/TTS/vc/modules/freevc/mel_processing.py b/TTS/vc/modules/freevc/mel_processing.py index a3e251891a..4da5e27c83 100644 --- a/TTS/vc/modules/freevc/mel_processing.py +++ b/TTS/vc/modules/freevc/mel_processing.py @@ -4,39 +4,12 @@ import torch.utils.data from librosa.filters import mel as librosa_mel_fn +from TTS.utils.audio.torch_transforms import amp_to_db + logger = logging.getLogger(__name__) MAX_WAV_VALUE = 32768.0 - -def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): - """ - PARAMS - ------ - C: compression factor - """ - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def dynamic_range_decompression_torch(x, C=1): - """ - PARAMS - ------ - C: compression factor used to compress - """ - return torch.exp(x) / C - - -def spectral_normalize_torch(magnitudes): - output = dynamic_range_compression_torch(magnitudes) - return output - - -def spectral_de_normalize_torch(magnitudes): - output = dynamic_range_decompression_torch(magnitudes) - return output - - mel_basis = {} hann_window = {} @@ -85,7 +58,7 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) spec = torch.matmul(mel_basis[fmax_dtype_device], spec) - spec = spectral_normalize_torch(spec) + spec = amp_to_db(spec) return spec @@ -128,6 +101,6 @@ def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) spec = torch.matmul(mel_basis[fmax_dtype_device], spec) - spec = spectral_normalize_torch(spec) + spec = amp_to_db(spec) return spec diff --git a/tests/aux_tests/test_stft_torch.py b/tests/aux_tests/test_stft_torch.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/aux_tests/test_torch_transforms.py b/tests/aux_tests/test_torch_transforms.py new file mode 100644 index 0000000000..2da5a359c1 --- /dev/null +++ b/tests/aux_tests/test_torch_transforms.py @@ -0,0 +1,16 @@ +import numpy as np +import torch + +from TTS.utils.audio import numpy_transforms as np_transforms +from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp + + +def test_amplitude_db_conversion(): + x = torch.rand(11) + o1 = amp_to_db(x=x, spec_gain=1.0) + o2 = db_to_amp(x=o1, spec_gain=1.0) + np_o1 = np_transforms.amp_to_db(x=x, 
base=np.e) + np_o2 = np_transforms.db_to_amp(x=np_o1, base=np.e) + assert torch.allclose(x, o2) + assert torch.allclose(o1, np_o1) + assert torch.allclose(o2, np_o2) diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index 17992773ad..a27bdfe5b5 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -13,14 +13,13 @@ Vits, VitsArgs, VitsAudioConfig, - amp_to_db, - db_to_amp, load_audio, spec_to_mel, wav_to_mel, wav_to_spec, ) from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json") SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json") From 6f25c2b90463dec6afe8c9c788a0a3a717030429 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 22 Nov 2024 00:38:37 +0100 Subject: [PATCH 10/25] refactor(delightful_tts): remove unused classes --- TTS/tts/layers/delightful_tts/conformer.py | 49 +----- TTS/tts/layers/delightful_tts/conv_layers.py | 142 ------------------ .../layers/delightful_tts/kernel_predictor.py | 128 ---------------- TTS/tts/layers/tacotron/gst_layers.py | 10 +- TTS/tts/models/delightful_tts.py | 113 ++------------ TTS/tts/utils/synthesis.py | 19 +-- 6 files changed, 24 insertions(+), 437 deletions(-) delete mode 100644 TTS/tts/layers/delightful_tts/kernel_predictor.py diff --git a/TTS/tts/layers/delightful_tts/conformer.py b/TTS/tts/layers/delightful_tts/conformer.py index b2175b3b96..227a871c69 100644 --- a/TTS/tts/layers/delightful_tts/conformer.py +++ b/TTS/tts/layers/delightful_tts/conformer.py @@ -1,20 +1,14 @@ ### credit: https://github.com/dunky11/voicesmith import math -from typing import Tuple import torch import torch.nn as nn # pylint: disable=consider-using-from-import import torch.nn.functional as F -from TTS.tts.layers.delightful_tts.conv_layers import Conv1dGLU, DepthWiseConv1d, PointwiseConv1d +from TTS.tts.layers.delightful_tts.conv_layers import Conv1dGLU, DepthWiseConv1d, PointwiseConv1d, calc_same_padding from TTS.tts.layers.delightful_tts.networks import GLUActivation -def calc_same_padding(kernel_size: int) -> Tuple[int, int]: - pad = kernel_size // 2 - return (pad, pad - (kernel_size + 1) % 2) - - class Conformer(nn.Module): def __init__( self, @@ -322,7 +316,7 @@ def forward( value: torch.Tensor, mask: torch.Tensor, encoding: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: batch_size, seq_length, _ = key.size() # pylint: disable=unused-variable encoding = encoding[:, : key.shape[1]] encoding = encoding.repeat(batch_size, 1, 1) @@ -378,7 +372,7 @@ def forward( value: torch.Tensor, pos_embedding: torch.Tensor, mask: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: batch_size = query.shape[0] query = self.query_proj(query).view(batch_size, -1, self.num_heads, self.d_head) key = self.key_proj(key).view(batch_size, -1, self.num_heads, self.d_head).permute(0, 2, 1, 3) @@ -411,40 +405,3 @@ def _relative_shift(self, pos_score: torch.Tensor) -> torch.Tensor: # pylint: d padded_pos_score = padded_pos_score.view(batch_size, num_heads, seq_length2 + 1, seq_length1) pos_score = padded_pos_score[:, :, 1:].view_as(pos_score) return pos_score - - -class MultiHeadAttention(nn.Module): - """ - input: - query --- [N, T_q, query_dim] - key --- [N, T_k, key_dim] - output: - out --- [N, T_q, num_units] - """ - - def __init__(self, query_dim: int, key_dim: int, 
num_units: int, num_heads: int): - super().__init__() - self.num_units = num_units - self.num_heads = num_heads - self.key_dim = key_dim - - self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False) - self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) - self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) - - def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor: - querys = self.W_query(query) # [N, T_q, num_units] - keys = self.W_key(key) # [N, T_k, num_units] - values = self.W_value(key) - split_size = self.num_units // self.num_heads - querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h] - keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] - values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] - # score = softmax(QK^T / (d_k ** 0.5)) - scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k] - scores = scores / (self.key_dim**0.5) - scores = F.softmax(scores, dim=3) - # out = score * V - out = torch.matmul(scores, values) # [h, N, T_q, num_units/h] - out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] - return out diff --git a/TTS/tts/layers/delightful_tts/conv_layers.py b/TTS/tts/layers/delightful_tts/conv_layers.py index fb9aa4495f..1d5139571e 100644 --- a/TTS/tts/layers/delightful_tts/conv_layers.py +++ b/TTS/tts/layers/delightful_tts/conv_layers.py @@ -3,9 +3,6 @@ import torch import torch.nn as nn # pylint: disable=consider-using-from-import import torch.nn.functional as F -from torch.nn.utils import parametrize - -from TTS.tts.layers.delightful_tts.kernel_predictor import KernelPredictor def calc_same_padding(kernel_size: int) -> Tuple[int, int]: @@ -530,142 +527,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.addcoords(x) x = self.conv(x) return x - - -class LVCBlock(torch.nn.Module): - """the location-variable convolutions""" - - def __init__( # pylint: disable=dangerous-default-value - self, - in_channels, - cond_channels, - stride, - dilations=[1, 3, 9, 27], - lReLU_slope=0.2, - conv_kernel_size=3, - cond_hop_length=256, - kpnet_hidden_channels=64, - kpnet_conv_size=3, - kpnet_dropout=0.0, - ): - super().__init__() - - self.cond_hop_length = cond_hop_length - self.conv_layers = len(dilations) - self.conv_kernel_size = conv_kernel_size - - self.kernel_predictor = KernelPredictor( - cond_channels=cond_channels, - conv_in_channels=in_channels, - conv_out_channels=2 * in_channels, - conv_layers=len(dilations), - conv_kernel_size=conv_kernel_size, - kpnet_hidden_channels=kpnet_hidden_channels, - kpnet_conv_size=kpnet_conv_size, - kpnet_dropout=kpnet_dropout, - kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}, - ) - - self.convt_pre = nn.Sequential( - nn.LeakyReLU(lReLU_slope), - nn.utils.parametrizations.weight_norm( - nn.ConvTranspose1d( - in_channels, - in_channels, - 2 * stride, - stride=stride, - padding=stride // 2 + stride % 2, - output_padding=stride % 2, - ) - ), - ) - - self.conv_blocks = nn.ModuleList() - for dilation in dilations: - self.conv_blocks.append( - nn.Sequential( - nn.LeakyReLU(lReLU_slope), - nn.utils.parametrizations.weight_norm( - nn.Conv1d( - in_channels, - in_channels, - conv_kernel_size, - padding=dilation * (conv_kernel_size - 1) // 2, - dilation=dilation, - ) - ), - nn.LeakyReLU(lReLU_slope), - ) - ) - - def forward(self, x, c): - """forward 
propagation of the location-variable convolutions. - Args: - x (Tensor): the input sequence (batch, in_channels, in_length) - c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) - - Returns: - Tensor: the output sequence (batch, in_channels, in_length) - """ - _, in_channels, _ = x.shape # (B, c_g, L') - - x = self.convt_pre(x) # (B, c_g, stride * L') - kernels, bias = self.kernel_predictor(c) - - for i, conv in enumerate(self.conv_blocks): - output = conv(x) # (B, c_g, stride * L') - - k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length) - b = bias[:, i, :, :] # (B, 2 * c_g, cond_length) - - output = self.location_variable_convolution( - output, k, b, hop_size=self.cond_hop_length - ) # (B, 2 * c_g, stride * L'): LVC - x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh( - output[:, in_channels:, :] - ) # (B, c_g, stride * L'): GAU - - return x - - def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256): # pylint: disable=no-self-use - """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. - Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. - Args: - x (Tensor): the input sequence (batch, in_channels, in_length). - kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) - bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) - dilation (int): the dilation of convolution. - hop_size (int): the hop_size of the conditioning sequence. - Returns: - (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). - """ - batch, _, in_length = x.shape - batch, _, out_channels, kernel_size, kernel_length = kernel.shape - assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched" - - padding = dilation * int((kernel_size - 1) / 2) - x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding) - x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) - - if hop_size < dilation: - x = F.pad(x, (0, dilation), "constant", 0) - x = x.unfold( - 3, dilation, dilation - ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) - x = x[:, :, :, :, :hop_size] - x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) - x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) - - o = torch.einsum("bildsk,biokl->bolsd", x, kernel) - o = o.to(memory_format=torch.channels_last_3d) - bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) - o = o + bias - o = o.contiguous().view(batch, out_channels, -1) - - return o - - def remove_weight_norm(self): - self.kernel_predictor.remove_weight_norm() - parametrize.remove_parametrizations(self.convt_pre[1], "weight") - for block in self.conv_blocks: - parametrize.remove_parametrizations(block[1], "weight") diff --git a/TTS/tts/layers/delightful_tts/kernel_predictor.py b/TTS/tts/layers/delightful_tts/kernel_predictor.py deleted file mode 100644 index 96c550b6c2..0000000000 --- a/TTS/tts/layers/delightful_tts/kernel_predictor.py +++ /dev/null @@ -1,128 +0,0 @@ -import torch.nn as nn # pylint: disable=consider-using-from-import -from torch.nn.utils import parametrize - - -class KernelPredictor(nn.Module): - """Kernel 
predictor for the location-variable convolutions - - Args: - cond_channels (int): number of channel for the conditioning sequence, - conv_in_channels (int): number of channel for the input sequence, - conv_out_channels (int): number of channel for the output sequence, - conv_layers (int): number of layers - - """ - - def __init__( # pylint: disable=dangerous-default-value - self, - cond_channels, - conv_in_channels, - conv_out_channels, - conv_layers, - conv_kernel_size=3, - kpnet_hidden_channels=64, - kpnet_conv_size=3, - kpnet_dropout=0.0, - kpnet_nonlinear_activation="LeakyReLU", - kpnet_nonlinear_activation_params={"negative_slope": 0.1}, - ): - super().__init__() - - self.conv_in_channels = conv_in_channels - self.conv_out_channels = conv_out_channels - self.conv_kernel_size = conv_kernel_size - self.conv_layers = conv_layers - - kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w - kpnet_bias_channels = conv_out_channels * conv_layers # l_b - - self.input_conv = nn.Sequential( - nn.utils.parametrizations.weight_norm( - nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True) - ), - getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), - ) - - self.residual_convs = nn.ModuleList() - padding = (kpnet_conv_size - 1) // 2 - for _ in range(3): - self.residual_convs.append( - nn.Sequential( - nn.Dropout(kpnet_dropout), - nn.utils.parametrizations.weight_norm( - nn.Conv1d( - kpnet_hidden_channels, - kpnet_hidden_channels, - kpnet_conv_size, - padding=padding, - bias=True, - ) - ), - getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), - nn.utils.parametrizations.weight_norm( - nn.Conv1d( - kpnet_hidden_channels, - kpnet_hidden_channels, - kpnet_conv_size, - padding=padding, - bias=True, - ) - ), - getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), - ) - ) - self.kernel_conv = nn.utils.parametrizations.weight_norm( - nn.Conv1d( - kpnet_hidden_channels, - kpnet_kernel_channels, - kpnet_conv_size, - padding=padding, - bias=True, - ) - ) - self.bias_conv = nn.utils.parametrizations.weight_norm( - nn.Conv1d( - kpnet_hidden_channels, - kpnet_bias_channels, - kpnet_conv_size, - padding=padding, - bias=True, - ) - ) - - def forward(self, c): - """ - Args: - c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) - """ - batch, _, cond_length = c.shape - c = self.input_conv(c) - for residual_conv in self.residual_convs: - residual_conv.to(c.device) - c = c + residual_conv(c) - k = self.kernel_conv(c) - b = self.bias_conv(c) - kernels = k.contiguous().view( - batch, - self.conv_layers, - self.conv_in_channels, - self.conv_out_channels, - self.conv_kernel_size, - cond_length, - ) - bias = b.contiguous().view( - batch, - self.conv_layers, - self.conv_out_channels, - cond_length, - ) - - return kernels, bias - - def remove_weight_norm(self): - parametrize.remove_parametrizations(self.input_conv[0], "weight") - parametrize.remove_parametrizations(self.kernel_conv, "weight") - parametrize.remove_parametrizations(self.bias_conv, "weight") - for block in self.residual_convs: - parametrize.remove_parametrizations(block[1], "weight") - parametrize.remove_parametrizations(block[3], "weight") diff --git a/TTS/tts/layers/tacotron/gst_layers.py b/TTS/tts/layers/tacotron/gst_layers.py index 05dba7084f..ac3d7d4aae 100644 --- a/TTS/tts/layers/tacotron/gst_layers.py +++ b/TTS/tts/layers/tacotron/gst_layers.py @@ -117,7 +117,7 @@ class 
MultiHeadAttention(nn.Module): out --- [N, T_q, num_units] """ - def __init__(self, query_dim, key_dim, num_units, num_heads): + def __init__(self, query_dim: int, key_dim: int, num_units: int, num_heads: int): super().__init__() self.num_units = num_units self.num_heads = num_heads @@ -127,7 +127,7 @@ def __init__(self, query_dim, key_dim, num_units, num_heads): self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) - def forward(self, query, key): + def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor: queries = self.W_query(query) # [N, T_q, num_units] keys = self.W_key(key) # [N, T_k, num_units] values = self.W_value(key) @@ -137,13 +137,11 @@ def forward(self, query, key): keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] - # score = softmax(QK^T / (d_k**0.5)) + # score = softmax(QK^T / (d_k ** 0.5)) scores = torch.matmul(queries, keys.transpose(2, 3)) # [h, N, T_q, T_k] scores = scores / (self.key_dim**0.5) scores = F.softmax(scores, dim=3) # out = score * V out = torch.matmul(scores, values) # [h, N, T_q, num_units/h] - out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] - - return out + return torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py index 880ea4ae26..2f34e4323b 100644 --- a/TTS/tts/models/delightful_tts.py +++ b/TTS/tts/models/delightful_tts.py @@ -8,11 +8,9 @@ import numpy as np import torch import torch.distributed as dist -import torchaudio from coqpit import Coqpit from librosa.filters import mel as librosa_mel_fn from torch import nn -from torch.nn import functional as F from torch.utils.data import DataLoader from torch.utils.data.sampler import WeightedRandomSampler from trainer.io import load_fsspec @@ -24,8 +22,10 @@ from TTS.tts.layers.losses import ForwardSumLoss, VitsDiscriminatorLoss from TTS.tts.layers.vits.discriminator import VitsDiscriminator from TTS.tts.models.base_tts import BaseTTSE2E +from TTS.tts.models.vits import load_audio from TTS.tts.utils.helpers import average_over_durations, compute_attn_prior, rand_segments, segment, sequence_mask from TTS.tts.utils.speakers import SpeakerManager +from TTS.tts.utils.synthesis import embedding_to_torch, id_to_torch, numpy_to_torch from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_avg_pitch, plot_pitch, plot_spectrogram from TTS.utils.audio.numpy_transforms import build_mel_basis, compute_f0 @@ -40,103 +40,10 @@ logger = logging.getLogger(__name__) -def id_to_torch(aux_id, cuda=False): - if aux_id is not None: - aux_id = np.asarray(aux_id) - aux_id = torch.from_numpy(aux_id) - if cuda: - return aux_id.cuda() - return aux_id - - -def embedding_to_torch(d_vector, cuda=False): - if d_vector is not None: - d_vector = np.asarray(d_vector) - d_vector = torch.from_numpy(d_vector).float() - d_vector = d_vector.squeeze().unsqueeze(0) - if cuda: - return d_vector.cuda() - return d_vector - - -def numpy_to_torch(np_array, dtype, cuda=False): - if np_array is None: - return None - tensor = torch.as_tensor(np_array, dtype=dtype) - if cuda: - return tensor.cuda() - return tensor - - -def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor: - batch_size = 
lengths.shape[0] - max_len = torch.max(lengths).item() - ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1) - mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) - return mask - - -def pad(input_ele: List[torch.Tensor], max_len: int) -> torch.Tensor: - out_list = torch.jit.annotate(List[torch.Tensor], []) - for batch in input_ele: - if len(batch.shape) == 1: - one_batch_padded = F.pad(batch, (0, max_len - batch.size(0)), "constant", 0.0) - else: - one_batch_padded = F.pad(batch, (0, 0, 0, max_len - batch.size(0)), "constant", 0.0) - out_list.append(one_batch_padded) - out_padded = torch.stack(out_list) - return out_padded - - -def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor: - return torch.ceil(lens / stride).int() - - -def initialize_embeddings(shape: Tuple[int]) -> torch.Tensor: - assert len(shape) == 2, "Can only initialize 2-D embedding matrices ..." - return torch.randn(shape) * np.sqrt(2 / shape[1]) - - -# pylint: disable=redefined-outer-name -def calc_same_padding(kernel_size: int) -> Tuple[int, int]: - pad = kernel_size // 2 - return (pad, pad - (kernel_size + 1) % 2) - - hann_window = {} mel_basis = {} -@torch.no_grad() -def weights_reset(m: nn.Module): - # check if the current module has reset_parameters and if it is reset the weight - reset_parameters = getattr(m, "reset_parameters", None) - if callable(reset_parameters): - m.reset_parameters() - - -def get_module_weights_sum(mdl: nn.Module): - dict_sums = {} - for name, w in mdl.named_parameters(): - if "weight" in name: - value = w.data.sum().item() - dict_sums[name] = value - return dict_sums - - -def load_audio(file_path: str): - """Load the audio file normalized in [-1, 1] - - Return Shapes: - - x: :math:`[1, T]` - """ - x, sr = torchaudio.load( - file_path, - ) - assert (x > 1).sum() + (x < -1).sum() == 0 - return x, sr - - def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): y = y.squeeze(1) @@ -1179,7 +1086,7 @@ def synthesize( **kwargs, ): # pylint: disable=unused-argument # TODO: add cloning support with ref_waveform - is_cuda = next(self.parameters()).is_cuda + device = next(self.parameters()).device # convert text to sequence of token IDs text_inputs = np.asarray( @@ -1193,14 +1100,14 @@ def synthesize( if isinstance(speaker_id, str) and self.args.use_speaker_embedding: # get the speaker id for the speaker embedding layer _speaker_id = self.speaker_manager.name_to_id[speaker_id] - _speaker_id = id_to_torch(_speaker_id, cuda=is_cuda) + _speaker_id = id_to_torch(_speaker_id, device=device) if speaker_id is not None and self.args.use_d_vector_file: # get the average d_vector for the speaker d_vector = self.speaker_manager.get_mean_embedding(speaker_id, num_samples=None, randomize=False) - d_vector = embedding_to_torch(d_vector, cuda=is_cuda) + d_vector = embedding_to_torch(d_vector, device=device) - text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) + text_inputs = numpy_to_torch(text_inputs, torch.long, device=device) text_inputs = text_inputs.unsqueeze(0) # synthesize voice @@ -1223,7 +1130,7 @@ def synthesize( return return_dict def synthesize_with_gl(self, text: str, speaker_id, d_vector): - is_cuda = next(self.parameters()).is_cuda + device = next(self.parameters()).device # convert text to sequence of token IDs text_inputs = np.asarray( @@ -1232,12 +1139,12 @@ def synthesize_with_gl(self, text: str, speaker_id, d_vector): ) # pass tensors to backend if speaker_id is not None: - speaker_id = id_to_torch(speaker_id, 
cuda=is_cuda) + speaker_id = id_to_torch(speaker_id, device=device) if d_vector is not None: - d_vector = embedding_to_torch(d_vector, cuda=is_cuda) + d_vector = embedding_to_torch(d_vector, device=device) - text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=is_cuda) + text_inputs = numpy_to_torch(text_inputs, torch.long, device=device) text_inputs = text_inputs.unsqueeze(0) # synthesize voice diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 797151c254..5dc4cc569f 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -1,17 +1,16 @@ -from typing import Dict +from typing import Dict, Optional, Union import numpy as np import torch from torch import nn -def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"): - if cuda: - device = "cuda" +def numpy_to_torch( + np_array: np.ndarray, dtype: torch.dtype, device: Union[str, torch.device] = "cpu" +) -> Optional[torch.Tensor]: if np_array is None: return None - tensor = torch.as_tensor(np_array, dtype=dtype, device=device) - return tensor + return torch.as_tensor(np_array, dtype=dtype, device=device) def compute_style_mel(style_wav, ap, cuda=False, device="cpu"): @@ -76,18 +75,14 @@ def inv_spectrogram(postnet_output, ap, CONFIG): return wav -def id_to_torch(aux_id, cuda=False, device="cpu"): - if cuda: - device = "cuda" +def id_to_torch(aux_id, device: Union[str, torch.device] = "cpu") -> Optional[torch.Tensor]: if aux_id is not None: aux_id = np.asarray(aux_id) aux_id = torch.from_numpy(aux_id).to(device) return aux_id -def embedding_to_torch(d_vector, cuda=False, device="cpu"): - if cuda: - device = "cuda" +def embedding_to_torch(d_vector, device: Union[str, torch.device] = "cpu") -> Optional[torch.Tensor]: if d_vector is not None: d_vector = np.asarray(d_vector) d_vector = torch.from_numpy(d_vector).type(torch.FloatTensor) From e63962c22662d76c0765fdb35fd0b30fce8888c8 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 22 Nov 2024 00:45:33 +0100 Subject: [PATCH 11/25] refactor(losses): move shared losses into losses.py --- TTS/tts/layers/losses.py | 87 +++++++++++++++++++------------- TTS/tts/models/delightful_tts.py | 44 ++++------------ TTS/tts/models/neuralhmm_tts.py | 19 +------ TTS/tts/models/overflow.py | 19 +------ 4 files changed, 64 insertions(+), 105 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 5ebed81dda..db62430c9d 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -309,6 +309,24 @@ def forward(self, attn_logprob, in_lens, out_lens): return total_loss +class NLLLoss(nn.Module): + """Negative log likelihood loss.""" + + def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use + """Compute the loss. 
+
+        Args:
+            log_prob (Tensor): [B, T, D]
+
+        Returns:
+            dict: {"loss": [1]}
+
+        """
+        return_dict = {}
+        return_dict["loss"] = -log_prob.mean()
+        return return_dict
+
+
 ########################
 # MODEL LOSS LAYERS
 ########################
@@ -619,6 +637,28 @@ def forward(
         return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
 
 
+def feature_loss(feats_real, feats_generated):
+    loss = 0
+    for dr, dg in zip(feats_real, feats_generated):
+        for rl, gl in zip(dr, dg):
+            rl = rl.float().detach()
+            gl = gl.float()
+            loss += torch.mean(torch.abs(rl - gl))
+    return loss * 2
+
+
+def generator_loss(scores_fake):
+    loss = 0
+    gen_losses = []
+    for dg in scores_fake:
+        dg = dg.float()
+        l = torch.mean((1 - dg) ** 2)
+        gen_losses.append(l)
+        loss += l
+
+    return loss, gen_losses
+
+
 class VitsGeneratorLoss(nn.Module):
     def __init__(self, c: Coqpit):
         super().__init__()
@@ -640,28 +680,6 @@ def __init__(self, c: Coqpit):
             do_amp_to_db=True,
         )
 
-    @staticmethod
-    def feature_loss(feats_real, feats_generated):
-        loss = 0
-        for dr, dg in zip(feats_real, feats_generated):
-            for rl, gl in zip(dr, dg):
-                rl = rl.float().detach()
-                gl = gl.float()
-                loss += torch.mean(torch.abs(rl - gl))
-        return loss * 2
-
-    @staticmethod
-    def generator_loss(scores_fake):
-        loss = 0
-        gen_losses = []
-        for dg in scores_fake:
-            dg = dg.float()
-            l = torch.mean((1 - dg) ** 2)
-            gen_losses.append(l)
-            loss += l
-
-        return loss, gen_losses
-
     @staticmethod
     def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
         """
@@ -722,10 +740,8 @@ def forward(
             self.kl_loss(z_p=z_p, logs_q=logs_q, m_p=m_p, logs_p=logs_p, z_mask=z_mask.unsqueeze(1)) * self.kl_loss_alpha
         )
-        loss_feat = (
-            self.feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
-        )
-        loss_gen = self.generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
+        loss_feat = feature_loss(feats_real=feats_disc_real, feats_generated=feats_disc_fake) * self.feat_loss_alpha
+        loss_gen = generator_loss(scores_fake=scores_disc_fake)[0] * self.gen_loss_alpha
         loss_mel = torch.nn.functional.l1_loss(mel_slice, mel_slice_hat) * self.mel_loss_alpha
         loss_duration = torch.sum(loss_duration.float()) * self.dur_loss_alpha
         loss = loss_kl + loss_feat + loss_mel + loss_gen + loss_duration
@@ -779,6 +795,15 @@ def forward(self, scores_disc_real, scores_disc_fake):
         return return_dict
 
 
+def _binary_alignment_loss(alignment_hard, alignment_soft):
+    """Binary loss that forces soft alignments to match the hard alignments.
+
+    Explained in `https://arxiv.org/pdf/2108.10447.pdf`.
+    """
+    log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum()
+    return -log_sum / alignment_hard.sum()
+
+
 class ForwardTTSLoss(nn.Module):
     """Generic configurable ForwardTTS loss."""
 
@@ -820,14 +845,6 @@ def __init__(self, c):
         self.dur_loss_alpha = c.dur_loss_alpha
         self.binary_alignment_loss_alpha = c.binary_align_loss_alpha
 
-    @staticmethod
-    def _binary_alignment_loss(alignment_hard, alignment_soft):
-        """Binary loss that forces soft alignments to match the hard alignments as
-        explained in `https://arxiv.org/pdf/2108.10447.pdf`.
- """ - log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() - return -log_sum / alignment_hard.sum() - def forward( self, decoder_output, @@ -879,7 +896,7 @@ def forward( return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None: - binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) + binary_alignment_loss = _binary_alignment_loss(alignment_hard, alignment_soft) loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss if binary_loss_weight: return_dict["loss_binary_alignment"] = ( diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py index 2f34e4323b..7216e8143a 100644 --- a/TTS/tts/models/delightful_tts.py +++ b/TTS/tts/models/delightful_tts.py @@ -19,7 +19,13 @@ from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel -from TTS.tts.layers.losses import ForwardSumLoss, VitsDiscriminatorLoss +from TTS.tts.layers.losses import ( + ForwardSumLoss, + VitsDiscriminatorLoss, + _binary_alignment_loss, + feature_loss, + generator_loss, +) from TTS.tts.layers.vits.discriminator import VitsDiscriminator from TTS.tts.models.base_tts import BaseTTSE2E from TTS.tts.models.vits import load_audio @@ -1491,36 +1497,6 @@ def __init__(self, config): self.gen_loss_alpha = config.gen_loss_alpha self.multi_scale_stft_loss_alpha = config.multi_scale_stft_loss_alpha - @staticmethod - def _binary_alignment_loss(alignment_hard, alignment_soft): - """Binary loss that forces soft alignments to match the hard alignments as - explained in `https://arxiv.org/pdf/2108.10447.pdf`. - """ - log_sum = torch.log(torch.clamp(alignment_soft[alignment_hard == 1], min=1e-12)).sum() - return -log_sum / alignment_hard.sum() - - @staticmethod - def feature_loss(feats_real, feats_generated): - loss = 0 - for dr, dg in zip(feats_real, feats_generated): - for rl, gl in zip(dr, dg): - rl = rl.float().detach() - gl = gl.float() - loss += torch.mean(torch.abs(rl - gl)) - return loss * 2 - - @staticmethod - def generator_loss(scores_fake): - loss = 0 - gen_losses = [] - for dg in scores_fake: - dg = dg.float() - l = torch.mean((1 - dg) ** 2) - gen_losses.append(l) - loss += l - - return loss, gen_losses - def forward( self, mel_output, @@ -1618,7 +1594,7 @@ def forward( ) if self.binary_alignment_loss_alpha > 0 and aligner_hard is not None: - binary_alignment_loss = self._binary_alignment_loss(aligner_hard, aligner_soft) + binary_alignment_loss = _binary_alignment_loss(aligner_hard, aligner_soft) total_loss = total_loss + self.binary_alignment_loss_alpha * binary_alignment_loss * binary_loss_weight if binary_loss_weight: loss_dict["loss_binary_alignment"] = ( @@ -1638,8 +1614,8 @@ def forward( # vocoder losses if not skip_disc: - loss_feat = self.feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha - loss_gen = self.generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha + loss_feat = feature_loss(feats_real=feats_real, feats_generated=feats_fake) * self.feat_loss_alpha + loss_gen = generator_loss(scores_fake=scores_fake)[0] * self.gen_loss_alpha loss_dict["vocoder_loss_feat"] = loss_feat loss_dict["vocoder_loss_gen"] = loss_gen loss_dict["loss"] = loss_dict["loss"] + loss_feat + loss_gen diff --git a/TTS/tts/models/neuralhmm_tts.py b/TTS/tts/models/neuralhmm_tts.py index de5401aac7..0b3fadafbf 100644 --- 
a/TTS/tts/models/neuralhmm_tts.py +++ b/TTS/tts/models/neuralhmm_tts.py @@ -8,6 +8,7 @@ from trainer.io import load_fsspec from trainer.logging.tensorboard_logger import TensorboardLogger +from TTS.tts.layers.losses import NLLLoss from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils from TTS.tts.layers.overflow.neural_hmm import NeuralHMM from TTS.tts.layers.overflow.plotting_utils import ( @@ -373,21 +374,3 @@ def test_log( ) -> None: logger.test_audios(steps, outputs[1], self.ap.sample_rate) logger.test_figures(steps, outputs[0]) - - -class NLLLoss(nn.Module): - """Negative log likelihood loss.""" - - def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use - """Compute the loss. - - Args: - logits (Tensor): [B, T, D] - - Returns: - Tensor: [1] - - """ - return_dict = {} - return_dict["loss"] = -log_prob.mean() - return return_dict diff --git a/TTS/tts/models/overflow.py b/TTS/tts/models/overflow.py index b72f4877cf..ac09e406ad 100644 --- a/TTS/tts/models/overflow.py +++ b/TTS/tts/models/overflow.py @@ -8,6 +8,7 @@ from trainer.io import load_fsspec from trainer.logging.tensorboard_logger import TensorboardLogger +from TTS.tts.layers.losses import NLLLoss from TTS.tts.layers.overflow.common_layers import Encoder, OverflowUtils from TTS.tts.layers.overflow.decoder import Decoder from TTS.tts.layers.overflow.neural_hmm import NeuralHMM @@ -389,21 +390,3 @@ def test_log( ) -> None: logger.test_audios(steps, outputs[1], self.ap.sample_rate) logger.test_figures(steps, outputs[0]) - - -class NLLLoss(nn.Module): - """Negative log likelihood loss.""" - - def forward(self, log_prob: torch.Tensor) -> dict: # pylint: disable=no-self-use - """Compute the loss. - - Args: - logits (Tensor): [B, T, D] - - Returns: - Tensor: [1] - - """ - return_dict = {} - return_dict["loss"] = -log_prob.mean() - return return_dict From 2e5f68df6a72cca1f25867cca3d38f585be14df6 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 22 Nov 2024 01:16:42 +0100 Subject: [PATCH 12/25] refactor(wavernn): remove duplicate Stretch2d I checked that the implementations are the same --- TTS/vocoder/models/wavernn.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 723f18dde2..1847679890 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -17,6 +17,7 @@ from TTS.utils.audio.numpy_transforms import mulaw_decode from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset from TTS.vocoder.layers.losses import WaveRNNLoss +from TTS.vocoder.layers.upsample import Stretch2d from TTS.vocoder.models.base_vocoder import BaseVocoder from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian @@ -66,19 +67,6 @@ def forward(self, x): return x -class Stretch2d(nn.Module): - def __init__(self, x_scale, y_scale): - super().__init__() - self.x_scale = x_scale - self.y_scale = y_scale - - def forward(self, x): - b, c, h, w = x.size() - x = x.unsqueeze(-1).unsqueeze(3) - x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale) - return x.view(b, c, h * self.y_scale, w * self.x_scale) - - class UpsampleNetwork(nn.Module): def __init__( self, From 69a599d403eb44140f8a640241faee4c551fda00 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 22 Nov 2024 12:12:50 +0100 Subject: [PATCH 13/25] refactor(freevc): remove duplicate code --- TTS/tts/layers/vits/discriminator.py | 4 +- TTS/vc/models/freevc.py | 75 ++-------------------------- 
TTS/vc/modules/freevc/commons.py | 28 +---------- TTS/vc/modules/freevc/modules.py | 4 +- 4 files changed, 8 insertions(+), 103 deletions(-) diff --git a/TTS/tts/layers/vits/discriminator.py b/TTS/tts/layers/vits/discriminator.py index 3449739fdc..49f7a0d074 100644 --- a/TTS/tts/layers/vits/discriminator.py +++ b/TTS/tts/layers/vits/discriminator.py @@ -2,7 +2,7 @@ from torch import nn from torch.nn.modules.conv import Conv1d -from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP +from TTS.vocoder.models.hifigan_discriminator import LRELU_SLOPE, DiscriminatorP class DiscriminatorS(torch.nn.Module): @@ -39,7 +39,7 @@ def forward(self, x): feat = [] for l in self.convs: x = l(x) - x = torch.nn.functional.leaky_relu(x, 0.1) + x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) feat.append(x) x = self.conv_post(x) feat.append(x) diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py index e5cfdc1e61..62559de534 100644 --- a/TTS/vc/models/freevc.py +++ b/TTS/vc/models/freevc.py @@ -6,15 +6,15 @@ import torch from coqpit import Coqpit from torch import nn -from torch.nn import Conv1d, Conv2d, ConvTranspose1d +from torch.nn import Conv1d, ConvTranspose1d from torch.nn import functional as F -from torch.nn.utils import spectral_norm from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils.parametrize import remove_parametrizations from trainer.io import load_fsspec import TTS.vc.modules.freevc.commons as commons import TTS.vc.modules.freevc.modules as modules +from TTS.tts.layers.vits.discriminator import DiscriminatorS from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.vc.configs.freevc_config import FreeVCConfig @@ -23,7 +23,7 @@ from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx from TTS.vc.modules.freevc.wavlm import get_wavlm -from TTS.vocoder.models.hifigan_generator import get_padding +from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP logger = logging.getLogger(__name__) @@ -164,75 +164,6 @@ def remove_weight_norm(self): remove_parametrizations(l, "weight") -class DiscriminatorP(torch.nn.Module): - def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): - super(DiscriminatorP, self).__init__() - self.period = period - self.use_spectral_norm = use_spectral_norm - norm_f = weight_norm if use_spectral_norm is False else spectral_norm - self.convs = nn.ModuleList( - [ - norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), - norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), - ] - ) - self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) - - def forward(self, x): - fmap = [] - - # 1d to 2d - b, c, t = x.shape - if t % self.period != 0: # pad first - n_pad = self.period - (t % self.period) - x = F.pad(x, (0, n_pad), "reflect") - t = t + n_pad - x = x.view(b, c, t // self.period, self.period) - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, 
-1) - - return x, fmap - - -class DiscriminatorS(torch.nn.Module): - def __init__(self, use_spectral_norm=False): - super(DiscriminatorS, self).__init__() - norm_f = weight_norm if use_spectral_norm is False else spectral_norm - self.convs = nn.ModuleList( - [ - norm_f(Conv1d(1, 16, 15, 1, padding=7)), - norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), - norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), - norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), - norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), - norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), - ] - ) - self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) - - def forward(self, x): - fmap = [] - - for l in self.convs: - x = l(x) - x = F.leaky_relu(x, modules.LRELU_SLOPE) - fmap.append(x) - x = self.conv_post(x) - fmap.append(x) - x = torch.flatten(x, 1, -1) - - return x, fmap - - class MultiPeriodDiscriminator(torch.nn.Module): def __init__(self, use_spectral_norm=False): super(MultiPeriodDiscriminator, self).__init__() diff --git a/TTS/vc/modules/freevc/commons.py b/TTS/vc/modules/freevc/commons.py index feea7f34dc..49889e4816 100644 --- a/TTS/vc/modules/freevc/commons.py +++ b/TTS/vc/modules/freevc/commons.py @@ -3,7 +3,7 @@ import torch from torch.nn import functional as F -from TTS.tts.utils.helpers import convert_pad_shape, sequence_mask +from TTS.tts.utils.helpers import convert_pad_shape def init_weights(m: torch.nn.Module, mean: float = 0.0, std: float = 0.01) -> None: @@ -96,37 +96,11 @@ def subsequent_mask(length): return mask -@torch.jit.script -def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): - n_channels_int = n_channels[0] - in_act = input_a + input_b - t_act = torch.tanh(in_act[:, :n_channels_int, :]) - s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) - acts = t_act * s_act - return acts - - def shift_1d(x): x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] return x -def generate_path(duration, mask): - """ - duration: [b, 1, t_x] - mask: [b, 1, t_y, t_x] - """ - b, _, t_y, t_x = mask.shape - cum_duration = torch.cumsum(duration, -1) - - cum_duration_flat = cum_duration.view(b * t_x) - path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) - path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - path = path.unsqueeze(1).transpose(2, 3) * mask - return path - - def clip_grad_value_(parameters, clip_value, norm_type=2): if isinstance(parameters, torch.Tensor): parameters = [parameters] diff --git a/TTS/vc/modules/freevc/modules.py b/TTS/vc/modules/freevc/modules.py index 722444a303..ea17be24d6 100644 --- a/TTS/vc/modules/freevc/modules.py +++ b/TTS/vc/modules/freevc/modules.py @@ -5,8 +5,8 @@ from torch.nn.utils.parametrizations import weight_norm from torch.nn.utils.parametrize import remove_parametrizations -import TTS.vc.modules.freevc.commons as commons from TTS.tts.layers.generic.normalization import LayerNorm2 +from TTS.tts.layers.generic.wavenet import fused_add_tanh_sigmoid_multiply from TTS.vc.modules.freevc.commons import init_weights from TTS.vocoder.models.hifigan_generator import get_padding @@ -99,7 +99,7 @@ def forward(self, x, x_mask, g=None, **kwargs): else: g_l = torch.zeros_like(x_in) - acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) + acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) acts = self.drop(acts) res_skip_acts = self.res_skip_layers[i](acts) From 6ecf47312c1cbe30dd47bd70c8ae30dbc9d2d407 Mon Sep 17 00:00:00 2001 
From: Enno Hermann Date: Fri, 22 Nov 2024 15:38:35 +0100 Subject: [PATCH 14/25] refactor(xtts): use tortoise conditioning encoder --- TTS/tts/layers/tortoise/autoregressive.py | 12 +++++------ TTS/tts/layers/xtts/gpt.py | 16 ++------------- TTS/tts/layers/xtts/latent_encoder.py | 25 ----------------------- TTS/tts/models/xtts.py | 19 ----------------- 4 files changed, 7 insertions(+), 65 deletions(-) diff --git a/TTS/tts/layers/tortoise/autoregressive.py b/TTS/tts/layers/tortoise/autoregressive.py index 19c1adc0a6..07cf3d542b 100644 --- a/TTS/tts/layers/tortoise/autoregressive.py +++ b/TTS/tts/layers/tortoise/autoregressive.py @@ -176,7 +176,6 @@ def __init__( embedding_dim, attn_blocks=6, num_attn_heads=4, - mean=False, ): super().__init__() attn = [] @@ -185,15 +184,14 @@ def __init__( attn.append(AttentionBlock(embedding_dim, num_attn_heads)) self.attn = nn.Sequential(*attn) self.dim = embedding_dim - self.mean = mean def forward(self, x): + """ + x: (b, 80, s) + """ h = self.init(x) h = self.attn(h) - if self.mean: - return h.mean(dim=2) - else: - return h[:, :, 0] + return h class LearnedPositionEmbeddings(nn.Module): @@ -473,7 +471,7 @@ def get_conditioning(self, speech_conditioning_input): ) conds = [] for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) + conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])[:, :, 0]) conds = torch.stack(conds, dim=1) conds = conds.mean(dim=1) return conds diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index 899522e091..20eff26ecc 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -8,12 +8,12 @@ from transformers import GPT2Config from TTS.tts.layers.tortoise.autoregressive import ( + ConditioningEncoder, LearnedPositionEmbeddings, _prepare_attention_mask_for_generation, build_hf_gpt_transformer, ) from TTS.tts.layers.xtts.gpt_inference import GPT2InferenceModel -from TTS.tts.layers.xtts.latent_encoder import ConditioningEncoder from TTS.tts.layers.xtts.perceiver_encoder import PerceiverResampler @@ -235,19 +235,6 @@ def get_logits( else: return first_logits - def get_conditioning(self, speech_conditioning_input): - speech_conditioning_input = ( - speech_conditioning_input.unsqueeze(1) - if len(speech_conditioning_input.shape) == 3 - else speech_conditioning_input - ) - conds = [] - for j in range(speech_conditioning_input.shape[1]): - conds.append(self.conditioning_encoder(speech_conditioning_input[:, j])) - conds = torch.stack(conds, dim=1) - conds = conds.mean(dim=1) - return conds - def get_prompts(self, prompt_codes): """ Create a prompt from the mel codes. This is used to condition the model on the mel codes. 
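
Note on the call-site effect of this refactor: the shared ConditioningEncoder no longer pools over time itself (the old `mean` flag is gone), so Tortoise now selects the first frame at the call site while XTTS keeps the full sequence for its perceiver resampler. A minimal sketch under those assumptions, with shapes taken from the docstrings in this patch (the batch of two 250-frame mels is an arbitrary example):

    import torch

    from TTS.tts.layers.tortoise.autoregressive import ConditioningEncoder

    encoder = ConditioningEncoder(spec_dim=80, embedding_dim=1024, attn_blocks=6, num_attn_heads=4)
    mel = torch.randn(2, 80, 250)  # (b, 80, s) conditioning spectrogram

    h = encoder(mel)     # (b, 1024, s): full sequence, no pooling inside the encoder
    pooled = h[:, :, 0]  # Tortoise's get_conditioning() now keeps one frame per clip
    assert pooled.shape == (2, 1024)
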
@@ -286,6 +273,7 @@ def get_style_emb(self, cond_input, return_latent=False): """ cond_input: (b, 80, s) or (b, 1, 80, s) conds: (b, 1024, s) + output: (b, 1024, 32) """ conds = None if not return_latent: diff --git a/TTS/tts/layers/xtts/latent_encoder.py b/TTS/tts/layers/xtts/latent_encoder.py index 7d385ec46a..6becffb8b7 100644 --- a/TTS/tts/layers/xtts/latent_encoder.py +++ b/TTS/tts/layers/xtts/latent_encoder.py @@ -93,28 +93,3 @@ def forward(self, x, mask=None, qk_bias=0): h = self.proj_out(h) xp = self.x_proj(x) return (xp + h).reshape(b, xp.shape[1], *spatial) - - -class ConditioningEncoder(nn.Module): - def __init__( - self, - spec_dim, - embedding_dim, - attn_blocks=6, - num_attn_heads=4, - ): - super().__init__() - attn = [] - self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1) - for a in range(attn_blocks): - attn.append(AttentionBlock(embedding_dim, num_attn_heads)) - self.attn = nn.Sequential(*attn) - self.dim = embedding_dim - - def forward(self, x): - """ - x: (b, 80, s) - """ - h = self.init(x) - h = self.attn(h) - return h diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 22d2720efa..35de91e359 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -93,25 +93,6 @@ def load_audio(audiopath, sampling_rate): return audio -def pad_or_truncate(t, length): - """ - Ensure a given tensor t has a specified sequence length by either padding it with zeros or clipping it. - - Args: - t (torch.Tensor): The input tensor to be padded or truncated. - length (int): The desired length of the tensor. - - Returns: - torch.Tensor: The padded or truncated tensor. - """ - tp = t[..., :length] - if t.shape[-1] == length: - tp = t - elif t.shape[-1] < length: - tp = F.pad(t, (0, length - t.shape[-1])) - return tp - - @dataclass class XttsAudioConfig(Coqpit): """ From 0f69d31f705dd28c132ea86f4dc1ab47d1d9efb0 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 22 Nov 2024 17:28:30 +0100 Subject: [PATCH 15/25] refactor(vocoder): remove duplicate function --- TTS/vocoder/models/parallel_wavegan_generator.py | 16 ++++++++-------- TTS/vocoder/models/univnet_generator.py | 10 ++-------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py index 6a4d4ca6e7..e60c8781f0 100644 --- a/TTS/vocoder/models/parallel_wavegan_generator.py +++ b/TTS/vocoder/models/parallel_wavegan_generator.py @@ -12,6 +12,13 @@ logger = logging.getLogger(__name__) +def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x): + assert layers % stacks == 0 + layers_per_cycle = layers // stacks + dilations = [dilation(i % layers_per_cycle) for i in range(layers)] + return (kernel_size - 1) * sum(dilations) + 1 + + class ParallelWaveganGenerator(torch.nn.Module): """PWGAN generator as in https://arxiv.org/pdf/1910.11480.pdf. It is similar to WaveNet with no causal convolution. 
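
The helper shared here computes the receptive field as (kernel_size - 1) * sum(dilations) + 1, with dilations cycling through powers of two within each stack. A quick worked check, assuming PWGAN's usual 30 layers, 3 stacks, and kernel size 3:

    from TTS.vocoder.models.parallel_wavegan_generator import _get_receptive_field_size

    # 10 layers per cycle -> dilations 1, 2, 4, ..., 512, repeated 3 times,
    # so sum(dilations) = 3 * 1023 = 3069 and the receptive field is
    # (3 - 1) * 3069 + 1 = 6139 samples.
    assert _get_receptive_field_size(layers=30, stacks=3, kernel_size=3) == 6139
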
@@ -144,16 +151,9 @@ def _apply_weight_norm(m): self.apply(_apply_weight_norm) - @staticmethod - def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x): - assert layers % stacks == 0 - layers_per_cycle = layers // stacks - dilations = [dilation(i % layers_per_cycle) for i in range(layers)] - return (kernel_size - 1) * sum(dilations) + 1 - @property def receptive_field_size(self): - return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size) + return _get_receptive_field_size(self.layers, self.stacks, self.kernel_size) def load_checkpoint( self, config, checkpoint_path, eval=False, cache=False diff --git a/TTS/vocoder/models/univnet_generator.py b/TTS/vocoder/models/univnet_generator.py index 72e57a9c39..5d1f817927 100644 --- a/TTS/vocoder/models/univnet_generator.py +++ b/TTS/vocoder/models/univnet_generator.py @@ -7,6 +7,7 @@ from torch.nn.utils import parametrize from TTS.vocoder.layers.lvc_block import LVCBlock +from TTS.vocoder.models.parallel_wavegan_generator import _get_receptive_field_size logger = logging.getLogger(__name__) @@ -133,17 +134,10 @@ def _apply_weight_norm(m): self.apply(_apply_weight_norm) - @staticmethod - def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2**x): - assert layers % stacks == 0 - layers_per_cycle = layers // stacks - dilations = [dilation(i % layers_per_cycle) for i in range(layers)] - return (kernel_size - 1) * sum(dilations) + 1 - @property def receptive_field_size(self): """Return receptive field size.""" - return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size) + return _get_receptive_field_size(self.layers, self.stacks, self.kernel_size) @torch.no_grad() def inference(self, c): From fa844e0fb7ea84a26cea1fc5ae3b3c3b7f811f55 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 22 Nov 2024 21:35:26 +0100 Subject: [PATCH 16/25] refactor(tacotron): remove duplicate function --- TTS/tts/layers/tacotron/capacitron_layers.py | 11 +++-------- TTS/tts/layers/tacotron/common_layers.py | 7 +++++++ TTS/tts/layers/tacotron/gst_layers.py | 11 +++-------- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/TTS/tts/layers/tacotron/capacitron_layers.py b/TTS/tts/layers/tacotron/capacitron_layers.py index 2181ffa7ec..817f42771b 100644 --- a/TTS/tts/layers/tacotron/capacitron_layers.py +++ b/TTS/tts/layers/tacotron/capacitron_layers.py @@ -3,6 +3,8 @@ from torch.distributions.multivariate_normal import MultivariateNormal as MVN from torch.nn import functional as F +from TTS.tts.layers.tacotron.common_layers import calculate_post_conv_height + class CapacitronVAE(nn.Module): """Effective Use of Variational Embedding Capacity for prosody transfer. 
@@ -97,7 +99,7 @@ def __init__(self, num_mel, out_dim): self.training = False self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]]) - post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 2, num_layers) + post_conv_height = calculate_post_conv_height(num_mel, 3, 2, 2, num_layers) self.recurrence = nn.LSTM( input_size=filters[-1] * post_conv_height, hidden_size=out_dim, batch_first=True, bidirectional=False ) @@ -155,13 +157,6 @@ def forward(self, inputs, input_lengths): return last_output.to(inputs.device) # [B, 128] - @staticmethod - def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs): - """Height of spec after n convolutions with fixed kernel/stride/pad.""" - for _ in range(n_convs): - height = (height - kernel_size + 2 * pad) // stride + 1 - return height - class TextSummary(nn.Module): def __init__(self, embedding_dim, encoder_output_dim): diff --git a/TTS/tts/layers/tacotron/common_layers.py b/TTS/tts/layers/tacotron/common_layers.py index f78ff1e75f..16e517fdca 100644 --- a/TTS/tts/layers/tacotron/common_layers.py +++ b/TTS/tts/layers/tacotron/common_layers.py @@ -3,6 +3,13 @@ from torch.nn import functional as F +def calculate_post_conv_height(height: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int: + """Height of spec after n convolutions with fixed kernel/stride/pad.""" + for _ in range(n_convs): + height = (height - kernel_size + 2 * pad) // stride + 1 + return height + + class Linear(nn.Module): """Linear layer with a specific initialization. diff --git a/TTS/tts/layers/tacotron/gst_layers.py b/TTS/tts/layers/tacotron/gst_layers.py index ac3d7d4aae..4a83fb1c83 100644 --- a/TTS/tts/layers/tacotron/gst_layers.py +++ b/TTS/tts/layers/tacotron/gst_layers.py @@ -2,6 +2,8 @@ import torch.nn.functional as F from torch import nn +from TTS.tts.layers.tacotron.common_layers import calculate_post_conv_height + class GST(nn.Module): """Global Style Token Module for factorizing prosody in speech. 
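
The shared helper iterates the standard convolution output-size formula, height = (height - kernel_size + 2 * pad) // stride + 1, once per layer. A worked check with the Capacitron reference-encoder call above (kernel 3, stride 2, pad 2; the 80 mel bands and 6 conv layers are assumed typical values, not fixed by this diff):

    from TTS.tts.layers.tacotron.common_layers import calculate_post_conv_height

    # 80 -> 41 -> 22 -> 12 -> 7 -> 5 -> 4
    assert calculate_post_conv_height(height=80, kernel_size=3, stride=2, pad=2, n_convs=6) == 4
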
@@ -44,7 +46,7 @@ def __init__(self, num_mel, embedding_dim): self.convs = nn.ModuleList(convs) self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]]) - post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 1, num_layers) + post_conv_height = calculate_post_conv_height(num_mel, 3, 2, 1, num_layers) self.recurrence = nn.GRU( input_size=filters[-1] * post_conv_height, hidden_size=embedding_dim // 2, batch_first=True ) @@ -71,13 +73,6 @@ def forward(self, inputs): return out.squeeze(0) - @staticmethod - def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs): - """Height of spec after n convolutions with fixed kernel/stride/pad.""" - for _ in range(n_convs): - height = (height - kernel_size + 2 * pad) // stride + 1 - return height - class StyleTokenLayer(nn.Module): """NN Module attending to style tokens based on prosody encodings.""" From b45a7a4220e21eea4825d24ba4498afb37591c64 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 22 Nov 2024 22:02:26 +0100 Subject: [PATCH 17/25] refactor: move exists() and default() into generic_utils --- TTS/tts/layers/bark/hubert/kmeans_hubert.py | 10 ++-------- TTS/tts/layers/tortoise/clvp.py | 4 ---- TTS/tts/layers/tortoise/transformer.py | 17 +++++++--------- TTS/tts/layers/tortoise/xtransformers.py | 22 +++------------------ TTS/tts/layers/xtts/dvae.py | 4 ---- TTS/tts/layers/xtts/perceiver_encoder.py | 11 +---------- TTS/utils/generic_utils.py | 15 +++++++++++++- 7 files changed, 27 insertions(+), 56 deletions(-) diff --git a/TTS/tts/layers/bark/hubert/kmeans_hubert.py b/TTS/tts/layers/bark/hubert/kmeans_hubert.py index 58a614cb87..ade84794eb 100644 --- a/TTS/tts/layers/bark/hubert/kmeans_hubert.py +++ b/TTS/tts/layers/bark/hubert/kmeans_hubert.py @@ -14,6 +14,8 @@ from torchaudio.functional import resample from transformers import HubertModel +from TTS.utils.generic_utils import exists + def round_down_nearest_multiple(num, divisor): return num // divisor * divisor @@ -26,14 +28,6 @@ def curtail_to_multiple(t, mult, from_left=False): return t[..., seq_slice] -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d - - class CustomHubert(nn.Module): """ checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert diff --git a/TTS/tts/layers/tortoise/clvp.py b/TTS/tts/layers/tortoise/clvp.py index 241dfdd4f4..44da1324e7 100644 --- a/TTS/tts/layers/tortoise/clvp.py +++ b/TTS/tts/layers/tortoise/clvp.py @@ -8,10 +8,6 @@ from TTS.tts.layers.tortoise.xtransformers import Encoder -def exists(val): - return val is not None - - def masked_mean(t, mask, dim=1): t = t.masked_fill(~mask[:, :, None], 0.0) return t.sum(dim=1) / mask.sum(dim=1)[..., None] diff --git a/TTS/tts/layers/tortoise/transformer.py b/TTS/tts/layers/tortoise/transformer.py index 6cb1bab96a..ed4d79d4ab 100644 --- a/TTS/tts/layers/tortoise/transformer.py +++ b/TTS/tts/layers/tortoise/transformer.py @@ -1,22 +1,19 @@ +from typing import TypeVar, Union + import torch import torch.nn.functional as F from einops import rearrange from torch import nn -# helpers - +from TTS.utils.generic_utils import exists -def exists(val): - return val is not None - - -def default(val, d): - return val if exists(val) else d +# helpers +_T = TypeVar("_T") -def cast_tuple(val, depth=1): +def cast_tuple(val: Union[tuple[_T], list[_T], _T], depth: int = 1) -> tuple[_T]: if isinstance(val, list): - val = tuple(val) + return tuple(val) return val 
if isinstance(val, tuple) else (val,) * depth diff --git a/TTS/tts/layers/tortoise/xtransformers.py b/TTS/tts/layers/tortoise/xtransformers.py index 9325b8c720..0892fee19d 100644 --- a/TTS/tts/layers/tortoise/xtransformers.py +++ b/TTS/tts/layers/tortoise/xtransformers.py @@ -1,13 +1,15 @@ import math from collections import namedtuple from functools import partial -from inspect import isfunction import torch import torch.nn.functional as F from einops import rearrange, repeat from torch import einsum, nn +from TTS.tts.layers.tortoise.transformer import cast_tuple, max_neg_value +from TTS.utils.generic_utils import default, exists + DEFAULT_DIM_HEAD = 64 Intermediates = namedtuple("Intermediates", ["pre_softmax_attn", "post_softmax_attn"]) @@ -25,20 +27,6 @@ # helpers -def exists(val): - return val is not None - - -def default(val, d): - if exists(val): - return val - return d() if isfunction(d) else d - - -def cast_tuple(val, depth): - return val if isinstance(val, tuple) else (val,) * depth - - class always: def __init__(self, val): self.val = val @@ -63,10 +51,6 @@ def __call__(self, x, *args, **kwargs): return x == self.val -def max_neg_value(tensor): - return -torch.finfo(tensor.dtype).max - - def l2norm(t): return F.normalize(t, p=2, dim=-1) diff --git a/TTS/tts/layers/xtts/dvae.py b/TTS/tts/layers/xtts/dvae.py index 73970fb0bf..4f806f82cb 100644 --- a/TTS/tts/layers/xtts/dvae.py +++ b/TTS/tts/layers/xtts/dvae.py @@ -14,10 +14,6 @@ logger = logging.getLogger(__name__) -def default(val, d): - return val if val is not None else d - - def eval_decorator(fn): def inner(model, *args, **kwargs): was_training = model.training diff --git a/TTS/tts/layers/xtts/perceiver_encoder.py b/TTS/tts/layers/xtts/perceiver_encoder.py index 4b42a0e467..7477087283 100644 --- a/TTS/tts/layers/xtts/perceiver_encoder.py +++ b/TTS/tts/layers/xtts/perceiver_encoder.py @@ -10,10 +10,7 @@ from torch import einsum, nn from TTS.tts.layers.tortoise.transformer import GEGLU - - -def exists(val): - return val is not None +from TTS.utils.generic_utils import default, exists def once(fn): @@ -153,12 +150,6 @@ def Sequential(*mods): return nn.Sequential(*filter(exists, mods)) -def default(val, d): - if exists(val): - return val - return d() if callable(d) else d - - class RMSNorm(nn.Module): def __init__(self, dim, scale=True, dim_cond=None): super().__init__() diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index c38282248d..087ae7d0e1 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -4,13 +4,26 @@ import logging import re from pathlib import Path -from typing import Dict, Optional +from typing import Callable, Dict, Optional, TypeVar, Union import torch from packaging.version import Version +from typing_extensions import TypeIs logger = logging.getLogger(__name__) +_T = TypeVar("_T") + + +def exists(val: Union[_T, None]) -> TypeIs[_T]: + return val is not None + + +def default(val: Union[_T, None], d: Union[_T, Callable[[], _T]]) -> _T: + if exists(val): + return val + return d() if callable(d) else d + def to_camel(text): text = text.capitalize() From 54f4228a466bf3042e7d6d56b00767559d66b942 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 22 Nov 2024 22:08:05 +0100 Subject: [PATCH 18/25] refactor(xtts): use existing cleaners --- TTS/tts/layers/xtts/tokenizer.py | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index e87eb0766b..076727239c 100644 --- 
a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -15,6 +15,7 @@ from tokenizers import Tokenizer from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words +from TTS.tts.utils.text.cleaners import collapse_whitespace, lowercase logger = logging.getLogger(__name__) @@ -72,8 +73,6 @@ def split_sentence(text, lang, text_split_length=250): return text_splits -_whitespace_re = re.compile(r"\s+") - # List of (regular expression, replacement) pairs for abbreviations: _abbreviations = { "en": [ @@ -564,14 +563,6 @@ def expand_numbers_multilingual(text, lang="en"): return text -def lowercase(text): - return text.lower() - - -def collapse_whitespace(text): - return re.sub(_whitespace_re, " ", text) - - def multilingual_cleaners(text, lang): text = text.replace('"', "") if lang == "tr": @@ -586,13 +577,6 @@ def multilingual_cleaners(text, lang): return text -def basic_cleaners(text): - """Basic pipeline that lowercases and collapses whitespace without transliteration.""" - text = lowercase(text) - text = collapse_whitespace(text) - return text - - def chinese_transliterate(text): try: import pypinyin From b1ac884e077e243015c59421c8385e55a61d0899 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 22 Nov 2024 22:33:25 +0100 Subject: [PATCH 19/25] refactor: move shared function into dataset.py --- TTS/tts/datasets/dataset.py | 25 +++++++++++++++++++++++++ TTS/tts/models/delightful_tts.py | 21 +-------------------- TTS/tts/models/vits.py | 26 +------------------------- 3 files changed, 27 insertions(+), 45 deletions(-) diff --git a/TTS/tts/datasets/dataset.py b/TTS/tts/datasets/dataset.py index 37e3a1779d..5f629f32a9 100644 --- a/TTS/tts/datasets/dataset.py +++ b/TTS/tts/datasets/dataset.py @@ -63,6 +63,31 @@ def get_audio_size(audiopath: Union[str, os.PathLike[Any]]) -> int: raise RuntimeError(msg) from e +def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: Optional[dict] = None): + """Create inverse frequency weights for balancing the dataset. 
+ + Use `multi_dict` to scale relative weights.""" + attr_names_samples = np.array([item[attr_name] for item in items]) + unique_attr_names = np.unique(attr_names_samples).tolist() + attr_idx = [unique_attr_names.index(l) for l in attr_names_samples] + attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names]) + weight_attr = 1.0 / attr_count + dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx]) + dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) + if multi_dict is not None: + # check if all keys are in the multi_dict + for k in multi_dict: + assert k in unique_attr_names, f"{k} not in {unique_attr_names}" + # scale weights + multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items]) + dataset_samples_weight *= multiplier_samples + return ( + torch.from_numpy(dataset_samples_weight).float(), + unique_attr_names, + np.unique(dataset_samples_weight).tolist(), + ) + + class TTSDataset(Dataset): def __init__( self, diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py index 7216e8143a..8857004725 100644 --- a/TTS/tts/models/delightful_tts.py +++ b/TTS/tts/models/delightful_tts.py @@ -17,7 +17,7 @@ from trainer.torch import DistributedSampler, DistributedSamplerWrapper from trainer.trainer_utils import get_optimizer, get_scheduler -from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample +from TTS.tts.datasets.dataset import F0Dataset, TTSDataset, _parse_sample, get_attribute_balancer_weights from TTS.tts.layers.delightful_tts.acoustic_model import AcousticModel from TTS.tts.layers.losses import ( ForwardSumLoss, @@ -194,25 +194,6 @@ def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fm ############################## -def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None): - """Create balancer weight for torch WeightedSampler""" - attr_names_samples = np.array([item[attr_name] for item in items]) - unique_attr_names = np.unique(attr_names_samples).tolist() - attr_idx = [unique_attr_names.index(l) for l in attr_names_samples] - attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names]) - weight_attr = 1.0 / attr_count - dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx]) - dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) - if multi_dict is not None: - multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items]) - dataset_samples_weight *= multiplier_samples - return ( - torch.from_numpy(dataset_samples_weight).float(), - unique_attr_names, - np.unique(dataset_samples_weight).tolist(), - ) - - class ForwardTTSE2eF0Dataset(F0Dataset): """Override F0Dataset to avoid slow computing of pitches""" diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index aea0f4e4f8..30d9caff02 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -21,7 +21,7 @@ from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.tts.configs.shared_configs import CharactersConfig -from TTS.tts.datasets.dataset import TTSDataset, _parse_sample +from TTS.tts.datasets.dataset import TTSDataset, _parse_sample, get_attribute_balancer_weights from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.layers.vits.discriminator import VitsDiscriminator from TTS.tts.layers.vits.networks import PosteriorEncoder, ResidualCouplingBlocks, 
TextEncoder @@ -219,30 +219,6 @@ class VitsAudioConfig(Coqpit): ############################## -def get_attribute_balancer_weights(items: list, attr_name: str, multi_dict: dict = None): - """Create inverse frequency weights for balancing the dataset. - Use `multi_dict` to scale relative weights.""" - attr_names_samples = np.array([item[attr_name] for item in items]) - unique_attr_names = np.unique(attr_names_samples).tolist() - attr_idx = [unique_attr_names.index(l) for l in attr_names_samples] - attr_count = np.array([len(np.where(attr_names_samples == l)[0]) for l in unique_attr_names]) - weight_attr = 1.0 / attr_count - dataset_samples_weight = np.array([weight_attr[l] for l in attr_idx]) - dataset_samples_weight = dataset_samples_weight / np.linalg.norm(dataset_samples_weight) - if multi_dict is not None: - # check if all keys are in the multi_dict - for k in multi_dict: - assert k in unique_attr_names, f"{k} not in {unique_attr_names}" - # scale weights - multiplier_samples = np.array([multi_dict.get(item[attr_name], 1.0) for item in items]) - dataset_samples_weight *= multiplier_samples - return ( - torch.from_numpy(dataset_samples_weight).float(), - unique_attr_names, - np.unique(dataset_samples_weight).tolist(), - ) - - class VitsDataset(TTSDataset): def __init__(self, model_args, *args, **kwargs): super().__init__(*args, **kwargs) From 2c82477a785dad0abe178453af70c0554dc7982f Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Fri, 22 Nov 2024 23:44:40 +0100 Subject: [PATCH 20/25] ci: merge integration tests back into unit tests --- .github/actions/setup-uv/action.yml | 4 +- .github/workflows/integration-tests.yml | 82 ------------------------- .github/workflows/tests.yml | 51 ++++++++++++++- 3 files changed, 50 insertions(+), 87 deletions(-) delete mode 100644 .github/workflows/integration-tests.yml diff --git a/.github/actions/setup-uv/action.yml b/.github/actions/setup-uv/action.yml index 619b138fb2..c7dd4f5f99 100644 --- a/.github/actions/setup-uv/action.yml +++ b/.github/actions/setup-uv/action.yml @@ -4,8 +4,8 @@ runs: using: 'composite' steps: - name: Install uv - uses: astral-sh/setup-uv@v3 + uses: astral-sh/setup-uv@v4 with: - version: "0.5.1" + version: "0.5.4" enable-cache: true cache-dependency-glob: "**/pyproject.toml" diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml deleted file mode 100644 index 4dc8c76c1a..0000000000 --- a/.github/workflows/integration-tests.yml +++ /dev/null @@ -1,82 +0,0 @@ -name: integration - -on: - push: - branches: - - main - pull_request: - types: [opened, synchronize, reopened] - workflow_dispatch: - inputs: - trainer_branch: - description: "Branch of Trainer to test" - required: false - default: "main" - coqpit_branch: - description: "Branch of Coqpit to test" - required: false - default: "main" -jobs: - test: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.9", "3.12"] - subset: ["test_tts", "test_tts2", "test_vocoder", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"] - steps: - - uses: actions/checkout@v4 - - name: Setup uv - uses: ./.github/actions/setup-uv - - name: Set up Python ${{ matrix.python-version }} - run: uv python install ${{ matrix.python-version }} - - name: Install Espeak - if: contains(fromJSON('["test_tts", "test_tts2", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]'), matrix.subset) - run: | - sudo apt-get update - sudo apt-get install espeak espeak-ng - - name: Install dependencies - run: | - sudo apt-get update - sudo apt-get 
install -y --no-install-recommends git make gcc - make system-deps - - name: Install custom Trainer and/or Coqpit if requested - run: | - if [[ -n "${{ github.event.inputs.trainer_branch }}" ]]; then - uv add git+https://github.com/idiap/coqui-ai-Trainer --branch ${{ github.event.inputs.trainer_branch }} - fi - if [[ -n "${{ github.event.inputs.coqpit_branch }}" ]]; then - uv add git+https://github.com/idiap/coqui-ai-coqpit --branch ${{ github.event.inputs.coqpit_branch }} - fi - - name: Integration tests - run: | - resolution=highest - if [ "${{ matrix.python-version }}" == "3.9" ]; then - resolution=lowest-direct - fi - uv run --resolution=$resolution --extra server --extra languages make ${{ matrix.subset }} - - name: Upload coverage data - uses: actions/upload-artifact@v4 - with: - include-hidden-files: true - name: coverage-data-${{ matrix.subset }}-${{ matrix.python-version }} - path: .coverage.* - if-no-files-found: ignore - coverage: - if: always() - needs: test - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - name: Setup uv - uses: ./.github/actions/setup-uv - - uses: actions/download-artifact@v4 - with: - pattern: coverage-data-* - merge-multiple: true - - name: Combine coverage - run: | - uv python install - uvx coverage combine - uvx coverage html --skip-covered --skip-empty - uvx coverage report --format=markdown >> $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 576de150fd..8d639d5dee 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,4 +1,4 @@ -name: unit +name: test on: push: @@ -17,7 +17,7 @@ on: required: false default: "main" jobs: - test: + unit: runs-on: ubuntu-latest strategy: fail-fast: false @@ -62,9 +62,54 @@ jobs: name: coverage-data-${{ matrix.subset }}-${{ matrix.python-version }} path: .coverage.* if-no-files-found: ignore + integration: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.12"] + subset: ["test_tts", "test_tts2", "test_vocoder", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"] + steps: + - uses: actions/checkout@v4 + - name: Setup uv + uses: ./.github/actions/setup-uv + - name: Set up Python ${{ matrix.python-version }} + run: uv python install ${{ matrix.python-version }} + - name: Install Espeak + if: contains(fromJSON('["test_tts", "test_tts2", "test_xtts", "test_zoo0", "test_zoo1", "test_zoo2"]'), matrix.subset) + run: | + sudo apt-get update + sudo apt-get install espeak espeak-ng + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y --no-install-recommends git make gcc + make system-deps + - name: Install custom Trainer and/or Coqpit if requested + run: | + if [[ -n "${{ github.event.inputs.trainer_branch }}" ]]; then + uv add git+https://github.com/idiap/coqui-ai-Trainer --branch ${{ github.event.inputs.trainer_branch }} + fi + if [[ -n "${{ github.event.inputs.coqpit_branch }}" ]]; then + uv add git+https://github.com/idiap/coqui-ai-coqpit --branch ${{ github.event.inputs.coqpit_branch }} + fi + - name: Integration tests + run: | + resolution=highest + if [ "${{ matrix.python-version }}" == "3.9" ]; then + resolution=lowest-direct + fi + uv run --resolution=$resolution --extra server --extra languages make ${{ matrix.subset }} + - name: Upload coverage data + uses: actions/upload-artifact@v4 + with: + include-hidden-files: true + name: coverage-data-${{ matrix.subset }}-${{ matrix.python-version }} + path: .coverage.* + if-no-files-found: ignore coverage: if: 
always() - needs: test + needs: [unit, integration] runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 From 76df6421dead004a40b1ded1b12916282f013132 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sat, 23 Nov 2024 01:16:50 +0100 Subject: [PATCH 21/25] refactor: move more audio processing into torch_transforms --- TTS/tts/models/delightful_tts.py | 139 +----------------------- TTS/tts/models/vits.py | 126 +-------------------- TTS/utils/audio/torch_transforms.py | 96 ++++++++++++++++ TTS/vc/modules/freevc/mel_processing.py | 48 -------- tests/tts_tests/test_vits.py | 5 +- 5 files changed, 100 insertions(+), 314 deletions(-) diff --git a/TTS/tts/models/delightful_tts.py b/TTS/tts/models/delightful_tts.py index 8857004725..e6db116081 100644 --- a/TTS/tts/models/delightful_tts.py +++ b/TTS/tts/models/delightful_tts.py @@ -9,7 +9,6 @@ import torch import torch.distributed as dist from coqpit import Coqpit -from librosa.filters import mel as librosa_mel_fn from torch import nn from torch.utils.data import DataLoader from torch.utils.data.sampler import WeightedRandomSampler @@ -38,7 +37,7 @@ from TTS.utils.audio.numpy_transforms import db_to_amp as db_to_amp_numpy from TTS.utils.audio.numpy_transforms import mel_to_wav as mel_to_wav_numpy from TTS.utils.audio.processor import AudioProcessor -from TTS.utils.audio.torch_transforms import amp_to_db +from TTS.utils.audio.torch_transforms import wav_to_mel, wav_to_spec from TTS.vocoder.layers.losses import MultiScaleSTFTLoss from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results @@ -50,145 +49,11 @@ mel_basis = {} -def _wav_to_spec(y, n_fft, hop_length, win_length, center=False): - y = y.squeeze(1) - - if torch.min(y) < -1.0: - logger.info("min value is %.3f", torch.min(y)) - if torch.max(y) > 1.0: - logger.info("max value is %.3f", torch.max(y)) - - global hann_window # pylint: disable=global-statement - dtype_device = str(y.dtype) + "_" + str(y.device) - wnsize_dtype_device = str(win_length) + "_" + dtype_device - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - ) - - return spec - - -def wav_to_spec(y, n_fft, hop_length, win_length, center=False): - """ - Args Shapes: - - y : :math:`[B, 1, T]` - - Return Shapes: - - spec : :math:`[B,C,T]` - """ - spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - return spec - - def wav_to_energy(y, n_fft, hop_length, win_length, center=False): - spec = _wav_to_spec(y, n_fft, hop_length, win_length, center=center) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + spec = wav_to_spec(y, n_fft, hop_length, win_length, center=center) return torch.norm(spec, dim=1, keepdim=True) -def name_mel_basis(spec, n_fft, fmax): - n_fft_len = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}" - return n_fft_len - - -def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): - """ - Args Shapes: - - spec : :math:`[B,C,T]` - - Return Shapes: - - mel : :math:`[B,C,T]` - """ - global 
mel_basis # pylint: disable=global-statement - mel_basis_key = name_mel_basis(spec, n_fft, fmax) - # pylint: disable=too-many-function-args - if mel_basis_key not in mel_basis: - # pylint: disable=missing-kwoa - mel = librosa_mel_fn(sample_rate, n_fft, num_mels, fmin, fmax) - mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) - mel = torch.matmul(mel_basis[mel_basis_key], spec) - mel = amp_to_db(mel) - return mel - - -def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): - """ - Args Shapes: - - y : :math:`[B, 1, T_y]` - - Return Shapes: - - spec : :math:`[B,C,T_spec]` - """ - y = y.squeeze(1) - - if torch.min(y) < -1.0: - logger.info("min value is %.3f", torch.min(y)) - if torch.max(y) > 1.0: - logger.info("max value is %.3f", torch.max(y)) - - global mel_basis, hann_window # pylint: disable=global-statement - mel_basis_key = name_mel_basis(y, n_fft, fmax) - wnsize_dtype_device = str(win_length) + "_" + str(y.dtype) + "_" + str(y.device) - if mel_basis_key not in mel_basis: - # pylint: disable=missing-kwoa - mel = librosa_mel_fn( - sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax - ) # pylint: disable=too-many-function-args - mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - spec = torch.matmul(mel_basis[mel_basis_key], spec) - spec = amp_to_db(spec) - return spec - - ############################## # DATASET ############################## diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 30d9caff02..7ec2519236 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -10,7 +10,6 @@ import torch.distributed as dist import torchaudio from coqpit import Coqpit -from librosa.filters import mel as librosa_mel_fn from monotonic_alignment_search import maximum_path from torch import nn from torch.nn import functional as F @@ -35,7 +34,7 @@ from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment -from TTS.utils.audio.torch_transforms import amp_to_db +from TTS.utils.audio.torch_transforms import spec_to_mel, wav_to_mel, wav_to_spec from TTS.utils.samplers import BucketBatchSampler from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results @@ -46,10 +45,6 @@ # IO / Feature extraction ############################## -# pylint: disable=global-statement -hann_window = {} -mel_basis = {} - @torch.no_grad() def weights_reset(m: nn.Module): @@ -79,125 +74,6 @@ def load_audio(file_path): return x, sr -def wav_to_spec(y, n_fft, hop_length, win_length, center=False): - """ - Args Shapes: - - y : :math:`[B, 1, T]` - - Return Shapes: - - spec : :math:`[B,C,T]` - """ - y = y.squeeze(1) - - if torch.min(y) < -1.0: - 
logger.info("min value is %.3f", torch.min(y)) - if torch.max(y) > 1.0: - logger.info("max value is %.3f", torch.max(y)) - - global hann_window - dtype_device = str(y.dtype) + "_" + str(y.device) - wnsize_dtype_device = str(win_length) + "_" + dtype_device - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - return spec - - -def spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): - """ - Args Shapes: - - spec : :math:`[B,C,T]` - - Return Shapes: - - mel : :math:`[B,C,T]` - """ - global mel_basis - dtype_device = str(spec.dtype) + "_" + str(spec.device) - fmax_dtype_device = str(fmax) + "_" + dtype_device - if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) - mel = torch.matmul(mel_basis[fmax_dtype_device], spec) - mel = amp_to_db(mel) - return mel - - -def wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): - """ - Args Shapes: - - y : :math:`[B, 1, T]` - - Return Shapes: - - spec : :math:`[B,C,T]` - """ - y = y.squeeze(1) - - if torch.min(y) < -1.0: - logger.info("min value is %.3f", torch.min(y)) - if torch.max(y) > 1.0: - logger.info("max value is %.3f", torch.max(y)) - - global mel_basis, hann_window - dtype_device = str(y.dtype) + "_" + str(y.device) - fmax_dtype_device = str(fmax) + "_" + dtype_device - wnsize_dtype_device = str(win_length) + "_" + dtype_device - if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), - (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), - mode="reflect", - ) - y = y.squeeze(1) - - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_length, - win_length=win_length, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - spec = torch.matmul(mel_basis[fmax_dtype_device], spec) - spec = amp_to_db(spec) - return spec - - ############################# # CONFIGS ############################# diff --git a/TTS/utils/audio/torch_transforms.py b/TTS/utils/audio/torch_transforms.py index dda4c0a419..59bb23cc4f 100644 --- a/TTS/utils/audio/torch_transforms.py +++ b/TTS/utils/audio/torch_transforms.py @@ -1,7 +1,15 @@ +import logging + import librosa import torch from torch import nn +logger = logging.getLogger(__name__) + + +hann_window = {} +mel_basis = {} + def amp_to_db(x: torch.Tensor, *, spec_gain: float = 1.0, clip_val: float = 1e-5) -> torch.Tensor: 
"""Spectral normalization / dynamic range compression.""" @@ -13,6 +21,94 @@ def db_to_amp(x: torch.Tensor, *, spec_gain: float = 1.0) -> torch.Tensor: return torch.exp(x) / spec_gain +def wav_to_spec(y: torch.Tensor, n_fft: int, hop_length: int, win_length: int, *, center: bool = False) -> torch.Tensor: + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + y = y.squeeze(1) + + if torch.min(y) < -1.0: + logger.info("min value is %.3f", torch.min(y)) + if torch.max(y) > 1.0: + logger.info("max value is %.3f", torch.max(y)) + + global hann_window + wnsize_dtype_device = f"{win_length}_{y.dtype}_{y.device}" + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.view_as_real( + torch.stft( + y, + n_fft, + hop_length=hop_length, + win_length=win_length, + window=hann_window[wnsize_dtype_device], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + ) + + return torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + +def spec_to_mel( + spec: torch.Tensor, n_fft: int, num_mels: int, sample_rate: int, fmin: float, fmax: float +) -> torch.Tensor: + """ + Args Shapes: + - spec : :math:`[B,C,T]` + + Return Shapes: + - mel : :math:`[B,C,T]` + """ + global mel_basis + fmax_dtype_device = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}" + if fmax_dtype_device not in mel_basis: + # TODO: switch librosa to torchaudio + mel = librosa.filters.mel(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + mel = torch.matmul(mel_basis[fmax_dtype_device], spec) + return amp_to_db(mel) + + +def wav_to_mel( + y: torch.Tensor, + n_fft: int, + num_mels: int, + sample_rate: int, + hop_length: int, + win_length: int, + fmin: float, + fmax: float, + *, + center: bool = False, +) -> torch.Tensor: + """ + Args Shapes: + - y : :math:`[B, 1, T]` + + Return Shapes: + - spec : :math:`[B,C,T]` + """ + spec = wav_to_spec(y, n_fft, hop_length, win_length, center=center) + return spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax) + + class TorchSTFT(nn.Module): # pylint: disable=abstract-method """Some of the audio processing funtions using Torch for faster batch processing. 
diff --git a/TTS/vc/modules/freevc/mel_processing.py b/TTS/vc/modules/freevc/mel_processing.py index 4da5e27c83..017d900284 100644 --- a/TTS/vc/modules/freevc/mel_processing.py +++ b/TTS/vc/modules/freevc/mel_processing.py @@ -14,54 +14,6 @@ hann_window = {} -def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): - if torch.min(y) < -1.0: - logger.info("Min value is: %.3f", torch.min(y)) - if torch.max(y) > 1.0: - logger.info("Max value is: %.3f", torch.max(y)) - - global hann_window - dtype_device = str(y.dtype) + "_" + str(y.device) - wnsize_dtype_device = str(win_size) + "_" + dtype_device - if wnsize_dtype_device not in hann_window: - hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) - - y = torch.nn.functional.pad( - y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" - ) - y = y.squeeze(1) - - spec = torch.view_as_real( - torch.stft( - y, - n_fft, - hop_length=hop_size, - win_length=win_size, - window=hann_window[wnsize_dtype_device], - center=center, - pad_mode="reflect", - normalized=False, - onesided=True, - return_complex=True, - ) - ) - - spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) - return spec - - -def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): - global mel_basis - dtype_device = str(spec.dtype) + "_" + str(spec.device) - fmax_dtype_device = str(fmax) + "_" + dtype_device - if fmax_dtype_device not in mel_basis: - mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) - mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) - spec = torch.matmul(mel_basis[fmax_dtype_device], spec) - spec = amp_to_db(spec) - return spec - - def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): if torch.min(y) < -1.0: logger.info("Min value is: %.3f", torch.min(y)) diff --git a/tests/tts_tests/test_vits.py b/tests/tts_tests/test_vits.py index a27bdfe5b5..c8a52e1c1b 100644 --- a/tests/tts_tests/test_vits.py +++ b/tests/tts_tests/test_vits.py @@ -14,12 +14,9 @@ VitsArgs, VitsAudioConfig, load_audio, - spec_to_mel, - wav_to_mel, - wav_to_spec, ) from TTS.tts.utils.speakers import SpeakerManager -from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp +from TTS.utils.audio.torch_transforms import amp_to_db, db_to_amp, spec_to_mel, wav_to_mel, wav_to_spec LANG_FILE = os.path.join(get_tests_input_path(), "language_ids.json") SPEAKER_ENCODER_CONFIG = os.path.join(get_tests_input_path(), "test_speaker_encoder_config.json") From 8bf288eeab63adcf26d32125615628cc1abdaf31 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sun, 24 Nov 2024 15:37:04 +0100 Subject: [PATCH 22/25] test: move test_helpers.py to fast unit tests --- tests/{tts_tests => aux_tests}/test_helpers.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{tts_tests => aux_tests}/test_helpers.py (100%) diff --git a/tests/tts_tests/test_helpers.py b/tests/aux_tests/test_helpers.py similarity index 100% rename from tests/tts_tests/test_helpers.py rename to tests/aux_tests/test_helpers.py From 7330ad8854f2e72cee5b66e481188cb3750e28df Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sun, 24 Nov 2024 17:46:51 +0100 Subject: [PATCH 23/25] refactor: move duplicate alignment functions into helpers --- .../layers/delightful_tts/acoustic_model.py | 61 ++++------------ TTS/tts/layers/delightful_tts/encoders.py | 13 +--- TTS/tts/models/align_tts.py | 36 +--------- 
TTS/tts/models/forward_tts.py | 50 +------------- TTS/tts/utils/helpers.py | 69 +++++++++++++++++-- tests/aux_tests/test_helpers.py | 31 ++++++++- tests/tts_tests2/test_forward_tts.py | 24 +------ 7 files changed, 114 insertions(+), 170 deletions(-) diff --git a/TTS/tts/layers/delightful_tts/acoustic_model.py b/TTS/tts/layers/delightful_tts/acoustic_model.py index 3c0e3a3a76..981d6cdb1f 100644 --- a/TTS/tts/layers/delightful_tts/acoustic_model.py +++ b/TTS/tts/layers/delightful_tts/acoustic_model.py @@ -12,7 +12,6 @@ from TTS.tts.layers.delightful_tts.encoders import ( PhonemeLevelProsodyEncoder, UtteranceLevelProsodyEncoder, - get_mask_from_lengths, ) from TTS.tts.layers.delightful_tts.energy_adaptor import EnergyAdaptor from TTS.tts.layers.delightful_tts.networks import EmbeddingPadded, positional_encoding @@ -20,7 +19,7 @@ from TTS.tts.layers.delightful_tts.pitch_adaptor import PitchAdaptor from TTS.tts.layers.delightful_tts.variance_predictor import VariancePredictor from TTS.tts.layers.generic.aligner import AlignmentNetwork -from TTS.tts.utils.helpers import generate_path, sequence_mask +from TTS.tts.utils.helpers import expand_encoder_outputs, generate_attention, sequence_mask logger = logging.getLogger(__name__) @@ -231,42 +230,6 @@ def _init_d_vector(self): raise ValueError("[!] Speaker embedding layer already initialized before d_vector settings.") self.embedded_speaker_dim = self.args.d_vector_dim - @staticmethod - def generate_attn(dr, x_mask, y_mask=None): - """Generate an attention mask from the linear scale durations. - - Args: - dr (Tensor): Linear scale durations. - x_mask (Tensor): Mask for the input (character) sequence. - y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations - if None. Defaults to None. 
- - Shapes - - dr: :math:`(B, T_{en})` - - x_mask: :math:`(B, T_{en})` - - y_mask: :math:`(B, T_{de})` - """ - # compute decode mask from the durations - if y_mask is None: - y_lengths = dr.sum(1).long() - y_lengths[y_lengths < 1] = 1 - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) - attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) - return attn - - def _expand_encoder_with_durations( - self, - o_en: torch.FloatTensor, - dr: torch.IntTensor, - x_mask: torch.IntTensor, - y_lengths: torch.IntTensor, - ): - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) - attn = self.generate_attn(dr, x_mask, y_mask) - o_en_ex = torch.einsum("kmn, kjm -> kjn", [attn.float(), o_en]) - return y_mask, o_en_ex, attn.transpose(1, 2) - def _forward_aligner( self, x: torch.FloatTensor, @@ -340,8 +303,8 @@ def forward( {"d_vectors": d_vectors, "speaker_ids": speaker_idx} ) # pylint: disable=unused-variable - src_mask = get_mask_from_lengths(src_lens) # [B, T_src] - mel_mask = get_mask_from_lengths(mel_lens) # [B, T_mel] + src_mask = ~sequence_mask(src_lens) # [B, T_src] + mel_mask = ~sequence_mask(mel_lens) # [B, T_mel] # Token embeddings token_embeddings = self.src_word_emb(tokens) # [B, T_src, C_hidden] @@ -420,8 +383,8 @@ def forward( encoder_outputs = encoder_outputs.transpose(1, 2) + pitch_emb + energy_emb log_duration_prediction = self.duration_predictor(x=encoder_outputs_res.detach(), mask=src_mask) - mel_pred_mask, encoder_outputs_ex, alignments = self._expand_encoder_with_durations( - o_en=encoder_outputs, y_lengths=mel_lens, dr=dr, x_mask=~src_mask[:, None] + encoder_outputs_ex, alignments, mel_pred_mask = expand_encoder_outputs( + encoder_outputs, y_lengths=mel_lens, duration=dr, x_mask=~src_mask[:, None] ) x = self.decoder( @@ -435,7 +398,7 @@ def forward( dr = torch.log(dr + 1) dr_pred = torch.exp(log_duration_prediction) - 1 - alignments_dp = self.generate_attn(dr_pred, src_mask.unsqueeze(1), mel_pred_mask) # [B, T_max, T_max2'] + alignments_dp = generate_attention(dr_pred, src_mask.unsqueeze(1), mel_pred_mask) # [B, T_max, T_max2'] return { "model_outputs": x, @@ -448,7 +411,7 @@ def forward( "p_prosody_pred": p_prosody_pred, "p_prosody_ref": p_prosody_ref, "alignments_dp": alignments_dp, - "alignments": alignments, # [B, T_de, T_en] + "alignments": alignments.transpose(1, 2), # [B, T_de, T_en] "aligner_soft": aligner_soft, "aligner_mas": aligner_mas, "aligner_durations": aligner_durations, @@ -469,7 +432,7 @@ def inference( pitch_transform: Callable = None, energy_transform: Callable = None, ) -> torch.Tensor: - src_mask = get_mask_from_lengths(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device)) + src_mask = ~sequence_mask(torch.tensor([tokens.shape[1]], dtype=torch.int64, device=tokens.device)) src_lens = torch.tensor(tokens.shape[1:2]).to(tokens.device) # pylint: disable=unused-variable sid, g, lid, _ = self._set_cond_input( # pylint: disable=unused-variable {"d_vectors": d_vectors, "speaker_ids": speaker_idx} @@ -536,11 +499,11 @@ def inference( duration_pred = torch.round(duration_pred) # -> [B, T_src] mel_lens = duration_pred.sum(1) # -> [B,] - _, encoder_outputs_ex, alignments = self._expand_encoder_with_durations( - o_en=encoder_outputs, y_lengths=mel_lens, dr=duration_pred.squeeze(1), x_mask=~src_mask[:, None] + encoder_outputs_ex, alignments, _ = expand_encoder_outputs( + encoder_outputs, y_lengths=mel_lens, 
duration=duration_pred.squeeze(1), x_mask=~src_mask[:, None] ) - mel_mask = get_mask_from_lengths( + mel_mask = ~sequence_mask( torch.tensor([encoder_outputs_ex.shape[2]], dtype=torch.int64, device=encoder_outputs_ex.device) ) @@ -557,7 +520,7 @@ def inference( x = self.to_mel(x) outputs = { "model_outputs": x, - "alignments": alignments, + "alignments": alignments.transpose(1, 2), # "pitch": pitch_emb_pred, "durations": duration_pred, "pitch": pitch_pred, diff --git a/TTS/tts/layers/delightful_tts/encoders.py b/TTS/tts/layers/delightful_tts/encoders.py index 0878f0677a..bd0c319dc1 100644 --- a/TTS/tts/layers/delightful_tts/encoders.py +++ b/TTS/tts/layers/delightful_tts/encoders.py @@ -7,14 +7,7 @@ from TTS.tts.layers.delightful_tts.conformer import ConformerMultiHeadedSelfAttention from TTS.tts.layers.delightful_tts.conv_layers import CoordConv1d from TTS.tts.layers.delightful_tts.networks import STL - - -def get_mask_from_lengths(lengths: torch.Tensor) -> torch.Tensor: - batch_size = lengths.shape[0] - max_len = torch.max(lengths).item() - ids = torch.arange(0, max_len, device=lengths.device).unsqueeze(0).expand(batch_size, -1) - mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) - return mask +from TTS.tts.utils.helpers import sequence_mask def stride_lens(lens: torch.Tensor, stride: int = 2) -> torch.Tensor: @@ -93,7 +86,7 @@ def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor outputs --- [N, E//2] """ - mel_masks = get_mask_from_lengths(mel_lens).unsqueeze(1) + mel_masks = ~sequence_mask(mel_lens).unsqueeze(1) x = x.masked_fill(mel_masks, 0) for conv, norm in zip(self.convs, self.norms): x = conv(x) @@ -103,7 +96,7 @@ def forward(self, x: torch.Tensor, mel_lens: torch.Tensor) -> Tuple[torch.Tensor for _ in range(2): mel_lens = stride_lens(mel_lens) - mel_masks = get_mask_from_lengths(mel_lens) + mel_masks = ~sequence_mask(mel_lens) x = x.masked_fill(mel_masks.unsqueeze(1), 0) x = x.permute((0, 2, 1)) diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 1c3d57582e..28a52bc558 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -13,7 +13,7 @@ from TTS.tts.layers.feed_forward.encoder import Encoder from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.utils.helpers import generate_path, sequence_mask +from TTS.tts.utils.helpers import expand_encoder_outputs, generate_attention, sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_spectrogram @@ -169,35 +169,6 @@ def compute_align_path(self, mu, log_sigma, y, x_mask, y_mask): dr_mas = torch.sum(attn, -1) return dr_mas.squeeze(1), log_p - @staticmethod - def generate_attn(dr, x_mask, y_mask=None): - # compute decode mask from the durations - if y_mask is None: - y_lengths = dr.sum(1).long() - y_lengths[y_lengths < 1] = 1 - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) - attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) - return attn - - def expand_encoder_outputs(self, en, dr, x_mask, y_mask): - """Generate attention alignment map from durations and - expand encoder outputs - - Examples:: - - encoder output: [a,b,c,d] - - durations: [1, 3, 2, 1] - - - expanded: [a, b, b, b, c, c, d] - - attention map: [[0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 1, 1, 0], - [0, 1, 1, 1, 
0, 0, 0], - [1, 0, 0, 0, 0, 0, 0]] - """ - attn = self.generate_attn(dr, x_mask, y_mask) - o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) - return o_en_ex, attn - def format_durations(self, o_dr_log, x_mask): o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale o_dr[o_dr < 1] = 1.0 @@ -243,9 +214,8 @@ def _forward_encoder(self, x, x_lengths, g=None): return o_en, o_en_dp, x_mask, g def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) # expand o_en with durations - o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + o_en_ex, attn, y_mask = expand_encoder_outputs(o_en, dr, x_mask, y_lengths) # positional encoding if hasattr(self, "pos_encoder"): o_en_ex = self.pos_encoder(o_en_ex, y_mask) @@ -282,7 +252,7 @@ def forward( o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) - attn = self.generate_attn(dr_mas, x_mask, y_mask) + attn = generate_attention(dr_mas, x_mask, y_mask) elif phase == 1: # train decoder o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) diff --git a/TTS/tts/models/forward_tts.py b/TTS/tts/models/forward_tts.py index d449e580da..d09e3ea91b 100644 --- a/TTS/tts/models/forward_tts.py +++ b/TTS/tts/models/forward_tts.py @@ -14,7 +14,7 @@ from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.utils.helpers import average_over_durations, generate_path, sequence_mask +from TTS.tts.utils.helpers import average_over_durations, expand_encoder_outputs, generate_attention, sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment, plot_avg_energy, plot_avg_pitch, plot_spectrogram @@ -310,49 +310,6 @@ def init_multispeaker(self, config: Coqpit): self.emb_g = nn.Embedding(self.num_speakers, self.args.hidden_channels) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) - @staticmethod - def generate_attn(dr, x_mask, y_mask=None): - """Generate an attention mask from the durations. 
- - Shapes - - dr: :math:`(B, T_{en})` - - x_mask: :math:`(B, T_{en})` - - y_mask: :math:`(B, T_{de})` - """ - # compute decode mask from the durations - if y_mask is None: - y_lengths = dr.sum(1).long() - y_lengths[y_lengths < 1] = 1 - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) - attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) - return attn - - def expand_encoder_outputs(self, en, dr, x_mask, y_mask): - """Generate attention alignment map from durations and - expand encoder outputs - - Shapes: - - en: :math:`(B, D_{en}, T_{en})` - - dr: :math:`(B, T_{en})` - - x_mask: :math:`(B, T_{en})` - - y_mask: :math:`(B, T_{de})` - - Examples:: - - encoder output: [a,b,c,d] - durations: [1, 3, 2, 1] - - expanded: [a, b, b, b, c, c, d] - attention map: [[0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 1, 1, 0], - [0, 1, 1, 1, 0, 0, 0], - [1, 0, 0, 0, 0, 0, 0]] - """ - attn = self.generate_attn(dr, x_mask, y_mask) - o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2) - return o_en_ex, attn - def format_durations(self, o_dr_log, x_mask): """Format predicted durations. 1. Convert to linear scale from log scale @@ -443,9 +400,8 @@ def _forward_decoder( Returns: Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations. """ - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) # expand o_en with durations - o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + o_en_ex, attn, y_mask = expand_encoder_outputs(o_en, dr, x_mask, y_lengths) # positional encoding if hasattr(self, "pos_encoder"): o_en_ex = self.pos_encoder(o_en_ex, y_mask) @@ -624,7 +580,7 @@ def forward( o_dr_log = self.duration_predictor(o_en, x_mask) o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) # generate attn mask from predicted durations - o_attn = self.generate_attn(o_dr.squeeze(1), x_mask) + o_attn = generate_attention(o_dr.squeeze(1), x_mask) # aligner o_alignment_dur = None alignment_soft = None diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index d1722501f7..ff10f751f2 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np import torch from scipy.stats import betabinom @@ -33,7 +35,7 @@ def inverse_transform(self, X): # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 -def sequence_mask(sequence_length, max_len=None): +def sequence_mask(sequence_length: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor: """Create a sequence mask for filtering padding in a sequence tensor. Args: @@ -44,7 +46,7 @@ def sequence_mask(sequence_length, max_len=None): - mask: :math:`[B, T_max]` """ if max_len is None: - max_len = sequence_length.max() + max_len = int(sequence_length.max()) seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device) # B x T_max return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1) @@ -143,22 +145,75 @@ def convert_pad_shape(pad_shape: list[list]) -> list: return [item for sublist in l for item in sublist] -def generate_path(duration, mask): - """ +def generate_path(duration: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Generate alignment path based on the given segment durations. 
+ Shapes: - duration: :math:`[B, T_en]` - mask: :math:'[B, T_en, T_de]` - path: :math:`[B, T_en, T_de]` """ b, t_x, t_y = mask.shape - cum_duration = torch.cumsum(duration, 1) + cum_duration = torch.cumsum(duration, dim=1) cum_duration_flat = cum_duration.view(b * t_x) path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) path = path.view(b, t_x, t_y) path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] - path = path * mask - return path + return path * mask + + +def generate_attention( + duration: torch.Tensor, x_mask: torch.Tensor, y_mask: Optional[torch.Tensor] = None +) -> torch.Tensor: + """Generate an attention map from the linear scale durations. + + Args: + duration (Tensor): Linear scale durations. + x_mask (Tensor): Mask for the input (character) sequence. + y_mask (Tensor): Mask for the output (spectrogram) sequence. Compute it from the predicted durations + if None. Defaults to None. + + Shapes + - duration: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + """ + # compute decode mask from the durations + if y_mask is None: + y_lengths = duration.sum(dim=1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = sequence_mask(y_lengths).unsqueeze(1).to(duration.dtype) + attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) + return generate_path(duration, attn_mask.squeeze(1)).to(duration.dtype) + + +def expand_encoder_outputs( + x: torch.Tensor, duration: torch.Tensor, x_mask: torch.Tensor, y_lengths: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Generate attention alignment map from durations and expand encoder outputs. + + Shapes: + - x: Encoder output :math:`(B, D_{en}, T_{en})` + - duration: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_lengths: :math:`(B)` + + Examples:: + + encoder output: [a,b,c,d] + durations: [1, 3, 2, 1] + + expanded: [a, b, b, b, c, c, d] + attention map: [[0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 1, 1, 0], + [0, 1, 1, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0]] + """ + y_mask = sequence_mask(y_lengths).unsqueeze(1).to(x.dtype) + attn = generate_attention(duration, x_mask, y_mask) + x_expanded = torch.einsum("kmn, kjm -> kjn", [attn.float(), x]) + return x_expanded, attn, y_mask def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling_factor=1.0): diff --git a/tests/aux_tests/test_helpers.py b/tests/aux_tests/test_helpers.py index d07efa3620..6781cbc5d4 100644 --- a/tests/aux_tests/test_helpers.py +++ b/tests/aux_tests/test_helpers.py @@ -1,6 +1,14 @@ import torch as T -from TTS.tts.utils.helpers import average_over_durations, generate_path, rand_segments, segment, sequence_mask +from TTS.tts.utils.helpers import ( + average_over_durations, + expand_encoder_outputs, + generate_attention, + generate_path, + rand_segments, + segment, + sequence_mask, +) def test_average_over_durations(): # pylint: disable=no-self-use @@ -86,3 +94,24 @@ def test_generate_path(): assert all(path[b, t, :current_idx] == 0.0) assert all(path[b, t, current_idx + durations[b, t].item() :] == 0.0) current_idx += durations[b, t].item() + + assert T.all(path == generate_attention(durations, x_mask, y_mask)) + assert T.all(path == generate_attention(durations, x_mask)) + + +def test_expand_encoder_outputs(): + inputs = T.rand(2, 5, 57) + durations = T.randint(1, 4, (2, 57)) + + x_mask = T.ones(2, 1, 57) + y_lengths = T.ones(2) * durations.sum(1).max() + + expanded, _, _ = expand_encoder_outputs(inputs, durations, x_mask, y_lengths) + + for b in range(durations.shape[0]): + 
index = 0 + for idx, dur in enumerate(durations[b]): + idx_expanded = expanded[b, :, index : index + dur.item()] + diff = (idx_expanded - inputs[b, :, idx].repeat(int(dur)).view(idx_expanded.shape)).sum() + assert abs(diff) < 1e-6, diff + index += dur diff --git a/tests/tts_tests2/test_forward_tts.py b/tests/tts_tests2/test_forward_tts.py index cec0f211c8..13a2c270af 100644 --- a/tests/tts_tests2/test_forward_tts.py +++ b/tests/tts_tests2/test_forward_tts.py @@ -6,29 +6,7 @@ # pylint: disable=unused-variable -def expand_encoder_outputs_test(): - model = ForwardTTS(ForwardTTSArgs(num_chars=10)) - - inputs = T.rand(2, 5, 57) - durations = T.randint(1, 4, (2, 57)) - - x_mask = T.ones(2, 1, 57) - y_mask = T.ones(2, 1, durations.sum(1).max()) - - expanded, _ = model.expand_encoder_outputs(inputs, durations, x_mask, y_mask) - - for b in range(durations.shape[0]): - index = 0 - for idx, dur in enumerate(durations[b]): - diff = ( - expanded[b, :, index : index + dur.item()] - - inputs[b, :, idx].repeat(dur.item()).view(expanded[b, :, index : index + dur.item()].shape) - ).sum() - assert abs(diff) < 1e-6, diff - index += dur - - -def model_input_output_test(): +def test_model_input_output(): """Assert the output shapes of the model in different modes""" # VANILLA MODEL From 170d3dae92641aacf99827358e3fadbc3b7436ea Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Sun, 24 Nov 2024 19:36:45 +0100 Subject: [PATCH 24/25] refactor: remove duplicate to_camel --- TTS/vc/models/__init__.py | 5 ----- TTS/vocoder/models/__init__.py | 7 ++----- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/TTS/vc/models/__init__.py b/TTS/vc/models/__init__.py index a498b292b7..a9807d7006 100644 --- a/TTS/vc/models/__init__.py +++ b/TTS/vc/models/__init__.py @@ -6,11 +6,6 @@ logger = logging.getLogger(__name__) -def to_camel(text): - text = text.capitalize() - return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) - - def setup_model(config: "Coqpit", samples: Union[List[List], List[Dict]] = None) -> "BaseVC": logger.info("Using model: %s", config.model) # fetch the right model implementation. 
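
The vc and vocoder copies of to_camel removed in this patch were effectively identical; both call sites now import the shared helper from TTS.utils.generic_utils. For reference, a sketch of the shared conversion with illustrative inputs (setup_model uses it to map a config.model name to its class name):

    import re

    def to_camel(text):
        # uppercase the first letter, then fold each "_x" into "X"
        text = text.capitalize()
        return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)

    assert to_camel("hifigan_generator") == "HifiganGenerator"
    assert to_camel("wavernn") == "Wavernn"
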
diff --git a/TTS/vocoder/models/__init__.py b/TTS/vocoder/models/__init__.py index 7a1716f16d..b6a1850484 100644 --- a/TTS/vocoder/models/__init__.py +++ b/TTS/vocoder/models/__init__.py @@ -4,12 +4,9 @@ from coqpit import Coqpit -logger = logging.getLogger(__name__) - +from TTS.utils.generic_utils import to_camel -def to_camel(text): - text = text.capitalize() - return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) +logger = logging.getLogger(__name__) def setup_model(config: Coqpit): From 63625e79af1e13928474fdf964a3322273542939 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Wed, 27 Nov 2024 16:12:38 +0100 Subject: [PATCH 25/25] refactor: import get_last_checkpoint from trainer.io --- TTS/bin/compute_attention_masks.py | 2 +- TTS/encoder/utils/training.py | 4 ++-- tests/tts_tests/test_neuralhmm_tts_train.py | 2 +- tests/tts_tests/test_overflow_train.py | 2 +- tests/tts_tests/test_speedy_speech_train.py | 2 +- tests/tts_tests/test_tacotron2_d-vectors_train.py | 2 +- tests/tts_tests/test_tacotron2_speaker_emb_train.py | 2 +- tests/tts_tests/test_tacotron2_train.py | 2 +- tests/tts_tests/test_tacotron_train.py | 2 +- tests/tts_tests/test_vits_multilingual_speaker_emb_train.py | 2 +- tests/tts_tests/test_vits_multilingual_train-d_vectors.py | 2 +- tests/tts_tests/test_vits_speaker_emb_train.py | 2 +- tests/tts_tests/test_vits_train.py | 2 +- tests/tts_tests2/test_align_tts_train.py | 2 +- tests/tts_tests2/test_delightful_tts_d-vectors_train.py | 2 +- tests/tts_tests2/test_delightful_tts_emb_spk.py | 2 +- tests/tts_tests2/test_delightful_tts_train.py | 2 +- tests/tts_tests2/test_fast_pitch_speaker_emb_train.py | 2 +- tests/tts_tests2/test_fast_pitch_train.py | 2 +- tests/tts_tests2/test_fastspeech_2_speaker_emb_train.py | 2 +- tests/tts_tests2/test_fastspeech_2_train.py | 2 +- tests/tts_tests2/test_glow_tts_d-vectors_train.py | 2 +- tests/tts_tests2/test_glow_tts_speaker_emb_train.py | 2 +- tests/tts_tests2/test_glow_tts_train.py | 2 +- 24 files changed, 25 insertions(+), 25 deletions(-) diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index 127199186b..535182d214 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -80,7 +80,7 @@ num_chars = len(phonemes) if C.use_phonemes else len(symbols) # TODO: handle multi-speaker model = setup_model(C) - model, _ = load_checkpoint(model, args.model_path, args.use_cuda, True) + model, _ = load_checkpoint(model, args.model_path, use_cuda=args.use_cuda, eval=True) # data loader preprocessor = importlib.import_module("TTS.tts.datasets.formatters") diff --git a/TTS/encoder/utils/training.py b/TTS/encoder/utils/training.py index cc3a78b084..48629c7a57 100644 --- a/TTS/encoder/utils/training.py +++ b/TTS/encoder/utils/training.py @@ -2,9 +2,9 @@ from dataclasses import dataclass, field from coqpit import Coqpit -from trainer import TrainerArgs, get_last_checkpoint +from trainer import TrainerArgs from trainer.generic_utils import get_experiment_folder_path, get_git_branch -from trainer.io import copy_model_files +from trainer.io import copy_model_files, get_last_checkpoint from trainer.logging import logger_factory from trainer.logging.console_logger import ConsoleLogger diff --git a/tests/tts_tests/test_neuralhmm_tts_train.py b/tests/tts_tests/test_neuralhmm_tts_train.py index 25d9aa8148..4789d53d9e 100644 --- a/tests/tts_tests/test_neuralhmm_tts_train.py +++ b/tests/tts_tests/test_neuralhmm_tts_train.py @@ -4,7 +4,7 @@ import shutil import torch -from trainer 
import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.neuralhmm_tts_config import NeuralhmmTTSConfig diff --git a/tests/tts_tests/test_overflow_train.py b/tests/tts_tests/test_overflow_train.py index 86fa60af72..d86bde6854 100644 --- a/tests/tts_tests/test_overflow_train.py +++ b/tests/tts_tests/test_overflow_train.py @@ -4,7 +4,7 @@ import shutil import torch -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.overflow_config import OverflowConfig diff --git a/tests/tts_tests/test_speedy_speech_train.py b/tests/tts_tests/test_speedy_speech_train.py index 530781ef88..2aac7f101d 100644 --- a/tests/tts_tests/test_speedy_speech_train.py +++ b/tests/tts_tests/test_speedy_speech_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig diff --git a/tests/tts_tests/test_tacotron2_d-vectors_train.py b/tests/tts_tests/test_tacotron2_d-vectors_train.py index 99ba4349c4..d2d1d5c35f 100644 --- a/tests/tts_tests/test_tacotron2_d-vectors_train.py +++ b/tests/tts_tests/test_tacotron2_d-vectors_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config diff --git a/tests/tts_tests/test_tacotron2_speaker_emb_train.py b/tests/tts_tests/test_tacotron2_speaker_emb_train.py index 5f1bc3fd50..83a07d1a6c 100644 --- a/tests/tts_tests/test_tacotron2_speaker_emb_train.py +++ b/tests/tts_tests/test_tacotron2_speaker_emb_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config diff --git a/tests/tts_tests/test_tacotron2_train.py b/tests/tts_tests/test_tacotron2_train.py index 40107070e1..df0e934d8e 100644 --- a/tests/tts_tests/test_tacotron2_train.py +++ b/tests/tts_tests/test_tacotron2_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config diff --git a/tests/tts_tests/test_tacotron_train.py b/tests/tts_tests/test_tacotron_train.py index f7751931ae..17f1fd46a6 100644 --- a/tests/tts_tests/test_tacotron_train.py +++ b/tests/tts_tests/test_tacotron_train.py @@ -2,7 +2,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron_config import TacotronConfig diff --git a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py index 71597ef32f..09df7d29f2 100644 --- a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import 
get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseDatasetConfig diff --git a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py index fd58db534a..7ae09c0e5c 100644 --- a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py +++ b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseDatasetConfig diff --git a/tests/tts_tests/test_vits_speaker_emb_train.py b/tests/tts_tests/test_vits_speaker_emb_train.py index b7fe197cfe..69fae21f8d 100644 --- a/tests/tts_tests/test_vits_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_speaker_emb_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.vits_config import VitsConfig diff --git a/tests/tts_tests/test_vits_train.py b/tests/tts_tests/test_vits_train.py index ea5dc02405..78f42d154b 100644 --- a/tests/tts_tests/test_vits_train.py +++ b/tests/tts_tests/test_vits_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.vits_config import VitsConfig diff --git a/tests/tts_tests2/test_align_tts_train.py b/tests/tts_tests2/test_align_tts_train.py index 9b0b730df4..91c3c35bc6 100644 --- a/tests/tts_tests2/test_align_tts_train.py +++ b/tests/tts_tests2/test_align_tts_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.align_tts_config import AlignTTSConfig diff --git a/tests/tts_tests2/test_delightful_tts_d-vectors_train.py b/tests/tts_tests2/test_delightful_tts_d-vectors_train.py index 8fc4ea7e9b..1e5cd49f73 100644 --- a/tests/tts_tests2/test_delightful_tts_d-vectors_train.py +++ b/tests/tts_tests2/test_delightful_tts_d-vectors_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig diff --git a/tests/tts_tests2/test_delightful_tts_emb_spk.py b/tests/tts_tests2/test_delightful_tts_emb_spk.py index 6fb70c5f61..9bbf7a55ea 100644 --- a/tests/tts_tests2/test_delightful_tts_emb_spk.py +++ b/tests/tts_tests2/test_delightful_tts_emb_spk.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.delightful_tts_config import DelightfulTtsAudioConfig, DelightfulTTSConfig diff --git a/tests/tts_tests2/test_delightful_tts_train.py b/tests/tts_tests2/test_delightful_tts_train.py index a917d77657..3e6fbd2e86 100644 --- a/tests/tts_tests2/test_delightful_tts_train.py +++ b/tests/tts_tests2/test_delightful_tts_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import 
get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig diff --git a/tests/tts_tests2/test_fast_pitch_speaker_emb_train.py b/tests/tts_tests2/test_fast_pitch_speaker_emb_train.py index 7f79bfcab2..e6bc9f9feb 100644 --- a/tests/tts_tests2/test_fast_pitch_speaker_emb_train.py +++ b/tests/tts_tests2/test_fast_pitch_speaker_emb_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig diff --git a/tests/tts_tests2/test_fast_pitch_train.py b/tests/tts_tests2/test_fast_pitch_train.py index a525715b53..fe87c8b600 100644 --- a/tests/tts_tests2/test_fast_pitch_train.py +++ b/tests/tts_tests2/test_fast_pitch_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig diff --git a/tests/tts_tests2/test_fastspeech_2_speaker_emb_train.py b/tests/tts_tests2/test_fastspeech_2_speaker_emb_train.py index 35bda597d5..735d2fc4c6 100644 --- a/tests/tts_tests2/test_fastspeech_2_speaker_emb_train.py +++ b/tests/tts_tests2/test_fastspeech_2_speaker_emb_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig diff --git a/tests/tts_tests2/test_fastspeech_2_train.py b/tests/tts_tests2/test_fastspeech_2_train.py index dd4b07d240..07fc5a1a2c 100644 --- a/tests/tts_tests2/test_fastspeech_2_train.py +++ b/tests/tts_tests2/test_fastspeech_2_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig diff --git a/tests/tts_tests2/test_glow_tts_d-vectors_train.py b/tests/tts_tests2/test_glow_tts_d-vectors_train.py index f1cfd4368f..8236607c25 100644 --- a/tests/tts_tests2/test_glow_tts_d-vectors_train.py +++ b/tests/tts_tests2/test_glow_tts_d-vectors_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig diff --git a/tests/tts_tests2/test_glow_tts_speaker_emb_train.py b/tests/tts_tests2/test_glow_tts_speaker_emb_train.py index b1eb6237a4..4a8bd0658d 100644 --- a/tests/tts_tests2/test_glow_tts_speaker_emb_train.py +++ b/tests/tts_tests2/test_glow_tts_speaker_emb_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig diff --git a/tests/tts_tests2/test_glow_tts_train.py b/tests/tts_tests2/test_glow_tts_train.py index 0a8e226b65..1d7f913575 100644 --- a/tests/tts_tests2/test_glow_tts_train.py +++ b/tests/tts_tests2/test_glow_tts_train.py @@ -3,7 +3,7 @@ import os import shutil -from trainer import get_last_checkpoint +from trainer.io import get_last_checkpoint from tests import get_device_id, 
get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig
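
The final patch only changes the import location of get_last_checkpoint: it lives in trainer.io, so the training tests and encoder training utilities now import it from there instead of from the trainer package root. A hedged sketch of the updated usage in a training test (the experiment folder is hypothetical, and the two-element unpacking assumes the helper returns the latest checkpoint path plus the best-model path, as these tests use it):

    import os

    from trainer.io import get_last_checkpoint  # was: from trainer import get_last_checkpoint

    continue_path = "tests/outputs/example_run"  # hypothetical folder from a previous run

    if os.path.isdir(continue_path):
        restore_path, best_model = get_last_checkpoint(continue_path)
        print("restoring from:", restore_path)
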