diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh index 37e1bc3204..0876cb47f2 100755 --- a/.github/scripts/ljspeech/TTS/run-matcha.sh +++ b/.github/scripts/ljspeech/TTS/run-matcha.sh @@ -56,7 +56,7 @@ function infer() { curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 - ./matcha/inference.py \ + ./matcha/infer.py \ --epoch 1 \ --exp-dir ./matcha/exp \ --tokens data/tokens.txt \ diff --git a/egs/ljspeech/TTS/README.md b/egs/ljspeech/TTS/README.md index 1cd6e8fd73..82850cd04f 100644 --- a/egs/ljspeech/TTS/README.md +++ b/egs/ljspeech/TTS/README.md @@ -131,12 +131,12 @@ To inference, use: wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1 -./matcha/inference \ +./matcha/synth.py \ --exp-dir ./matcha/exp-new-3 \ --epoch 4000 \ --tokens ./data/tokens.txt \ --vocoder ./generator_v1 \ - --input-text "how are you doing?" + --input-text "how are you doing?" \ --output-wav ./generated.wav ``` diff --git a/egs/ljspeech/TTS/local/audio.py b/egs/ljspeech/TTS/local/audio.py new file mode 120000 index 0000000000..b70d91c920 --- /dev/null +++ b/egs/ljspeech/TTS/local/audio.py @@ -0,0 +1 @@ +../matcha/audio.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py index 5152ae675c..296f9a4f42 100755 --- a/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py +++ b/egs/ljspeech/TTS/local/compute_fbank_ljspeech.py @@ -27,102 +27,17 @@ import argparse import logging import os -from dataclasses import dataclass from pathlib import Path -from typing import Union -import numpy as np import torch +from fbank import MatchaFbank, MatchaFbankConfig from lhotse import CutSet, LilcomChunkyWriter, load_manifest from lhotse.audio import RecordingSet -from lhotse.features.base import FeatureExtractor, register_extractor from lhotse.supervision import SupervisionSet -from lhotse.utils import Seconds, compute_num_frames -from matcha.audio import mel_spectrogram from icefall.utils import get_executor -@dataclass -class MyFbankConfig: - n_fft: int - n_mels: int - sampling_rate: int - hop_length: int - win_length: int - f_min: float - f_max: float - - -@register_extractor -class MyFbank(FeatureExtractor): - - name = "MyFbank" - config_type = MyFbankConfig - - def __init__(self, config): - super().__init__(config=config) - - @property - def device(self) -> Union[str, torch.device]: - return self.config.device - - def feature_dim(self, sampling_rate: int) -> int: - return self.config.n_mels - - def extract( - self, - samples: np.ndarray, - sampling_rate: int, - ) -> torch.Tensor: - # Check for sampling rate compatibility. - expected_sr = self.config.sampling_rate - assert sampling_rate == expected_sr, ( - f"Mismatched sampling rate: extractor expects {expected_sr}, " - f"got {sampling_rate}" - ) - samples = torch.from_numpy(samples) - assert samples.ndim == 2, samples.shape - assert samples.shape[0] == 1, samples.shape - - mel = ( - mel_spectrogram( - samples, - self.config.n_fft, - self.config.n_mels, - self.config.sampling_rate, - self.config.hop_length, - self.config.win_length, - self.config.f_min, - self.config.f_max, - center=False, - ) - .squeeze() - .t() - ) - - assert mel.ndim == 2, mel.shape - assert mel.shape[1] == self.config.n_mels, mel.shape - - num_frames = compute_num_frames( - samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate - ) - - if mel.shape[0] > num_frames: - mel = mel[:num_frames] - elif mel.shape[0] < num_frames: - mel = mel.unsqueeze(0) - mel = torch.nn.functional.pad( - mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" - ).squeeze(0) - - return mel.numpy() - - @property - def frame_shift(self) -> Seconds: - return self.config.hop_length / self.config.sampling_rate - - def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -149,7 +64,7 @@ def compute_fbank_ljspeech(num_jobs: int): logging.info(f"num_jobs: {num_jobs}") logging.info(f"src_dir: {src_dir}") logging.info(f"output_dir: {output_dir}") - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=22050, @@ -170,7 +85,7 @@ def compute_fbank_ljspeech(num_jobs: int): src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet ) - extractor = MyFbank(config) + extractor = MatchaFbank(config) with get_executor() as ex: # Initialize the executor only once. cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" diff --git a/egs/ljspeech/TTS/local/fbank.py b/egs/ljspeech/TTS/local/fbank.py new file mode 120000 index 0000000000..5bcf1fde57 --- /dev/null +++ b/egs/ljspeech/TTS/local/fbank.py @@ -0,0 +1 @@ +../matcha/fbank.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py b/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py deleted file mode 120000 index 85255ba0c0..0000000000 --- a/egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py +++ /dev/null @@ -1 +0,0 @@ -../local/compute_fbank_ljspeech.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py index 63d1fac205..5c96b3bc79 100755 --- a/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py +++ b/egs/ljspeech/TTS/matcha/export_onnx_hifigan.py @@ -7,7 +7,7 @@ import onnx import torch -from inference import load_vocoder +from infer import load_vocoder def add_meta_data(filename: str, meta_data: Dict[str, Any]): diff --git a/egs/ljspeech/TTS/matcha/fbank.py b/egs/ljspeech/TTS/matcha/fbank.py new file mode 100644 index 0000000000..d729fa425f --- /dev/null +++ b/egs/ljspeech/TTS/matcha/fbank.py @@ -0,0 +1,88 @@ +from dataclasses import dataclass +from typing import Union + +import numpy as np +import torch +from audio import mel_spectrogram +from lhotse.features.base import FeatureExtractor, register_extractor +from lhotse.utils import Seconds, compute_num_frames + + +@dataclass +class MatchaFbankConfig: + n_fft: int + n_mels: int + sampling_rate: int + hop_length: int + win_length: int + f_min: float + f_max: float + + +@register_extractor +class MatchaFbank(FeatureExtractor): + + name = "MatchaFbank" + config_type = MatchaFbankConfig + + def __init__(self, config): + super().__init__(config=config) + + @property + def device(self) -> Union[str, torch.device]: + return self.config.device + + def feature_dim(self, sampling_rate: int) -> int: + return self.config.n_mels + + def extract( + self, + samples: np.ndarray, + sampling_rate: int, + ) -> torch.Tensor: + # Check for sampling rate compatibility. + expected_sr = self.config.sampling_rate + assert sampling_rate == expected_sr, ( + f"Mismatched sampling rate: extractor expects {expected_sr}, " + f"got {sampling_rate}" + ) + samples = torch.from_numpy(samples) + assert samples.ndim == 2, samples.shape + assert samples.shape[0] == 1, samples.shape + + mel = ( + mel_spectrogram( + samples, + self.config.n_fft, + self.config.n_mels, + self.config.sampling_rate, + self.config.hop_length, + self.config.win_length, + self.config.f_min, + self.config.f_max, + center=False, + ) + .squeeze() + .t() + ) + + assert mel.ndim == 2, mel.shape + assert mel.shape[1] == self.config.n_mels, mel.shape + + num_frames = compute_num_frames( + samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate + ) + + if mel.shape[0] > num_frames: + mel = mel[:num_frames] + elif mel.shape[0] < num_frames: + mel = mel.unsqueeze(0) + mel = torch.nn.functional.pad( + mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate" + ).squeeze(0) + + return mel.numpy() + + @property + def frame_shift(self) -> Seconds: + return self.config.hop_length / self.config.sampling_rate diff --git a/egs/ljspeech/TTS/matcha/infer.py b/egs/ljspeech/TTS/matcha/infer.py new file mode 100755 index 0000000000..0b221d5c55 --- /dev/null +++ b/egs/ljspeech/TTS/matcha/infer.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) + +import argparse +import datetime as dt +import json +import logging +from pathlib import Path + +import soundfile as sf +import torch +import torch.nn as nn +from hifigan.config import v1, v2, v3 +from hifigan.denoiser import Denoiser +from hifigan.models import Generator as HiFiGAN +from tokenizer import Tokenizer +from train import get_model, get_params +from tts_datamodule import LJSpeechTtsDataModule + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=4000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="matcha/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--vocoder", + type=Path, + default="./generator_v1", + help="Path to the vocoder", + ) + + parser.add_argument( + "--tokens", + type=Path, + default="data/tokens.txt", + ) + + parser.add_argument( + "--cmvn", + type=str, + default="data/fbank/cmvn.json", + help="""Path to vocabulary.""", + ) + + # The following arguments are used for inference on single text + parser.add_argument( + "--input-text", + type=str, + required=False, + help="The text to generate speech for", + ) + + parser.add_argument( + "--output-wav", + type=str, + required=False, + help="The filename of the wave to save the generated speech", + ) + + parser.add_argument( + "--sampling-rate", + type=int, + default=22050, + help="The sampling rate of the generated speech (default: 22050 for LJSpeech)", + ) + + return parser + + +def load_vocoder(checkpoint_path: Path) -> nn.Module: + checkpoint_path = str(checkpoint_path) + if checkpoint_path.endswith("v1"): + h = AttributeDict(v1) + elif checkpoint_path.endswith("v2"): + h = AttributeDict(v2) + elif checkpoint_path.endswith("v3"): + h = AttributeDict(v3) + else: + raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}") + + hifigan = HiFiGAN(h).to("cpu") + hifigan.load_state_dict( + torch.load(checkpoint_path, map_location="cpu")["generator"] + ) + _ = hifigan.eval() + hifigan.remove_weight_norm() + return hifigan + + +def to_waveform( + mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module +) -> torch.Tensor: + audio = vocoder(mel).clamp(-1, 1) + audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze() + return audio.squeeze() + + +def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict: + x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) + x = torch.tensor(x, dtype=torch.long, device=device) + x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=device) + return {"x_orig": text, "x": x, "x_lengths": x_lengths} + + +def synthesize( + model: nn.Module, + tokenizer: Tokenizer, + n_timesteps: int, + text: str, + length_scale: float, + temperature: float, + device: str = "cpu", + spks=None, +) -> dict: + text_processed = process_text(text=text, tokenizer=tokenizer, device=device) + start_t = dt.datetime.now() + output = model.synthesise( + text_processed["x"], + text_processed["x_lengths"], + n_timesteps=n_timesteps, + temperature=temperature, + spks=spks, + length_scale=length_scale, + ) + # merge everything to one dict + output.update({"start_t": start_t, **text_processed}) + return output + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + vocoder: nn.Module, + denoiser: nn.Module, + tokenizer: Tokenizer, +) -> None: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + tokenizer: + Used to convert text to phonemes. + """ + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["tokens"]) + + texts = [c.supervisions[0].normalized_text for c in batch["cut"]] + + audio = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + + for i in range(batch_size): + output = synthesize( + model=model, + tokenizer=tokenizer, + n_timesteps=params.n_timesteps, + text=texts[i], + length_scale=params.length_scale, + temperature=params.temperature, + device=device, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + + sf.write( + file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav", + data=output["waveform"], + samplerate=params.data_args.sampling_rate, + subtype="PCM_16", + ) + sf.write( + file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav", + data=audio[i].numpy(), + samplerate=params.data_args.sampling_rate, + subtype="PCM_16", + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + +@torch.inference_mode() +def main(): + parser = get_parser() + LJSpeechTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + params.model_args.n_vocab = params.vocab_size + + with open(params.cmvn) as f: + stats = json.load(f) + params.data_args.data_statistics.mel_mean = stats["fbank_mean"] + params.data_args.data_statistics.mel_std = stats["fbank_std"] + + params.model_args.data_statistics.mel_mean = stats["fbank_mean"] + params.model_args.data_statistics.mel_std = stats["fbank_std"] + + # Number of ODE Solver steps + params.n_timesteps = 2 + + # Changes to the speaking rate + params.length_scale = 1.0 + + # Sampling temperature + params.temperature = 0.667 + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.to(device) + model.eval() + + # we need cut ids to organize tts results. + args.return_cuts = True + ljspeech = LJSpeechTtsDataModule(args) + + test_cuts = ljspeech.test_cuts() + test_dl = ljspeech.test_dataloaders(test_cuts) + + if not Path(params.vocoder).is_file(): + raise ValueError(f"{params.vocoder} does not exist") + + vocoder = load_vocoder(params.vocoder) + vocoder.to(device) + + denoiser = Denoiser(vocoder, mode="zeros") + denoiser.to(device) + + if params.input_text is not None and params.output_wav is not None: + logging.info("Synthesizing a single text") + output = synthesize( + model=model, + tokenizer=tokenizer, + n_timesteps=params.n_timesteps, + text=params.input_text, + length_scale=params.length_scale, + temperature=params.temperature, + device=device, + ) + output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + + sf.write( + file=params.output_wav, + data=output["waveform"], + samplerate=params.sampling_rate, + subtype="PCM_16", + ) + else: + logging.info("Decoding the test set") + infer_dataset( + dl=test_dl, + params=params, + model=model, + vocoder=vocoder, + denoiser=denoiser, + tokenizer=tokenizer, + ) + + +if __name__ == "__main__": + main() diff --git a/egs/ljspeech/TTS/matcha/inference.py b/egs/ljspeech/TTS/matcha/inference.py deleted file mode 100755 index 64abd8e50b..0000000000 --- a/egs/ljspeech/TTS/matcha/inference.py +++ /dev/null @@ -1,199 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) - -import argparse -import datetime as dt -import json -import logging -from pathlib import Path - -import soundfile as sf -import torch -from matcha.hifigan.config import v1, v2, v3 -from matcha.hifigan.denoiser import Denoiser -from matcha.hifigan.models import Generator as HiFiGAN -from tokenizer import Tokenizer -from train import get_model, get_params - -from icefall.checkpoint import load_checkpoint -from icefall.utils import AttributeDict - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=4000, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="matcha/exp-new-3", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--vocoder", - type=Path, - default="./generator_v1", - help="Path to the vocoder", - ) - - parser.add_argument( - "--tokens", - type=Path, - default="data/tokens.txt", - ) - - parser.add_argument( - "--cmvn", - type=str, - default="data/fbank/cmvn.json", - help="""Path to vocabulary.""", - ) - - parser.add_argument( - "--input-text", - type=str, - required=True, - help="The text to generate speech for", - ) - - parser.add_argument( - "--output-wav", - type=str, - required=True, - help="The filename of the wave to save the generated speech", - ) - - return parser - - -def load_vocoder(checkpoint_path): - checkpoint_path = str(checkpoint_path) - if checkpoint_path.endswith("v1"): - h = AttributeDict(v1) - elif checkpoint_path.endswith("v2"): - h = AttributeDict(v2) - elif checkpoint_path.endswith("v3"): - h = AttributeDict(v3) - else: - raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}") - - hifigan = HiFiGAN(h).to("cpu") - hifigan.load_state_dict( - torch.load(checkpoint_path, map_location="cpu")["generator"] - ) - _ = hifigan.eval() - hifigan.remove_weight_norm() - return hifigan - - -def to_waveform(mel, vocoder, denoiser): - audio = vocoder(mel).clamp(-1, 1) - audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze() - return audio.cpu().squeeze() - - -def process_text(text: str, tokenizer): - x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True) - x = torch.tensor(x, dtype=torch.long) - x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu") - return {"x_orig": text, "x": x, "x_lengths": x_lengths} - - -def synthesise( - model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None -): - text_processed = process_text(text, tokenizer) - start_t = dt.datetime.now() - output = model.synthesise( - text_processed["x"], - text_processed["x_lengths"], - n_timesteps=n_timesteps, - temperature=temperature, - spks=spks, - length_scale=length_scale, - ) - # merge everything to one dict - output.update({"start_t": start_t, **text_processed}) - return output - - -@torch.inference_mode() -def main(): - parser = get_parser() - args = parser.parse_args() - params = get_params() - - params.update(vars(args)) - - tokenizer = Tokenizer(params.tokens) - params.blank_id = tokenizer.pad_id - params.vocab_size = tokenizer.vocab_size - params.model_args.n_vocab = params.vocab_size - - with open(params.cmvn) as f: - stats = json.load(f) - params.data_args.data_statistics.mel_mean = stats["fbank_mean"] - params.data_args.data_statistics.mel_std = stats["fbank_std"] - - params.model_args.data_statistics.mel_mean = stats["fbank_mean"] - params.model_args.data_statistics.mel_std = stats["fbank_std"] - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file(): - raise ValueError("{params.exp_dir}/epoch-{params.epoch}.pt does not exist") - - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - model.eval() - - if not Path(params.vocoder).is_file(): - raise ValueError(f"{params.vocoder} does not exist") - - vocoder = load_vocoder(params.vocoder) - denoiser = Denoiser(vocoder, mode="zeros") - - # Number of ODE Solver steps - n_timesteps = 2 - - # Changes to the speaking rate - length_scale = 1.0 - - # Sampling temperature - temperature = 0.667 - - output = synthesise( - model=model, - tokenizer=tokenizer, - n_timesteps=n_timesteps, - text=params.input_text, - length_scale=length_scale, - temperature=temperature, - ) - output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) - - sf.write(params.output_wav, output["waveform"], 22050, "PCM_16") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - main() diff --git a/egs/ljspeech/TTS/matcha/models/components/decoder.py b/egs/ljspeech/TTS/matcha/models/components/decoder.py index 14d19f5d4e..102d87713c 100644 --- a/egs/ljspeech/TTS/matcha/models/components/decoder.py +++ b/egs/ljspeech/TTS/matcha/models/components/decoder.py @@ -7,7 +7,7 @@ from conformer import ConformerBlock from diffusers.models.activations import get_activation from einops import pack, rearrange, repeat -from matcha.models.components.transformer import BasicTransformerBlock +from models.components.transformer import BasicTransformerBlock class SinusoidalPosEmb(torch.nn.Module): diff --git a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py index 997689b1cb..eb795ef32d 100644 --- a/egs/ljspeech/TTS/matcha/models/components/flow_matching.py +++ b/egs/ljspeech/TTS/matcha/models/components/flow_matching.py @@ -2,7 +2,7 @@ import torch import torch.nn.functional as F -from matcha.models.components.decoder import Decoder +from models.components.decoder import Decoder class BASECFM(torch.nn.Module, ABC): diff --git a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py index ca77cba51c..364ff19381 100644 --- a/egs/ljspeech/TTS/matcha/models/components/text_encoder.py +++ b/egs/ljspeech/TTS/matcha/models/components/text_encoder.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn from einops import rearrange -from matcha.model import sequence_mask +from model import sequence_mask class LayerNorm(nn.Module): diff --git a/egs/ljspeech/TTS/matcha/models/matcha_tts.py b/egs/ljspeech/TTS/matcha/models/matcha_tts.py index 330d1dc472..fe0a72402d 100644 --- a/egs/ljspeech/TTS/matcha/models/matcha_tts.py +++ b/egs/ljspeech/TTS/matcha/models/matcha_tts.py @@ -2,17 +2,17 @@ import math import random -import matcha.monotonic_align as monotonic_align +import monotonic_align as monotonic_align import torch -from matcha.model import ( +from model import ( denormalize, duration_loss, fix_len_compatibility, generate_path, sequence_mask, ) -from matcha.models.components.flow_matching import CFM -from matcha.models.components.text_encoder import TextEncoder +from models.components.flow_matching import CFM +from models.components.text_encoder import TextEncoder class MatchaTTS(torch.nn.Module): # 🍵 diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore b/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore index 28bdad6b84..3def4ae263 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore +++ b/egs/ljspeech/TTS/matcha/monotonic_align/.gitignore @@ -1,3 +1,3 @@ build core.c -*.so +*.so \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py index 5b26fe4743..f87ae1bd3d 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py +++ b/egs/ljspeech/TTS/matcha/monotonic_align/__init__.py @@ -1,8 +1,7 @@ -# Copied from -# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/__init__.py import numpy as np import torch -from matcha.monotonic_align.core import maximum_path_c + +from .core import maximum_path_c def maximum_path(value, mask): diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx b/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx index eabc7f2736..091fcc3a50 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx +++ b/egs/ljspeech/TTS/matcha/monotonic_align/core.pyx @@ -1,5 +1,3 @@ -# Copied from -# https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/core.pyx import numpy as np cimport cython diff --git a/egs/ljspeech/TTS/matcha/monotonic_align/setup.py b/egs/ljspeech/TTS/matcha/monotonic_align/setup.py index df26c633e0..beacf2e369 100644 --- a/egs/ljspeech/TTS/matcha/monotonic_align/setup.py +++ b/egs/ljspeech/TTS/matcha/monotonic_align/setup.py @@ -1,12 +1,30 @@ -# Copied from +# Modified from # https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/utils/monotonic_align/setup.py -from distutils.core import setup - -import numpy from Cython.Build import cythonize +from setuptools import Extension, setup +from setuptools.command.build_ext import build_ext as _build_ext + + +class build_ext(_build_ext): + """Overwrite build_ext.""" + + def finalize_options(self): + """Prevent numpy from thinking it is still in its setup process.""" + _build_ext.finalize_options(self) + __builtins__.__NUMPY_SETUP__ = False + import numpy + + self.include_dirs.append(numpy.get_include()) + +exts = [ + Extension( + name="core", + sources=["core.pyx"], + ) +] setup( name="monotonic_align", - ext_modules=cythonize("core.pyx"), - include_dirs=[numpy.get_include()], + ext_modules=cythonize(exts, language_level=3), + cmdclass={"build_ext": build_ext}, ) diff --git a/egs/ljspeech/TTS/matcha/requirements.txt b/egs/ljspeech/TTS/matcha/requirements.txt index 5aadc89844..d7829c1e12 100644 --- a/egs/ljspeech/TTS/matcha/requirements.txt +++ b/egs/ljspeech/TTS/matcha/requirements.txt @@ -1,3 +1,4 @@ conformer==0.3.2 diffusers # developed using version ==0.25.0 librosa +einops \ No newline at end of file diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index 5e713fdfdb..31135f623b 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -14,9 +14,9 @@ import torch.multiprocessing as mp import torch.nn as nn from lhotse.utils import fix_random_seed -from matcha.model import fix_len_compatibility -from matcha.models.matcha_tts import MatchaTTS -from matcha.tokenizer import Tokenizer +from model import fix_len_compatibility +from models.matcha_tts import MatchaTTS +from tokenizer import Tokenizer from torch.cuda.amp import GradScaler, autocast from torch.nn.parallel import DistributedDataParallel as DDP from torch.optim import Optimizer @@ -150,7 +150,7 @@ def _get_data_params() -> AttributeDict: "n_spks": 1, "n_fft": 1024, "n_feats": 80, - "sample_rate": 22050, + "sampling_rate": 22050, "hop_length": 256, "win_length": 1024, "f_min": 0, @@ -445,11 +445,6 @@ def train_one_epoch( saved_bad_model = False - # used to track the stats over iterations in one epoch - tot_loss = MetricsTracker() - - saved_bad_model = False - def save_bad_model(suffix: str = ""): save_checkpoint( filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", diff --git a/egs/ljspeech/TTS/matcha/tts_datamodule.py b/egs/ljspeech/TTS/matcha/tts_datamodule.py index 8e37fc0308..1e637b766a 100644 --- a/egs/ljspeech/TTS/matcha/tts_datamodule.py +++ b/egs/ljspeech/TTS/matcha/tts_datamodule.py @@ -24,7 +24,7 @@ from typing import Any, Dict, Optional import torch -from compute_fbank_ljspeech import MyFbank, MyFbankConfig +from fbank import MatchaFbank, MatchaFbankConfig from lhotse import CutSet, load_manifest_lazy from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures CutConcatenate, @@ -32,7 +32,6 @@ DynamicBucketingSampler, PrecomputedFeatures, SimpleCutSampler, - SpecAugment, SpeechSynthesisDataset, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples @@ -177,7 +176,7 @@ def train_dataloaders( if self.args.on_the_fly_feats: sampling_rate = 22050 - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=sampling_rate, @@ -189,7 +188,7 @@ def train_dataloaders( train = SpeechSynthesisDataset( return_text=False, return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), return_cuts=self.args.return_cuts, ) @@ -238,7 +237,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: logging.info("About to create dev dataset") if self.args.on_the_fly_feats: sampling_rate = 22050 - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=sampling_rate, @@ -250,7 +249,7 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: validate = SpeechSynthesisDataset( return_text=False, return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), return_cuts=self.args.return_cuts, ) else: @@ -282,7 +281,7 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.info("About to create test dataset") if self.args.on_the_fly_feats: sampling_rate = 22050 - config = MyFbankConfig( + config = MatchaFbankConfig( n_fft=1024, n_mels=80, sampling_rate=sampling_rate, @@ -294,7 +293,7 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: test = SpeechSynthesisDataset( return_text=False, return_tokens=True, - feature_input_strategy=OnTheFlyFeatures(MyFbank(config)), + feature_input_strategy=OnTheFlyFeatures(MatchaFbank(config)), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/ljspeech/TTS/prepare.sh b/egs/ljspeech/TTS/prepare.sh index 6f16f8d473..ec5062933e 100755 --- a/egs/ljspeech/TTS/prepare.sh +++ b/egs/ljspeech/TTS/prepare.sh @@ -25,26 +25,16 @@ log() { log "dl_dir: $dl_dir" if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage -1: build monotonic_align lib" - if [ ! -d vits/monotonic_align/build ]; then - cd vits/monotonic_align - python3 setup.py build_ext --inplace - cd ../../ - else - log "monotonic_align lib for vits already built" - fi - - if [ ! -f ./matcha/monotonic_align/core.cpython-38-x86_64-linux-gnu.so ]; then - pushd matcha/monotonic_align - python3 setup.py build - mv -v build/lib.*/matcha/monotonic_align/core.*.so . - rm -rf build - rm core.c - ls -lh - popd - else - log "monotonic_align lib for matcha-tts already built" - fi + log "Stage -1: build monotonic_align lib (used by vits and matcha recipes)" + for recipe in vits matcha; do + if [ ! -d $recipe/monotonic_align/build ]; then + cd $recipe/monotonic_align + python3 setup.py build_ext --inplace + cd ../../ + else + log "monotonic_align lib for $recipe already built" + fi + done fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index 7be76e3151..cf1067dfcc 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ b/egs/ljspeech/TTS/vits/infer.py @@ -234,7 +234,7 @@ def main(): logging.info(f"Number of parameters in discriminator: {num_param_d}") logging.info(f"Total number of parameters: {num_param_g + num_param_d}") - # we need cut ids to display recognition results. + # we need cut ids to organize tts results. args.return_cuts = True ljspeech = LJSpeechTtsDataModule(args) diff --git a/egs/ljspeech/TTS/vits/monotonic_align/.gitignore b/egs/ljspeech/TTS/vits/monotonic_align/.gitignore new file mode 100644 index 0000000000..3def4ae263 --- /dev/null +++ b/egs/ljspeech/TTS/vits/monotonic_align/.gitignore @@ -0,0 +1,3 @@ +build +core.c +*.so \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits/test_model.py b/egs/ljspeech/TTS/vits/test_model.py index 1de10f012b..4faaa96a54 100755 --- a/egs/ljspeech/TTS/vits/test_model.py +++ b/egs/ljspeech/TTS/vits/test_model.py @@ -18,7 +18,6 @@ from tokenizer import Tokenizer from train import get_model, get_params -from vits import VITS def test_model_type(model_type):