
Commit

Merge branch 'master' into master
lipeng31 authored Dec 16, 2024
2 parents b29ab59 + d475de5 commit 6ffd624
Showing 38 changed files with 629 additions and 378 deletions.
2 changes: 1 addition & 1 deletion .github/scripts/ljspeech/TTS/run-matcha.sh
@@ -56,7 +56,7 @@ function infer() {

curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1

./matcha/inference.py \
./matcha/infer.py \
--epoch 1 \
--exp-dir ./matcha/exp \
--tokens data/tokens.txt \
2 changes: 1 addition & 1 deletion .github/workflows/style_check.yml
@@ -36,7 +36,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
python-version: [3.8]
python-version: [3.10.15]
fail-fast: false

steps:
11 changes: 1 addition & 10 deletions egs/libritts/TTS/local/prepare_tokens_libritts.py
@@ -31,15 +31,6 @@
from tqdm.auto import tqdm


def remove_punc_to_upper(text: str) -> str:
text = text.replace("‘", "'")
text = text.replace("’", "'")
tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
s_list = [x.upper() if x in tokens else " " for x in text]
s = " ".join("".join(s_list).split()).strip()
return s


def prepare_tokens_libritts():
output_dir = Path("data/spectrogram")
prefix = "libritts"
@@ -72,7 +63,7 @@ def prepare_tokens_libritts():
for t in tokens_list:
tokens.extend(t)
cut.tokens = tokens
cut.supervisions[0].normalized_text = remove_punc_to_upper(text)
cut.supervisions[0].normalized_text = text

new_cuts.append(cut)

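For reference, the helper removed from this script upper-cased the text and replaced everything outside letters, digits, and apostrophes with a space, while the supervision's normalized_text is now stored as the raw text here. A small self-contained illustration of the difference (the sample sentence is made up):

```python
def remove_punc_to_upper(text: str) -> str:
    # Copied from the lines deleted above.
    text = text.replace("‘", "'").replace("’", "'")
    tokens = set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'")
    s_list = [x.upper() if x in tokens else " " for x in text]
    return " ".join("".join(s_list).split()).strip()


text = "Hello, world! It’s 2024."
print(remove_punc_to_upper(text))  # old behaviour: HELLO WORLD IT'S 2024
print(text)                        # new behaviour keeps the text unchanged
```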
2 changes: 1 addition & 1 deletion egs/libritts/TTS/prepare.sh
@@ -84,7 +84,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then
cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \
<(gunzip -c data/spectrogram/libritts_cuts_train-clean-360.jsonl.gz) \
<(gunzip -c data/spectrogramlibritts_cuts_train-other-500.jsonl.gz) | \
<(gunzip -c data/spectrogram/libritts_cuts_train-other-500.jsonl.gz) | \
shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz
fi

4 changes: 2 additions & 2 deletions egs/ljspeech/TTS/README.md
@@ -131,12 +131,12 @@ To inference, use:

wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1

./matcha/inference \
./matcha/infer.py \
--exp-dir ./matcha/exp-new-3 \
--epoch 4000 \
--tokens ./data/tokens.txt \
--vocoder ./generator_v1 \
--input-text "how are you doing?"
--input-text "how are you doing?" \
--output-wav ./generated.wav
```

1 change: 1 addition & 0 deletions egs/ljspeech/TTS/local/audio.py
91 changes: 3 additions & 88 deletions egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
@@ -27,102 +27,17 @@
import argparse
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Union

import numpy as np
import torch
from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
from lhotse.audio import RecordingSet
from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.supervision import SupervisionSet
from lhotse.utils import Seconds, compute_num_frames
from matcha.audio import mel_spectrogram

from icefall.utils import get_executor


@dataclass
class MyFbankConfig:
n_fft: int
n_mels: int
sampling_rate: int
hop_length: int
win_length: int
f_min: float
f_max: float


@register_extractor
class MyFbank(FeatureExtractor):

name = "MyFbank"
config_type = MyFbankConfig

def __init__(self, config):
super().__init__(config=config)

@property
def device(self) -> Union[str, torch.device]:
return self.config.device

def feature_dim(self, sampling_rate: int) -> int:
return self.config.n_mels

def extract(
self,
samples: np.ndarray,
sampling_rate: int,
) -> torch.Tensor:
# Check for sampling rate compatibility.
expected_sr = self.config.sampling_rate
assert sampling_rate == expected_sr, (
f"Mismatched sampling rate: extractor expects {expected_sr}, "
f"got {sampling_rate}"
)
samples = torch.from_numpy(samples)
assert samples.ndim == 2, samples.shape
assert samples.shape[0] == 1, samples.shape

mel = (
mel_spectrogram(
samples,
self.config.n_fft,
self.config.n_mels,
self.config.sampling_rate,
self.config.hop_length,
self.config.win_length,
self.config.f_min,
self.config.f_max,
center=False,
)
.squeeze()
.t()
)

assert mel.ndim == 2, mel.shape
assert mel.shape[1] == self.config.n_mels, mel.shape

num_frames = compute_num_frames(
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
)

if mel.shape[0] > num_frames:
mel = mel[:num_frames]
elif mel.shape[0] < num_frames:
mel = mel.unsqueeze(0)
mel = torch.nn.functional.pad(
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
).squeeze(0)

return mel.numpy()

@property
def frame_shift(self) -> Seconds:
return self.config.hop_length / self.config.sampling_rate


def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -149,7 +64,7 @@ def compute_fbank_ljspeech(num_jobs: int):
logging.info(f"num_jobs: {num_jobs}")
logging.info(f"src_dir: {src_dir}")
logging.info(f"output_dir: {output_dir}")
config = MyFbankConfig(
config = MatchaFbankConfig(
n_fft=1024,
n_mels=80,
sampling_rate=22050,
@@ -170,7 +85,7 @@
src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
)

extractor = MyFbank(config)
extractor = MatchaFbank(config)

with get_executor() as ex: # Initialize the executor only once.
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
1 change: 1 addition & 0 deletions egs/ljspeech/TTS/local/fbank.py
1 change: 0 additions & 1 deletion egs/ljspeech/TTS/local/validate_manifest.py
@@ -33,7 +33,6 @@
import logging
from pathlib import Path

from compute_fbank_ljspeech import MyFbank
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset.speech_synthesis import validate_for_tts

1 change: 0 additions & 1 deletion egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py

This file was deleted.

2 changes: 1 addition & 1 deletion egs/ljspeech/TTS/matcha/export_onnx_hifigan.py
@@ -7,7 +7,7 @@

import onnx
import torch
from inference import load_vocoder
from infer import load_vocoder


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
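The body of `add_meta_data` is collapsed in this hunk. As rough orientation only, helpers with this signature in icefall export scripts usually write the entries into the ONNX model's `metadata_props`; the sketch below is an assumption about that pattern, not a copy of this file:

```python
from typing import Any, Dict

import onnx


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Rewrite the ONNX model file in place with the given key/value metadata."""
    model = onnx.load(filename)
    del model.metadata_props[:]  # drop any existing entries
    for key, value in meta_data.items():
        entry = model.metadata_props.add()
        entry.key = key
        entry.value = str(value)
    onnx.save(model, filename)
```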
88 changes: 88 additions & 0 deletions egs/ljspeech/TTS/matcha/fbank.py
@@ -0,0 +1,88 @@
from dataclasses import dataclass
from typing import Union

import numpy as np
import torch
from audio import mel_spectrogram
from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.utils import Seconds, compute_num_frames


@dataclass
class MatchaFbankConfig:
n_fft: int
n_mels: int
sampling_rate: int
hop_length: int
win_length: int
f_min: float
f_max: float


@register_extractor
class MatchaFbank(FeatureExtractor):

name = "MatchaFbank"
config_type = MatchaFbankConfig

def __init__(self, config):
super().__init__(config=config)

@property
def device(self) -> Union[str, torch.device]:
return self.config.device

def feature_dim(self, sampling_rate: int) -> int:
return self.config.n_mels

def extract(
self,
samples: np.ndarray,
sampling_rate: int,
) -> torch.Tensor:
# Check for sampling rate compatibility.
expected_sr = self.config.sampling_rate
assert sampling_rate == expected_sr, (
f"Mismatched sampling rate: extractor expects {expected_sr}, "
f"got {sampling_rate}"
)
samples = torch.from_numpy(samples)
assert samples.ndim == 2, samples.shape
assert samples.shape[0] == 1, samples.shape

mel = (
mel_spectrogram(
samples,
self.config.n_fft,
self.config.n_mels,
self.config.sampling_rate,
self.config.hop_length,
self.config.win_length,
self.config.f_min,
self.config.f_max,
center=False,
)
.squeeze()
.t()
)

assert mel.ndim == 2, mel.shape
assert mel.shape[1] == self.config.n_mels, mel.shape

num_frames = compute_num_frames(
samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
)

if mel.shape[0] > num_frames:
mel = mel[:num_frames]
elif mel.shape[0] < num_frames:
mel = mel.unsqueeze(0)
mel = torch.nn.functional.pad(
mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
).squeeze(0)

return mel.numpy()

@property
def frame_shift(self) -> Seconds:
return self.config.hop_length / self.config.sampling_rate
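A minimal usage sketch for the new extractor. The n_fft, n_mels, and sampling_rate values mirror what is visible in compute_fbank_ljspeech.py above; hop_length, win_length, f_min, and f_max are truncated out of that hunk, so the numbers below are assumptions (typical 22.05 kHz HiFi-GAN settings), and the snippet must be run from a directory where the local `fbank.py` and `audio.py` are importable:

```python
import numpy as np
from fbank import MatchaFbank, MatchaFbankConfig

config = MatchaFbankConfig(
    n_fft=1024,
    n_mels=80,
    sampling_rate=22050,
    hop_length=256,   # assumption: not visible in the truncated hunk above
    win_length=1024,  # assumption
    f_min=0,          # assumption
    f_max=8000,       # assumption
)
extractor = MatchaFbank(config)

# extract() asserts a (1, num_samples) array at the configured sampling rate.
samples = np.zeros((1, 22050), dtype=np.float32)
mel = extractor.extract(samples, sampling_rate=22050)
print(mel.shape)  # -> (num_frames, 80)
```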
