From 66225fbe3323c3fec34c14dbe2dc920c39bc87cf Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 1 Nov 2024 15:33:13 +0800 Subject: [PATCH] VITS recipe for LibriTTS corpus (#1776) --- README.md | 3 + egs/libritts/CODEC/encodec/train.py | 18 +- egs/libritts/CODEC/prepare.sh | 5 +- egs/libritts/TTS/README.md | 51 + .../TTS/local/compute_spectrogram_libritts.py | 1 + egs/libritts/TTS/local/prepare_token_file.py | 1 + .../TTS/local/prepare_tokens_libritts.py | 89 ++ egs/libritts/TTS/local/validate_manifest.py | 1 + egs/libritts/TTS/prepare.sh | 134 +++ egs/libritts/TTS/shared | 1 + egs/libritts/TTS/vits/duration_predictor.py | 1 + egs/libritts/TTS/vits/flow.py | 1 + egs/libritts/TTS/vits/generator.py | 1 + egs/libritts/TTS/vits/hifigan.py | 1 + egs/libritts/TTS/vits/infer.py | 280 +++++ egs/libritts/TTS/vits/loss.py | 1 + egs/libritts/TTS/vits/monotonic_align | 1 + egs/libritts/TTS/vits/posterior_encoder.py | 1 + egs/libritts/TTS/vits/residual_coupling.py | 1 + egs/libritts/TTS/vits/test_onnx.py | 141 +++ egs/libritts/TTS/vits/text_encoder.py | 1 + egs/libritts/TTS/vits/tokenizer.py | 1 + egs/libritts/TTS/vits/train.py | 1015 +++++++++++++++++ egs/libritts/TTS/vits/transform.py | 1 + egs/libritts/TTS/vits/tts_datamodule.py | 432 +++++++ egs/libritts/TTS/vits/utils.py | 1 + egs/libritts/TTS/vits/vits.py | 1 + egs/libritts/TTS/vits/wavenet.py | 1 + egs/ljspeech/TTS/vits/generator.py | 7 +- egs/ljspeech/TTS/vits/train.py | 4 +- egs/ljspeech/TTS/vits/vits.py | 6 + egs/vctk/TTS/vits/train.py | 4 +- 32 files changed, 2190 insertions(+), 17 deletions(-) create mode 100644 egs/libritts/TTS/README.md create mode 120000 egs/libritts/TTS/local/compute_spectrogram_libritts.py create mode 120000 egs/libritts/TTS/local/prepare_token_file.py create mode 100755 egs/libritts/TTS/local/prepare_tokens_libritts.py create mode 120000 egs/libritts/TTS/local/validate_manifest.py create mode 100755 egs/libritts/TTS/prepare.sh create mode 120000 egs/libritts/TTS/shared create mode 120000 
egs/libritts/TTS/vits/duration_predictor.py create mode 120000 egs/libritts/TTS/vits/flow.py create mode 120000 egs/libritts/TTS/vits/generator.py create mode 120000 egs/libritts/TTS/vits/hifigan.py create mode 100755 egs/libritts/TTS/vits/infer.py create mode 120000 egs/libritts/TTS/vits/loss.py create mode 120000 egs/libritts/TTS/vits/monotonic_align create mode 120000 egs/libritts/TTS/vits/posterior_encoder.py create mode 120000 egs/libritts/TTS/vits/residual_coupling.py create mode 100755 egs/libritts/TTS/vits/test_onnx.py create mode 120000 egs/libritts/TTS/vits/text_encoder.py create mode 120000 egs/libritts/TTS/vits/tokenizer.py create mode 100755 egs/libritts/TTS/vits/train.py create mode 120000 egs/libritts/TTS/vits/transform.py create mode 100644 egs/libritts/TTS/vits/tts_datamodule.py create mode 120000 egs/libritts/TTS/vits/utils.py create mode 120000 egs/libritts/TTS/vits/vits.py create mode 120000 egs/libritts/TTS/vits/wavenet.py diff --git a/README.md b/README.md index 57db5eb8db..0e550ffb12 100644 --- a/README.md +++ b/README.md @@ -333,6 +333,7 @@ We provide a Colab notebook to test the pre-trained model: [![Open In Colab](htt - [LJSpeech][ljspeech] - [VCTK][vctk] + - [LibriTTS][libritts_tts] ### Supported Models @@ -372,6 +373,7 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [commonvoice]: egs/commonvoice/ASR [csj]: egs/csj/ASR [libricss]: egs/libricss/SURT +[libritts_asr]: egs/libritts/ASR [libriheavy]: egs/libriheavy/ASR [mgb2]: egs/mgb2/ASR [spgispeech]: egs/spgispeech/ASR @@ -380,3 +382,4 @@ Please see: [![Open In Colab](https://colab.research.google.com/assets/colab-bad [vctk]: egs/vctk/TTS [ljspeech]: egs/ljspeech/TTS +[libritts_tts]: egs/libritts/TTS diff --git a/egs/libritts/CODEC/encodec/train.py b/egs/libritts/CODEC/encodec/train.py index bf231c5b66..a4f2eb7ab7 100755 --- a/egs/libritts/CODEC/encodec/train.py +++ b/egs/libritts/CODEC/encodec/train.py @@ -138,7 +138,7 @@ def get_parser(): 
parser.add_argument( "--save-every-n", type=int, - default=1, + default=5, help="""Save checkpoint after processing this number of epochs" periodically. We save checkpoint to exp-dir/ whenever params.cur_epoch % save_every_n == 0. The checkpoint filename @@ -1093,14 +1093,14 @@ def run(rank, world_size, args): rank=rank, ) - # if not params.print_diagnostics: - # scan_pessimistic_batches_for_oom( - # model=model, - # train_dl=train_dl, - # optimizer_g=optimizer_g, - # optimizer_d=optimizer_d, - # params=params, - # ) + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + params=params, + ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: diff --git a/egs/libritts/CODEC/prepare.sh b/egs/libritts/CODEC/prepare.sh index 6a471c3adc..da04249ac3 100755 --- a/egs/libritts/CODEC/prepare.sh +++ b/egs/libritts/CODEC/prepare.sh @@ -45,12 +45,11 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # to $dl_dir/LibriTTS mkdir -p data/manifests if [ ! -e data/manifests/.libritts.done ]; then - lhotse prepare libritts --num-jobs 32 $dl_dir/LibriTTS data/manifests + lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests touch data/manifests/.libritts.done fi fi - if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Compute Spectrogram for LibriTTS" mkdir -p data/spectrogram @@ -64,7 +63,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then if [ ! 
-f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ <(gunzip -c data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz) \ - <(gunzip -c /data/spectrogramlibritts_cuts_train-other-500.jsonl.gz) | \ + <(gunzip -c data/spectrogram/libritts_cuts_train-other-500.jsonl.gz) | \ shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz fi diff --git a/egs/libritts/TTS/README.md b/egs/libritts/TTS/README.md new file mode 100644 index 0000000000..4d4fb85803 --- /dev/null +++ b/egs/libritts/TTS/README.md @@ -0,0 +1,51 @@ +# Introduction + +LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members. +The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus. +The main differences from the LibriSpeech corpus are listed below: +1. The audio files are at 24kHz sampling rate. +2. The speech is split at sentence breaks. +3. Both original and normalized texts are included. +4. Contextual information (e.g., neighbouring sentences) can be extracted. +5. Utterances with significant background noise are excluded. +For more information, refer to the paper "LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech", Heiga Zen, Viet Dang, Rob Clark, Yu Zhang, Ron J. Weiss, Ye Jia, Zhifeng Chen, and Yonghui Wu, arXiv, 2019. If you use the LibriTTS corpus in your work, please cite this paper where it was introduced. + +> [!CAUTION] +> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). 
+> While these recipes have the potential to advance various fields such as accessibility, language education, and AI-driven solutions, they also carry certain ethical and legal responsibilities. +> +> By using this framework, you agree to the following: +> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data. +> +> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology. +> +> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required. +> +> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties. + + +# VITS + +This recipe provides a VITS model trained on the LibriTTS dataset. + +Pretrained model can be found [here](https://huggingface.co/zrjin/icefall-tts-libritts-vits-2024-10-30). 
+ +The training command is given below: +``` +export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +./vits/train.py \ + --world-size 4 \ + --num-epochs 400 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir vits/exp \ + --max-duration 500 +``` + +To inference, use: +``` +./vits/infer.py \ + --exp-dir vits/exp \ + --epoch 400 \ + --tokens data/tokens.txt +``` diff --git a/egs/libritts/TTS/local/compute_spectrogram_libritts.py b/egs/libritts/TTS/local/compute_spectrogram_libritts.py new file mode 120000 index 0000000000..5a6ebba58c --- /dev/null +++ b/egs/libritts/TTS/local/compute_spectrogram_libritts.py @@ -0,0 +1 @@ +../../CODEC/local/compute_spectrogram_libritts.py \ No newline at end of file diff --git a/egs/libritts/TTS/local/prepare_token_file.py b/egs/libritts/TTS/local/prepare_token_file.py new file mode 120000 index 0000000000..afc29a22ba --- /dev/null +++ b/egs/libritts/TTS/local/prepare_token_file.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/local/prepare_token_file.py \ No newline at end of file diff --git a/egs/libritts/TTS/local/prepare_tokens_libritts.py b/egs/libritts/TTS/local/prepare_tokens_libritts.py new file mode 100755 index 0000000000..faeb611f5d --- /dev/null +++ b/egs/libritts/TTS/local/prepare_tokens_libritts.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao, +# Zengrui Jin,) +# 2024 Tsinghua University (authors: Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file reads the texts in given manifest and save the new cuts with phoneme tokens. +""" + +import logging +from pathlib import Path + +import tacotron_cleaner.cleaners +from lhotse import CutSet, load_manifest +from piper_phonemize import phonemize_espeak +from tqdm.auto import tqdm + + +def remove_punc_to_upper(text: str) -> str: + text = text.replace("‘", "'") + text = text.replace("’", "'") + tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'") + s_list = [x.upper() if x in tokens else " " for x in text] + s = " ".join("".join(s_list).split()).strip() + return s + + +def prepare_tokens_libritts(): + output_dir = Path("data/spectrogram") + prefix = "libritts" + suffix = "jsonl.gz" + partitions = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-all-shuf", + "train-clean-460", + # "train-clean-100", + # "train-clean-360", + # "train-other-500", + ) + + for partition in partitions: + cut_set = load_manifest(output_dir / f"{prefix}_cuts_{partition}.{suffix}") + + new_cuts = [] + for cut in tqdm(cut_set): + # Each cut only contains one supervision + assert len(cut.supervisions) == 1, (len(cut.supervisions), cut) + text = cut.supervisions[0].text + # Text normalization + text = tacotron_cleaner.cleaners.custom_english_cleaners(text) + # Convert to phonemes + tokens_list = phonemize_espeak(text, "en-us") + tokens = [] + for t in tokens_list: + tokens.extend(t) + cut.tokens = tokens + cut.supervisions[0].normalized_text = remove_punc_to_upper(text) + + new_cuts.append(cut) + + new_cut_set = CutSet.from_cuts(new_cuts) + new_cut_set.to_file( + output_dir / f"{prefix}_cuts_with_tokens_{partition}.{suffix}" + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + 
prepare_tokens_libritts() diff --git a/egs/libritts/TTS/local/validate_manifest.py b/egs/libritts/TTS/local/validate_manifest.py new file mode 120000 index 0000000000..b4d52ebca0 --- /dev/null +++ b/egs/libritts/TTS/local/validate_manifest.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/local/validate_manifest.py \ No newline at end of file diff --git a/egs/libritts/TTS/prepare.sh b/egs/libritts/TTS/prepare.sh new file mode 100755 index 0000000000..44016e6d21 --- /dev/null +++ b/egs/libritts/TTS/prepare.sh @@ -0,0 +1,134 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +stage=0 +stop_stage=100 +sampling_rate=24000 +nj=32 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: build monotonic_align lib" + if [ ! -d vits/monotonic_align/build ]; then + cd vits/monotonic_align + python setup.py build_ext --inplace + cd ../../ + else + log "monotonic_align lib already built" + fi +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/LibriTTS, + # you can create a symlink + # + # ln -sfv /path/to/LibriTTS $dl_dir/LibriTTS + # + if [ ! -d $dl_dir/LibriTTS ]; then + lhotse download libritts $dl_dir + fi + + if [ ! 
-d $dl_dir/xvector_nnet_1a_libritts_clean_460 ]; then + log "Downloading x-vector" + + git clone https://huggingface.co/datasets/zrjin/xvector_nnet_1a_libritts_clean_460 $dl_dir/xvector_nnet_1a_libritts_clean_460 + + mkdir -p exp/xvector_nnet_1a/ + cp -r $dl_dir/xvector_nnet_1a_libritts_clean_460/* exp/xvector_nnet_1a/ + fi + +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare LibriTTS manifest" + # We assume that you have downloaded the LibriTTS corpus + # to $dl_dir/LibriTTS + mkdir -p data/manifests + if [ ! -e data/manifests/.libritts.done ]; then + lhotse prepare libritts --num-jobs ${nj} $dl_dir/LibriTTS data/manifests + touch data/manifests/.libritts.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Compute Spectrogram for LibriTTS" + mkdir -p data/spectrogram + if [ ! -e data/spectrogram/.libritts.done ]; then + ./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate + touch data/spectrogram/.libritts.done + fi + + # Here we shuffle and combine the train-clean-100, train-clean-360 and + # train-other-500 together to form the training set. + if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then + cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c data/spectrogram/libritts_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz + fi + + # Here we shuffle and combine the train-clean-100, train-clean-360 + # together to form the training set. + if [ ! -f data/spectrogram/libritts_cuts_train-clean-460.jsonl.gz ]; then + cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz) | \ + shuf | gzip -c > data/spectrogram/libritts_cuts_train-clean-460.jsonl.gz + fi + + if [ ! 
-e data/spectrogram/.libritts-validated.done ]; then + log "Validating data/spectrogram for LibriTTS" + ./local/validate_manifest.py \ + data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz + touch data/spectrogram/.libritts-validated.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare phoneme tokens for LibriTTS" + # We assume you have installed piper_phonemize and espnet_tts_frontend. + # If not, please install them with: + # - piper_phonemize: + # refer to https://github.com/rhasspy/piper-phonemize, + # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 + # - espnet_tts_frontend: + # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ + if [ ! -e data/spectrogram/.libritts_with_token.done ]; then + ./local/prepare_tokens_libritts.py + touch data/spectrogram/.libritts_with_token.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Generate token file" + # We assume you have installed piper_phonemize and espnet_tts_frontend. + # If not, please install them with: + # - piper_phonemize: + # refer to https://github.com/rhasspy/piper-phonemize, + # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 + # - espnet_tts_frontend: + # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ + if [ ! 
-e data/tokens.txt ]; then + ./local/prepare_token_file.py --tokens data/tokens.txt + fi +fi diff --git a/egs/libritts/TTS/shared b/egs/libritts/TTS/shared new file mode 120000 index 0000000000..4c5e91438c --- /dev/null +++ b/egs/libritts/TTS/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/libritts/TTS/vits/duration_predictor.py b/egs/libritts/TTS/vits/duration_predictor.py new file mode 120000 index 0000000000..9972b476f9 --- /dev/null +++ b/egs/libritts/TTS/vits/duration_predictor.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/duration_predictor.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/flow.py b/egs/libritts/TTS/vits/flow.py new file mode 120000 index 0000000000..e65d91ea75 --- /dev/null +++ b/egs/libritts/TTS/vits/flow.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/flow.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/generator.py b/egs/libritts/TTS/vits/generator.py new file mode 120000 index 0000000000..611679bfa8 --- /dev/null +++ b/egs/libritts/TTS/vits/generator.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/generator.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/hifigan.py b/egs/libritts/TTS/vits/hifigan.py new file mode 120000 index 0000000000..5ac025de72 --- /dev/null +++ b/egs/libritts/TTS/vits/hifigan.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/hifigan.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/infer.py b/egs/libritts/TTS/vits/infer.py new file mode 100755 index 0000000000..6756786061 --- /dev/null +++ b/egs/libritts/TTS/vits/infer.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python3 +# +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script performs model inference on test set. + +Usage: +./vits/infer.py \ + --epoch 1000 \ + --exp-dir ./vits/exp \ + --max-duration 500 +""" + + +import argparse +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import List + +import k2 +import numpy as np +import torch +import torch.nn as nn +import torchaudio +from lhotse.features.io import KaldiReader +from tokenizer import Tokenizer +from train import get_model, get_params +from tts_datamodule import LibrittsTtsDataModule + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, setup_logger + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=1000, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +def infer_dataset( + dl: torch.utils.data.DataLoader, + subset: str, + params: AttributeDict, + model: nn.Module, + tokenizer: Tokenizer, + speaker_map: KaldiReader, +) -> None: + """Decode dataset. + The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. 
+ params: + It is returned by :func:`get_params`. + model: + The neural model. + tokenizer: + Used to convert text to phonemes. + """ + + # Background worker save audios to disk. + def _save_worker( + subset: str, + batch_size: int, + cut_ids: List[str], + audio: torch.Tensor, + audio_pred: torch.Tensor, + audio_lens: List[int], + audio_lens_pred: List[int], + ): + for i in range(batch_size): + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_gt.wav"), + audio[i : i + 1, : audio_lens[i]], + sample_rate=params.sampling_rate, + ) + torchaudio.save( + str(params.save_wav_dir / subset / f"{cut_ids[i]}_pred.wav"), + audio_pred[i : i + 1, : audio_lens_pred[i]], + sample_rate=params.sampling_rate, + ) + + device = next(model.parameters()).device + num_cuts = 0 + log_interval = 5 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + futures = [] + with ThreadPoolExecutor(max_workers=1) as executor: + for batch_idx, batch in enumerate(dl): + batch_size = len(batch["tokens"]) + + tokens = batch["tokens"] + tokens = tokenizer.tokens_to_token_ids( + tokens, intersperse_blank=True, add_sos=True, add_eos=True + ) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) + + audio = batch["audio"] + audio_lens = batch["audio_lens"].tolist() + cut_ids = [cut.id for cut in batch["cut"]] + sids = ["_".join(cut_id.split("_")[:2]) for cut_id in cut_ids] + spembs = ( + torch.Tensor(np.array([speaker_map.read(sid) for sid in sids])) + .squeeze(1) + .to(device) + ) + + audio_pred, _, durations = model.inference_batch( + text=tokens, + text_lengths=tokens_lens, + spembs=spembs, + ) + audio_pred = audio_pred.detach().cpu() + # convert to samples + audio_lens_pred = ( + (durations.sum(1) * 
params.frame_shift).to(dtype=torch.int64).tolist() + ) + + futures.append( + executor.submit( + _save_worker, + subset, + batch_size, + cut_ids, + audio, + audio_pred, + audio_lens, + audio_lens_pred, + ) + ) + + num_cuts += batch_size + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + # return results + for f in futures: + f.result() + + +@torch.no_grad() +def main(): + parser = get_parser() + LibrittsTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.suffix = f"epoch-{params.epoch}" + + params.res_dir = params.exp_dir / "infer" / params.suffix + params.save_wav_dir = params.res_dir / "wav" + params.save_wav_dir.mkdir(parents=True, exist_ok=True) + + setup_logger(f"{params.res_dir}/log-infer-{params.suffix}") + logging.info("Infer started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + libritts = LibrittsTtsDataModule(args) + + logging.info(f"Device: {device}") + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + + model.to(device) + model.eval() + + num_param_g = sum([p.numel() for p in model.generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in model.discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + test_clean_cuts = libritts.test_clean_cuts() + test_clean_speaker_map = libritts.test_clean_xvector() + test_clean_dl = libritts.test_dataloaders(test_clean_cuts) + + dev_clean_cuts = libritts.dev_clean_cuts() + dev_clean_speaker_map = libritts.dev_clean_xvector() + dev_clean_dl = libritts.dev_dataloaders(dev_clean_cuts) + + infer_sets = { + "test-clean": (test_clean_dl, test_clean_speaker_map), + "dev-clean": (dev_clean_dl, dev_clean_speaker_map), + } + + for subset, data in infer_sets.items(): + save_wav_dir = params.res_dir / "wav" / subset + save_wav_dir.mkdir(parents=True, exist_ok=True) + dl, speaker_map = data + + logging.info(f"Processing {subset} set, saving to {save_wav_dir}") + + infer_dataset( + dl=dl, + subset=subset, + params=params, + model=model, + tokenizer=tokenizer, + speaker_map=speaker_map, + ) + + logging.info(f"Wav files are saved to {params.save_wav_dir}") + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/libritts/TTS/vits/loss.py b/egs/libritts/TTS/vits/loss.py new file mode 120000 index 0000000000..672e5ff68d --- /dev/null +++ b/egs/libritts/TTS/vits/loss.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/loss.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/monotonic_align b/egs/libritts/TTS/vits/monotonic_align new file mode 120000 
index 0000000000..71934e7cca --- /dev/null +++ b/egs/libritts/TTS/vits/monotonic_align @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/monotonic_align \ No newline at end of file diff --git a/egs/libritts/TTS/vits/posterior_encoder.py b/egs/libritts/TTS/vits/posterior_encoder.py new file mode 120000 index 0000000000..41d64a3a66 --- /dev/null +++ b/egs/libritts/TTS/vits/posterior_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/posterior_encoder.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/residual_coupling.py b/egs/libritts/TTS/vits/residual_coupling.py new file mode 120000 index 0000000000..f979adbf00 --- /dev/null +++ b/egs/libritts/TTS/vits/residual_coupling.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/residual_coupling.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/test_onnx.py b/egs/libritts/TTS/vits/test_onnx.py new file mode 100755 index 0000000000..ae6587338e --- /dev/null +++ b/egs/libritts/TTS/vits/test_onnx.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# +# Copyright 2023-2024 Xiaomi Corporation (Author: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +This script is used to test the exported onnx model by vits/export-onnx.py + +Use the onnx model to generate a wav: +./vits/test_onnx.py \ + --model-filename vits/exp/vits-epoch-1000.onnx \ + --tokens data/tokens.txt +""" + + +import argparse +import logging +from pathlib import Path + +import onnxruntime as ort +import torch +import torchaudio +from tokenizer import Tokenizer + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the onnx model.", + ) + + parser.add_argument( + "--speakers", + type=Path, + default=Path("data/speakers.txt"), + help="Path to speakers.txt file.", + ) + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + return parser + + +class OnnxModel: + def __init__(self, model_filename: str): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.model = ort.InferenceSession( + model_filename, + sess_options=self.session_opts, + providers=["CPUExecutionProvider"], + ) + logging.info(f"{self.model.get_modelmeta().custom_metadata_map}") + + def __call__( + self, tokens: torch.Tensor, tokens_lens: torch.Tensor, speaker: torch.Tensor + ) -> torch.Tensor: + """ + Args: + tokens: + A 1-D tensor of shape (1, T) + Returns: + A tensor of shape (1, T') + """ + noise_scale = torch.tensor([0.667], dtype=torch.float32) + noise_scale_dur = torch.tensor([0.8], dtype=torch.float32) + alpha = torch.tensor([1.0], dtype=torch.float32) + + out = self.model.run( + [ + self.model.get_outputs()[0].name, + ], + { + self.model.get_inputs()[0].name: tokens.numpy(), + self.model.get_inputs()[1].name: tokens_lens.numpy(), + self.model.get_inputs()[2].name: noise_scale.numpy(), + self.model.get_inputs()[3].name: alpha.numpy(), + 
self.model.get_inputs()[4].name: noise_scale_dur.numpy(), + self.model.get_inputs()[5].name: speaker.numpy(), + }, + )[0] + return torch.from_numpy(out) + + +def main(): + args = get_parser().parse_args() + + tokenizer = Tokenizer(args.tokens) + + with open(args.speakers) as f: + speaker_map = {line.strip(): i for i, line in enumerate(f)} + args.num_spks = len(speaker_map) + + logging.info("About to create onnx model") + model = OnnxModel(args.model_filename) + + text = "I went there to see the land, the people and how their system works, end quote." + tokens = tokenizer.texts_to_token_ids( + [text], intersperse_blank=True, add_sos=True, add_eos=True + ) + tokens = torch.tensor(tokens) # (1, T) + tokens_lens = torch.tensor([tokens.shape[1]], dtype=torch.int64) # (1,) + speaker = torch.tensor([1], dtype=torch.int64) # (1, ) + audio = model(tokens, tokens_lens, speaker) # (1, T') + + torchaudio.save(str("test_onnx.wav"), audio, sample_rate=24000) + logging.info("Saved to test_onnx.wav") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/libritts/TTS/vits/text_encoder.py b/egs/libritts/TTS/vits/text_encoder.py new file mode 120000 index 0000000000..0efba277e1 --- /dev/null +++ b/egs/libritts/TTS/vits/text_encoder.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/text_encoder.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/tokenizer.py b/egs/libritts/TTS/vits/tokenizer.py new file mode 120000 index 0000000000..057b0dc4b1 --- /dev/null +++ b/egs/libritts/TTS/vits/tokenizer.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/tokenizer.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/train.py b/egs/libritts/TTS/vits/train.py new file mode 100755 index 0000000000..447fbcf5db --- /dev/null +++ b/egs/libritts/TTS/vits/train.py @@ -0,0 +1,1015 @@ +#!/usr/bin/env python3 +# Copyright 2023-2024 Xiaomi 
Corporation (Author: Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import numpy as np +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse.cut import Cut +from lhotse.features.io import KaldiReader +from lhotse.utils import fix_random_seed +from tokenizer import Tokenizer +from torch.cuda.amp import GradScaler, autocast +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim import Optimizer +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import LibrittsTtsDataModule +from utils import MetricsTracker, plot_feature, save_checkpoint +from vits import VITS + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, setup_logger, str2bool + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", 
+ ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=1000, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="vits/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/tokens.txt", + help="""Path to vocabulary.""", + ) + + parser.add_argument( + "--lr", type=float, default=2.0e-4, help="The base learning rate." + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=20, + help="""Save checkpoint after processing this number of epochs" + periodically. We save checkpoint to exp-dir/ whenever + params.cur_epoch % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/epoch-{params.cur_epoch}.pt'. + Since it will take around 1000 epochs, we suggest using a large + save_every_n to save disk space. 
+ """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. 
+ """ + params = AttributeDict( + { + # training params + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": -1, # 0 + "log_interval": 50, + "valid_interval": 200, + "env_info": get_env_info(), + "sampling_rate": 24000, + "frame_shift": 256, + "frame_length": 1024, + "feature_dim": 513, # 1024 // 2 + 1, 1024 is fft_length + "n_mels": 80, + "lambda_adv": 1.0, # loss scaling coefficient for adversarial loss + "lambda_mel": 45.0, # loss scaling coefficient for Mel loss + "lambda_feat_match": 2.0, # loss scaling coefficient for feat match loss + "lambda_dur": 1.0, # loss scaling coefficient for duration loss + "lambda_kl": 1.0, # loss scaling coefficient for KL divergence loss + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, model: nn.Module +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint(filename, model=model) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def get_model(params: AttributeDict) -> nn.Module: + mel_loss_params = { + "n_mels": params.n_mels, + "frame_length": params.frame_length, + "frame_shift": params.frame_shift, + } + generator_params = { + "hidden_channels": 192, + "spks": None, + "langs": None, + "spk_embed_dim": 512, + "global_channels": 256, + "segment_size": 32, + "text_encoder_attention_heads": 2, + "text_encoder_ffn_expand": 4, + "text_encoder_cnn_module_kernel": 5, + "text_encoder_blocks": 6, + "text_encoder_dropout_rate": 0.1, + "decoder_kernel_size": 7, + "decoder_channels": 512, + "decoder_upsample_scales": [8, 8, 2, 2], + "decoder_upsample_kernel_sizes": [16, 16, 4, 4], + "decoder_resblock_kernel_sizes": [3, 7, 11], + "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "use_weight_norm_in_decoder": True, + "posterior_encoder_kernel_size": 5, + "posterior_encoder_layers": 16, + "posterior_encoder_stacks": 1, + "posterior_encoder_base_dilation": 1, + "posterior_encoder_dropout_rate": 0.0, + "use_weight_norm_in_posterior_encoder": True, + "flow_flows": 4, + "flow_kernel_size": 5, + "flow_base_dilation": 1, + "flow_layers": 4, + "flow_dropout_rate": 0.0, + "use_weight_norm_in_flow": True, + "use_only_mean_in_flow": True, + "stochastic_duration_predictor_kernel_size": 3, + "stochastic_duration_predictor_dropout_rate": 0.5, + "stochastic_duration_predictor_flows": 4, + "stochastic_duration_predictor_dds_conv_layers": 3, + } + model = VITS( + vocab_size=params.vocab_size, + feature_dim=params.feature_dim, + sampling_rate=params.sampling_rate, + generator_params=generator_params, + mel_loss_params=mel_loss_params, + lambda_adv=params.lambda_adv, + lambda_mel=params.lambda_mel, + lambda_feat_match=params.lambda_feat_match, + 
lambda_dur=params.lambda_dur, + lambda_kl=params.lambda_kl, + ) + return model + + +def prepare_input( + batch: dict, + tokenizer: Tokenizer, + device: torch.device, + speaker_map: KaldiReader, +): + """Parse batch data""" + + def parse_sids(batch: dict) -> List[str]: + return ["_".join(cut.id.split("_")[:2]) for cut in batch["cut"]] + + audio = batch["audio"].to(device) + features = batch["features"].to(device) + audio_lens = batch["audio_lens"].to(device) + features_lens = batch["features_lens"].to(device) + tokens = batch["tokens"] + spembs = ( + torch.Tensor(np.array([speaker_map.read(sid) for sid in parse_sids(batch)])) + .squeeze(1) + .to(device) + ) + + tokens = tokenizer.tokens_to_token_ids( + tokens, intersperse_blank=True, add_sos=True, add_eos=True + ) + tokens = k2.RaggedTensor(tokens) + row_splits = tokens.shape.row_splits(1) + tokens_lens = row_splits[1:] - row_splits[:-1] + tokens = tokens.to(device) + tokens_lens = tokens_lens.to(device) + # a tensor of shape (B, T) + tokens = tokens.pad(mode="constant", padding_value=tokenizer.pad_id) + + return audio, audio_lens, features, features_lens, tokens, tokens_lens, spembs + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + optimizer_g: Optimizer, + optimizer_d: Optimizer, + scheduler_g: LRSchedulerType, + scheduler_d: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + dev_dl: torch.utils.data.DataLoader, + train_speaker_map: KaldiReader, + dev_speaker_map: KaldiReader, + scaler: GradScaler, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + tokenizer: + Used to convert text to phonemes. 
+ optimizer_g: + The optimizer for generator. + optimizer_d: + The optimizer for discriminator. + scheduler_g: + The learning rate scheduler for generator, we call step() every epoch. + scheduler_d: + The learning rate scheduler for discriminator, we call step() every epoch. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations in one epoch + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + params=params, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + spembs, + ) = prepare_input(batch, tokenizer, device, train_speaker_map) + + loss_info = MetricsTracker() + loss_info["samples"] = batch_size + + try: + with autocast(enabled=params.use_fp16): + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + spembs=spembs, + forward_generator=False, + ) + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + # update discriminator + optimizer_d.zero_grad() + scaler.scale(loss_d).backward() + 
scaler.step(optimizer_d) + + with autocast(enabled=params.use_fp16): + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + spembs=spembs, + forward_generator=True, + return_sample=params.batch_idx_train % params.log_interval == 0, + ) + for k, v in stats_g.items(): + if "returned_sample" not in k: + loss_info[k] = v * batch_size + # update generator + optimizer_g.zero_grad() + scaler.scale(loss_g).backward() + scaler.step(optimizer_g) + scaler.update() + + # summary stats + tot_loss = tot_loss + loss_info + except: # noqa + save_bad_model() + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if params.batch_idx_train % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or ( + cur_grad_scale < 32.0 and params.batch_idx_train % 400 == 0 + ): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if params.batch_idx_train % params.log_interval == 0: + cur_lr_g = max(scheduler_g.get_last_lr()) + cur_lr_d = max(scheduler_d.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, batch {batch_idx}, " + f"global_batch_idx: {params.batch_idx_train}, batch size: {batch_size}, " + f"loss[{loss_info}], tot_loss[{tot_loss}], " + f"cur_lr_g: {cur_lr_g:.2e}, cur_lr_d: {cur_lr_d:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate_g", cur_lr_g, params.batch_idx_train + ) + tb_writer.add_scalar( + "train/learning_rate_d", cur_lr_d, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + if "returned_sample" in stats_g: + speech_hat_, speech_, mel_hat_, mel_ = stats_g["returned_sample"] + tb_writer.add_audio( + "train/speech_hat_", + speech_hat_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/speech_", + speech_, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_image( + "train/mel_hat_", + plot_feature(mel_hat_), + params.batch_idx_train, + dataformats="HWC", + ) + tb_writer.add_image( + "train/mel_", + 
plot_feature(mel_), + params.batch_idx_train, + dataformats="HWC", + ) + + if ( + params.batch_idx_train % params.valid_interval == 0 + and not params.print_diagnostics + ): + logging.info("Computing validation loss") + valid_info, (speech_hat, speech) = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + dev_dl=dev_dl, + dev_speaker_map=dev_speaker_map, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + tb_writer.add_audio( + "train/valid_speech_hat", + speech_hat, + params.batch_idx_train, + params.sampling_rate, + ) + tb_writer.add_audio( + "train/valid_speech", + speech, + params.batch_idx_train, + params.sampling_rate, + ) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: Tokenizer, + dev_dl: torch.utils.data.DataLoader, + dev_speaker_map: KaldiReader, + world_size: int = 1, + rank: int = 0, +) -> Tuple[MetricsTracker, Tuple[np.ndarray, np.ndarray]]: + """Run the validation process.""" + model.eval() + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + + # used to summary the stats over iterations + tot_loss = MetricsTracker() + returned_sample = None + + with torch.no_grad(): + for batch_idx, batch in enumerate(dev_dl): + batch_size = len(batch["tokens"]) + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + spembs, + ) = prepare_input(batch, tokenizer, device, dev_speaker_map) + + loss_info = 
MetricsTracker() + loss_info["samples"] = batch_size + + # forward discriminator + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + spembs=spembs, + forward_generator=False, + ) + assert loss_d.requires_grad is False + for k, v in stats_d.items(): + loss_info[k] = v * batch_size + + # forward generator + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + spembs=spembs, + forward_generator=True, + ) + assert loss_g.requires_grad is False + for k, v in stats_g.items(): + loss_info[k] = v * batch_size + + # summary stats + tot_loss = tot_loss + loss_info + + # infer for first batch: + if batch_idx == 0 and rank == 0: + inner_model = model.module if isinstance(model, DDP) else model + audio_pred, _, duration = inner_model.inference( + text=tokens[0, : tokens_lens[0].item()], + spembs=spembs[0], + ) + audio_pred = audio_pred.data.cpu().numpy() + audio_len_pred = ( + (duration.sum(0) * params.frame_shift).to(dtype=torch.int64).item() + ) + assert audio_len_pred == len(audio_pred), ( + audio_len_pred, + len(audio_pred), + ) + audio_gt = audio[0, : audio_lens[0].item()].data.cpu().numpy() + returned_sample = (audio_pred, audio_gt) + + if world_size > 1: + tot_loss.reduce(device) + + loss_value = tot_loss["generator_loss"] / tot_loss["samples"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss, returned_sample + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + tokenizer: Tokenizer, + optimizer_g: torch.optim.Optimizer, + optimizer_d: torch.optim.Optimizer, + train_speaker_map: KaldiReader, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + 
"Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + ( + audio, + audio_lens, + features, + features_lens, + tokens, + tokens_lens, + spembs, + ) = prepare_input(batch, tokenizer, device, train_speaker_map) + try: + # for discriminator + with autocast(enabled=params.use_fp16): + loss_d, stats_d = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + spembs=spembs, + forward_generator=False, + ) + optimizer_d.zero_grad() + loss_d.backward() + # for generator + with autocast(enabled=params.use_fp16): + loss_g, stats_g = model( + text=tokens, + text_lengths=tokens_lens, + feats=features, + feats_lengths=features_lens, + speech=audio, + speech_lengths=audio_lens, + spembs=spembs, + forward_generator=True, + ) + optimizer_g.zero_grad() + loss_g.backward() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + tokenizer = Tokenizer(params.tokens) + params.blank_id = tokenizer.pad_id + params.vocab_size = tokenizer.vocab_size + + libritts = LibrittsTtsDataModule(args) + + if params.full_libri: + train_cuts = libritts.train_all_shuf_cuts() + train_speaker_map = libritts.train_all_shuf_xvector() + else: + train_cuts = libritts.train_clean_460_cuts() + train_speaker_map = libritts.train_clean_460_xvector() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + generator = model.generator + discriminator = model.discriminator + + num_param_g = sum([p.numel() for p in generator.parameters()]) + logging.info(f"Number of parameters in generator: {num_param_g}") + num_param_d = sum([p.numel() for p in discriminator.parameters()]) + logging.info(f"Number of parameters in discriminator: {num_param_d}") + logging.info(f"Total number of parameters: {num_param_g + num_param_d}") + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer_g = torch.optim.AdamW( + generator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + optimizer_d = torch.optim.AdamW( + discriminator.parameters(), lr=params.lr, betas=(0.8, 0.99), eps=1e-9 + ) + + scheduler_g 
= torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.999875) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.999875) + + if checkpoints is not None: + # load state_dict for optimizers + if "optimizer_g" in checkpoints: + logging.info("Loading optimizer_g state dict") + optimizer_g.load_state_dict(checkpoints["optimizer_g"]) + if "optimizer_d" in checkpoints: + logging.info("Loading optimizer_d state dict") + optimizer_d.load_state_dict(checkpoints["optimizer_d"]) + + # load state_dict for schedulers + if "scheduler_g" in checkpoints: + logging.info("Loading scheduler_g state dict") + scheduler_g.load_state_dict(checkpoints["scheduler_g"]) + if "scheduler_d" in checkpoints: + logging.info("Loading scheduler_d state dict") + scheduler_d.load_state_dict(checkpoints["scheduler_d"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + # ) + return False + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = libritts.train_dataloaders(train_cuts) + + dev_clean_cuts = libritts.dev_clean_cuts() + dev_speaker_map = libritts.dev_clean_xvector() + dev_dl = libritts.dev_dataloaders(dev_clean_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + train_speaker_map=train_speaker_map, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + logging.info(f"Start epoch {epoch}") + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + params.cur_epoch = epoch + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + train_dl=train_dl, + dev_dl=dev_dl, + train_speaker_map=train_speaker_map, + dev_speaker_map=dev_speaker_map, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + if epoch % params.save_every_n == 0 or epoch == params.num_epochs: + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint( + filename=filename, + params=params, + model=model, + optimizer_g=optimizer_g, + optimizer_d=optimizer_d, + scheduler_g=scheduler_g, + scheduler_d=scheduler_d, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + if rank == 0: + if params.best_train_epoch == params.cur_epoch: + 
best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + # step per epoch + scheduler_g.step() + scheduler_d.step() + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LibrittsTtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/libritts/TTS/vits/transform.py b/egs/libritts/TTS/vits/transform.py new file mode 120000 index 0000000000..962647408b --- /dev/null +++ b/egs/libritts/TTS/vits/transform.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/transform.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/tts_datamodule.py b/egs/libritts/TTS/vits/tts_datamodule.py new file mode 100644 index 0000000000..e98e49c1f1 --- /dev/null +++ b/egs/libritts/TTS/vits/tts_datamodule.py @@ -0,0 +1,432 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao, +# Zengrui Jin,) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutConcatenate,
    DynamicBucketingSampler,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpeechSynthesisDataset,
)
from lhotse.dataset.input_strategies import (  # noqa F401 For AudioSamples
    AudioSamples,
    OnTheFlyFeatures,
)
from lhotse.features.io import KaldiReader
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class _SeedWorkers:
    """Callable used as the DataLoader ``worker_init_fn``.

    Each worker calls ``fix_random_seed`` with ``seed + worker_id``.
    """

    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        # Offset the base seed by the worker id so every worker gets its own
        # seed value.
        fix_random_seed(self.seed + worker_id)


LIBRITTS_SAMPLING_RATE = 24000

# Input strategies that can be selected with --input-strategy. Using an
# explicit table (instead of eval() on the command-line string) means only
# these classes can ever be instantiated from user input.
_INPUT_STRATEGIES = {
    "AudioSamples": AudioSamples,
    "PrecomputedFeatures": PrecomputedFeatures,
}


class LibrittsTtsDataModule:
    """
    DataModule for LibriTTS TTS experiments.

    It builds one train dataloader, one valid dataloader and test
    dataloaders (one per call, e.g. for LibriTTS test-clean and test-other).

    It wires together the data pipeline pieces used by the recipe:
        - dynamic batch size (``--max-duration``),
        - bucketing samplers,
        - precomputed or on-the-fly spectrogram features.

    NOTE: the manifest / x-vector loaders below are wrapped in
    ``lru_cache``; on a method the cache is keyed on ``self`` and keeps the
    instance alive for as long as the cache exists.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Register the TTS data related command-line options on ``parser``."""
        group = parser.add_argument_group(
            title="TTS data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--full-libri",
            type=str2bool,
            default=False,
            help="When enabled, use the entire LibriTTS training set. "
            "Otherwise, use the 460h clean subset.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/spectrogram"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--speaker-embeds",
            type=Path,
            default=Path("exp/xvector_nnet_1a/"),
            help="Path to directory with speaker embeddings.",
        )
        group.add_argument(
            "--max-duration",
            # Durations are in seconds and the default is a float, so parse
            # the option as a float; integer values such as "200" still work.
            type=float,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )

        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['cut'] with the cuts that "
            "were used to construct it.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=8,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            choices=sorted(_INPUT_STRATEGIES),
            help="AudioSamples or PrecomputedFeatures",
        )

    def _input_strategy(self):
        """Instantiate the input strategy named by ``--input-strategy``.

        Raises:
          ValueError: if the name is not one of the supported strategies
            (possible when ``args`` was not produced by ``add_arguments``).
        """
        name = self.args.input_strategy
        try:
            strategy_cls = _INPUT_STRATEGIES[name]
        except KeyError:
            raise ValueError(
                f"Unsupported input strategy: {name!r}. "
                f"Expected one of {sorted(_INPUT_STRATEGIES)}."
            ) from None
        return strategy_cls()

    def _create_dataset(self) -> SpeechSynthesisDataset:
        """Build the dataset shared by the train/dev/test dataloaders.

        With ``--on-the-fly-feats`` the spectrogram is computed from audio
        while loading; otherwise the strategy selected by
        ``--input-strategy`` is used.
        """
        if self.args.on_the_fly_feats:
            sampling_rate = LIBRITTS_SAMPLING_RATE
            config = SpectrogramConfig(
                sampling_rate=sampling_rate,
                # 1024-sample window and 256-sample hop, expressed in seconds.
                frame_length=1024 / sampling_rate,
                frame_shift=256 / sampling_rate,
                use_fft_mag=True,
            )
            feature_input_strategy = OnTheFlyFeatures(Spectrogram(config))
        else:
            feature_input_strategy = self._input_strategy()

        return SpeechSynthesisDataset(
            return_text=True,
            return_tokens=True,
            return_spk_ids=True,
            feature_input_strategy=feature_input_strategy,
            return_cuts=self.args.return_cuts,
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        Returns:
          The training DataLoader.
        """
        logging.info("About to create train dataset")
        # Build the dataset exactly once; the feature source depends on
        # --on-the-fly-feats (see _create_dataset).
        train = self._create_dataset()

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                buffer_size=self.args.num_buckets * 2000,
                shuffle_buffer_size=self.args.num_buckets * 5000,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            # The sampler already yields whole batches of cuts.
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        """Return the validation DataLoader (unshuffled) for ``cuts_valid``."""
        logging.info("About to create dev dataset")
        validate = self._create_dataset()
        dev_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create valid dataloader")
        dev_dl = DataLoader(
            validate,
            sampler=dev_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return dev_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        """Return a test DataLoader (unshuffled) for ``cuts``."""
        logging.info("About to create test dataset")
        test = self._create_dataset()
        test_sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=test_sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_all_shuf_cuts(self) -> CutSet:
        """Lazily open the shuffled manifest of the full training set."""
        logging.info(
            "About to get the shuffled train-clean-100, "
            "train-clean-360 and train-other-500 cuts"
        )
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_with_tokens_train-all-shuf.jsonl.gz"
        )

    @lru_cache()
    def train_clean_460_cuts(self) -> CutSet:
        """Lazily open the train-clean-460 manifest."""
        logging.info(
            "About to get the shuffled train-clean-100 and train-clean-360 cuts"
        )
        return load_manifest_lazy(
            self.args.manifest_dir
            / "libritts_cuts_with_tokens_train-clean-460.jsonl.gz"
        )

    @lru_cache()
    def dev_clean_cuts(self) -> CutSet:
        """Lazily open the dev-clean manifest."""
        logging.info("About to get dev-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_with_tokens_dev-clean.jsonl.gz"
        )

    @lru_cache()
    def dev_other_cuts(self) -> CutSet:
        """Lazily open the dev-other manifest."""
        logging.info("About to get dev-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_with_tokens_dev-other.jsonl.gz"
        )

    @lru_cache()
    def test_clean_cuts(self) -> CutSet:
        """Lazily open the test-clean manifest."""
        logging.info("About to get test-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_with_tokens_test-clean.jsonl.gz"
        )

    @lru_cache()
    def test_other_cuts(self) -> CutSet:
        """Lazily open the test-other manifest."""
        logging.info("About to get test-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_with_tokens_test-other.jsonl.gz"
        )

    @lru_cache()
    def train_all_shuf_xvector(self) -> KaldiReader:
        """Not implemented: speaker embeddings for the full training set."""
        raise NotImplementedError(
            "Please implement the method to load speaker embeddings."
        )

    @lru_cache()
    def train_clean_460_xvector(self) -> KaldiReader:
        """Open the ``feats.scp`` with train-clean-460 speaker embeddings."""
        logging.info("About to get speaker embeddings for train-clean-460")
        return KaldiReader(
            str(self.args.speaker_embeds / "xvectors_train_clean_460" / "feats.scp")
        )

    @lru_cache()
    def train_clean_100_xvector(self) -> KaldiReader:
        """Not implemented: speaker embeddings for train-clean-100."""
        raise NotImplementedError(
            "Please implement the method to load speaker embeddings."
        )

    @lru_cache()
    def train_clean_360_xvector(self) -> KaldiReader:
        """Not implemented: speaker embeddings for train-clean-360."""
        raise NotImplementedError(
            "Please implement the method to load speaker embeddings."
        )

    @lru_cache()
    def train_other_500_xvector(self) -> KaldiReader:
        """Not implemented: speaker embeddings for train-other-500."""
        raise NotImplementedError(
            "Please implement the method to load speaker embeddings."
        )

    @lru_cache()
    def dev_clean_xvector(self) -> KaldiReader:
        """Open the ``feats.scp`` with dev-clean speaker embeddings."""
        logging.info("About to get speaker embeddings for dev-clean")
        return KaldiReader(
            str(self.args.speaker_embeds / "xvectors_dev_clean" / "feats.scp")
        )

    @lru_cache()
    def dev_other_xvector(self) -> KaldiReader:
        """Not implemented: speaker embeddings for dev-other."""
        raise NotImplementedError(
            "Please implement the method to load speaker embeddings."
        )

    @lru_cache()
    def test_clean_xvector(self) -> KaldiReader:
        """Open the ``feats.scp`` with test-clean speaker embeddings."""
        logging.info("About to get speaker embeddings for test-clean")
        return KaldiReader(
            str(self.args.speaker_embeds / "xvectors_test_clean" / "feats.scp")
        )

    @lru_cache()
    def test_other_xvector(self) -> KaldiReader:
        """Not implemented: speaker embeddings for test-other."""
        raise NotImplementedError(
            "Please implement the method to load speaker embeddings."
        )
+ ) diff --git a/egs/libritts/TTS/vits/utils.py b/egs/libritts/TTS/vits/utils.py new file mode 120000 index 0000000000..085e764b43 --- /dev/null +++ b/egs/libritts/TTS/vits/utils.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/utils.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/vits.py b/egs/libritts/TTS/vits/vits.py new file mode 120000 index 0000000000..1f58cf6fea --- /dev/null +++ b/egs/libritts/TTS/vits/vits.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/vits.py \ No newline at end of file diff --git a/egs/libritts/TTS/vits/wavenet.py b/egs/libritts/TTS/vits/wavenet.py new file mode 120000 index 0000000000..28f0a78eeb --- /dev/null +++ b/egs/libritts/TTS/vits/wavenet.py @@ -0,0 +1 @@ +../../../ljspeech/TTS/vits/wavenet.py \ No newline at end of file diff --git a/egs/ljspeech/TTS/vits/generator.py b/egs/ljspeech/TTS/vits/generator.py index b9add9e828..521b0121f9 100644 --- a/egs/ljspeech/TTS/vits/generator.py +++ b/egs/ljspeech/TTS/vits/generator.py @@ -409,7 +409,12 @@ def inference( g = self.global_emb(sids.view(-1)).unsqueeze(-1) if self.spk_embed_dim is not None: # (B, global_channels, 1) - g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1) + if spembs.ndim == 2: + g_ = self.spemb_proj(F.normalize(spembs)).unsqueeze(-1) + elif spembs.ndim == 1: + g_ = self.spemb_proj(F.normalize(spembs.unsqueeze(0))).unsqueeze(-1) + else: + raise ValueError("spembs should be 1D or 2D (batch mode) tensor.") if g is None: g = g_ else: diff --git a/egs/ljspeech/TTS/vits/train.py b/egs/ljspeech/TTS/vits/train.py index 34b943765a..184ae79afa 100755 --- a/egs/ljspeech/TTS/vits/train.py +++ b/egs/ljspeech/TTS/vits/train.py @@ -542,13 +542,13 @@ def save_bad_model(suffix: str = ""): tb_writer, "train/valid_", params.batch_idx_train ) tb_writer.add_audio( - "train/valdi_speech_hat", + "train/valid_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate, ) tb_writer.add_audio( - "train/valdi_speech", + "train/valid_speech", speech, 
params.batch_idx_train, params.sampling_rate, diff --git a/egs/ljspeech/TTS/vits/vits.py b/egs/ljspeech/TTS/vits/vits.py index 0b9575cbde..a1fabf9ad6 100644 --- a/egs/ljspeech/TTS/vits/vits.py +++ b/egs/ljspeech/TTS/vits/vits.py @@ -622,6 +622,8 @@ def inference_batch( text: torch.Tensor, text_lengths: torch.Tensor, sids: Optional[torch.Tensor] = None, + spembs: Optional[torch.Tensor] = None, + lids: Optional[torch.Tensor] = None, durations: Optional[torch.Tensor] = None, noise_scale: float = 0.667, noise_scale_dur: float = 0.8, @@ -635,6 +637,8 @@ def inference_batch( text (Tensor): Input text index tensor (B, T_text). text_lengths (Tensor): Input text index tensor (B,). sids (Tensor): Speaker index tensor (B,). + spembs (Optional[Tensor]): Speaker embedding tensor (B, spk_embed_dim). + lids (Tensor): Language index tensor (B,). noise_scale (float): Noise scale value for flow. noise_scale_dur (float): Noise scale value for duration predictor. alpha (float): Alpha parameter to control the speed of generated speech. @@ -650,6 +654,8 @@ def inference_batch( text=text, text_lengths=text_lengths, sids=sids, + spembs=spembs, + lids=lids, noise_scale=noise_scale, noise_scale_dur=noise_scale_dur, alpha=alpha, diff --git a/egs/vctk/TTS/vits/train.py b/egs/vctk/TTS/vits/train.py index 55bd693275..4686de1694 100755 --- a/egs/vctk/TTS/vits/train.py +++ b/egs/vctk/TTS/vits/train.py @@ -597,13 +597,13 @@ def save_bad_model(suffix: str = ""): tb_writer, "train/valid_", params.batch_idx_train ) tb_writer.add_audio( - "train/valdi_speech_hat", + "train/valid_speech_hat", speech_hat, params.batch_idx_train, params.sampling_rate, ) tb_writer.add_audio( - "train/valdi_speech", + "train/valid_speech", speech, params.batch_idx_train, params.sampling_rate,