From cdaa98a8563c5238c55fda4f60a1f3c839ea9fa8 Mon Sep 17 00:00:00 2001 From: PRADUMNA Date: Fri, 5 Apr 2024 14:02:54 -0700 Subject: [PATCH] =?UTF-8?q?=E2=80=9Ctfgridnet=5Fintegration=5Ffrom=5FESPNE?= =?UTF-8?q?T=5Flibrary=5Ffor=5Fdenoising=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../preprocessing/denoise_and_vad_audio.py | 88 +- .../preprocessing/tfgridnet/enh/decoder.py | 287 +++++ .../preprocessing/tfgridnet/enh/encoder.py | 299 +++++ .../tfgridnet/enh/layers_complex_utils.py | 34 + .../tfgridnet/enh/loss_criterion.py | 141 +++ .../preprocessing/tfgridnet/enh/separator.py | 710 +++++++++++ .../preprocessing/tfgridnet/enh/tcn.py | 362 ++++++ .../preprocessing/tfgridnet/enh/wrappers.py | 177 +++ .../preprocessing/tfgridnet/enh_inference.py | 396 ++++++ .../tfgridnet/layers/inversible_interface.py | 13 + .../tfgridnet/layers/nets_utils.py | 583 +++++++++ .../preprocessing/tfgridnet/layers/stft.py | 215 ++++ .../preprocessing/tfgridnet/mask.py | 129 ++ .../preprocessing/tfgridnet/tasks/abs_task.py | 296 +++++ .../preprocessing/tfgridnet/tasks/enh.py | 458 +++++++ .../tfgridnet/torch_utils/device_functions.py | 31 + .../torch_utils/get_layer_from_string.py | 43 + .../tfgridnet/torch_utils/initialize.py | 125 ++ .../tfgridnet/train/class_choices.py | 90 ++ .../tfgridnet/train/espnet_model.py | 282 +++++ .../tfgridnet/train/preprocessor.py | 1096 +++++++++++++++++ 21 files changed, 5827 insertions(+), 28 deletions(-) create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/enh/decoder.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/enh/encoder.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/enh/layers_complex_utils.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/enh/loss_criterion.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/enh/separator.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/enh/tcn.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/enh/wrappers.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/enh_inference.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/layers/inversible_interface.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/layers/nets_utils.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/layers/stft.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/mask.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/tasks/abs_task.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/tasks/enh.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/torch_utils/device_functions.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/torch_utils/get_layer_from_string.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/torch_utils/initialize.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/train/class_choices.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/train/espnet_model.py create mode 100644 examples/speech_synthesis/preprocessing/tfgridnet/train/preprocessor.py diff --git a/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py b/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py index 4e13b38a5d..b65cc77b41 100644 --- a/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py +++ 
b/examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py @@ -26,6 +26,7 @@ SCALE ) from examples.speech_to_text.data_utils import save_df_to_tsv +from examples.speech_synthesis.preprocessing.tfgridnet.enh_inference import SeparateSpeech log = logging.getLogger(__name__) @@ -78,7 +79,7 @@ def write(wav, filename, sr=16_000): def process(args): - # making sure we are requested either denoise or vad + # Making sure we are requested either denoise or vad if not args.denoise and not args.vad: log.error("No denoise or vad is requested.") return @@ -91,12 +92,6 @@ def process(args): out_vad = Path(args.output_dir).absolute().joinpath(PATHS[1]) out_vad.mkdir(parents=True, exist_ok=True) - log.info("Loading pre-trained speech enhancement model...") - model = master64().to(args.device) - - log.info("Building the VAD model...") - vad = webrtcvad.Vad(int(args.vad_agg_level)) - # preparing the output dict output_dict = defaultdict(list) @@ -104,18 +99,45 @@ def process(args): with open(args.audio_manifest, "r") as f: manifest_dict = csv.DictReader(f, delimiter="\t") for row in tqdm(manifest_dict): - filename = str(row["audio"]) - - final_output = filename - keep_sample = True - n_frames = row["n_frames"] - snr = -1 - if args.denoise: - output_path_denoise = out_denoise.joinpath(Path(filename).name) - # convert to 16khz in case we use a differet sr + filename = str(row["audio"]) + + final_output = filename + keep_sample = True + n_frames = row["n_frames"] + snr = -1 + # Denoise + if args.denoise: + # Load pre-trained speech enhancement model and build VAD model + log.info("Loading SeperateSpeech(TFGridnet) enhancement model...") + if args.model == "SeparateSpeech": + + log.info(f"Training Configuration .yaml file: {args.config}") + log.info(f"Pre-trained model .pth file: {args.pth_model}") + model = SeparateSpeech( + train_config = args.config, + model_file= args.pth_model, + normalize_segment_scale=False, + show_progressbar=True, + ref_channel=4, + normalize_output_wav=True) + + output_path_denoise = out_denoise.joinpath(Path(f"SeperateSpeech_{filename}").name) + waveform, sr = torchaudio.load(filename) + waveform = waveform.to("cpu") + estimate = model(waveform) + estimate = torch.tensor(estimate) + torchaudio.save(output_path_denoise, estimate[0], 16_000, encoding="PCM_S", bits_per_sample=16) + + else: + + log.info("Loading pre-trained speech enhancement model...") + model = master64().to(args.device) + # Set the output path for denoised audio + output_path_denoise = out_denoise.joinpath(Path(f"master64_{filename}").name) + + # Convert to 16kHz if the sample rate is different tmp_path = convert_sr(final_output, 16000) - - # loading audio file and generating the enhanced version + # Load audio file and generate the enhanced version out, sr = torchaudio.load(tmp_path) out = out.to(args.device) estimate = model(out) @@ -126,7 +148,10 @@ def process(args): snr = snr.cpu().detach().numpy()[0][0] final_output = str(output_path_denoise) - if args.vad: + log.info("Building the VAD model...") + vad = webrtcvad.Vad(int(args.vad_agg_level)) + + if args.vad: output_path_vad = out_vad.joinpath(Path(filename).name) sr = torchaudio.info(final_output).sample_rate if sr in [16000, 32000, 48000]: @@ -140,37 +165,36 @@ def process(args): # apply VAD segment, sample_rate = apply_vad(vad, tmp_path) if len(segment) < sample_rate * MIN_T: - keep_sample = False - print(( + keep_sample = False + print(( f"WARNING: skip {filename} because it is too short " f"after VAD ({len(segment) / sample_rate} < {MIN_T})" - )) + 
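The SeparateSpeech branch above wraps the TF-GridNet enhancement model behind a single call. A minimal standalone sketch of the same flow, assuming the tfgridnet package added by this patch is importable and that the config/checkpoint paths (placeholders here) point to a trained TF-GridNet model:

import torch
import torchaudio

from examples.speech_synthesis.preprocessing.tfgridnet.enh_inference import SeparateSpeech

# Placeholder paths: substitute your own training config and checkpoint.
model = SeparateSpeech(
    train_config="exp/tfgridnet/config.yaml",
    model_file="exp/tfgridnet/model.pth",
    normalize_segment_scale=False,
    show_progressbar=True,
    ref_channel=4,
    normalize_output_wav=True,
)

waveform, sr = torchaudio.load("noisy.wav")          # [channels, samples]
estimate = torch.tensor(model(waveform.to("cpu")))   # enhanced source(s)
torchaudio.save("denoised.wav", estimate[0], 16_000,
                encoding="PCM_S", bits_per_sample=16)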
)) else: if sample_rate != sr: tmp_path = generate_tmp_filename("wav") write_wave(tmp_path, segment, sample_rate) convert_sr(tmp_path, sr, - output_path=str(output_path_vad)) + output_path=str(output_path_vad)) else: write_wave(str(output_path_vad), segment, sample_rate) final_output = str(output_path_vad) segment, _ = torchaudio.load(final_output) - n_frames = segment.size(1) + n_frames = segment.size(1) - if keep_sample: + if keep_sample: output_dict["id"].append(row["id"]) output_dict["audio"].append(final_output) output_dict["n_frames"].append(n_frames) output_dict["tgt_text"].append(row["tgt_text"]) output_dict["speaker"].append(row["speaker"]) output_dict["src_text"].append(row["src_text"]) - output_dict["snr"].append(snr) + output_dict["snr"].append(snr) out_tsv_path = Path(args.output_dir) / Path(args.audio_manifest).name log.info(f"Saving manifest to {out_tsv_path.as_posix()}") save_df_to_tsv(pd.DataFrame.from_dict(output_dict), out_tsv_path) - def main(): parser = argparse.ArgumentParser() parser.add_argument("--audio-manifest", "-i", required=True, @@ -194,6 +218,14 @@ def main(): ) parser.add_argument("--denoise", action="store_true", help="apply a denoising") + parser.add_argument( + "--model", "-m", type=str, default="master64", + help="the speech enhancement model to be used: master64 | SeparateSpeech." + ) + parser.add_argument("--config", type=str, + help="Training Configuration file for SeparateSpeech model.") + parser.add_argument("--pth-model", type=str, + help="Path to the pre-trained model file for SeparateSpeech.") parser.add_argument("--vad", action="store_true", help="apply a VAD") args = parser.parse_args() @@ -201,4 +233,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/enh/decoder.py b/examples/speech_synthesis/preprocessing/tfgridnet/enh/decoder.py new file mode 100644 index 0000000000..1ec513bb2d --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/enh/decoder.py @@ -0,0 +1,287 @@ +from abc import ABC, abstractmethod +from typing import Tuple +import torch + +import torch_complex +from torch_complex.tensor import ComplexTensor +from examples.speech_synthesis.preprocessing.tfgridnet.enh.layers_complex_utils import is_torch_complex_tensor +from examples.speech_synthesis.preprocessing.tfgridnet.layers.stft import Stft + + +class AbsDecoder(torch.nn.Module, ABC): + @abstractmethod + def forward( + self, + input: torch.Tensor, + ilens: torch.Tensor, + fs: int = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def forward_streaming(self, input_frame: torch.Tensor): + raise NotImplementedError + + def streaming_merge(self, chunks: torch.Tensor, ilens: torch.tensor = None): + """streaming_merge. It merges the frame-level processed audio chunks + in the streaming *simulation*. It is noted that, in real applications, + the processed audio should be sent to the output channel frame by frame. + You may refer to this function to manage your streaming output buffer. 
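With the new --model, --config and --pth-model flags wired into main(), the enhancement backend is chosen at the command line. A hypothetical invocation of the updated script (all paths are placeholders):

# Equivalent command line:
#   python examples/speech_synthesis/preprocessing/denoise_and_vad_audio.py \
#       --audio-manifest manifests/train.tsv --output-dir out \
#       --denoise --vad --model SeparateSpeech \
#       --config exp/tfgridnet/config.yaml --pth-model exp/tfgridnet/model.pth
# The same run can also be driven programmatically through main():
import sys
from examples.speech_synthesis.preprocessing.denoise_and_vad_audio import main

sys.argv = [
    "denoise_and_vad_audio.py",
    "--audio-manifest", "manifests/train.tsv",
    "--output-dir", "out",
    "--denoise", "--vad",
    "--model", "SeparateSpeech",
    "--config", "exp/tfgridnet/config.yaml",
    "--pth-model", "exp/tfgridnet/model.pth",
]
main()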
+ + Args: + chunks: List [(B, frame_size),] + ilens: [B] + Returns: + merge_audio: [B, T] + """ + + raise NotImplementedError + + +class ConvDecoder(AbsDecoder): + """Transposed Convolutional decoder for speech enhancement and separation""" + + def __init__( + self, + channel: int, + kernel_size: int, + stride: int, + ): + super().__init__() + self.convtrans1d = torch.nn.ConvTranspose1d( + channel, 1, kernel_size, bias=False, stride=stride + ) + + self.kernel_size = kernel_size + self.stride = stride + + def forward(self, input: torch.Tensor, ilens: torch.Tensor, fs: int = None): + """Forward. + + Args: + input (torch.Tensor): spectrum [Batch, T, F] + ilens (torch.Tensor): input lengths [Batch] + fs (int): sampling rate in Hz (Not used) + """ + input = input.transpose(1, 2) + batch_size = input.shape[0] + wav = self.convtrans1d(input, output_size=(batch_size, 1, ilens.max())) + wav = wav.squeeze(1) + + return wav, ilens + + def forward_streaming(self, input_frame: torch.Tensor): + return self.forward(input_frame, ilens=torch.LongTensor([self.kernel_size]))[0] + + def streaming_merge(self, chunks: torch.Tensor, ilens: torch.tensor = None): + """streaming_merge. It merges the frame-level processed audio chunks + in the streaming *simulation*. It is noted that, in real applications, + the processed audio should be sent to the output channel frame by frame. + You may refer to this function to manage your streaming output buffer. + + Args: + chunks: List [(B, frame_size),] + ilens: [B] + Returns: + merge_audio: [B, T] + """ + hop_size = self.stride + frame_size = self.kernel_size + + num_chunks = len(chunks) + batch_size = chunks[0].shape[0] + audio_len = ( + int(hop_size * num_chunks + frame_size - hop_size) + if not ilens + else ilens.max() + ) + + output = torch.zeros((batch_size, audio_len), dtype=chunks[0].dtype).to( + chunks[0].device + ) + + for i, chunk in enumerate(chunks): + output[:, i * hop_size : i * hop_size + frame_size] += chunk + + return output + + +class STFTDecoder(AbsDecoder): + """STFT decoder for speech enhancement and separation""" + + def __init__( + self, + n_fft: int = 512, + win_length: int = None, + hop_length: int = 128, + window="hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + default_fs: int = 16000, + ): + super().__init__() + self.stft = Stft( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + ) + + self.win_length = win_length if win_length else n_fft + self.n_fft = n_fft + self.hop_length = hop_length + self.window = window + self.center = center + self.default_fs = default_fs + + @torch.cuda.amp.autocast(enabled=False) + def forward(self, input: ComplexTensor, ilens: torch.Tensor, fs: int = None): + """Forward. + + Args: + input (ComplexTensor): spectrum [Batch, T, (C,) F] + ilens (torch.Tensor): input lengths [Batch] + fs (int): sampling rate in Hz + If not None, reconfigure iSTFT window and hop lengths for a new + sampling rate while keeping their duration fixed. 
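ConvDecoder.streaming_merge reassembles frame-level chunks by overlap-add at the decoder's stride. A small sketch with arbitrary sizes:

import torch
from examples.speech_synthesis.preprocessing.tfgridnet.enh.decoder import ConvDecoder

dec = ConvDecoder(channel=64, kernel_size=16, stride=8)   # arbitrary sizes for illustration
chunks = [torch.ones(1, 16) for _ in range(5)]            # 5 frame-level chunks, batch size 1
merged = dec.streaming_merge(chunks)                      # overlap-add at hop = stride
print(merged.shape)   # torch.Size([1, 48]) == 8 * 5 + 16 - 8 samples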
+ """ + if not isinstance(input, ComplexTensor) and ( + is_torch_1_9_plus and not torch.is_complex(input) + ): + raise TypeError("Only support complex tensors for stft decoder") + if fs is not None: + self._reconfig_for_fs(fs) + + bs = input.size(0) + if input.dim() == 4: + multi_channel = True + # input: (Batch, T, C, F) -> (Batch * C, T, F) + input = input.transpose(1, 2).reshape(-1, input.size(1), input.size(3)) + else: + multi_channel = False + + # for supporting half-precision training + if input.dtype in (torch.float16, torch.bfloat16): + wav, wav_lens = self.stft.inverse(input.float(), ilens) + wav = wav.to(dtype=input.dtype) + elif ( + is_torch_complex_tensor(input) + and hasattr(torch, "complex32") + and input.dtype == torch.complex32 + ): + wav, wav_lens = self.stft.inverse(input.cfloat(), ilens) + wav = wav.to(dtype=input.dtype) + else: + wav, wav_lens = self.stft.inverse(input, ilens) + + if multi_channel: + # wav: (Batch * C, Nsamples) -> (Batch, Nsamples, C) + wav = wav.reshape(bs, -1, wav.size(1)).transpose(1, 2) + + self._reset_config() + return wav, wav_lens + + def _reset_config(self): + """Reset the configuration of iSTFT window and hop lengths.""" + self._reconfig_for_fs(self.default_fs) + + def _reconfig_for_fs(self, fs): + """Reconfigure iSTFT window and hop lengths for a new sampling rate + while keeping their duration fixed. + + Args: + fs (int): new sampling rate + """ # noqa: H405 + assert fs % self.default_fs == 0 or self.default_fs % fs == 0 + self.stft.n_fft = self.n_fft * fs // self.default_fs + self.stft.win_length = self.win_length * fs // self.default_fs + self.stft.hop_length = self.hop_length * fs // self.default_fs + + def _get_window_func(self): + window_func = getattr(torch, f"{self.window}_window") + window = window_func(self.win_length) + n_pad_left = (self.n_fft - window.shape[0]) // 2 + n_pad_right = self.n_fft - window.shape[0] - n_pad_left + return window + + def forward_streaming(self, input_frame: torch.Tensor): + """Forward. + + Args: + input (ComplexTensor): spectrum [Batch, 1, F] + output: wavs [Batch, 1, self.win_length] + """ + + input_frame = input_frame.real + 1j * input_frame.imag + output_wav = ( + torch.fft.irfft(input_frame) + if self.stft.onesided + else torch.fft.ifft(input_frame).real + ) + + output_wav = output_wav.squeeze(1) + + n_pad_left = (self.n_fft - self.win_length) // 2 + output_wav = output_wav[..., n_pad_left : n_pad_left + self.win_length] + + return output_wav * self._get_window_func() + + def streaming_merge(self, chunks, ilens=None): + """streaming_merge. It merges the frame-level processed audio chunks + in the streaming *simulation*. It is noted that, in real applications, + the processed audio should be sent to the output channel frame by frame. + You may refer to this function to manage your streaming output buffer. 
+ + Args: + chunks: List [(B, frame_size),] + ilens: [B] + Returns: + merge_audio: [B, T] + """ # noqa: H405 + + frame_size = self.win_length + hop_size = self.hop_length + + num_chunks = len(chunks) + batch_size = chunks[0].shape[0] + audio_len = int(hop_size * num_chunks + frame_size - hop_size) + + output = torch.zeros((batch_size, audio_len), dtype=chunks[0].dtype).to( + chunks[0].device + ) + + for i, chunk in enumerate(chunks): + output[:, i * hop_size : i * hop_size + frame_size] += chunk + + window_sq = self._get_window_func().pow(2) + window_envelop = torch.zeros((batch_size, audio_len), dtype=chunks[0].dtype).to( + chunks[0].device + ) + for i in range(len(chunks)): + window_envelop[:, i * hop_size : i * hop_size + frame_size] += window_sq + output = output / window_envelop + + # We need to trim the front padding away if center. + start = (frame_size // 2) if self.center else 0 + end = -(frame_size // 2) if ilens.max() is None else start + ilens.max() + + return output[..., start:end] + + +class NullDecoder(AbsDecoder): + """Null decoder, return the same args.""" + + def __init__(self): + super().__init__() + + def forward(self, input: torch.Tensor, ilens: torch.Tensor): + """Forward. The input should be the waveform already. + + Args: + input (torch.Tensor): wav [Batch, sample] + ilens (torch.Tensor): input lengths [Batch] + """ + return input, ilens diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/enh/encoder.py b/examples/speech_synthesis/preprocessing/tfgridnet/enh/encoder.py new file mode 100644 index 0000000000..c8b01cce67 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/enh/encoder.py @@ -0,0 +1,299 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +import torch +from torch_complex.tensor import ComplexTensor + +from packaging.version import parse as V + +is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") + +from examples.speech_synthesis.preprocessing.tfgridnet.layers.stft import Stft + + +class AbsEncoder(torch.nn.Module, ABC): + @abstractmethod + def forward( + self, + input: torch.Tensor, + ilens: torch.Tensor, + fs: int = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError + + @property + @abstractmethod + def output_dim(self) -> int: + raise NotImplementedError + + def forward_streaming(self, input: torch.Tensor): + raise NotImplementedError + + def streaming_frame(self, audio: torch.Tensor): + """streaming_frame. It splits the continuous audio into frame-level + audio chunks in the streaming *simulation*. It is noted that this + function takes the entire long audio as input for a streaming simulation. + You may refer to this function to manage your streaming input + buffer in a real streaming application. + + Args: + audio: (B, T) + Returns: + chunked: List [(B, frame_size),] + """ + NotImplementedError + + +class ConvEncoder(AbsEncoder): + """Convolutional encoder for speech enhancement and separation""" + + def __init__( + self, + channel: int, + kernel_size: int, + stride: int, + ): + super().__init__() + self.conv1d = torch.nn.Conv1d( + 1, channel, kernel_size=kernel_size, stride=stride, bias=False + ) + self.stride = stride + self.kernel_size = kernel_size + + self._output_dim = channel + + @property + def output_dim(self) -> int: + return self._output_dim + + def forward(self, input: torch.Tensor, ilens: torch.Tensor, fs: int = None): + """Forward. 
+ + Args: + input (torch.Tensor): mixed speech [Batch, sample] + ilens (torch.Tensor): input lengths [Batch] + fs (int): sampling rate in Hz (Not used) + Returns: + feature (torch.Tensor): mixed feature after encoder [Batch, flens, channel] + """ + assert input.dim() == 2, "Currently only support single channel input" + + input = torch.unsqueeze(input, 1) + + feature = self.conv1d(input) + feature = torch.nn.functional.relu(feature) + feature = feature.transpose(1, 2) + + flens = ( + torch.div(ilens - self.kernel_size, self.stride, rounding_mode="trunc") + 1 + ) + + return feature, flens + + def forward_streaming(self, input: torch.Tensor): + output, _ = self.forward(input, 0) + return output + + def streaming_frame(self, audio: torch.Tensor): + """streaming_frame. It splits the continuous audio into frame-level + audio chunks in the streaming *simulation*. It is noted that this + function takes the entire long audio as input for a streaming simulation. + You may refer to this function to manage your streaming input + buffer in a real streaming application. + + Args: + audio: (B, T) + Returns: + chunked: List [(B, frame_size),] + """ + batch_size, audio_len = audio.shape + + hop_size = self.stride + frame_size = self.kernel_size + + audio = [ + audio[:, i * hop_size : i * hop_size + frame_size] + for i in range((audio_len - frame_size) // hop_size + 1) + ] + + return audio + + +class STFTEncoder(AbsEncoder): + """STFT encoder for speech enhancement and separation""" + + def __init__( + self, + n_fft: int = 512, + win_length: int = None, + hop_length: int = 128, + window="hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + use_builtin_complex: bool = True, + default_fs: int = 16000, + ): + super().__init__() + self.stft = Stft( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + window=window, + center=center, + normalized=normalized, + onesided=onesided, + ) + + self._output_dim = n_fft // 2 + 1 if onesided else n_fft + self.use_builtin_complex = use_builtin_complex + self.win_length = win_length if win_length else n_fft + self.hop_length = hop_length + self.window = window + self.n_fft = n_fft + self.center = center + self.default_fs = default_fs + + @property + def output_dim(self) -> int: + return self._output_dim + + @torch.cuda.amp.autocast(enabled=False) + def forward(self, input: torch.Tensor, ilens: torch.Tensor, fs: int = None): + """Forward. + + Args: + input (torch.Tensor): mixed speech [Batch, sample] + ilens (torch.Tensor): input lengths [Batch] + fs (int): sampling rate in Hz + If not None, reconfigure STFT window and hop lengths for a new + sampling rate while keeping their duration fixed. 
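ConvEncoder turns a single-channel waveform into frame-level features; the returned flens follow the strided-convolution formula (ilens - kernel_size) // stride + 1, and streaming_frame yields exactly that many chunks. A quick check with arbitrary sizes:

import torch
from examples.speech_synthesis.preprocessing.tfgridnet.enh.encoder import ConvEncoder

enc = ConvEncoder(channel=64, kernel_size=16, stride=8)   # arbitrary sizes for illustration
wav = torch.randn(2, 8000)                                # [B, samples]
ilens = torch.tensor([8000, 6000])
feat, flens = enc(wav, ilens)
print(feat.shape, flens)                # torch.Size([2, 999, 64]) tensor([999, 749])
chunks = enc.streaming_frame(wav)
print(len(chunks), chunks[0].shape)     # 999 torch.Size([2, 16])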
+ Returns: + spectrum (ComplexTensor): [Batch, T, (C,) F] + flens (torch.Tensor): [Batch] + """ + if fs is not None: + self._reconfig_for_fs(fs) + # for supporting half-precision training + if input.dtype in (torch.float16, torch.bfloat16): + spectrum, flens = self.stft(input.float(), ilens) + spectrum = spectrum.to(dtype=input.dtype) + else: + spectrum, flens = self.stft(input, ilens) + if is_torch_1_9_plus and self.use_builtin_complex: + spectrum = torch.complex(spectrum[..., 0], spectrum[..., 1]) + else: + spectrum = ComplexTensor(spectrum[..., 0], spectrum[..., 1]) + + self._reset_config() + return spectrum, flens + + def _reset_config(self): + """Reset the configuration of STFT window and hop lengths.""" + self._reconfig_for_fs(self.default_fs) + + def _reconfig_for_fs(self, fs): + """Reconfigure STFT window and hop lengths for a new sampling rate + while keeping their duration fixed. + + Args: + fs (int): new sampling rate + """ # noqa: H405 + assert fs % self.default_fs == 0 or self.default_fs % fs == 0 + self.stft.n_fft = self.n_fft * fs // self.default_fs + self.stft.win_length = self.win_length * fs // self.default_fs + self.stft.hop_length = self.hop_length * fs // self.default_fs + + def _apply_window_func(self, input): + B = input.shape[0] + + window_func = getattr(torch, f"{self.window}_window") + window = window_func(self.win_length, dtype=input.dtype, device=input.device) + n_pad_left = (self.n_fft - window.shape[0]) // 2 + n_pad_right = self.n_fft - window.shape[0] - n_pad_left + + windowed = input * window + + windowed = torch.cat( + [torch.zeros(B, n_pad_left), windowed, torch.zeros(B, n_pad_right)], 1 + ) + return windowed + + def forward_streaming(self, input: torch.Tensor): + """Forward. + + Args: + input (torch.Tensor): mixed speech [Batch, frame_length] + Return: + B, 1, F + """ + + assert ( + input.dim() == 2 + ), "forward_streaming only support for single-channel input currently." + + windowed = self._apply_window_func(input) + + feature = ( + torch.fft.rfft(windowed) if self.stft.onesided else torch.fft.fft(windowed) + ) + feature = feature.unsqueeze(1) + if not (is_torch_1_9_plus and self.use_builtin_complex): + feature = ComplexTensor(feature.real, feature.imag) + + return feature + + def streaming_frame(self, audio): + """streaming_frame. It splits the continuous audio into frame-level + audio chunks in the streaming *simulation*. It is noted that this + function takes the entire long audio as input for a streaming simulation. + You may refer to this function to manage your streaming input + buffer in a real streaming application. 
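STFTEncoder and STFTDecoder are designed as an analysis/synthesis pair. A round-trip sketch (use_builtin_complex=False makes the encoder return a ComplexTensor, which the decoder accepts on any PyTorch version):

import torch
from examples.speech_synthesis.preprocessing.tfgridnet.enh.encoder import STFTEncoder
from examples.speech_synthesis.preprocessing.tfgridnet.enh.decoder import STFTDecoder

enc = STFTEncoder(n_fft=512, hop_length=128, use_builtin_complex=False)
dec = STFTDecoder(n_fft=512, hop_length=128)

wav = torch.randn(1, 16000)           # 1 second at 16 kHz
ilens = torch.tensor([16000])
spec, flens = enc(wav, ilens)         # complex spectrum [B, T, F]
recon, olens = dec(spec, ilens)       # back to the time domain
print(spec.shape, recon.shape)        # roughly [1, 126, 257] and [1, 16000]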
+ + Args: + audio: (B, T) + Returns: + chunked: List [(B, frame_size),] + """ # noqa: H405 + + if self.center: + pad_len = int(self.win_length // 2) + signal_dim = audio.dim() + extended_shape = [1] * (3 - signal_dim) + list(audio.size()) + # the default STFT pad mode is "reflect", + # which is not configurable in STFT encoder, + # so, here we just use "reflect mode" + audio = torch.nn.functional.pad( + audio.view(extended_shape), [pad_len, pad_len], "reflect" + ) + audio = audio.view(audio.shape[-signal_dim:]) + + _, audio_len = audio.shape + + n_frames = 1 + (audio_len - self.win_length) // self.hop_length + strides = list(audio.stride()) + + shape = list(audio.shape[:-1]) + [self.win_length, n_frames] + strides = strides + [self.hop_length] + + return audio.as_strided(shape, strides, storage_offset=0).unbind(dim=-1) + + +class NullEncoder(AbsEncoder): + """Null encoder.""" + + def __init__(self): + super().__init__() + + @property + def output_dim(self) -> int: + return 1 + + def forward(self, input: torch.Tensor, ilens: torch.Tensor): + """Forward. + + Args: + input (torch.Tensor): mixed speech [Batch, sample] + ilens (torch.Tensor): input lengths [Batch] + """ + return input, ilens diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/enh/layers_complex_utils.py b/examples/speech_synthesis/preprocessing/tfgridnet/enh/layers_complex_utils.py new file mode 100644 index 0000000000..0880c368ba --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/enh/layers_complex_utils.py @@ -0,0 +1,34 @@ +from typing import Sequence, Tuple, Union + +import torch +from torch_complex import functional as FC +from torch_complex.tensor import ComplexTensor + + +def new_complex_like( + ref: Union[torch.Tensor, ComplexTensor], + real_imag: Tuple[torch.Tensor, torch.Tensor], +): + if isinstance(ref, ComplexTensor): + return ComplexTensor(*real_imag) + elif is_torch_complex_tensor(ref): + return torch.complex(*real_imag) + else: + raise ValueError( + "Please update your PyTorch version to 1.9+ for complex support." 
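The helpers in layers_complex_utils keep the rest of the code agnostic to whether spectra are torch_complex ComplexTensors or native complex tensors (to_complex is defined just below). For instance:

import torch
from torch_complex.tensor import ComplexTensor
from examples.speech_synthesis.preprocessing.tfgridnet.enh.layers_complex_utils import (
    new_complex_like,
    to_complex,
)

real, imag = torch.randn(2, 4), torch.randn(2, 4)
ct = ComplexTensor(real, imag)        # torch_complex wrapper
nt = torch.complex(real, imag)        # native complex tensor

# new_complex_like builds a complex tensor of the same flavour as the reference
print(type(new_complex_like(ct, (real, imag))).__name__)   # ComplexTensor
print(new_complex_like(nt, (real, imag)).dtype)            # torch.complex64

# to_complex always converts to the native representation
print(to_complex(ct).dtype)                                # torch.complex64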
+ ) + + +def is_torch_complex_tensor(c): + return not isinstance(c, ComplexTensor) and torch.is_complex(c) + + +def to_complex(c): + # Convert to torch native complex + if isinstance(c, ComplexTensor): + c = c.real + 1j * c.imag + return c + elif torch.is_complex(c): + return c + else: + return torch.view_as_complex(c) diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/enh/loss_criterion.py b/examples/speech_synthesis/preprocessing/tfgridnet/enh/loss_criterion.py new file mode 100644 index 0000000000..1d049fb8b3 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/enh/loss_criterion.py @@ -0,0 +1,141 @@ +import logging +import math +from abc import ABC, abstractmethod + +import torch + +EPS = torch.finfo(torch.get_default_dtype()).eps + + +class AbsEnhLoss(torch.nn.Module, ABC): + """Base class for all Enhancement loss modules.""" + + # the name will be the key that appears in the reporter + @property + def name(self) -> str: + return NotImplementedError + + # This property specifies whether the criterion will only + # be evaluated during the inference stage + @property + def only_for_test(self) -> bool: + return False + + @abstractmethod + def forward( + self, + ref, + inf, + ) -> torch.Tensor: + # the return tensor should be shape of (batch) + raise NotImplementedError + + +class TimeDomainLoss(AbsEnhLoss, ABC): + """Base class for all time-domain Enhancement loss modules.""" + + @property + def name(self) -> str: + return self._name + + @property + def only_for_test(self) -> bool: + return self._only_for_test + + @property + def is_noise_loss(self) -> bool: + return self._is_noise_loss + + @property + def is_dereverb_loss(self) -> bool: + return self._is_dereverb_loss + + def __init__( + self, + name, + only_for_test=False, + is_noise_loss=False, + is_dereverb_loss=False, + ): + super().__init__() + # only used during validation + self._only_for_test = only_for_test + # only used to calculate the noise-related loss + self._is_noise_loss = is_noise_loss + # only used to calculate the dereverberation-related loss + self._is_dereverb_loss = is_dereverb_loss + if is_noise_loss and is_dereverb_loss: + raise ValueError( + "`is_noise_loss` and `is_dereverb_loss` cannot be True at the same time" + ) + if is_noise_loss and "noise" not in name: + name = name + "_noise" + if is_dereverb_loss and "dereverb" not in name: + name = name + "_dereverb" + self._name = name + + +class SISNRLoss(TimeDomainLoss): + """SI-SNR (or named SI-SDR) loss + + A more stable SI-SNR loss with clamp from `fast_bss_eval`. + + Attributes: + clamp_db: float + clamp the output value in [-clamp_db, clamp_db] + zero_mean: bool + When set to True, the mean of all signals is subtracted prior. + eps: float + Deprecated. Kept for compatibility. + """ + + def __init__( + self, + clamp_db=None, + zero_mean=True, + eps=None, + name=None, + only_for_test=False, + is_noise_loss=False, + is_dereverb_loss=False, + ): + _name = "si_snr_loss" if name is None else name + super().__init__( + _name, + only_for_test=only_for_test, + is_noise_loss=is_noise_loss, + is_dereverb_loss=is_dereverb_loss, + ) + + self.clamp_db = clamp_db + self.zero_mean = zero_mean + if eps is not None: + logging.warning("Eps is deprecated in si_snr loss, set clamp_db instead.") + if self.clamp_db is None: + self.clamp_db = -math.log10(eps / (1 - eps)) * 10 + + def forward(self, ref: torch.Tensor, est: torch.Tensor) -> torch.Tensor: + """SI-SNR forward. 
+ + Args: + + ref: Tensor, (..., n_samples) + reference signal + est: Tensor (..., n_samples) + estimated signal + + Returns: + loss: (...,) + the SI-SDR loss (negative si-sdr) + """ + assert torch.is_tensor(est) and torch.is_tensor(ref), est + + si_snr = fast_bss_eval.si_sdr_loss( + est=est, + ref=ref, + zero_mean=self.zero_mean, + clamp_db=self.clamp_db, + pairwise=False, + ) + + return si_snr diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/enh/separator.py b/examples/speech_synthesis/preprocessing/tfgridnet/enh/separator.py new file mode 100644 index 0000000000..dc3a47db43 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/enh/separator.py @@ -0,0 +1,710 @@ +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Dict, List, Optional, Tuple, Union + +import torch +from torch_complex.tensor import ComplexTensor + +from examples.speech_synthesis.preprocessing.tfgridnet.enh.tcn import TemporalConvNet + +import math +from collections import OrderedDict +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init +from torch.nn.parameter import Parameter + +from examples.speech_synthesis.preprocessing.tfgridnet.enh.decoder import STFTDecoder +from examples.speech_synthesis.preprocessing.tfgridnet.enh.encoder import STFTEncoder +from examples.speech_synthesis.preprocessing.tfgridnet.enh.layers_complex_utils import new_complex_like +from examples.speech_synthesis.preprocessing.tfgridnet.torch_utils.get_layer_from_string import get_layer + + +def is_torch_complex_tensor(c): + return not isinstance(c, ComplexTensor) and torch.is_complex(c) + + +def is_complex(c): + return isinstance(c, ComplexTensor) or is_torch_complex_tensor(c) + + +class AbsSeparator(torch.nn.Module, ABC): + @abstractmethod + def forward( + self, + input: torch.Tensor, + ilens: torch.Tensor, + additional: Optional[Dict] = None, + ) -> Tuple[Tuple[torch.Tensor], torch.Tensor, OrderedDict]: + raise NotImplementedError + + def forward_streaming( + self, + input_frame: torch.Tensor, + buffer=None, + ): + raise NotImplementedError + + @property + @abstractmethod + def num_spk(self): + raise NotImplementedError + + +class RNNSeparator(AbsSeparator): + def __init__( + self, + input_dim: int, + rnn_type: str = "blstm", + num_spk: int = 2, + predict_noise: bool = False, + nonlinear: str = "sigmoid", + layer: int = 3, + unit: int = 512, + dropout: float = 0.0, + ): + """RNN Separator + + Args: + input_dim: input feature dimension + rnn_type: string, select from 'blstm', 'lstm' etc. + bidirectional: bool, whether the inter-chunk RNN layers are bidirectional. + num_spk: number of speakers + predict_noise: whether to output the estimated noise signal + nonlinear: the nonlinear function for mask estimation, + select from 'relu', 'tanh', 'sigmoid' + layer: int, number of stacked RNN layers. Default is 3. + unit: int, dimension of the hidden state. + dropout: float, dropout ratio. Default is 0. 
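All of the separators in this file follow the masking recipe described in the docstrings: estimate one mask per speaker from the magnitude of the encoded feature, then multiply it back onto the (possibly complex) feature. Stripped of the model specifics, the pattern looks like this (toy mask network, purely illustrative):

import torch

spec = torch.randn(2, 100, 257, dtype=torch.complex64)    # [B, T, F] encoded feature
feature = abs(spec)                                        # magnitude drives the mask estimator
mask_net = torch.nn.Sequential(torch.nn.Linear(257, 257), torch.nn.Sigmoid())
masks = [mask_net(feature) for _ in range(2)]              # one mask per speaker
separated = [spec * m for m in masks]                      # same shape as the input feature
print(separated[0].shape)                                  # torch.Size([2, 100, 257])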
+ """ + super().__init__() + + self._num_spk = num_spk + self.predict_noise = predict_noise + + self.rnn = RNN( + idim=input_dim, + elayers=layer, + cdim=unit, + hdim=unit, + dropout=dropout, + typ=rnn_type, + ) + + num_outputs = self.num_spk + 1 if self.predict_noise else self.num_spk + self.linear = torch.nn.ModuleList( + [torch.nn.Linear(unit, input_dim) for _ in range(num_outputs)] + ) + + if nonlinear not in ("sigmoid", "relu", "tanh"): + raise ValueError("Not supporting nonlinear={}".format(nonlinear)) + + self.nonlinear = { + "sigmoid": torch.nn.Sigmoid(), + "relu": torch.nn.ReLU(), + "tanh": torch.nn.Tanh(), + }[nonlinear] + + def forward( + self, + input: Union[torch.Tensor, ComplexTensor], + ilens: torch.Tensor, + additional: Optional[Dict] = None, + ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]: + """Forward. + + Args: + input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N] + ilens (torch.Tensor): input lengths [Batch] + additional (Dict or None): other data included in model + NOTE: not used in this model + + Returns: + masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...] + ilens (torch.Tensor): (B,) + others predicted data, e.g. masks: OrderedDict[ + 'mask_spk1': torch.Tensor(Batch, Frames, Freq), + 'mask_spk2': torch.Tensor(Batch, Frames, Freq), + ... + 'mask_spkn': torch.Tensor(Batch, Frames, Freq), + ] + """ + + # if complex spectrum, + if is_complex(input): + feature = abs(input) + else: + feature = input + + x, ilens, _ = self.rnn(feature, ilens) + + masks = [] + + for linear in self.linear: + y = linear(x) + y = self.nonlinear(y) + masks.append(y) + + if self.predict_noise: + *masks, mask_noise = masks + + masked = [input * m for m in masks] + + others = OrderedDict( + zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks) + ) + if self.predict_noise: + others["noise1"] = input * mask_noise + + return masked, ilens, others + + @property + def num_spk(self): + return self._num_spk + + def forward_streaming(self, input_frame: torch.Tensor, states=None): + # input_frame # B, 1, N + + # if complex spectrum, + if is_complex(input_frame): + feature = abs(input_frame) + else: + feature = input_frame + + ilens = torch.ones(feature.shape[0], device=feature.device) + + x, _, states = self.rnn(feature, ilens, states) + + masks = [] + + for linear in self.linear: + y = linear(x) + y = self.nonlinear(y) + masks.append(y) + + if self.predict_noise: + *masks, mask_noise = masks + + masked = [input_frame * m for m in masks] + + others = OrderedDict( + zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks) + ) + if self.predict_noise: + others["noise1"] = input * mask_noise + + return masked, states, others + + +class TCNSeparator(AbsSeparator): + def __init__( + self, + input_dim: int, + num_spk: int = 2, + predict_noise: bool = False, + layer: int = 8, + stack: int = 3, + bottleneck_dim: int = 128, + hidden_dim: int = 512, + kernel: int = 3, + causal: bool = False, + norm_type: str = "gLN", + nonlinear: str = "relu", + ): + """Temporal Convolution Separator + + Args: + input_dim: input feature dimension + num_spk: number of speakers + predict_noise: whether to output the estimated noise signal + layer: int, number of layers in each stack. + stack: int, number of stacks + bottleneck_dim: bottleneck dimension + hidden_dim: number of convolution channel + kernel: int, kernel size. + causal: bool, defalut False. 
+ norm_type: str, choose from 'BN', 'gLN', 'cLN' + nonlinear: the nonlinear function for mask estimation, + select from 'relu', 'tanh', 'sigmoid' + """ + super().__init__() + + self._num_spk = num_spk + self.predict_noise = predict_noise + + if nonlinear not in ("sigmoid", "relu", "tanh"): + raise ValueError("Not supporting nonlinear={}".format(nonlinear)) + + self.tcn = TemporalConvNet( + N=input_dim, + B=bottleneck_dim, + H=hidden_dim, + P=kernel, + X=layer, + R=stack, + C=num_spk + 1 if predict_noise else num_spk, + norm_type=norm_type, + causal=causal, + mask_nonlinear=nonlinear, + ) + + def forward( + self, + input: Union[torch.Tensor, ComplexTensor], + ilens: torch.Tensor, + additional: Optional[Dict] = None, + ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]: + """Forward. + + Args: + input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N] + ilens (torch.Tensor): input lengths [Batch] + additional (Dict or None): other data included in model + NOTE: not used in this model + + Returns: + masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...] + ilens (torch.Tensor): (B,) + others predicted data, e.g. masks: OrderedDict[ + 'mask_spk1': torch.Tensor(Batch, Frames, Freq), + 'mask_spk2': torch.Tensor(Batch, Frames, Freq), + ... + 'mask_spkn': torch.Tensor(Batch, Frames, Freq), + ] + """ + # if complex spectrum + if is_complex(input): + feature = abs(input) + else: + feature = input + B, L, N = feature.shape + + feature = feature.transpose(1, 2) # B, N, L + + masks = self.tcn(feature) # B, num_spk, N, L + masks = masks.transpose(2, 3) # B, num_spk, L, N + if self.predict_noise: + *masks, mask_noise = masks.unbind(dim=1) # List[B, L, N] + else: + masks = masks.unbind(dim=1) # List[B, L, N] + + masked = [input * m for m in masks] + + others = OrderedDict( + zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks) + ) + if self.predict_noise: + others["noise1"] = input * mask_noise + + return masked, ilens, others + + def forward_streaming(self, input_frame: torch.Tensor, buffer=None): + # input_frame: B, 1, N + + B, _, N = input_frame.shape + + receptive_field = self.tcn.receptive_field + + if buffer is None: + buffer = torch.zeros((B, receptive_field, N), device=input_frame.device) + + buffer = torch.roll(buffer, shifts=-1, dims=1) + buffer[:, -1, :] = input_frame[:, 0, :] + + masked, ilens, others = self.forward(buffer, None) + + masked = [m[:, -1, :].unsqueeze(1) for m in masked] + + return masked, buffer, others + + @property + def num_spk(self): + return self._num_spk + + +class TFGridNetMasking(AbsSeparator): + """Offline TFGridNet + + Reference: + [1] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe, + "TF-GridNet: Integrating Full- and Sub-Band Modeling for Speech Separation", + in arXiv preprint arXiv:2211.12433, 2022. + [2] Z.-Q. Wang, S. Cornell, S. Choi, Y. Lee, B.-Y. Kim, and S. Watanabe, + "TF-GridNet: Making Time-Frequency Domain Models Great Again for Monaural + Speaker Separation", in arXiv preprint arXiv:2209.03952, 2022. + + NOTES: + As outlined in the Reference, this model works best when trained with variance + normalized mixture input and target, e.g., with mixture of shape [batch, samples, + microphones], you normalize it by dividing with torch.std(mixture, (1, 2)). You + must do the same for the target signals. It is encouraged to do so when not using + scale-invariant loss functions such as SI-SDR. 
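The variance normalization recommended in the NOTES above amounts to one scale factor per utterance, shared across time and microphones, applied to both the mixture and the targets and undone on the separated outputs. A sketch:

import torch

mixture = torch.randn(4, 64000, 2)               # [batch, samples, microphones]
target = torch.randn(4, 64000, 2)                # targets get the same scaling

std = torch.std(mixture, (1, 2), keepdim=True)   # one scale per utterance
mixture_norm = mixture / std
target_norm = target / std
# train on (mixture_norm, target_norm); multiply the separated outputs by std to undo it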
+ + Args: + input_dim: placeholder, not used + n_srcs: number of output sources/speakers. + n_fft: stft window size. + stride: stft stride. + window: stft window type choose between 'hamming', 'hanning' or None. + n_imics: number of microphones channels (only fixed-array geometry supported). + n_layers: number of TFGridNet blocks. + lstm_hidden_units: number of hidden units in LSTM. + attn_n_head: number of heads in self-attention + attn_approx_qk_dim: approximate dimention of frame-level key and value tensors + emb_dim: embedding dimension + emb_ks: kernel size for unfolding and deconv1D + emb_hs: hop size for unfolding and deconv1D + activation: activation function to use in the whole TFGridNet model, + you can use any torch supported activation e.g. 'relu' or 'elu'. + eps: small epsilon for normalization layers. + use_builtin_complex: whether to use builtin complex type or not. + """ + + def __init__( + self, + input_dim, + n_srcs=2, + n_fft=128, + stride=64, + window="hann", + n_imics=1, + n_layers=6, + lstm_hidden_units=192, + attn_n_head=4, + attn_approx_qk_dim=512, + emb_dim=48, + emb_ks=4, + emb_hs=1, + activation="prelu", + eps=1.0e-5, + use_builtin_complex=False, + ref_channel=-1, + ): + super().__init__() + self.n_srcs = n_srcs + self.n_layers = n_layers + self.n_imics = n_imics + assert n_fft % 2 == 0 + n_freqs = n_fft // 2 + 1 + self.ref_channel = ref_channel + + self.enc = STFTEncoder( + n_fft, n_fft, stride, window=window, use_builtin_complex=use_builtin_complex + ) + self.dec = STFTDecoder(n_fft, n_fft, stride, window=window) + + t_ksize = 3 + ks, padding = (t_ksize, 3), (t_ksize // 2, 1) + self.conv = nn.Sequential( + nn.Conv2d(2 * n_imics, emb_dim, ks, padding=padding), + nn.GroupNorm(1, emb_dim, eps=eps), + ) + + self.blocks = nn.ModuleList([]) + for _ in range(n_layers): + self.blocks.append( + GridNetBlock( + emb_dim, + emb_ks, + emb_hs, + n_freqs, + lstm_hidden_units, + n_head=attn_n_head, + approx_qk_dim=attn_approx_qk_dim, + activation=activation, + eps=eps, + ) + ) + + self.deconv = nn.ConvTranspose2d(emb_dim, n_srcs * 2, ks, padding=padding) + + def forward( + self, + input: torch.Tensor, + ilens: torch.Tensor, + additional: Optional[Dict] = None, + ) -> Tuple[List[torch.Tensor], torch.Tensor, OrderedDict]: + """Forward. + + Args: + input (torch.Tensor): batched multi-channel audio tensor with + M audio channels and N samples [B, N, M] + ilens (torch.Tensor): input lengths [B] + additional (Dict or None): other data, currently unused in this model. + + Returns: + enhanced (List[Union(torch.Tensor)]): + [(B, T), ...] list of len n_srcs + of mono audio tensors with T samples. + ilens (torch.Tensor): (B,) + additional (Dict or None): other data, currently unused in this model, + we return it also in output. 
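Matching the shapes documented above, the model maps a [B, N] (or [B, N, M]) mixture to a list of n_srcs mono signals. A randomly initialized, CPU-sized sketch; the small n_fft, emb_dim and layer count are chosen only to keep the example light, and input_dim is the unused placeholder:

import torch
from examples.speech_synthesis.preprocessing.tfgridnet.enh.separator import TFGridNetMasking

model = TFGridNetMasking(
    input_dim=None,        # placeholder, not used
    n_srcs=2,
    n_fft=128,
    stride=64,
    n_imics=1,
    n_layers=1,
    lstm_hidden_units=32,
    attn_n_head=2,
    emb_dim=16,
)
wav = torch.randn(1, 16000)                 # [B, N] single-channel mixture
ilens = torch.tensor([16000])
separated, olens, _ = model(wav, ilens)
print(len(separated), separated[0].shape)   # 2 torch.Size([1, 16000])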
+ """ + n_samples = input.shape[1] + if self.n_imics == 1: + assert len(input.shape) == 2 + input = input[..., None] # [B, N, M] + + mix_std_ = torch.std(input, dim=(1, 2), keepdim=True) # [B, 1, 1] + input = input / mix_std_ # RMS normalization + + batch = self.enc(input, ilens)[0] # [B, T, M, F] + batch0 = batch.transpose(1, 2) # [B, M, T, F] + batch = torch.cat((batch0.real, batch0.imag), dim=1) # [B, 2*M, T, F] + n_batch, _, n_frames, n_freqs = batch.shape + + batch = self.conv(batch) # [B, -1, T, F] + + for ii in range(self.n_layers): + batch = self.blocks[ii](batch) # [B, -1, T, F] + + mask = self.deconv(batch) # [B, n_srcs*2, T, F] + mask[mask > 5] = 5 + mask[mask < -5] = -5 + + mask = mask.view([n_batch, self.n_srcs, 2, n_frames, n_freqs]) + mask = new_complex_like(batch0, (mask[:, :, 0], mask[:, :, 1])) + + batch = mask * batch0 + + # batch = batch.view([n_batch, self.n_srcs, 2, n_frames, n_freqs]) + # batch = new_complex_like(batch0, (batch[:, :, 0], batch[:, :, 1])) + + batch = self.dec(batch.view(-1, n_frames, n_freqs), ilens)[0] # [B, n_srcs, -1] + + batch = self.pad2(batch.view([n_batch, self.num_spk, -1]), n_samples) + + batch = batch * mix_std_ # reverse the RMS normalization + + batch = [batch[:, src] for src in range(self.num_spk)] + + return batch, ilens, OrderedDict() + + @property + def num_spk(self): + return self.n_srcs + + @staticmethod + def pad2(input_tensor, target_len): + input_tensor = torch.nn.functional.pad( + input_tensor, (0, target_len - input_tensor.shape[-1]) + ) + return input_tensor + + +class GridNetBlock(nn.Module): + def __getitem__(self, key): + return getattr(self, key) + + def __init__( + self, + emb_dim, + emb_ks, + emb_hs, + n_freqs, + hidden_channels, + n_head=4, + approx_qk_dim=512, + activation="prelu", + eps=1e-5, + ): + super().__init__() + + in_channels = emb_dim * emb_ks + + self.intra_norm = LayerNormalization4D(emb_dim, eps=eps) + self.intra_rnn = nn.LSTM( + in_channels, hidden_channels, 1, batch_first=True, bidirectional=True + ) + self.intra_linear = nn.ConvTranspose1d( + hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs + ) + + self.inter_norm = LayerNormalization4D(emb_dim, eps=eps) + self.inter_rnn = nn.LSTM( + in_channels, hidden_channels, 1, batch_first=True, bidirectional=True + ) + self.inter_linear = nn.ConvTranspose1d( + hidden_channels * 2, emb_dim, emb_ks, stride=emb_hs + ) + + E = math.ceil( + approx_qk_dim * 1.0 / n_freqs + ) # approx_qk_dim is only approximate + assert emb_dim % n_head == 0 + for ii in range(n_head): + self.add_module( + "attn_conv_Q_%d" % ii, + nn.Sequential( + nn.Conv2d(emb_dim, E, 1), + get_layer(activation)(), + LayerNormalization4DCF((E, n_freqs), eps=eps), + ), + ) + self.add_module( + "attn_conv_K_%d" % ii, + nn.Sequential( + nn.Conv2d(emb_dim, E, 1), + get_layer(activation)(), + LayerNormalization4DCF((E, n_freqs), eps=eps), + ), + ) + self.add_module( + "attn_conv_V_%d" % ii, + nn.Sequential( + nn.Conv2d(emb_dim, emb_dim // n_head, 1), + get_layer(activation)(), + LayerNormalization4DCF((emb_dim // n_head, n_freqs), eps=eps), + ), + ) + self.add_module( + "attn_concat_proj", + nn.Sequential( + nn.Conv2d(emb_dim, emb_dim, 1), + get_layer(activation)(), + LayerNormalization4DCF((emb_dim, n_freqs), eps=eps), + ), + ) + + self.emb_dim = emb_dim + self.emb_ks = emb_ks + self.emb_hs = emb_hs + self.n_head = n_head + + def forward(self, x): + """GridNetBlock Forward. 
+ + Args: + x: [B, C, T, Q] + out: [B, C, T, Q] + """ + B, C, old_T, old_Q = x.shape + T = math.ceil((old_T - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks + Q = math.ceil((old_Q - self.emb_ks) / self.emb_hs) * self.emb_hs + self.emb_ks + x = F.pad(x, (0, Q - old_Q, 0, T - old_T)) + + # intra RNN + input_ = x + intra_rnn = self.intra_norm(input_) # [B, C, T, Q] + intra_rnn = ( + intra_rnn.transpose(1, 2).contiguous().view(B * T, C, Q) + ) # [BT, C, Q] + intra_rnn = F.unfold( + intra_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1) + ) # [BT, C*emb_ks, -1] + intra_rnn = intra_rnn.transpose(1, 2) # [BT, -1, C*emb_ks] + intra_rnn, _ = self.intra_rnn(intra_rnn) # [BT, -1, H] + intra_rnn = intra_rnn.transpose(1, 2) # [BT, H, -1] + intra_rnn = self.intra_linear(intra_rnn) # [BT, C, Q] + intra_rnn = intra_rnn.view([B, T, C, Q]) + intra_rnn = intra_rnn.transpose(1, 2).contiguous() # [B, C, T, Q] + intra_rnn = intra_rnn + input_ # [B, C, T, Q] + + # inter RNN + input_ = intra_rnn + inter_rnn = self.inter_norm(input_) # [B, C, T, F] + inter_rnn = ( + inter_rnn.permute(0, 3, 1, 2).contiguous().view(B * Q, C, T) + ) # [BF, C, T] + inter_rnn = F.unfold( + inter_rnn[..., None], (self.emb_ks, 1), stride=(self.emb_hs, 1) + ) # [BF, C*emb_ks, -1] + inter_rnn = inter_rnn.transpose(1, 2) # [BF, -1, C*emb_ks] + inter_rnn, _ = self.inter_rnn(inter_rnn) # [BF, -1, H] + inter_rnn = inter_rnn.transpose(1, 2) # [BF, H, -1] + inter_rnn = self.inter_linear(inter_rnn) # [BF, C, T] + inter_rnn = inter_rnn.view([B, Q, C, T]) + inter_rnn = inter_rnn.permute(0, 2, 3, 1).contiguous() # [B, C, T, Q] + inter_rnn = inter_rnn + input_ # [B, C, T, Q] + + # attention + inter_rnn = inter_rnn[..., :old_T, :old_Q] + batch = inter_rnn + + all_Q, all_K, all_V = [], [], [] + for ii in range(self.n_head): + all_Q.append(self["attn_conv_Q_%d" % ii](batch)) # [B, C, T, Q] + all_K.append(self["attn_conv_K_%d" % ii](batch)) # [B, C, T, Q] + all_V.append(self["attn_conv_V_%d" % ii](batch)) # [B, C, T, Q] + + Q = torch.cat(all_Q, dim=0) # [B', C, T, Q] + K = torch.cat(all_K, dim=0) # [B', C, T, Q] + V = torch.cat(all_V, dim=0) # [B', C, T, Q] + + Q = Q.transpose(1, 2) + Q = Q.flatten(start_dim=2) # [B', T, C*Q] + K = K.transpose(1, 2) + K = K.flatten(start_dim=2) # [B', T, C*Q] + V = V.transpose(1, 2) # [B', T, C, Q] + old_shape = V.shape + V = V.flatten(start_dim=2) # [B', T, C*Q] + emb_dim = Q.shape[-1] + + attn_mat = torch.matmul(Q, K.transpose(1, 2)) / (emb_dim**0.5) # [B', T, T] + attn_mat = F.softmax(attn_mat, dim=2) # [B', T, T] + V = torch.matmul(attn_mat, V) # [B', T, C*Q] + + V = V.reshape(old_shape) # [B', T, C, Q] + V = V.transpose(1, 2) # [B', C, T, Q] + emb_dim = V.shape[1] + + batch = V.view([self.n_head, B, emb_dim, old_T, -1]) # [n_head, B, C, T, Q]) + batch = batch.transpose(0, 1) # [B, n_head, C, T, Q]) + batch = batch.contiguous().view( + [B, self.n_head * emb_dim, old_T, -1] + ) # [B, C, T, Q]) + batch = self["attn_concat_proj"](batch) # [B, C, T, Q]) + + out = batch + inter_rnn + return out + + +class LayerNormalization4D(nn.Module): + def __init__(self, input_dimension, eps=1e-5): + super().__init__() + param_size = [1, input_dimension, 1, 1] + self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32)) + self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32)) + init.ones_(self.gamma) + init.zeros_(self.beta) + self.eps = eps + + def forward(self, x): + if x.ndim == 4: + _, C, _, _ = x.shape + stat_dim = (1,) + else: + raise ValueError("Expect x to have 4 dimensions, but got 
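Each GridNetBlock is shape-preserving: the intra-frame BLSTM, the inter-frame BLSTM and the frame-level attention all map [B, C, T, Q] back to [B, C, T, Q]. A quick check with small, arbitrary sizes (n_freqs must match the last dimension of the input):

import torch
from examples.speech_synthesis.preprocessing.tfgridnet.enh.separator import GridNetBlock

blk = GridNetBlock(emb_dim=48, emb_ks=4, emb_hs=1, n_freqs=65,
                   hidden_channels=192, n_head=4, approx_qk_dim=512)
x = torch.randn(2, 48, 50, 65)      # [B, C, T, Q]
y = blk(x)
print(y.shape)                      # torch.Size([2, 48, 50, 65])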
{}".format(x.ndim)) + mu_ = x.mean(dim=stat_dim, keepdim=True) # [B,1,T,F] + std_ = torch.sqrt( + x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps + ) # [B,1,T,F] + x_hat = ((x - mu_) / std_) * self.gamma + self.beta + return x_hat + + +class LayerNormalization4DCF(nn.Module): + def __init__(self, input_dimension, eps=1e-5): + super().__init__() + assert len(input_dimension) == 2 + param_size = [1, input_dimension[0], 1, input_dimension[1]] + self.gamma = Parameter(torch.Tensor(*param_size).to(torch.float32)) + self.beta = Parameter(torch.Tensor(*param_size).to(torch.float32)) + init.ones_(self.gamma) + init.zeros_(self.beta) + self.eps = eps + + def forward(self, x): + if x.ndim == 4: + stat_dim = (1, 3) + else: + raise ValueError("Expect x to have 4 dimensions, but got {}".format(x.ndim)) + mu_ = x.mean(dim=stat_dim, keepdim=True) # [B,1,T,1] + std_ = torch.sqrt( + x.var(dim=stat_dim, unbiased=False, keepdim=True) + self.eps + ) # [B,1,T,F] + x_hat = ((x - mu_) / std_) * self.gamma + self.beta + return x_hat diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/enh/tcn.py b/examples/speech_synthesis/preprocessing/tfgridnet/enh/tcn.py new file mode 100644 index 0000000000..1262861ee0 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/enh/tcn.py @@ -0,0 +1,362 @@ +# Implementation of the TCN proposed in +# Luo. et al. "Conv-tasnet: Surpassing ideal time–frequency +# magnitude masking for speech separation." +# +# The code is based on: +# https://github.com/kaituoxu/Conv-TasNet/blob/master/src/conv_tasnet.py +# Licensed under MIT. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +EPS = torch.finfo(torch.get_default_dtype()).eps + +class TemporalConvNet(nn.Module): + def __init__( + self, + N, + B, + H, + P, + X, + R, + C, + Sc=None, + out_channel=None, + norm_type="gLN", + causal=False, + pre_mask_nonlinear="linear", + mask_nonlinear="relu", + ): + """Basic Module of tasnet. + + Args: + N: Number of filters in autoencoder + B: Number of channels in bottleneck 1 * 1-conv block + H: Number of channels in convolutional blocks + P: Kernel size in convolutional blocks + X: Number of convolutional blocks in each repeat + R: Number of repeats + C: Number of speakers + Sc: Number of channels in skip-connection paths' 1x1-conv blocks + out_channel: Number of output channels + if it is None, `N` will be used instead. 
+ norm_type: BN, gLN, cLN + causal: causal or non-causal + pre_mask_nonlinear: the non-linear function before masknet + mask_nonlinear: use which non-linear function to generate mask + """ + super().__init__() + # Hyper-parameter + self.C = C + self.mask_nonlinear = mask_nonlinear + self.skip_connection = Sc is not None + self.out_channel = N if out_channel is None else out_channel + if self.skip_connection: + assert Sc == B, (Sc, B) + # Components + # [M, N, K] -> [M, N, K] + layer_norm = ChannelwiseLayerNorm(N) + # [M, N, K] -> [M, B, K] + bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False) + # [M, B, K] -> [M, B, K] + repeats = [] + + self.receptive_field = 0 + for r in range(R): + blocks = [] + for x in range(X): + dilation = 2**x + if r == 0 and x == 0: + self.receptive_field += P + else: + self.receptive_field += (P - 1) * dilation + padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2 + blocks += [ + TemporalBlock( + B, + H, + Sc, + P, + stride=1, + padding=padding, + dilation=dilation, + norm_type=norm_type, + causal=causal, + ) + ] + repeats += [nn.Sequential(*blocks)] + temporal_conv_net = nn.Sequential(*repeats) + # [M, B, K] -> [M, C*N, K] + mask_conv1x1 = nn.Conv1d(B, C * self.out_channel, 1, bias=False) + # Put together (for compatibility with older versions) + if pre_mask_nonlinear == "linear": + self.network = nn.Sequential( + layer_norm, bottleneck_conv1x1, temporal_conv_net, mask_conv1x1 + ) + else: + activ = { + "prelu": nn.PReLU(), + "relu": nn.ReLU(), + "tanh": nn.Tanh(), + "sigmoid": nn.Sigmoid(), + }[pre_mask_nonlinear] + self.network = nn.Sequential( + layer_norm, bottleneck_conv1x1, temporal_conv_net, activ, mask_conv1x1 + ) + + def forward(self, mixture_w): + """Keep this API same with TasNet. + + Args: + mixture_w: [M, N, K], M is batch size + + Returns: + est_mask: [M, C, N, K] + """ + M, N, K = mixture_w.size() + bottleneck = self.network[:2] + tcns = self.network[2] + masknet = self.network[3:] + output = bottleneck(mixture_w) + skip_conn = 0.0 + for block in tcns: + for layer in block: + tcn_out = layer(output) + if self.skip_connection: + residual, skip = tcn_out + skip_conn = skip_conn + skip + else: + residual = tcn_out + output = output + residual + # Use residual output when no skip connection + if self.skip_connection: + score = masknet(skip_conn) + else: + score = masknet(output) + + # [M, C*self.out_channel, K] -> [M, C, self.out_channel, K] + score = score.view(M, self.C, self.out_channel, K) + if self.mask_nonlinear == "softmax": + est_mask = torch.softmax(score, dim=1) + elif self.mask_nonlinear == "relu": + est_mask = torch.relu(score) + elif self.mask_nonlinear == "sigmoid": + est_mask = torch.sigmoid(score) + elif self.mask_nonlinear == "tanh": + est_mask = torch.tanh(score) + elif self.mask_nonlinear == "linear": + est_mask = score + else: + raise ValueError("Unsupported mask non-linear function") + return est_mask + +class TemporalBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels, + kernel_size, + stride, + padding, + dilation, + norm_type="gLN", + causal=False, + ): + super().__init__() + self.skip_connection = skip_channels is not None + # [M, B, K] -> [M, H, K] + conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False) + prelu = nn.PReLU() + norm = choose_norm(norm_type, out_channels) + # [M, H, K] -> [M, B, K] + dsconv = DepthwiseSeparableConv( + out_channels, + in_channels, + skip_channels, + kernel_size, + stride, + padding, + dilation, + norm_type, + causal, + ) + # Put together 
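The receptive_field accumulated in __init__ grows with the dilation schedule: P frames for the very first block, then (P - 1) * 2**x for every later block. With the Conv-TasNet-style defaults used by TCNSeparator (P=3, X=8, R=3) that comes to 1531 frames; a quick check with illustrative channel sizes:

import torch
from examples.speech_synthesis.preprocessing.tfgridnet.enh.tcn import TemporalConvNet

tcn = TemporalConvNet(N=256, B=128, H=512, P=3, X=8, R=3, C=2)
print(tcn.receptive_field)          # 1531 = 3 + 2*(2**8 - 2) + 2 * 2*(2**8 - 1)

feats = torch.randn(4, 256, 100)    # [M, N, K]
masks = tcn(feats)                  # ReLU masks by default
print(masks.shape)                  # torch.Size([4, 2, 256, 100]) -> [M, C, N, K]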
+ self.net = nn.Sequential(conv1x1, prelu, norm, dsconv) + + def forward(self, x): + """Forward. + + Args: + x: [M, B, K] + + Returns: + [M, B, K] + """ + if self.skip_connection: + res_out, skip_out = self.net(x) + return res_out, skip_out + else: + res_out = self.net(x) + return res_out + +def choose_norm(norm_type, channel_size, shape="BDT"): + """The input of normalization will be (M, C, K), where M is batch size. + + C is channel size and K is sequence length. + """ + if norm_type == "gLN": + return GlobalLayerNorm(channel_size, shape=shape) + elif norm_type == "cLN": + return ChannelwiseLayerNorm(channel_size, shape=shape) + elif norm_type == "BN": + # Given input (M, C, K), nn.BatchNorm1d(C) will accumulate statics + # along M and K, so this BN usage is right. + return nn.BatchNorm1d(channel_size) + elif norm_type == "GN": + return nn.GroupNorm(1, channel_size, eps=1e-8) + else: + raise ValueError("Unsupported normalization type") + + +class ChannelwiseLayerNorm(nn.Module): + """Channel-wise Layer Normalization (cLN).""" + + def __init__(self, channel_size, shape="BDT"): + super().__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.reset_parameters() + assert shape in ["BDT", "BTD"] + self.shape = shape + + def reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + @torch.cuda.amp.autocast(enabled=False) + def forward(self, y): + """Forward. + + Args: + y: [M, N, K], M is batch size, N is channel size, K is length + + Returns: + cLN_y: [M, N, K] + """ + + assert y.dim() == 3 + + if self.shape == "BTD": + y = y.transpose(1, 2).contiguous() + + mean = torch.mean(y, dim=1, keepdim=True) # [M, 1, K] + var = torch.var(y, dim=1, keepdim=True, unbiased=False) # [M, 1, K] + cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + + if self.shape == "BTD": + cLN_y = cLN_y.transpose(1, 2).contiguous() + + return cLN_y + +class DepthwiseSeparableConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels, + kernel_size, + stride, + padding, + dilation, + norm_type="gLN", + causal=False, + ): + super().__init__() + # Use `groups` option to implement depthwise convolution + # [M, H, K] -> [M, H, K] + depthwise_conv = nn.Conv1d( + in_channels, + in_channels, + kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=False, + ) + if causal: + chomp = Chomp1d(padding) + prelu = nn.PReLU() + norm = choose_norm(norm_type, in_channels) + # [M, H, K] -> [M, B, K] + pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False) + # Put together + if causal: + self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm, pointwise_conv) + else: + self.net = nn.Sequential(depthwise_conv, prelu, norm, pointwise_conv) + + # skip connection + if skip_channels is not None: + self.skip_conv = nn.Conv1d(in_channels, skip_channels, 1, bias=False) + else: + self.skip_conv = None + + def forward(self, x): + """Forward. 
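The depthwise/pointwise factorization above is what keeps the TCN light: a per-channel filter followed by a 1x1 channel mixer needs far fewer parameters than a full convolution. A quick comparison with the default sizes (H=512, B=128, P=3):

import torch.nn as nn

H, B, P = 512, 128, 3
full = nn.Conv1d(H, B, P, padding=1, bias=False)                  # ordinary convolution
depthwise = nn.Conv1d(H, H, P, padding=1, groups=H, bias=False)   # per-channel filtering
pointwise = nn.Conv1d(H, B, 1, bias=False)                        # channel mixing

count = lambda m: sum(p.numel() for p in m.parameters())
print(count(full), count(depthwise) + count(pointwise))           # 196608 vs 67072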
+ + Args: + x: [M, H, K] + + Returns: + res_out: [M, B, K] + skip_out: [M, Sc, K] + """ + shared_block = self.net[:-1] + shared = shared_block(x) + res_out = self.net[-1](shared) + if self.skip_conv is None: + return res_out + skip_out = self.skip_conv(shared) + return res_out, skip_out + + + +class GlobalLayerNorm(nn.Module): + """Global Layer Normalization (gLN).""" + + def __init__(self, channel_size, shape="BDT"): + super().__init__() + self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1)) # [1, N, 1] + self.reset_parameters() + assert shape in ["BDT", "BTD"] + self.shape = shape + + def reset_parameters(self): + self.gamma.data.fill_(1) + self.beta.data.zero_() + + @torch.cuda.amp.autocast(enabled=False) + def forward(self, y): + """Forward. + + Args: + y: [M, N, K], M is batch size, N is channel size, K is length + + Returns: + gLN_y: [M, N, K] + """ + if self.shape == "BTD": + y = y.transpose(1, 2).contiguous() + + mean = y.mean(dim=(1, 2), keepdim=True) # [M, 1, 1] + var = (torch.pow(y - mean, 2)).mean(dim=(1, 2), keepdim=True) + gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta + + if self.shape == "BTD": + gLN_y = gLN_y.transpose(1, 2).contiguous() + return gLN_y \ No newline at end of file diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/enh/wrappers.py b/examples/speech_synthesis/preprocessing/tfgridnet/enh/wrappers.py new file mode 100644 index 0000000000..7b4e978ce2 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/enh/wrappers.py @@ -0,0 +1,177 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple +from collections import defaultdict +from itertools import permutations + +import torch + +from examples.speech_synthesis.preprocessing.tfgridnet.enh.loss_criterion import AbsEnhLoss + +class AbsLossWrapper(torch.nn.Module, ABC): + """Base class for all Enhancement loss wrapper modules.""" + + # The weight for the current loss in the multi-task learning. + # The overall training target will be combined as: + # loss = weight_1 * loss_1 + ... + weight_N * loss_N + weight = 1.0 + + @abstractmethod + def forward( + self, + ref: List, + inf: List, + others: Dict, + ) -> Tuple[torch.Tensor, Dict, Dict]: + raise NotImplementedError + +class PITSolver(AbsLossWrapper): + def __init__( + self, + criterion: AbsEnhLoss, + weight=1.0, + independent_perm=True, + flexible_numspk=False, + ): + """Permutation Invariant Training Solver. + + Args: + criterion (AbsEnhLoss): an instance of AbsEnhLoss + weight (float): weight (between 0 and 1) of current loss + for multi-task learning. + independent_perm (bool): + If True, PIT will be performed in forward to find the best permutation; + If False, the permutation from the last LossWrapper output will be + inherited. + NOTE (wangyou): You should be careful about the ordering of loss + wrappers defined in the yaml config, if this argument is False. + flexible_numspk (bool): + If True, num_spk will be taken from inf to handle flexible numbers of + speakers. This is because ref may include dummy data in this case. + """ + super().__init__() + self.criterion = criterion + self.weight = weight + self.independent_perm = independent_perm + self.flexible_numspk = flexible_numspk + + def forward(self, ref, inf, others={}): + """PITSolver forward. + + Args: + ref (List[torch.Tensor]): [(batch, ...), ...] x n_spk + inf (List[torch.Tensor]): [(batch, ...), ...] 
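The permutation search that this forward performs can be summarized with a small self-contained sketch. It uses a plain MSE pair loss as a stand-in for the configured criterion (the patch itself wires in SISNRLoss and other AbsEnhLoss criteria), and pit_loss is a throwaway helper, not part of the code being added.

from itertools import permutations

import torch

def pit_loss(ref, inf):
    """ref, inf: lists of (batch, T) tensors, one entry per speaker."""
    num_spk = len(ref)
    perms = list(permutations(range(num_spk)))
    # (batch, n_perms): average pairwise loss under every possible speaker ordering
    losses = torch.stack(
        [
            sum(((ref[s] - inf[t]) ** 2).mean(dim=-1) for s, t in enumerate(p)) / num_spk
            for p in perms
        ],
        dim=1,
    )
    loss, idx = losses.min(dim=1)        # best ordering per utterance
    perm = torch.tensor(perms)[idx]      # (batch, num_spk)
    return loss.mean(), perm

ref = [torch.randn(2, 8000) for _ in range(2)]
inf = [ref[1].clone(), ref[0].clone()]   # estimates come out in swapped order
loss, perm = pit_loss(ref, inf)
print(loss.item(), perm.tolist())        # ~0.0, [[1, 0], [1, 0]]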
+ + Returns: + loss: (torch.Tensor): minimum loss with the best permutation + stats: dict, for collecting training status + others: dict, in this PIT solver, permutation order will be returned + """ + perm = others["perm"] if "perm" in others else None + + if not self.flexible_numspk: + assert len(ref) == len(inf), (len(ref), len(inf)) + num_spk = len(ref) + else: + num_spk = len(inf) + + stats = defaultdict(list) + + def pre_hook(func, *args, **kwargs): + ret = func(*args, **kwargs) + for k, v in getattr(self.criterion, "stats", {}).items(): + stats[k].append(v) + return ret + + def pair_loss(permutation): + return sum( + [ + pre_hook(self.criterion, ref[s], inf[t]) + for s, t in enumerate(permutation) + ] + ) / len(permutation) + + if self.independent_perm or perm is None: + # computate permuatation independently + device = ref[0].device + all_permutations = list(permutations(range(num_spk))) + losses = torch.stack([pair_loss(p) for p in all_permutations], dim=1) + loss, perm_ = torch.min(losses, dim=1) + perm = torch.index_select( + torch.tensor(all_permutations, device=device, dtype=torch.long), + 0, + perm_, + ) + # remove stats from unused permutations + for k, v in stats.items(): + # (B, num_spk * len(all_permutations), ...) + new_v = torch.stack(v, dim=1) + B, L, *rest = new_v.shape + assert L == num_spk * len(all_permutations), (L, num_spk) + new_v = new_v.view(B, L // num_spk, num_spk, *rest).mean(2) + if new_v.dim() > 2: + shapes = [1 for _ in rest] + perm0 = perm_.view(perm_.shape[0], 1, *shapes).expand(-1, -1, *rest) + else: + perm0 = perm_.unsqueeze(1) + stats[k] = new_v.gather(1, perm0.to(device=new_v.device)).unbind(1) + else: + loss = torch.tensor( + [ + torch.tensor( + [ + pre_hook( + self.criterion, + ref[s][batch].unsqueeze(0), + inf[t][batch].unsqueeze(0), + ) + for s, t in enumerate(p) + ] + ).mean() + for batch, p in enumerate(perm) + ] + ) + + loss = loss.mean() + + for k, v in stats.items(): + stats[k] = torch.stack(v, dim=1).mean() + stats[self.criterion.name] = loss.detach() + + return loss.mean(), dict(stats), {"perm": perm} + + +class FixedOrderSolver(AbsLossWrapper): + def __init__(self, criterion: AbsEnhLoss, weight=1.0): + super().__init__() + self.criterion = criterion + self.weight = weight + + def forward(self, ref, inf, others={}): + """An naive fixed-order solver + + Args: + ref (List[torch.Tensor]): [(batch, ...), ...] x n_spk + inf (List[torch.Tensor]): [(batch, ...), ...] 
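For contrast with the PIT sketch above: a fixed-order reduction keeps whatever pairing it is given and always returns the identity permutation, so the same swapped-stream toy reports a large loss (again with an MSE stand-in for the configured criterion; this snippet is illustrative only).

import torch

ref = [torch.randn(2, 8000) for _ in range(2)]
inf = [ref[1].clone(), ref[0].clone()]                       # streams swapped
loss = sum(((r - i) ** 2).mean() for r, i in zip(ref, inf)) / len(ref)
perm = torch.arange(len(ref)).unsqueeze(0).repeat(2, 1)      # identity ordering
print(loss.item(), perm.tolist())                            # ~2.0, [[0, 1], [0, 1]]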
+ + Returns: + loss: (torch.Tensor): minimum loss with the best permutation + stats: dict, for collecting training status + others: reserved + """ + assert len(ref) == len(inf), (len(ref), len(inf)) + num_spk = len(ref) + + loss = 0.0 + stats = defaultdict(list) + for r, i in zip(ref, inf): + loss += torch.mean(self.criterion(r, i)) / num_spk + for k, v in getattr(self.criterion, "stats", {}).items(): + stats[k].append(v) + + for k, v in stats.items(): + stats[k] = torch.stack(v, dim=1).mean() + stats[self.criterion.name] = loss.detach() + + perm = torch.arange(num_spk).unsqueeze(0).repeat(ref[0].size(0), 1) + return loss.mean(), dict(stats), {"perm": perm} + diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/enh_inference.py b/examples/speech_synthesis/preprocessing/tfgridnet/enh_inference.py new file mode 100644 index 0000000000..f2fec13f83 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/enh_inference.py @@ -0,0 +1,396 @@ +#!/usr/bin/env python3 +import logging + +from pathlib import Path +from typing import Any, List, Optional, Sequence, Tuple, Union +import numpy as np +import torch +from typeguard import check_argument_types + + +from examples.speech_synthesis.preprocessing.tfgridnet.tasks.enh import EnhancementTask +from examples.speech_synthesis.preprocessing.tfgridnet.torch_utils.device_functions import to_device + + +def get_train_config(train_config, model_file=None): + if train_config is None: + assert model_file is not None, ( + "The argument 'model_file' must be provided " + "if the argument 'train_config' is not specified." + ) + train_config = Path(model_file).parent / "config.yaml" + else: + train_config = Path(train_config) + return train_config + + +def recursive_dict_update(dict_org, dict_patch, verbose=False, log_prefix=""): + """Update `dict_org` with `dict_patch` in-place recursively.""" + for key, value in dict_patch.items(): + if key not in dict_org: + if verbose: + logging.info( + "Overwriting config: [{}{}]: None -> {}".format( + log_prefix, key, value + ) + ) + dict_org[key] = value + elif isinstance(value, dict): + recursive_dict_update( + dict_org[key], value, verbose=verbose, log_prefix=f"{key}." + ) + else: + if verbose and dict_org[key] != value: + logging.info( + "Overwriting config: [{}{}]: {} -> {}".format( + log_prefix, key, dict_org[key], value + ) + ) + dict_org[key] = value + + +def build_model_from_args_and_file(task, args, model_file, device): + model = task.build_model(args) + if not isinstance(model, AbsESPnetModel): + raise RuntimeError( + f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" + ) + model.to(device) + if model_file is not None: + if device == "cuda": + # NOTE(kamo): "cuda" for torch.load always indicates cuda:0 + # in PyTorch<=1.4 + device = f"cuda:{torch.cuda.current_device()}" + model.load_state_dict(torch.load(model_file, map_location=device)) + return model + + +class SeparateSpeech: + """SeparateSpeech class + + Examples: + >>> import soundfile + >>> separate_speech = SeparateSpeech("enh_config.yml", "enh.pth") + >>> audio, rate = soundfile.read("speech.wav") + >>> separate_speech(audio) + [separated_audio1, separated_audio2, ...] 
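For long recordings, segment_size and hop_size (both in seconds, see __init__ below) switch the instance into segment-wise processing with overlap-add stitching. A usage sketch; the config, checkpoint, and wav paths are placeholders for an actual TFGridNet training config and .pth file.

import soundfile
import torch

from examples.speech_synthesis.preprocessing.tfgridnet.enh_inference import SeparateSpeech

separate_speech = SeparateSpeech(
    train_config="config.yaml",      # placeholder path
    model_file="model.pth",          # placeholder path
    segment_size=4.0,                # seconds per segment
    hop_size=2.0,                    # 50% overlap between segments
    normalize_segment_scale=True,
    normalize_output_wav=True,
    device="cpu",
)
audio, rate = soundfile.read("long_noisy.wav")               # (n_samples,)
waves = separate_speech(torch.as_tensor(audio)[None, :], fs=rate)
# waves: list with num_spk numpy arrays, each of shape (1, n_samples)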
+ + """ + + def __init__( + self, + train_config: Union[Path, str] = None, + model_file: Union[Path, str] = None, + inference_config: Union[Path, str] = None, + segment_size: Optional[float] = None, + hop_size: Optional[float] = None, + normalize_segment_scale: bool = False, + show_progressbar: bool = False, + ref_channel: Optional[int] = None, + normalize_output_wav: bool = False, + device: str = "cpu", + dtype: str = "float32", + enh_s2t_task: bool = False, + ): + assert check_argument_types() + + task = EnhancementTask if not enh_s2t_task else EnhS2TTask + + # 1. Build Enh model + + if inference_config is None: + enh_model, enh_train_args = task.build_model_from_file( + train_config, model_file, device + ) + else: + # Overwrite model attributes + train_config = get_train_config(train_config, model_file=model_file) + with train_config.open("r", encoding="utf-8") as f: + train_args = yaml.safe_load(f) + + with Path(inference_config).open("r", encoding="utf-8") as f: + infer_args = yaml.safe_load(f) + + if enh_s2t_task: + arg_list = ("enh_encoder", "enh_separator", "enh_decoder") + else: + arg_list = ("encoder", "separator", "decoder") + supported_keys = list(chain(*[[k, k + "_conf"] for k in arg_list])) + for k in infer_args.keys(): + if k not in supported_keys: + raise ValueError( + "Only the following top-level keys are supported: %s" + % ", ".join(supported_keys) + ) + + recursive_dict_update(train_args, infer_args, verbose=True) + enh_train_args = argparse.Namespace(**train_args) + enh_model = build_model_from_args_and_file( + task, enh_train_args, model_file, device + ) + + if enh_s2t_task: + enh_model = enh_model.enh_model + enh_model.to(dtype=getattr(torch, dtype)).eval() + + self.device = device + self.dtype = dtype + self.enh_train_args = enh_train_args + self.enh_model = enh_model + + # only used when processing long speech, i.e. + # segment_size is not None and hop_size is not None + self.segment_size = segment_size + self.hop_size = hop_size + self.normalize_segment_scale = normalize_segment_scale + self.normalize_output_wav = normalize_output_wav + self.show_progressbar = show_progressbar + + self.num_spk = enh_model.num_spk + task = "enhancement" if self.num_spk == 1 else "separation" + + # reference channel for processing multi-channel speech + if ref_channel is not None: + logging.info( + "Overwrite enh_model.separator.ref_channel with {}".format(ref_channel) + ) + enh_model.separator.ref_channel = ref_channel + if hasattr(enh_model.separator, "beamformer"): + enh_model.separator.beamformer.ref_channel = ref_channel + self.ref_channel = ref_channel + else: + self.ref_channel = enh_model.ref_channel + + self.segmenting = segment_size is not None and hop_size is not None + if self.segmenting: + logging.info("Perform segment-wise speech %s" % task) + logging.info( + "Segment length = {} sec, hop length = {} sec".format( + segment_size, hop_size + ) + ) + else: + logging.info("Perform direct speech %s on the input" % task) + + @torch.no_grad() + def __call__( + self, speech_mix: Union[torch.Tensor, np.ndarray], fs: int = 8000, **kwargs + ) -> List[torch.Tensor]: + """Inference + + Args: + speech_mix: Input speech data (Batch, Nsamples [, Channels]) + fs: sample rate + Returns: + [separated_audio1, separated_audio2, ...] 
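The segmenting branch below slices the mixture into windows of segment_size * fs samples advancing by hop_size * fs, so the segment count and overlap follow directly from those two settings. A worked check with illustrative numbers (10 s of 8 kHz audio, 4 s segments, 2 s hop; none of these values are mandated by the patch):

import numpy as np

fs, n_samples = 8000, 80000                   # 10 s mixture at 8 kHz
segment_size, hop_size = 4.0, 2.0             # seconds
T = int(segment_size * fs)                    # 32000 samples per segment
overlap = int(np.round(fs * (segment_size - hop_size)))                # 16000 samples
num_segments = int(np.ceil((n_samples - overlap) / (hop_size * fs)))   # 4 segments
# The four windows start at 0 s, 2 s, 4 s, 6 s and jointly cover the whole input:
assert T + (num_segments - 1) * int(hop_size * fs) >= n_samples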
+ + """ + assert check_argument_types() + + # Input as audio signal + if isinstance(speech_mix, np.ndarray): + speech_mix = torch.as_tensor(speech_mix) + + assert speech_mix.dim() > 1, speech_mix.size() + batch_size = speech_mix.size(0) + speech_mix = speech_mix.to(getattr(torch, self.dtype)) + # lengths: (B,) + lengths = speech_mix.new_full( + [batch_size], dtype=torch.long, fill_value=speech_mix.size(1) + ) + + # a. To device + speech_mix = to_device(speech_mix, device=self.device) + lengths = to_device(lengths, device=self.device) + + ################################### + # Normalize the signal variance + if getattr(self.enh_model, "normalize_variance_per_ch", False): + dim = 1 + mix_std_ = torch.std(speech_mix, dim=dim, keepdim=True) + speech_mix = speech_mix / mix_std_ # RMS normalization + elif getattr(self.enh_model, "normalize_variance", False): + if speech_mix.ndim > 2: + dim = (1, 2) + else: + dim = 1 + mix_std_ = torch.std(speech_mix, dim=dim, keepdim=True) + speech_mix = speech_mix / mix_std_ # RMS normalization + + category = kwargs.get("utt2category", None) + if ( + self.enh_model.categories + and category is not None + and category[0].item() not in self.enh_model.categories + ): + raise ValueError(f"Category '{category}' is not listed in self.categories") + + additional = {} + if category is not None: + cat = self.enh_model.categories[category[0].item()] + print(f"category: {cat}", flush=True) + if cat.endswith("_reverb"): + additional["mode"] = "dereverb" + else: + additional["mode"] = "no_dereverb" + + if self.segmenting and lengths[0] > self.segment_size * fs: + # Segment-wise speech enhancement/separation + overlap_length = int(np.round(fs * (self.segment_size - self.hop_size))) + num_segments = int( + np.ceil((speech_mix.size(1) - overlap_length) / (self.hop_size * fs)) + ) + t = T = int(self.segment_size * fs) + pad_shape = speech_mix[:, :T].shape + enh_waves = [] + range_ = trange if self.show_progressbar else range + for i in range_(num_segments): + st = int(i * self.hop_size * fs) + en = st + T + if en >= lengths[0]: + # en - st < T (last segment) + en = lengths[0] + speech_seg = speech_mix.new_zeros(pad_shape) + t = en - st + speech_seg[:, :t] = speech_mix[:, st:en] + else: + t = T + speech_seg = speech_mix[:, st:en] # B x T [x C] + + lengths_seg = speech_mix.new_full( + [batch_size], dtype=torch.long, fill_value=T + ) + # b. Enhancement/Separation Forward + feats, f_lens = self.enh_model.encoder(speech_seg, lengths_seg) + feats, _, _ = self.enh_model.separator(feats, f_lens, additional) + processed_wav = [ + self.enh_model.decoder(f, lengths_seg)[0] for f in feats + ] + if speech_seg.dim() > 2: + # multi-channel speech + speech_seg_ = speech_seg[:, self.ref_channel] + else: + speech_seg_ = speech_seg + + if self.normalize_segment_scale: + # normalize the scale to match the input mixture scale + mix_energy = torch.sqrt( + torch.mean(speech_seg_[:, :t].pow(2), dim=1, keepdim=True) + ) + enh_energy = torch.sqrt( + torch.mean( + sum(processed_wav)[:, :t].pow(2), dim=1, keepdim=True + ) + ) + processed_wav = [ + w * (mix_energy / enh_energy) for w in processed_wav + ] + # List[torch.Tensor(num_spk, B, T)] + enh_waves.append(torch.stack(processed_wav, dim=0)) + + # c. 
Stitch the enhanced segments together + waves = enh_waves[0] + for i in range(1, num_segments): + # permutation between separated streams in last and current segments + perm = self.cal_permumation( + waves[:, :, -overlap_length:], + enh_waves[i][:, :, :overlap_length], + criterion="si_snr", + ) + # repermute separated streams in current segment + for batch in range(batch_size): + enh_waves[i][:, batch] = enh_waves[i][perm[batch], batch] + + if i == num_segments - 1: + enh_waves[i][:, :, t:] = 0 + enh_waves_res_i = enh_waves[i][:, :, overlap_length:t] + else: + enh_waves_res_i = enh_waves[i][:, :, overlap_length:] + + # overlap-and-add (average over the overlapped part) + waves[:, :, -overlap_length:] = ( + waves[:, :, -overlap_length:] + enh_waves[i][:, :, :overlap_length] + ) / 2 + # concatenate the residual parts of the later segment + waves = torch.cat([waves, enh_waves_res_i], dim=2) + # ensure the stitched length is same as input + assert waves.size(2) == speech_mix.size(1), (waves.shape, speech_mix.shape) + waves = torch.unbind(waves, dim=0) + else: + # b. Enhancement/Separation Forward + feats, f_lens = self.enh_model.encoder(speech_mix, lengths) + feats, _, _ = self.enh_model.separator(feats, f_lens, additional) + waves = [self.enh_model.decoder(f, lengths)[0] for f in feats] + + ################################### + # De-normalize the signal variance + if getattr(self.enh_model, "normalize_variance_per_ch", False): + if mix_std_.ndim > 2: + mix_std_ = mix_std_[:, :, self.ref_channel] + waves = [w * mix_std_ for w in waves] + elif getattr(self.enh_model, "normalize_variance", False): + if mix_std_.ndim > 2: + mix_std_ = mix_std_.squeeze(2) + waves = [w * mix_std_ for w in waves] + + assert len(waves) == self.num_spk, len(waves) == self.num_spk + assert len(waves[0]) == batch_size, (len(waves[0]), batch_size) + if self.normalize_output_wav: + waves = [ + (w / abs(w).max(dim=1, keepdim=True)[0] * 0.9).cpu().numpy() + for w in waves + ] # list[(batch, sample)] + else: + waves = [w.cpu().numpy() for w in waves] + + return waves + + @torch.no_grad() + def cal_permumation(self, ref_wavs, enh_wavs, criterion="si_snr"): + """Calculate the permutation between seaprated streams in two adjacent segments. + + Args: + ref_wavs (List[torch.Tensor]): [(Batch, Nsamples)] + enh_wavs (List[torch.Tensor]): [(Batch, Nsamples)] + criterion (str): one of ("si_snr", "mse", "corr) + Returns: + perm (torch.Tensor): permutation for enh_wavs (Batch, num_spk) + """ + + criterion_class = {"si_snr": SISNRLoss, "mse": FrequencyDomainMSE}[criterion] + + pit_solver = PITSolver(criterion=criterion_class()) + + _, _, others = pit_solver(ref_wavs, enh_wavs) + perm = others["perm"] + return perm + + @staticmethod + def from_pretrained( + model_tag: Optional[str] = None, + **kwargs: Optional[Any], + ): + """Build SeparateSpeech instance from the pretrained model. + + Args: + model_tag (Optional[str]): Model tag of the pretrained models. + Currently, the tags of espnet_model_zoo are supported. + + Returns: + SeparateSpeech: SeparateSpeech instance. + + """ + if model_tag is not None: + try: + from espnet_model_zoo.downloader import ModelDownloader + + except ImportError: + logging.error( + "`espnet_model_zoo` is not installed. " + "Please install via `pip install -U espnet_model_zoo`." 
+ ) + raise + d = ModelDownloader() + kwargs.update(**d.download_and_unpack(model_tag)) + + return SeparateSpeech(**kwargs) diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/layers/inversible_interface.py b/examples/speech_synthesis/preprocessing/tfgridnet/layers/inversible_interface.py new file mode 100644 index 0000000000..30874a87e8 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/layers/inversible_interface.py @@ -0,0 +1,13 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +import torch + + +class InversibleInterface(ABC): + @abstractmethod + def inverse( + self, input: torch.Tensor, input_lengths: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + # return output, output_lengths + raise NotImplementedError diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/layers/nets_utils.py b/examples/speech_synthesis/preprocessing/tfgridnet/layers/nets_utils.py new file mode 100644 index 0000000000..437b87341b --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/layers/nets_utils.py @@ -0,0 +1,583 @@ +# -*- coding: utf-8 -*- + +"""Network related utility tools.""" + +import logging +from typing import Dict + +import numpy as np +import torch + + +def to_device(m, x): + """Send tensor into the device of the module. + + Args: + m (torch.nn.Module): Torch module. + x (Tensor): Torch tensor. + + Returns: + Tensor: Torch tensor located in the same place as torch module. + + """ + if isinstance(m, torch.nn.Module): + device = next(m.parameters()).device + elif isinstance(m, torch.Tensor): + device = m.device + else: + raise TypeError( + "Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}" + ) + return x.to(device) + + +def pad_list(xs, pad_value): + """Perform padding for the list of tensors. + + Args: + xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. + pad_value (float): Value for padding. + + Returns: + Tensor: Padded tensor (B, Tmax, `*`). + + Examples: + >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] + >>> x + [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] + >>> pad_list(x, 0) + tensor([[1., 1., 1., 1.], + [1., 1., 0., 0.], + [1., 0., 0., 0.]]) + + """ + n_batch = len(xs) + max_len = max(x.size(0) for x in xs) + pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) + + for i in range(n_batch): + pad[i, : xs[i].size(0)] = xs[i] + + return pad + + +def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None): + """Make mask tensor containing indices of padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. + If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. + See the example. + + Returns: + Tensor: Mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + + With the reference tensor. 
+ + >>> xs = torch.zeros((3, 2, 4)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0], + [0, 0, 0, 0]], + [[0, 0, 0, 1], + [0, 0, 0, 1]], + [[0, 0, 1, 1], + [0, 0, 1, 1]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. + + >>> xs = torch.zeros((3, 6, 6)) + >>> make_pad_mask(lengths, xs, 1) + tensor([[[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) + >>> make_pad_mask(lengths, xs, 2) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + """ + if length_dim == 0: + raise ValueError("length_dim cannot be 0: {}".format(length_dim)) + + # If the input dimension is 2 or 3, + # then we use ESPnet-ONNX based implementation for tracable modeling. + # otherwise we use the traditional implementation for research use. + if isinstance(lengths, list): + logging.warning( + "Using make_pad_mask with a list of lengths is not tracable. " + + "If you try to trace this function with type(lengths) == list, " + + "please change the type of lengths to torch.LongTensor." + ) + + if ( + (xs is None or xs.dim() in (2, 3)) + and length_dim <= 2 + and (not isinstance(lengths, list) and lengths.dim() == 1) + ): + return _make_pad_mask_traceable(lengths, xs, length_dim, maxlen) + else: + return _make_pad_mask(lengths, xs, length_dim, maxlen) + + +def _make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None): + if not isinstance(lengths, list): + lengths = lengths.long().tolist() + + bs = int(len(lengths)) + if maxlen is None: + if xs is None: + maxlen = int(max(lengths)) + else: + maxlen = xs.size(length_dim) + else: + assert xs is None, "When maxlen is specified, xs must not be specified." + assert maxlen >= int( + max(lengths) + ), f"maxlen {maxlen} must be >= max(lengths) {max(lengths)}" + + seq_range = torch.arange(0, maxlen, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) + seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + if xs is not None: + assert ( + xs.size(0) == bs + ), f"The size of x.size(0) {xs.size(0)} must match the batch size {bs}" + + if length_dim < 0: + length_dim = xs.dim() + length_dim + # ind = (:, None, ..., None, :, , None, ..., None) + ind = tuple( + slice(None) if i in (0, length_dim) else None for i in range(xs.dim()) + ) + mask = mask[ind].expand_as(xs).to(xs.device) + return mask + + +def _make_pad_mask_traceable(lengths, xs, length_dim, maxlen=None): + """ + Make mask tensor containing indices of padded part. 
+ This is a simplified implementation of make_pad_mask without the xs input + that supports JIT tracing for applications like exporting models to ONNX. + Dimension length of xs should be 2 or 3 + This function will create torch.ones(maxlen, maxlen).triu(diagonal=1) and + select rows to create mask tensor. + """ + + if xs is None: + device = lengths.device + else: + device = xs.device + + if xs is not None and len(xs.shape) == 3: + if length_dim == 1: + lengths = lengths.unsqueeze(1).expand(*xs.transpose(1, 2).shape[:2]) + else: + # Then length_dim is 2 or -1. + if length_dim not in (-1, 2): + logging.warn( + f"Invalid length_dim {length_dim}." + + "We set it to -1, which is the default value." + ) + length_dim = -1 + lengths = lengths.unsqueeze(1).expand(*xs.shape[:2]) + + if maxlen is not None: + assert xs is None + assert maxlen >= lengths.max() + elif xs is not None: + maxlen = xs.shape[length_dim] + else: + maxlen = lengths.max() + + # clip max(length) to maxlen + lengths = torch.clamp(lengths, max=maxlen).type(torch.long) + + mask = torch.ones(maxlen + 1, maxlen + 1, dtype=torch.bool, device=device) + mask = triu_onnx(mask)[1:, :-1] # onnx cannot handle diagonal argument. + mask = mask[lengths - 1][..., :maxlen] + + if xs is not None and len(xs.shape) == 3 and length_dim == 1: + return mask.transpose(1, 2) + else: + return mask + + +def triu_onnx(x): + arange = torch.arange(x.size(0), device=x.device) + mask = arange.unsqueeze(-1).expand(-1, x.size(0)) <= arange + return x * mask + + +def make_non_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of non-padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. + If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. + See the example. + + Returns: + ByteTensor: mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[1, 1, 1, 1 ,1], + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0]] + + With the reference tensor. + + >>> xs = torch.zeros((3, 2, 4)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 1, 0], + [1, 1, 1, 0]], + [[1, 1, 0, 0], + [1, 1, 0, 0]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. 
+ + >>> xs = torch.zeros((3, 6, 6)) + >>> make_non_pad_mask(lengths, xs, 1) + tensor([[[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) + >>> make_non_pad_mask(lengths, xs, 2) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + """ + return ~make_pad_mask(lengths, xs, length_dim) + + +def mask_by_length(xs, lengths, fill=0): + """Mask tensor according to length. + + Args: + xs (Tensor): Batch of input tensor (B, `*`). + lengths (LongTensor or List): Batch of lengths (B,). + fill (int or float): Value to fill masked part. + + Returns: + Tensor: Batch of masked input tensor (B, `*`). + + Examples: + >>> x = torch.arange(5).repeat(3, 1) + 1 + >>> x + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 4, 5], + [1, 2, 3, 4, 5]]) + >>> lengths = [5, 3, 2] + >>> mask_by_length(x, lengths) + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 0, 0], + [1, 2, 0, 0, 0]]) + + """ + assert xs.size(0) == len(lengths) + ret = xs.data.new(*xs.size()).fill_(fill) + for i, l in enumerate(lengths): + ret[i, :l] = xs[i, :l] + return ret + + +def th_accuracy(pad_outputs, pad_targets, ignore_label): + """Calculate accuracy. + + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax, D). + ignore_label (int): Ignore label id. + + Returns: + float: Accuracy value (0.0 - 1.0). + + """ + pad_pred = pad_outputs.view( + pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1) + ).argmax(2) + mask = pad_targets != ignore_label + numerator = torch.sum( + pad_pred.masked_select(mask) == pad_targets.masked_select(mask) + ) + denominator = torch.sum(mask) + return float(numerator) / float(denominator) + + +def to_torch_tensor(x): + """Change to torch.Tensor or ComplexTensor from numpy.ndarray. + + Args: + x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict. + + Returns: + Tensor or ComplexTensor: Type converted inputs. 
+ + Examples: + >>> xs = np.ones(3, dtype=np.float32) + >>> xs = to_torch_tensor(xs) + tensor([1., 1., 1.]) + >>> xs = torch.ones(3, 4, 5) + >>> assert to_torch_tensor(xs) is xs + >>> xs = {'real': xs, 'imag': xs} + >>> to_torch_tensor(xs) + ComplexTensor( + Real: + tensor([1., 1., 1.]) + Imag; + tensor([1., 1., 1.]) + ) + + """ + # If numpy, change to torch tensor + if isinstance(x, np.ndarray): + if x.dtype.kind == "c": + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + + return ComplexTensor(x) + else: + return torch.from_numpy(x) + + # If {'real': ..., 'imag': ...}, convert to ComplexTensor + elif isinstance(x, dict): + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + + if "real" not in x or "imag" not in x: + raise ValueError("has 'real' and 'imag' keys: {}".format(list(x))) + # Relative importing because of using python3 syntax + return ComplexTensor(x["real"], x["imag"]) + + # If torch.Tensor, as it is + elif isinstance(x, torch.Tensor): + return x + + else: + error = ( + "x must be numpy.ndarray, torch.Tensor or a dict like " + "{{'real': torch.Tensor, 'imag': torch.Tensor}}, " + "but got {}".format(type(x)) + ) + try: + from torch_complex.tensor import ComplexTensor + except Exception: + # If PY2 + raise ValueError(error) + else: + # If PY3 + if isinstance(x, ComplexTensor): + return x + else: + raise ValueError(error) + + +def get_subsample(train_args, mode, arch): + """Parse the subsampling factors from the args for the specified `mode` and `arch`. + + Args: + train_args: argument Namespace containing options. + mode: one of ('asr', 'mt', 'st') + arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer') + + Returns: + np.ndarray / List[np.ndarray]: subsampling factors. + """ + if arch == "transformer": + return np.array([1]) + + elif mode == "mt" and arch == "rnn": + # +1 means input (+1) and layers outputs (train_args.elayer) + subsample = np.ones(train_args.elayers + 1, dtype=np.int64) + logging.warning("Subsampling is not performed for machine translation.") + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + return subsample + + elif ( + (mode == "asr" and arch in ("rnn", "rnn-t")) + or (mode == "mt" and arch == "rnn") + or (mode == "st" and arch == "rnn") + ): + subsample = np.ones(train_args.elayers + 1, dtype=np.int64) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range(min(train_args.elayers + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + "Subsampling is not performed for vgg*. " + "It is performed in max pooling layers at CNN." + ) + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + return subsample + + elif mode == "asr" and arch == "rnn_mix": + subsample = np.ones( + train_args.elayers_sd + train_args.elayers + 1, dtype=np.int64 + ) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range( + min(train_args.elayers_sd + train_args.elayers + 1, len(ss)) + ): + subsample[j] = int(ss[j]) + else: + logging.warning( + "Subsampling is not performed for vgg*. " + "It is performed in max pooling layers at CNN." 
+ ) + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + return subsample + + elif mode == "asr" and arch == "rnn_mulenc": + subsample_list = [] + for idx in range(train_args.num_encs): + subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int64) + if train_args.etype[idx].endswith("p") and not train_args.etype[ + idx + ].startswith("vgg"): + ss = train_args.subsample[idx].split("_") + for j in range(min(train_args.elayers[idx] + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + "Encoder %d: Subsampling is not performed for vgg*. " + "It is performed in max pooling layers at CNN.", + idx + 1, + ) + logging.info("subsample: " + " ".join([str(x) for x in subsample])) + subsample_list.append(subsample) + return subsample_list + + else: + raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch)) + + +def rename_state_dict( + old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor] +): + """Replace keys of old prefix with new prefix in state dict.""" + # need this list not to break the dict iterator + old_keys = [k for k in state_dict if k.startswith(old_prefix)] + if len(old_keys) > 0: + logging.warning(f"Rename: {old_prefix} -> {new_prefix}") + for k in old_keys: + v = state_dict.pop(k) + new_k = k.replace(old_prefix, new_prefix) + state_dict[new_k] = v + + +def get_activation(act): + """Return activation function.""" + # Lazy load to avoid unused import + from espnet.nets.pytorch_backend.conformer.swish import Swish + + activation_funcs = { + "hardtanh": torch.nn.Hardtanh, + "tanh": torch.nn.Tanh, + "relu": torch.nn.ReLU, + "selu": torch.nn.SELU, + "swish": Swish, + } + + return activation_funcs[act]() diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/layers/stft.py b/examples/speech_synthesis/preprocessing/tfgridnet/layers/stft.py new file mode 100644 index 0000000000..e501ebfc30 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/layers/stft.py @@ -0,0 +1,215 @@ +from typing import Optional, Tuple, Union + +import librosa +import numpy as np +import torch +from torch_complex.tensor import ComplexTensor +from typeguard import check_argument_types + +from examples.speech_synthesis.preprocessing.tfgridnet.enh.layers_complex_utils import to_complex +from examples.speech_synthesis.preprocessing.tfgridnet.layers.inversible_interface import InversibleInterface +from examples.speech_synthesis.preprocessing.tfgridnet.layers.nets_utils import make_pad_mask + + +class Stft(torch.nn.Module, InversibleInterface): + def __init__( + self, + n_fft: int = 512, + win_length: int = None, + hop_length: int = 128, + window: Optional[str] = "hann", + center: bool = True, + normalized: bool = False, + onesided: bool = True, + ): + assert check_argument_types() + super().__init__() + self.n_fft = n_fft + if win_length is None: + self.win_length = n_fft + else: + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.normalized = normalized + self.onesided = onesided + if window is not None and not hasattr(torch, f"{window}_window"): + raise ValueError(f"{window} window is not implemented") + self.window = window + + def extra_repr(self): + return ( + f"n_fft={self.n_fft}, " + f"win_length={self.win_length}, " + f"hop_length={self.hop_length}, " + f"center={self.center}, " + f"normalized={self.normalized}, " + f"onesided={self.onesided}" + ) + + def forward( + self, input: torch.Tensor, ilens: torch.Tensor = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """STFT 
forward function. + + Args: + input: (Batch, Nsamples) or (Batch, Nsample, Channels) + ilens: (Batch) + Returns: + output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2) + + """ + bs = input.size(0) + if input.dim() == 3: + multi_channel = True + # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample) + input = input.transpose(1, 2).reshape(-1, input.size(1)) + else: + multi_channel = False + + # NOTE(kamo): + # The default behaviour of torch.stft is compatible with librosa.stft + # about padding and scaling. + # Note that it's different from scipy.signal.stft + + # output: (Batch, Freq, Frames, 2=real_imag) + # or (Batch, Channel, Freq, Frames, 2=real_imag) + if self.window is not None: + window_func = getattr(torch, f"{self.window}_window") + window = window_func( + self.win_length, dtype=input.dtype, device=input.device + ) + else: + window = None + + # For the compatibility of ARM devices, which do not support + # torch.stft() due to the lack of MKL (on older pytorch versions), + # there is an alternative replacement implementation with librosa. + # Note: pytorch >= 1.10.0 now has native support for FFT and STFT + # on all cpu targets including ARM. + if input.is_cuda or torch.backends.mkl.is_available(): + stft_kwargs = dict( + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + center=self.center, + window=window, + normalized=self.normalized, + onesided=self.onesided, + ) + stft_kwargs["return_complex"] = True + output = torch.stft(input, **stft_kwargs) + output = torch.view_as_real(output) + else: + if self.training: + raise NotImplementedError( + "stft is implemented with librosa on this device, which does not " + "support the training mode." + ) + + # use stft_kwargs to flexibly control different PyTorch versions' kwargs + # note: librosa does not support a win_length that is < n_ftt + # but the window can be manually padded (see below). 
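Whichever backend computes the transform, center=True implicitly pads the signal by n_fft // 2 on both sides, which is what the olens computation at the end of this method accounts for. A quick check of that arithmetic with the class defaults (n_fft=512, hop_length=128) and an illustrative 1 s, 16 kHz input:

import torch

n_fft, hop = 512, 128
ilens = torch.tensor([16000])
pad = n_fft // 2                              # added on each side when center=True
olens = torch.div(ilens + 2 * pad - n_fft, hop, rounding_mode="trunc") + 1
print(olens)                                  # tensor([126])

frames = torch.stft(
    torch.randn(1, 16000),
    n_fft=n_fft,
    hop_length=hop,
    window=torch.hann_window(n_fft),
    center=True,
    return_complex=True,
).shape[-1]
assert frames == olens.item()                 # torch.stft agrees: 126 frames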
+ stft_kwargs = dict( + n_fft=self.n_fft, + win_length=self.n_fft, + hop_length=self.hop_length, + center=self.center, + window=window, + pad_mode="reflect", + ) + + if window is not None: + # pad the given window to n_fft + n_pad_left = (self.n_fft - window.shape[0]) // 2 + n_pad_right = self.n_fft - window.shape[0] - n_pad_left + stft_kwargs["window"] = torch.cat( + [torch.zeros(n_pad_left), window, torch.zeros(n_pad_right)], 0 + ).numpy() + else: + win_length = ( + self.win_length if self.win_length is not None else self.n_fft + ) + stft_kwargs["window"] = torch.ones(win_length) + + output = [] + # iterate over istances in a batch + for i, instance in enumerate(input): + stft = librosa.stft(input[i].numpy(), **stft_kwargs) + output.append(torch.tensor(np.stack([stft.real, stft.imag], -1))) + output = torch.stack(output, 0) + if not self.onesided: + len_conj = self.n_fft - output.shape[1] + conj = output[:, 1 : 1 + len_conj].flip(1) + conj[:, :, :, -1].data *= -1 + output = torch.cat([output, conj], 1) + if self.normalized: + output = output * (stft_kwargs["window"].shape[0] ** (-0.5)) + + # output: (Batch, Freq, Frames, 2=real_imag) + # -> (Batch, Frames, Freq, 2=real_imag) + output = output.transpose(1, 2) + if multi_channel: + # output: (Batch * Channel, Frames, Freq, 2=real_imag) + # -> (Batch, Frame, Channel, Freq, 2=real_imag) + output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose( + 1, 2 + ) + + if ilens is not None: + if self.center: + pad = self.n_fft // 2 + ilens = ilens + 2 * pad + + olens = ( + torch.div(ilens - self.n_fft, self.hop_length, rounding_mode="trunc") + + 1 + ) + output.masked_fill_(make_pad_mask(olens, output, 1), 0.0) + else: + olens = None + + return output, olens + + def inverse( + self, input: Union[torch.Tensor, ComplexTensor], ilens: torch.Tensor = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Inverse STFT. 
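A round-trip sketch with the shapes documented below; it assumes the (real, imag) output of forward is convertible by to_complex, and the random waveform is only there to exercise the shapes.

import torch

from examples.speech_synthesis.preprocessing.tfgridnet.layers.stft import Stft

stft = Stft(n_fft=512, hop_length=128)
wav = torch.randn(2, 16000)
ilens = torch.tensor([16000, 16000])
spec, olens = stft(wav, ilens)           # spec: (2, 126, 257, 2), olens: [126, 126]
recon, _ = stft.inverse(spec, ilens)     # recon: (2, 16000)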
+ + Args: + input: Tensor(batch, T, F, 2) or ComplexTensor(batch, T, F) + ilens: (batch,) + Returns: + wavs: (batch, samples) + ilens: (batch,) + """ + input = to_complex(input) + + if self.window is not None: + window_func = getattr(torch, f"{self.window}_window") + datatype = input.real.dtype + window = window_func(self.win_length, dtype=datatype, device=input.device) + else: + window = None + + # if is_complex(input): + # input = torch.stack([input.real, input.imag], dim=-1) + # input = torch.view_as_complex(input) + # elif input.shape[-1] != 2: + # raise TypeError("Invalid input type") + input = input.transpose(1, 2) + + wavs = torch.functional.istft( + input, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=window, + center=self.center, + normalized=self.normalized, + onesided=self.onesided, + length=ilens.max() if ilens is not None else ilens, + return_complex=False, + ) + + return wavs, ilens diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/mask.py b/examples/speech_synthesis/preprocessing/tfgridnet/mask.py new file mode 100644 index 0000000000..0232b7d2ef --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/mask.py @@ -0,0 +1,129 @@ +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_complex.tensor import ComplexTensor + +class AbsMask(torch.nn.Module, ABC): + @property + @abstractmethod + def max_num_spk(self) -> int: + raise NotImplementedError + + @abstractmethod + def forward( + self, + input, + ilens, + bottleneck_feat, + num_spk, + ) -> Tuple[Tuple[torch.Tensor], torch.Tensor, OrderedDict]: + raise NotImplementedError + + +# This is an implementation of the multiple 1x1 convolution layer architecture +# in https://arxiv.org/pdf/2203.17068.pdf + +class MultiMask(AbsMask): + def __init__( + self, + input_dim: int, + bottleneck_dim: int = 128, + max_num_spk: int = 3, + mask_nonlinear="relu", + ): + """Multiple 1x1 convolution layer Module. + + This module corresponds to the final 1x1 conv block and + non-linear function in TCNSeparator. + This module has multiple 1x1 conv blocks. One of them is selected + according to the given num_spk to handle flexible num_spk. + + Args: + input_dim: Number of filters in autoencoder + bottleneck_dim: Number of channels in bottleneck 1 * 1-conv block + max_num_spk: Number of mask_conv1x1 modules + (>= Max number of speakers in the dataset) + mask_nonlinear: use which non-linear function to generate mask + """ + super().__init__() + # Hyper-parameter + self._max_num_spk = max_num_spk + self.mask_nonlinear = mask_nonlinear + # [M, B, K] -> [M, C*N, K] + self.mask_conv1x1 = nn.ModuleList() + for z in range(1, max_num_spk + 1): + self.mask_conv1x1.append( + nn.Conv1d(bottleneck_dim, z * input_dim, 1, bias=False) + ) + + @property + def max_num_spk(self) -> int: + return self._max_num_spk + + def forward( + self, + input: Union[torch.Tensor, ComplexTensor], + ilens: torch.Tensor, + bottleneck_feat: torch.Tensor, + num_spk: int, + ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]: + """Keep this API same with TasNet. + + Args: + input: [M, K, N], M is batch size + ilens (torch.Tensor): (M,) + bottleneck_feat: [M, K, B] + num_spk: number of speakers + (Training: oracle, + Inference: estimated by other module (e.g, EEND-EDA)) + + Returns: + masked (List[Union(torch.Tensor, ComplexTensor)]): [(M, K, N), ...] 
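Shape-wise, the module picks the (num_spk - 1)-th 1x1 convolution and reshapes its output into one mask per speaker. An illustrative run with hypothetical dimensions (N=256 encoder filters, B=128 bottleneck channels, batch of 4, 100 frames; these values are not prescribed by the patch):

import torch

from examples.speech_synthesis.preprocessing.tfgridnet.mask import MultiMask

mask_module = MultiMask(input_dim=256, bottleneck_dim=128, max_num_spk=3)
feat = torch.randn(4, 100, 256)                    # encoder output [M, K, N]
bottleneck = torch.randn(4, 100, 128)              # separator bottleneck [M, K, B]
ilens = torch.full((4,), 100, dtype=torch.long)

masked, ilens, others = mask_module(feat, ilens, bottleneck, num_spk=2)
print(len(masked), masked[0].shape)                # 2 torch.Size([4, 100, 256])
print(sorted(others))                              # ['mask_spk1', 'mask_spk2']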
+ ilens (torch.Tensor): (M,) + others predicted data, e.g. masks: OrderedDict[ + 'mask_spk1': torch.Tensor(Batch, Frames, Freq), + 'mask_spk2': torch.Tensor(Batch, Frames, Freq), + ... + 'mask_spkn': torch.Tensor(Batch, Frames, Freq), + ] + + """ + M, K, N = input.size() + bottleneck_feat = bottleneck_feat.transpose(1, 2) # [M, B, K] + score = self.mask_conv1x1[num_spk - 1]( + bottleneck_feat + ) # [M, B, K] -> [M, num_spk*N, K] + # add other outputs of the module list with factor 0.0 + # to enable distributed training + for z in range(self._max_num_spk): + if z != num_spk - 1: + score += 0.0 * F.interpolate( + self.mask_conv1x1[z](bottleneck_feat).transpose(1, 2), + size=num_spk * N, + ).transpose(1, 2) + score = score.view(M, num_spk, N, K) # [M, num_spk*N, K] -> [M, num_spk, N, K] + if self.mask_nonlinear == "softmax": + est_mask = F.softmax(score, dim=1) + elif self.mask_nonlinear == "relu": + est_mask = F.relu(score) + elif self.mask_nonlinear == "sigmoid": + est_mask = torch.sigmoid(score) + elif self.mask_nonlinear == "tanh": + est_mask = torch.tanh(score) + else: + raise ValueError("Unsupported mask non-linear function") + + masks = est_mask.transpose(2, 3) # [M, num_spk, K, N] + masks = masks.unbind(dim=1) # List[M, K, N] + + masked = [input * m for m in masks] + + others = OrderedDict( + zip(["mask_spk{}".format(i + 1) for i in range(len(masks))], masks) + ) + + return masked, ilens, others \ No newline at end of file diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/tasks/abs_task.py b/examples/speech_synthesis/preprocessing/tfgridnet/tasks/abs_task.py new file mode 100644 index 0000000000..910eb5bde6 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/tasks/abs_task.py @@ -0,0 +1,296 @@ +"""Abstract task module.""" +import argparse +import logging +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import torch.multiprocessing +import torch.nn +import torch.optim +import yaml +from packaging.version import parse as V +from torch.utils.data import DataLoader +from typeguard import check_argument_types, check_return_type + +from examples.speech_synthesis.preprocessing.tfgridnet.train.espnet_model import AbsESPnetModel +from examples.speech_synthesis.preprocessing.tfgridnet.train.class_choices import ClassChoices + +try: + import wandb +except Exception: + wandb = None + +if V(torch.__version__) >= V("1.5.0"): + from torch.multiprocessing.spawn import ProcessContext +else: + from torch.multiprocessing.spawn import SpawnContext as ProcessContext + + +optim_classes = dict( + adam=torch.optim.Adam, + adamw=torch.optim.AdamW, + #sgd=SGD, + adadelta=torch.optim.Adadelta, + adagrad=torch.optim.Adagrad, + adamax=torch.optim.Adamax, + asgd=torch.optim.ASGD, + lbfgs=torch.optim.LBFGS, + rmsprop=torch.optim.RMSprop, + rprop=torch.optim.Rprop, +) +if V(torch.__version__) >= V("1.10.0"): + # From 1.10.0, RAdam is officially supported + optim_classes.update( + radam=torch.optim.RAdam, + ) +try: + import torch_optimizer + + optim_classes.update( + accagd=torch_optimizer.AccSGD, + adabound=torch_optimizer.AdaBound, + adamod=torch_optimizer.AdaMod, + diffgrad=torch_optimizer.DiffGrad, + lamb=torch_optimizer.Lamb, + novograd=torch_optimizer.NovoGrad, + pid=torch_optimizer.PID, + # torch_optimizer<=0.0.1a10 doesn't support + # qhadam=torch_optimizer.QHAdam, + qhm=torch_optimizer.QHM, + 
sgdw=torch_optimizer.SGDW, + yogi=torch_optimizer.Yogi, + ) + if V(torch_optimizer.__version__) < V("0.2.0"): + # From 0.2.0, RAdam is dropped + optim_classes.update( + radam=torch_optimizer.RAdam, + ) + del torch_optimizer +except ImportError: + pass +try: + import apex + + optim_classes.update( + fusedadam=apex.optimizers.FusedAdam, + fusedlamb=apex.optimizers.FusedLAMB, + fusednovograd=apex.optimizers.FusedNovoGrad, + fusedsgd=apex.optimizers.FusedSGD, + ) + del apex +except ImportError: + pass +try: + import fairscale +except ImportError: + fairscale = None + + +scheduler_classes = dict( + ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau, + lambdalr=torch.optim.lr_scheduler.LambdaLR, + steplr=torch.optim.lr_scheduler.StepLR, + multisteplr=torch.optim.lr_scheduler.MultiStepLR, + exponentiallr=torch.optim.lr_scheduler.ExponentialLR, + CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR, + #noamlr=NoamLR, + #warmuplr=WarmupLR, + #warmupsteplr=WarmupStepLR, + #warmupReducelronplateau=WarmupReduceLROnPlateau, + cycliclr=torch.optim.lr_scheduler.CyclicLR, + onecyclelr=torch.optim.lr_scheduler.OneCycleLR, + CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts, + #CosineAnnealingWarmupRestarts=CosineAnnealingWarmupRestarts, +) +# To lower keys +optim_classes = {k.lower(): v for k, v in optim_classes.items()} +scheduler_classes = {k.lower(): v for k, v in scheduler_classes.items()} + + +@dataclass +class IteratorOptions: + preprocess_fn: callable + collate_fn: callable + data_path_and_name_and_type: list + shape_files: list + batch_size: int + batch_bins: int + batch_type: str + max_cache_size: float + max_cache_fd: int + allow_multi_rates: bool + distributed: bool + num_batches: Optional[int] + num_iters_per_epoch: Optional[int] + train: bool + + +class AbsTask(ABC): + # Use @staticmethod, or @classmethod, + # instead of instance method to avoid God classes + + # If you need more than one optimizers, change this value in inheritance + num_optimizers: int = 1 + #trainer = Trainer + class_choices_list: List[ClassChoices] = [] + + def __init__(self): + raise RuntimeError("This class can't be instantiated.") + + @classmethod + @abstractmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + pass + + @classmethod + @abstractmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[[Sequence[Dict[str, np.ndarray]]], Dict[str, torch.Tensor]]: + """Return "collate_fn", which is a callable object and given to DataLoader. + + >>> from torch.utils.data import DataLoader + >>> loader = DataLoader(collate_fn=cls.build_collate_fn(args, train=True), ...) + + In many cases, you can use our common collate_fn. + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + raise NotImplementedError + + @classmethod + @abstractmethod + def required_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + """Define the required names by Task + + This function is used by + >>> cls.check_task_requirements() + If your model is defined as following, + + >>> from espnet2.train.abs_espnet_model import AbsESPnetModel + >>> class Model(AbsESPnetModel): + ... 
def forward(self, input, output, opt=None): pass + + then "required_data_names" should be as + + >>> required_data_names = ('input', 'output') + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def optional_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + """Define the optional names by Task + + This function is used by + >>> cls.check_task_requirements() + If your model is defined as follows, + + >>> from espnet2.train.abs_espnet_model import AbsESPnetModel + >>> class Model(AbsESPnetModel): + ... def forward(self, input, output, opt=None): pass + + then "optional_data_names" should be as + + >>> optional_data_names = ('opt',) + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel: + raise NotImplementedError + + + # ~~~~~~~~~ The methods below are mainly used for inference ~~~~~~~~~ + @classmethod + def build_model_from_file( + cls, + config_file: Union[Path, str] = None, + model_file: Union[Path, str] = None, + device: str = "cpu", + ) -> Tuple[AbsESPnetModel, argparse.Namespace]: + """Build model from the files. + + This method is used for inference or fine-tuning. + + Args: + config_file: The yaml file saved when training. + model_file: The model file saved when training. + device: Device type, "cpu", "cuda", or "cuda:N". + + """ + assert check_argument_types() + if config_file is None: + assert model_file is not None, ( + "The argument 'model_file' must be provided " + "if the argument 'config_file' is not specified." + ) + config_file = Path(model_file).parent / "config.yaml" + else: + config_file = Path(config_file) + + logging.info("config file: {}".format(config_file)) + with config_file.open("r", encoding="utf-8") as f: + args = yaml.safe_load(f) + args = argparse.Namespace(**args) + model = cls.build_model(args) + if not isinstance(model, AbsESPnetModel): + raise RuntimeError( + f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}" + ) + model.to(device) + + # For LoRA finetuned model, create LoRA adapter + use_lora = getattr(args, "use_lora", False) + if use_lora: + create_lora_adapter(model, **args.lora_conf) + + if model_file is not None: + if device == "cuda": + # NOTE(kamo): "cuda" for torch.load always indicates cuda:0 + # in PyTorch<=1.4 + device = f"cuda:{torch.cuda.current_device()}" + try: + model.load_state_dict( + torch.load(model_file, map_location=device), + strict=not use_lora, + ) + except RuntimeError: + # Note(simpleoier): the following part is to be compatible with + # pretrained model using earlier versions before `0a625088` + state_dict = torch.load(model_file, map_location=device) + if any(["frontend.upstream.model" in k for k in state_dict.keys()]): + if any( + [ + "frontend.upstream.upstream.model" in k + for k in dict(model.named_parameters()) + ] + ): + state_dict = { + k.replace( + "frontend.upstream.model", + "frontend.upstream.upstream.model", + ): v + for k, v in state_dict.items() + } + model.load_state_dict(state_dict, strict=not use_lora) + else: + raise + else: + raise + + return model, args diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/tasks/enh.py b/examples/speech_synthesis/preprocessing/tfgridnet/tasks/enh.py new file mode 100644 index 0000000000..5808f76c60 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/tasks/enh.py @@ -0,0 +1,458 @@ +import argparse +import copy +import os +from typing import Callable, Collection, Dict, List, Optional, Tuple + 
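At inference time, SeparateSpeech (in enh_inference.py above) drives this file's EnhancementTask (defined below) through AbsTask.build_model_from_file, which reads the training config.yaml next to the checkpoint and instantiates the encoder, separator, and decoder selected via the ClassChoices registries declared below. A usage sketch with placeholder paths:

from examples.speech_synthesis.preprocessing.tfgridnet.tasks.enh import EnhancementTask

model, train_args = EnhancementTask.build_model_from_file(
    config_file="exp/enh_tfgridnet/config.yaml",         # placeholder path
    model_file="exp/enh_tfgridnet/valid.loss.best.pth",  # placeholder path
    device="cpu",
)
model.eval()
# train_args holds the resolved choices, e.g. train_args.separator could be
# "tfgridnet_masking" with its options under train_args.separator_conf.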
+import numpy as np +import torch +from typeguard import check_argument_types, check_return_type + +from examples.speech_synthesis.preprocessing.tfgridnet.mask import AbsMask, MultiMask +from examples.speech_synthesis.preprocessing.tfgridnet.enh.decoder import ( + AbsDecoder, + ConvDecoder, + STFTDecoder, + NullDecoder, +) +from examples.speech_synthesis.preprocessing.tfgridnet.enh.encoder import ( + AbsEncoder, + ConvEncoder, + STFTEncoder, + NullEncoder, +) +from examples.speech_synthesis.preprocessing.tfgridnet.train.espnet_model import ESPnetEnhancementModel +from examples.speech_synthesis.preprocessing.tfgridnet.enh.loss_criterion import AbsEnhLoss, SISNRLoss +from examples.speech_synthesis.preprocessing.tfgridnet.enh.wrappers import AbsLossWrapper, PITSolver, FixedOrderSolver +from examples.speech_synthesis.preprocessing.tfgridnet.enh.separator import ( + AbsSeparator, + RNNSeparator, + TCNSeparator, + TFGridNetMasking, +) + +from examples.speech_synthesis.preprocessing.tfgridnet.tasks.abs_task import AbsTask +from examples.speech_synthesis.preprocessing.tfgridnet.torch_utils.initialize import initialize +from examples.speech_synthesis.preprocessing.tfgridnet.train.class_choices import ClassChoices +from examples.speech_synthesis.preprocessing.tfgridnet.train.preprocessor import ( + AbsPreprocessor, + DynamicMixingPreprocessor, + EnhPreprocessor, +) + +encoder_choices = ClassChoices( + name="encoder", + classes=dict(stft=STFTEncoder, conv=ConvEncoder, same=NullEncoder), + type_check=AbsEncoder, + default="stft", +) + +separator_choices = ClassChoices( + name="separator", + classes=dict( + rnn=RNNSeparator, tcn=TCNSeparator, tfgridnet_masking=TFGridNetMasking + ), + type_check=AbsSeparator, + default="rnn", +) + +mask_module_choices = ClassChoices( + name="mask_module", + classes=dict(multi_mask=MultiMask), + type_check=AbsMask, + default="multi_mask", +) + +decoder_choices = ClassChoices( + name="decoder", + classes=dict(stft=STFTDecoder, conv=ConvDecoder, same=NullDecoder), + type_check=AbsDecoder, + default="stft", +) + +loss_wrapper_choices = ClassChoices( + name="loss_wrappers", + classes=dict( + pit=PITSolver, + fixed_order=FixedOrderSolver, + ), + type_check=AbsLossWrapper, + default="pit", +) + +criterion_choices = ClassChoices( + name="criterions", + classes=dict( + si_snr=SISNRLoss, + ), + type_check=AbsEnhLoss, + default="SISNRLoss", +) + +preprocessor_choices = ClassChoices( + name="preprocessor", + classes=dict( + dynamic_mixing=DynamicMixingPreprocessor, + enh=EnhPreprocessor, + ), + type_check=AbsPreprocessor, + default=None, +) + +MAX_REFERENCE_NUM = 100 + + +class EnhancementTask(AbsTask): + # If you need more than one optimizers, change this value + num_optimizers: int = 1 + + class_choices_list = [ + # --encoder and --encoder_conf + encoder_choices, + # --separator and --separator_conf + separator_choices, + # --decoder and --decoder_conf + decoder_choices, + # --mask_module and --mask_module_conf + mask_module_choices, + # --preprocessor and --preprocessor_conf + preprocessor_choices, + ] + + # If you need to modify train() or eval() procedures, change Trainer class here + # trainer = Trainer + + @classmethod + def add_task_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group(description="Task related") + + # NOTE(kamo): add_arguments(..., required=True) can't be used + # to provide --print_config mode. 
Instead of it, do as + # required = parser.get_default("required") + + group.add_argument( + "--init", + type=lambda x: str_or_none(x.lower()), + default=None, + help="The initialization method", + choices=[ + "chainer", + "xavier_uniform", + "xavier_normal", + "kaiming_uniform", + "kaiming_normal", + None, + ], + ) + + group.add_argument( + "--model_conf", + action=NestedDictAction, + default=get_default_kwargs(ESPnetEnhancementModel), + help="The keyword arguments for model class.", + ) + + group.add_argument( + "--criterions", + action=NestedDictAction, + default=[ + { + "name": "si_snr", + "conf": {}, + "wrapper": "fixed_order", + "wrapper_conf": {}, + }, + ], + help="The criterions binded with the loss wrappers.", + ) + + group = parser.add_argument_group(description="Preprocess related") + group.add_argument( + "--speech_volume_normalize", + type=str_or_none, + default=None, + help="Scale the maximum amplitude to the given value or range. " + "e.g. --speech_volume_normalize 1.0 scales it to 1.0.\n" + "--speech_volume_normalize 0.5_1.0 scales it to a random number in " + "the range [0.5, 1.0)", + ) + group.add_argument( + "--rir_scp", + type=str_or_none, + default=None, + help="The file path of rir scp file.", + ) + group.add_argument( + "--rir_apply_prob", + type=float, + default=1.0, + help="THe probability for applying RIR convolution.", + ) + group.add_argument( + "--noise_scp", + type=str_or_none, + default=None, + help="The file path of noise scp file.", + ) + group.add_argument( + "--noise_apply_prob", + type=float, + default=1.0, + help="The probability applying Noise adding.", + ) + group.add_argument( + "--noise_db_range", + type=str, + default="13_15", + help="The range of signal-to-noise ratio (SNR) level in decibel.", + ) + group.add_argument( + "--short_noise_thres", + type=float, + default=0.5, + help="If len(noise) / len(speech) is smaller than this threshold during " + "dynamic mixing, a warning will be displayed.", + ) + group.add_argument( + "--use_reverberant_ref", + type=str2bool, + default=False, + help="Whether to use reverberant speech references " + "instead of anechoic ones", + ) + group.add_argument( + "--num_spk", + type=int, + default=1, + help="Number of speakers in the input signal.", + ) + group.add_argument( + "--num_noise_type", + type=int, + default=1, + help="Number of noise types.", + ) + group.add_argument( + "--sample_rate", + type=int, + default=8000, + help="Sampling rate of the data (in Hz).", + ) + group.add_argument( + "--force_single_channel", + type=str2bool, + default=False, + help="Whether to force all data to be single-channel.", + ) + group.add_argument( + "--channel_reordering", + type=str2bool, + default=False, + help="Whether to randomly reorder the channels of the " + "multi-channel signals.", + ) + group.add_argument( + "--categories", + nargs="+", + default=[], + type=str, + help="The set of all possible categories in the dataset. Used to add the " + "category information to each sample", + ) + group.add_argument( + "--speech_segment", + type=int_or_none, + default=None, + help="Truncate the audios to the specified length (in samples) if not None", + ) + group.add_argument( + "--avoid_allzero_segment", + type=str2bool, + default=True, + help="Only used when --speech_segment is specified. If True, make sure " + "all truncated segments are not all-zero", + ) + group.add_argument( + "--flexible_numspk", + type=str2bool, + default=False, + help="Whether to load variable numbers of speakers in each sample. 
" + "In this case, only the first-speaker files such as 'spk1.scp' and " + "'dereverb1.scp' are used, which are expected to have multiple columns. " + "Other numbered files such as 'spk2.scp' and 'dereverb2.scp' are ignored.", + ) + + group.add_argument( + "--dynamic_mixing", + type=str2bool, + default=False, + help="Apply dynamic mixing", + ) + group.add_argument( + "--utt2spk", + type=str_or_none, + default=None, + help="The file path of utt2spk file. Only used in dynamic_mixing mode.", + ) + group.add_argument( + "--dynamic_mixing_gain_db", + type=float, + default=0.0, + help="Random gain (in dB) for dynamic mixing sources", + ) + + for class_choices in cls.class_choices_list: + # Append -- and --_conf. + # e.g. --encoder and --encoder_conf + class_choices.add_arguments(group) + + @classmethod + def build_collate_fn( + cls, args: argparse.Namespace, train: bool + ) -> Callable[ + [Collection[Tuple[str, Dict[str, np.ndarray]]]], + Tuple[List[str], Dict[str, torch.Tensor]], + ]: + assert check_argument_types() + + return CommonCollateFn(float_pad_value=0.0, int_pad_value=0) + + @classmethod + def build_preprocess_fn( + cls, args: argparse.Namespace, train: bool + ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: + assert check_argument_types() + + use_preprocessor = getattr(args, "preprocessor", None) is not None + + if use_preprocessor: + # TODO(simpleoier): To make this as simple as model parts, e.g. encoder + if args.preprocessor == "dynamic_mixing": + retval = preprocessor_choices.get_class(args.preprocessor)( + train=train, + source_scp=os.path.join( + os.path.dirname(args.train_data_path_and_name_and_type[0][0]), + args.preprocessor_conf.get("source_scp_name", "spk1.scp"), + ), + ref_num=args.preprocessor_conf.get( + "ref_num", args.separator_conf["num_spk"] + ), + dynamic_mixing_gain_db=args.preprocessor_conf.get( + "dynamic_mixing_gain_db", 0.0 + ), + speech_name=args.preprocessor_conf.get("speech_name", "speech_mix"), + speech_ref_name_prefix=args.preprocessor_conf.get( + "speech_ref_name_prefix", "speech_ref" + ), + mixture_source_name=args.preprocessor_conf.get( + "mixture_source_name", None + ), + utt2spk=getattr(args, "utt2spk", None), + categories=args.preprocessor_conf.get("categories", None), + ) + elif args.preprocessor == "enh": + kwargs = dict( + # NOTE(kamo): Check attribute existence for backward compatibility + rir_scp=getattr(args, "rir_scp", None), + rir_apply_prob=getattr(args, "rir_apply_prob", 1.0), + noise_scp=getattr(args, "noise_scp", None), + noise_apply_prob=getattr(args, "noise_apply_prob", 1.0), + noise_db_range=getattr(args, "noise_db_range", "13_15"), + short_noise_thres=getattr(args, "short_noise_thres", 0.5), + speech_volume_normalize=getattr( + args, "speech_volume_normalize", None + ), + use_reverberant_ref=getattr(args, "use_reverberant_ref", None), + num_spk=getattr(args, "num_spk", 1), + num_noise_type=getattr(args, "num_noise_type", 1), + sample_rate=getattr(args, "sample_rate", 8000), + force_single_channel=getattr(args, "force_single_channel", False), + channel_reordering=getattr(args, "channel_reordering", False), + categories=getattr(args, "categories", None), + speech_segment=getattr(args, "speech_segment", None), + avoid_allzero_segment=getattr(args, "avoid_allzero_segment", True), + flexible_numspk=getattr(args, "flexible_numspk", False), + ) + kwargs.update(args.preprocessor_conf) + retval = preprocessor_choices.get_class(args.preprocessor)( + train=train, **kwargs + ) + else: + raise ValueError( + 
f"Preprocessor type {args.preprocessor} is not supported." + ) + else: + retval = None + assert check_return_type(retval) + return retval + + @classmethod + def required_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + if not inference: + retval = ("speech_ref1",) + else: + # Inference mode + retval = ("speech_mix",) + return retval + + @classmethod + def optional_data_names( + cls, train: bool = True, inference: bool = False + ) -> Tuple[str, ...]: + retval = ["speech_mix"] + retval += ["dereverb_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)] + retval += ["speech_ref{}".format(n) for n in range(2, MAX_REFERENCE_NUM + 1)] + retval += ["noise_ref{}".format(n) for n in range(1, MAX_REFERENCE_NUM + 1)] + retval += ["category"] + retval = tuple(retval) + assert check_return_type(retval) + return retval + + @classmethod + def build_model(cls, args: argparse.Namespace) -> ESPnetEnhancementModel: + assert check_argument_types() + + encoder = encoder_choices.get_class(args.encoder)(**args.encoder_conf) + separator = separator_choices.get_class(args.separator)( + encoder.output_dim, **args.separator_conf + ) + decoder = decoder_choices.get_class(args.decoder)(**args.decoder_conf) + if args.separator.endswith("nomask"): + mask_module = mask_module_choices.get_class(args.mask_module)( + input_dim=encoder.output_dim, + **args.mask_module_conf, + ) + else: + mask_module = None + + loss_wrappers = [] + + if getattr(args, "criterions", None) is not None: + # This check is for the compatibility when load models + # that packed by older version + for ctr in args.criterions: + criterion_conf = ctr.get("conf", {}) + criterion = criterion_choices.get_class(ctr["name"])(**criterion_conf) + loss_wrapper = loss_wrapper_choices.get_class(ctr["wrapper"])( + criterion=criterion, **ctr["wrapper_conf"] + ) + loss_wrappers.append(loss_wrapper) + + # 1. Build model + model = ESPnetEnhancementModel( + encoder=encoder, + separator=separator, + decoder=decoder, + loss_wrappers=loss_wrappers, + mask_module=mask_module, + **args.model_conf, + ) + + # FIXME(kamo): Should be done in model? + # 2. Initialize + if args.init is not None: + initialize(model, args.init) + + assert check_return_type(model) + return model diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/torch_utils/device_functions.py b/examples/speech_synthesis/preprocessing/tfgridnet/torch_utils/device_functions.py new file mode 100644 index 0000000000..99ad6bf6e5 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/torch_utils/device_functions.py @@ -0,0 +1,31 @@ +import dataclasses +import numpy as np +import torch + + +def to_device(data, device=None, dtype=None, non_blocking=False, copy=False): + """Change the device of object recursively""" + if isinstance(data, dict): + return { + k: to_device(v, device, dtype, non_blocking, copy) for k, v in data.items() + } + elif dataclasses.is_dataclass(data) and not isinstance(data, type): + return type(data)( + *[ + to_device(v, device, dtype, non_blocking, copy) + for v in dataclasses.astuple(data) + ] + ) + # maybe namedtuple. I don't know the correct way to judge namedtuple. 
+ elif isinstance(data, tuple) and type(data) is not tuple: + return type(data)( + *[to_device(o, device, dtype, non_blocking, copy) for o in data] + ) + elif isinstance(data, (list, tuple)): + return type(data)(to_device(v, device, dtype, non_blocking, copy) for v in data) + elif isinstance(data, np.ndarray): + return to_device(torch.from_numpy(data), device, dtype, non_blocking, copy) + elif isinstance(data, torch.Tensor): + return data.to(device, dtype, non_blocking, copy) + else: + return data diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/torch_utils/get_layer_from_string.py b/examples/speech_synthesis/preprocessing/tfgridnet/torch_utils/get_layer_from_string.py new file mode 100644 index 0000000000..aa07ffca5d --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/torch_utils/get_layer_from_string.py @@ -0,0 +1,43 @@ +import difflib + +import torch + + +def get_layer(l_name, library=torch.nn): + """Return layer object handler from library e.g. from torch.nn + + E.g. if l_name=="elu", returns torch.nn.ELU. + + Args: + l_name (string): Case insensitive name for layer in library (e.g. .'elu'). + library (module): Name of library/module where to search for object handler + with l_name e.g. "torch.nn". + + Returns: + layer_handler (object): handler for the requested layer e.g. (torch.nn.ELU) + + """ + + all_torch_layers = [x for x in dir(torch.nn)] + match = [x for x in all_torch_layers if l_name.lower() == x.lower()] + if len(match) == 0: + close_matches = difflib.get_close_matches( + l_name, [x.lower() for x in all_torch_layers] + ) + raise NotImplementedError( + "Layer with name {} not found in {}.\n Closest matches: {}".format( + l_name, str(library), close_matches + ) + ) + elif len(match) > 1: + close_matches = difflib.get_close_matches( + l_name, [x.lower() for x in all_torch_layers] + ) + raise NotImplementedError( + "Multiple matchs for layer with name {} not found in {}.\n " + "All matches: {}".format(l_name, str(library), close_matches) + ) + else: + # valid + layer_handler = getattr(library, match[0]) + return layer_handler diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/torch_utils/initialize.py b/examples/speech_synthesis/preprocessing/tfgridnet/torch_utils/initialize.py new file mode 100644 index 0000000000..e271132f36 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/torch_utils/initialize.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 + +"""Initialize modules for espnet2 neural networks.""" + +import logging +import math + +import torch +from typeguard import check_argument_types + + +def initialize(model: torch.nn.Module, init: str): + """Initialize weights of a neural network module. + + Parameters are initialized using the given method or distribution. + + Custom initialization routines can be implemented into submodules + as function `espnet_initialization_fn` within the custom module. + + Args: + model: Target. + init: Method of initialization. + """ + assert check_argument_types() + + if init == "chainer": + # 1. 
lecun_normal_init_parameters + for name, p in model.named_parameters(): + data = p.data + if ".bias" in name and data.dim() == 1: + # bias + data.zero_() + logging.info(f"Initialize {name} to zeros") + elif data.dim() == 1: + # linear weight + n = data.size(0) + stdv = 1.0 / math.sqrt(n) + data.normal_(0, stdv) + elif data.dim() == 2: + # linear weight + n = data.size(1) + stdv = 1.0 / math.sqrt(n) + data.normal_(0, stdv) + elif data.dim() in (3, 4): + # conv weight + n = data.size(1) + for k in data.size()[2:]: + n *= k + stdv = 1.0 / math.sqrt(n) + data.normal_(0, stdv) + else: + raise NotImplementedError + + for mod in model.modules(): + # 2. embed weight ~ Normal(0, 1) + if isinstance(mod, torch.nn.Embedding): + mod.weight.data.normal_(0, 1) + # 3. forget-bias = 1.0 + elif isinstance(mod, torch.nn.RNNCellBase): + n = mod.bias_ih.size(0) + mod.bias_ih.data[n // 4 : n // 2].fill_(1.0) + elif isinstance(mod, torch.nn.RNNBase): + for name, param in mod.named_parameters(): + if "bias" in name: + n = param.size(0) + param.data[n // 4 : n // 2].fill_(1.0) + if hasattr(mod, "espnet_initialization_fn"): + mod.espnet_initialization_fn() + + else: + # weight init + for p in model.parameters(): + if p.dim() > 1: + if init == "xavier_uniform": + torch.nn.init.xavier_uniform_(p.data) + elif init == "xavier_normal": + torch.nn.init.xavier_normal_(p.data) + elif init == "kaiming_uniform": + torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu") + elif init == "kaiming_normal": + torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu") + else: + raise ValueError("Unknown initialization: " + init) + # bias init + for name, p in model.named_parameters(): + if ".bias" in name and p.dim() == 1: + p.data.zero_() + logging.info(f"Initialize {name} to zeros") + + # reset some modules with default init + for m in model.modules(): + if isinstance( + m, (torch.nn.Embedding, torch.nn.LayerNorm, torch.nn.GroupNorm) + ): + m.reset_parameters() + if hasattr(m, "espnet_initialization_fn"): + m.espnet_initialization_fn() + + # TODO(xkc): Hacking s3prl_frontend and wav2vec2encoder initialization + if getattr(model, "encoder", None) and getattr( + model.encoder, "reload_pretrained_parameters", None + ): + model.encoder.reload_pretrained_parameters() + if getattr(model, "frontend", None): + if getattr(model.frontend, "reload_pretrained_parameters", None): + model.frontend.reload_pretrained_parameters() + elif isinstance( + getattr(model.frontend, "frontends", None), + torch.nn.ModuleList, + ): + for i, _ in enumerate(getattr(model.frontend, "frontends")): + if getattr( + model.frontend.frontends[i], + "reload_pretrained_parameters", + None, + ): + model.frontend.frontends[i].reload_pretrained_parameters() + if getattr(model, "postencoder", None) and getattr( + model.postencoder, "reload_pretrained_parameters", None + ): + model.postencoder.reload_pretrained_parameters() + if getattr(model, "decoder", None) and getattr( + model.decoder, "reload_pretrained_parameters", None + ): + model.decoder.reload_pretrained_parameters() diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/train/class_choices.py b/examples/speech_synthesis/preprocessing/tfgridnet/train/class_choices.py new file mode 100644 index 0000000000..3d51ee5024 --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/train/class_choices.py @@ -0,0 +1,90 @@ +from typing import Mapping, Optional, Tuple + +from typeguard import check_argument_types, check_return_type + + +class ClassChoices: + """Helper class to manage the options for 
variable objects and its configuration. + + Example: + + >>> class A: + ... def __init__(self, foo=3): pass + >>> class B: + ... def __init__(self, bar="aaaa"): pass + >>> choices = ClassChoices("var", dict(a=A, b=B), default="a") + >>> import argparse + >>> parser = argparse.ArgumentParser() + >>> choices.add_arguments(parser) + >>> args = parser.parse_args(["--var", "a", "--var_conf", "foo=4") + >>> args.var + a + >>> args.var_conf + {"foo": 4} + >>> class_obj = choices.get_class(args.var) + >>> a_object = class_obj(**args.var_conf) + + """ + + def __init__( + self, + name: str, + classes: Mapping[str, type], + type_check: type = None, + default: str = None, + optional: bool = False, + ): + assert check_argument_types() + self.name = name + self.base_type = type_check + self.classes = {k.lower(): v for k, v in classes.items()} + if "none" in self.classes or "nil" in self.classes or "null" in self.classes: + raise ValueError('"none", "nil", and "null" are reserved.') + if type_check is not None: + for v in self.classes.values(): + if not issubclass(v, type_check): + raise ValueError(f"must be {type_check.__name__}, but got {v}") + + self.optional = optional + self.default = default + if default is None: + self.optional = True + + def choices(self) -> Tuple[Optional[str], ...]: + retval = tuple(self.classes) + if self.optional: + return retval + (None,) + else: + return retval + + def get_class(self, name: Optional[str]) -> Optional[type]: + assert check_argument_types() + print(f"--{self.name} must be one of {self.choices()}: ") + if name is None or (self.optional and name.lower() == ("none", "null", "nil")): + retval = None + elif name.lower() in self.classes: + class_obj = self.classes[name] + assert check_return_type(class_obj) + retval = class_obj + else: + raise ValueError( + f"--{self.name} must be one of {self.choices()}: " + f"--{self.name} {name.lower()}" + ) + + return retval + + def add_arguments(self, parser): + parser.add_argument( + f"--{self.name}", + type=lambda x: str_or_none(x.lower()), + default=self.default, + choices=self.choices(), + help=f"The {self.name} type", + ) + parser.add_argument( + f"--{self.name}_conf", + action=NestedDictAction, + default=dict(), + help=f"The keyword arguments for {self.name}", + ) diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/train/espnet_model.py b/examples/speech_synthesis/preprocessing/tfgridnet/train/espnet_model.py new file mode 100644 index 0000000000..22ed07540f --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/train/espnet_model.py @@ -0,0 +1,282 @@ +"""Enhancement model module.""" +from typing import Dict, List, Optional, OrderedDict, Tuple +from abc import ABC, abstractmethod + +import numpy as np +import torch +from typeguard import check_argument_types + +from examples.speech_synthesis.preprocessing.tfgridnet.mask import AbsMask +from examples.speech_synthesis.preprocessing.tfgridnet.enh.decoder import AbsDecoder +from examples.speech_synthesis.preprocessing.tfgridnet.enh.encoder import AbsEncoder +from examples.speech_synthesis.preprocessing.tfgridnet.enh.wrappers import AbsLossWrapper +from examples.speech_synthesis.preprocessing.tfgridnet.enh.separator import AbsSeparator + +#from packaging.version import parse as V +#is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0") + +class AbsESPnetModel(torch.nn.Module, ABC): + + @abstractmethod + def forward( + self, **batch: torch.Tensor + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + raise NotImplementedError + + 
@abstractmethod + def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]: + raise NotImplementedError + + +class ESPnetEnhancementModel(AbsESPnetModel): + """Speech enhancement or separation Frontend model""" + + def __init__( + self, + encoder: AbsEncoder, + separator: AbsSeparator, + decoder: AbsDecoder, + mask_module: Optional[AbsMask], + loss_wrappers: List[AbsLossWrapper], + stft_consistency: bool = False, + loss_type: str = "mask_mse", + mask_type: Optional[str] = None, + flexible_numspk: bool = False, + extract_feats_in_collect_stats: bool = False, + normalize_variance: bool = False, + normalize_variance_per_ch: bool = False, + categories: list = [], + category_weights: list = [], + ): + + assert check_argument_types() + + super().__init__() + + self.encoder = encoder + self.separator = separator + self.decoder = decoder + self.mask_module = mask_module + self.num_spk = separator.num_spk + # If True, self.num_spk is regarded as the MAXIMUM possible number of speakers + self.flexible_numspk = flexible_numspk + self.num_noise_type = getattr(self.separator, "num_noise_type", 1) + + self.loss_wrappers = loss_wrappers + names = [w.criterion.name for w in self.loss_wrappers] + if len(set(names)) != len(names): + raise ValueError("Duplicated loss names are not allowed: {}".format(names)) + + # kept for compatibility + self.mask_type = mask_type.upper() if mask_type else None + self.loss_type = loss_type + self.stft_consistency = stft_consistency + + # for multi-channel signal + self.ref_channel = getattr(self.separator, "ref_channel", None) + if self.ref_channel is None: + self.ref_channel = 0 + + self.extract_feats_in_collect_stats = extract_feats_in_collect_stats + + self.normalize_variance = normalize_variance + self.normalize_variance_per_ch = normalize_variance_per_ch + if normalize_variance and normalize_variance_per_ch: + raise ValueError( + "normalize_variance and normalize_variance_per_ch cannot be True " + "at the same time." + ) + + # list all possible categories of the batch (order matters!) + # (used to convert category index to the corresponding name for logging) + self.categories = {} + if categories: + count = 0 + for c in categories: + if c not in self.categories: + self.categories[count] = c + count += 1 + # used to set loss weights for batches of different categories + if category_weights: + assert len(category_weights) == len(self.categories) + self.category_weights = tuple(category_weights) + else: + self.category_weights = tuple(1.0 for _ in self.categories) + + def forward( + self, + speech_mix: torch.Tensor, + speech_mix_lengths: torch.Tensor = None, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + """Frontend + Encoder + Decoder + Calc loss + + Args: + speech_mix: (Batch, samples) or (Batch, samples, channels) + speech_ref: (Batch, num_speaker, samples) + or (Batch, num_speaker, samples, channels) + speech_mix_lengths: (Batch,), default None for chunk interator, + because the chunk-iterator does not have the + speech_lengths returned. see in + espnet2/iterators/chunk_iter_factory.py + kwargs: "utt_id" is among the input. + """ + # reference speech signal of each speaker + assert "speech_ref1" in kwargs, "At least 1 reference signal input is required." 
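+        # Gather the per-speaker references (speech_ref1, speech_ref2, ...) that
+        # are actually present in the batch; with flexible_numspk the number of
+        # references found determines num_spk for this batch.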
+ speech_ref = [ + kwargs.get( + f"speech_ref{spk + 1}", + torch.zeros_like(kwargs["speech_ref1"]), + ) + for spk in range(self.num_spk) + if f"speech_ref{spk + 1}" in kwargs + ] + num_spk = len(speech_ref) if self.flexible_numspk else self.num_spk + assert len(speech_ref) == num_spk, (len(speech_ref), num_spk) + # (Batch, num_speaker, samples) or (Batch, num_speaker, samples, channels) + speech_ref = torch.stack(speech_ref, dim=1) + + if "noise_ref1" in kwargs: + # noise signal (optional, required when using beamforming-based + # frontend models) + noise_ref = [ + kwargs["noise_ref{}".format(n + 1)] for n in range(self.num_noise_type) + ] + # (Batch, num_noise_type, samples) or + # (Batch, num_noise_type, samples, channels) + noise_ref = torch.stack(noise_ref, dim=1) + else: + noise_ref = None + + # dereverberated (noisy) signal + # (optional, only used for frontend models with WPE) + if "dereverb_ref1" in kwargs: + # noise signal (optional, required when using + # frontend models with beamformering) + dereverb_speech_ref = [ + kwargs["dereverb_ref{}".format(n + 1)] + for n in range(num_spk) + if "dereverb_ref{}".format(n + 1) in kwargs + ] + assert len(dereverb_speech_ref) in (1, num_spk), len(dereverb_speech_ref) + # (Batch, N, samples) or (Batch, N, samples, channels) + dereverb_speech_ref = torch.stack(dereverb_speech_ref, dim=1) + else: + dereverb_speech_ref = None + + batch_size = speech_mix.shape[0] + speech_lengths = ( + speech_mix_lengths + if speech_mix_lengths is not None + else torch.ones(batch_size).int().fill_(speech_mix.shape[1]) + ) + assert speech_lengths.dim() == 1, speech_lengths.shape + # Check that batch_size is unified + assert speech_mix.shape[0] == speech_ref.shape[0] == speech_lengths.shape[0], ( + speech_mix.shape, + speech_ref.shape, + speech_lengths.shape, + ) + + # for data-parallel + speech_ref = speech_ref[..., : speech_lengths.max()].unbind(dim=1) + if noise_ref is not None: + noise_ref = noise_ref[..., : speech_lengths.max()].unbind(dim=1) + if dereverb_speech_ref is not None: + dereverb_speech_ref = dereverb_speech_ref[..., : speech_lengths.max()] + dereverb_speech_ref = dereverb_speech_ref.unbind(dim=1) + + # sampling frequency information about the batch + fs = None + if "utt2fs" in kwargs: + # All samples must have the same sampling rate + fs = kwargs["utt2fs"][0].item() + assert all([fs == kwargs["utt2fs"][0].item() for fs in kwargs["utt2fs"]]) + + # Adaptively adjust the STFT/iSTFT window/hop sizes for USESSeparator + if not isinstance(self.separator, USESSeparator): + fs = None + + # category information (integer) about the batch + category = kwargs.get("utt2category", None) + if ( + self.categories + and category is not None + and category[0].item() not in self.categories + ): + raise ValueError(f"Category '{category}' is not listed in self.categories") + + additional = {} + # Additional data is required in Deep Attractor Network + if isinstance(self.separator, DANSeparator): + additional["feature_ref"] = [ + self.encoder(r, speech_lengths, fs=fs)[0] for r in speech_ref + ] + if self.flexible_numspk: + additional["num_spk"] = num_spk + # Additional information is required in USES for multi-condition training + if category is not None and isinstance(self.separator, USESSeparator): + cat = self.categories[category[0].item()] + if cat.endswith("_both"): + additional["mode"] = "both" + elif cat.endswith("_reverb"): + additional["mode"] = "dereverb" + else: + additional["mode"] = "no_dereverb" + + speech_mix = speech_mix[:, : speech_lengths.max()] + + 
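+        # Variance (RMS) normalization: the mixture's standard deviation is kept
+        # in mix_std_ so the enhanced signals can be rescaled afterwards.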
################################### + # Normalize the signal variance + if self.normalize_variance_per_ch: + dim = 1 + mix_std_ = torch.std(speech_mix, dim=dim, keepdim=True) + speech_mix = speech_mix / mix_std_ # RMS normalization + elif self.normalize_variance: + if speech_mix.ndim > 2: + dim = (1, 2) + else: + dim = 1 + mix_std_ = torch.std(speech_mix, dim=dim, keepdim=True) + speech_mix = speech_mix / mix_std_ # RMS normalization + + # model forward + speech_pre, feature_mix, feature_pre, others = self.forward_enhance( + speech_mix, speech_lengths, additional, fs=fs + ) + + ################################### + # De-normalize the signal variance + if self.normalize_variance_per_ch and speech_pre is not None: + if mix_std_.ndim > 2: + mix_std_ = mix_std_[:, :, self.ref_channel] + speech_pre = [sp * mix_std_ for sp in speech_pre] + elif self.normalize_variance and speech_pre is not None: + if mix_std_.ndim > 2: + mix_std_ = mix_std_.squeeze(2) + speech_pre = [sp * mix_std_ for sp in speech_pre] + + # loss computation + loss, stats, weight, perm = self.forward_loss( + speech_pre, + speech_lengths, + feature_mix, + feature_pre, + others, + speech_ref, + noise_ref, + dereverb_speech_ref, + category, + num_spk=num_spk, + fs=fs, + ) + return loss, stats, weight + + def collect_feats( + self, speech_mix: torch.Tensor, speech_mix_lengths: torch.Tensor, **kwargs + ) -> Dict[str, torch.Tensor]: + # for data-parallel + speech_mix = speech_mix[:, : speech_mix_lengths.max()] + + feats, feats_lengths = speech_mix, speech_mix_lengths + return {"feats": feats, "feats_lengths": feats_lengths} diff --git a/examples/speech_synthesis/preprocessing/tfgridnet/train/preprocessor.py b/examples/speech_synthesis/preprocessing/tfgridnet/train/preprocessor.py new file mode 100644 index 0000000000..37622b47ef --- /dev/null +++ b/examples/speech_synthesis/preprocessing/tfgridnet/train/preprocessor.py @@ -0,0 +1,1096 @@ +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Collection, Dict, Iterable, List, Optional, Tuple, Union +import numpy as np + +class AbsPreprocessor(ABC): + def __init__(self, train: bool): + self.train = train + + @abstractmethod + def __call__( + self, uid: str, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, np.ndarray]: + raise NotImplementedError + + +def framing( + x, + frame_length: int = 512, + frame_shift: int = 256, + centered: bool = True, + padded: bool = True, +): + if x.size == 0: + raise ValueError("Input array size is zero") + if frame_length < 1: + raise ValueError("frame_length must be a positive integer") + if frame_length > x.shape[-1]: + raise ValueError("frame_length is greater than input length") + if 0 >= frame_shift: + raise ValueError("frame_shift must be greater than 0") + + if centered: + pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [ + (frame_length // 2, frame_length // 2) + ] + x = np.pad(x, pad_shape, mode="constant", constant_values=0) + + if padded: + # Pad to integer number of windowed segments + # I.e make x.shape[-1] = frame_length + (nseg-1)*nstep, + # with integer nseg + nadd = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length + pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [(0, nadd)] + x = np.pad(x, pad_shape, mode="constant", constant_values=0) + + # Created strided array of data segments + if frame_length == 1 and frame_length == frame_shift: + result = x[..., None] + else: + shape = x.shape[:-1] + ( + (x.shape[-1] - frame_length) // frame_shift + 1, + frame_length, + ) + strides = 
x.strides[:-1] + (frame_shift * x.strides[-1], x.strides[-1]) + result = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides) + return result + + +def detect_non_silence( + x: np.ndarray, + threshold: float = 0.01, + frame_length: int = 1024, + frame_shift: int = 512, + window: str = "boxcar", +) -> np.ndarray: + """Power based voice activity detection. + + Args: + x: (Channel, Time) + >>> x = np.random.randn(1000) + >>> detect = detect_non_silence(x) + >>> assert x.shape == detect.shape + >>> assert detect.dtype == np.bool + """ + if x.shape[-1] < frame_length: + return np.full(x.shape, fill_value=True, dtype=np.bool) + + if x.dtype.kind == "i": + x = x.astype(np.float64) + # framed_w: (C, T, F) + framed_w = framing( + x, + frame_length=frame_length, + frame_shift=frame_shift, + centered=False, + padded=True, + ) + framed_w *= scipy.signal.get_window(window, frame_length).astype(framed_w.dtype) + # power: (C, T) + power = (framed_w**2).mean(axis=-1) + # mean_power: (C, 1) + mean_power = np.mean(power, axis=-1, keepdims=True) + if np.all(mean_power == 0): + return np.full(x.shape, fill_value=True, dtype=np.bool) + # detect_frames: (C, T) + detect_frames = power / mean_power > threshold + # detects: (C, T, F) + detects = np.broadcast_to( + detect_frames[..., None], detect_frames.shape + (frame_shift,) + ) + # detects: (C, TF) + detects = detects.reshape(*detect_frames.shape[:-1], -1) + # detects: (C, TF) + return np.pad( + detects, + [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])], + mode="edge", + ) + + +def any_allzero(signal): + if isinstance(signal, (list, tuple)): + return any([np.allclose(s, 0.0) for s in signal]) + return np.allclose(signal, 0.0) + +class CommonPreprocessor(AbsPreprocessor): + def __init__( + self, + train: bool, + use_lang_prompt: bool = False, + use_nlp_prompt: bool = False, + token_type: str = None, + token_list: Union[Path, str, Iterable[str]] = None, + bpemodel: Union[Path, str, Iterable[str]] = None, + text_cleaner: Collection[str] = None, + g2p_type: str = None, + unk_symbol: str = "", + space_symbol: str = "", + non_linguistic_symbols: Union[Path, str, Iterable[str]] = None, + delimiter: str = None, + rir_scp: str = None, + rir_apply_prob: float = 1.0, + noise_scp: str = None, + noise_apply_prob: float = 1.0, + noise_db_range: str = "3_10", + short_noise_thres: float = 0.5, + aux_task_names: Collection[str] = None, + speech_volume_normalize: float = None, + speech_name: str = "speech", + text_name: str = "text", + fs: int = 0, + nonsplit_symbol: Iterable[str] = None, + data_aug_effects: List = None, + data_aug_num: List[int] = [1, 1], + data_aug_prob: float = 0.0, + # only use for whisper + whisper_language: str = None, + whisper_task: str = None, + ): + super().__init__(train) + self.train = train + self.speech_name = speech_name + self.text_name = text_name + self.speech_volume_normalize = speech_volume_normalize + self.rir_apply_prob = rir_apply_prob + self.noise_apply_prob = noise_apply_prob + self.short_noise_thres = short_noise_thres + self.aux_task_names = aux_task_names + self.use_lang_prompt = use_lang_prompt + self.use_nlp_prompt = use_nlp_prompt + + if token_type is not None: + if token_list is None: + raise ValueError("token_list is required if token_type is not None") + self.text_cleaner = TextCleaner(text_cleaner) + + self.tokenizer = build_tokenizer( + token_type=token_type, + bpemodel=bpemodel, + delimiter=delimiter, + space_symbol=space_symbol, + non_linguistic_symbols=non_linguistic_symbols, + g2p_type=g2p_type, 
+ nonsplit_symbol=nonsplit_symbol, + whisper_language=whisper_language, + whisper_task=whisper_task, + ) + if token_type == "hugging_face": + self.token_id_converter = HuggingFaceTokenIDConverter( + model_name_or_path=bpemodel + ) + elif bpemodel not in ["whisper_en", "whisper_multilingual"]: + self.token_id_converter = TokenIDConverter( + token_list=token_list, + unk_symbol=unk_symbol, + ) + else: + self.token_id_converter = OpenAIWhisperTokenIDConverter( + model_type=bpemodel, + added_tokens_txt=non_linguistic_symbols, + language=whisper_language or "en", + task=whisper_task or "transcribe", + ) + else: + self.text_cleaner = None + self.tokenizer = None + self.token_id_converter = None + + if train and rir_scp is not None: + self.rirs = [] + rir_scp = [rir_scp] if not isinstance(rir_scp, (list, tuple)) else rir_scp + for scp in rir_scp: + with open(scp, "r", encoding="utf-8") as f: + for line in f: + sps = line.strip().split(None, 1) + if len(sps) == 1: + self.rirs.append(sps[0]) + else: + self.rirs.append(sps[1]) + else: + self.rirs = None + + if train and noise_scp is not None: + self.noises = [] + noise_scp = ( + [noise_scp] if not isinstance(noise_scp, (list, tuple)) else noise_scp + ) + for scp in noise_scp: + with open(scp, "r", encoding="utf-8") as f: + for line in f: + sps = line.strip().split(None, 1) + if len(sps) == 1: + self.noises.append(sps[0]) + else: + self.noises.append(sps[1]) + sps = noise_db_range.split("_") + if len(sps) == 1: + self.noise_db_low = self.noise_db_high = float(sps[0]) + elif len(sps) == 2: + self.noise_db_low, self.noise_db_high = float(sps[0]), float(sps[1]) + else: + raise ValueError( + "Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]" + ) + else: + self.noises = None + + # Check DataAugmentation docstring for more information of `data_aug_effects` + self.fs = fs + if data_aug_effects is not None: + assert self.fs > 0, self.fs + self.data_aug = DataAugmentation(data_aug_effects, apply_n=data_aug_num) + else: + self.data_aug = None + self.data_aug_prob = data_aug_prob + + def _convolve_rir(self, speech, power, rirs, tgt_fs=None, single_channel=False): + rir_path = np.random.choice(rirs) + rir = None + if rir_path is not None: + rir, fs = soundfile.read(rir_path, dtype=np.float64, always_2d=True) + + if single_channel: + num_ch = rir.shape[1] + chs = [np.random.randint(num_ch)] + rir = rir[:, chs] + # rir: (Nmic, Time) + rir = rir.T + if tgt_fs and fs != tgt_fs: + logging.warning( + f"Resampling RIR to match the sampling rate ({fs} -> {tgt_fs} Hz)" + ) + rir = librosa.resample( + rir, orig_sr=fs, target_sr=tgt_fs, res_type="kaiser_fast" + ) + + # speech: (Nmic, Time) + speech = speech[:1] + # Note that this operation doesn't change the signal length + speech = scipy.signal.convolve(speech, rir, mode="full")[ + :, : speech.shape[1] + ] + # Reverse mean power to the original power + power2 = (speech[detect_non_silence(speech)] ** 2).mean() + speech = np.sqrt(power / max(power2, 1e-10)) * speech + return speech, rir + + def _add_noise( + self, + speech, + power, + noises, + noise_db_low, + noise_db_high, + tgt_fs=None, + single_channel=False, + ): + nsamples = speech.shape[1] + noise_path = np.random.choice(noises) + noise = None + if noise_path is not None: + noise_db = np.random.uniform(noise_db_low, noise_db_high) + with soundfile.SoundFile(noise_path) as f: + fs = f.samplerate + if tgt_fs and fs != tgt_fs: + nsamples_ = int(nsamples / tgt_fs * fs) + 1 + else: + nsamples_ = nsamples + if f.frames == nsamples_: + noise = 
f.read(dtype=np.float64, always_2d=True) + elif f.frames < nsamples_: + if f.frames / nsamples_ < self.short_noise_thres: + logging.warning( + f"Noise ({f.frames}) is much shorter than " + f"speech ({nsamples_}) in dynamic mixing" + ) + offset = np.random.randint(0, nsamples_ - f.frames) + # noise: (Time, Nmic) + noise = f.read(dtype=np.float64, always_2d=True) + # Repeat noise + noise = np.pad( + noise, + [(offset, nsamples_ - f.frames - offset), (0, 0)], + mode="wrap", + ) + else: + offset = np.random.randint(0, f.frames - nsamples_) + f.seek(offset) + # noise: (Time, Nmic) + noise = f.read(nsamples_, dtype=np.float64, always_2d=True) + if len(noise) != nsamples_: + raise RuntimeError(f"Something wrong: {noise_path}") + if single_channel: + num_ch = noise.shape[1] + chs = [np.random.randint(num_ch)] + noise = noise[:, chs] + # noise: (Nmic, Time) + noise = noise.T + if tgt_fs and fs != tgt_fs: + logging.warning( + f"Resampling noise to match the sampling rate ({fs} -> {tgt_fs} Hz)" + ) + noise = librosa.resample( + noise, orig_sr=fs, target_sr=tgt_fs, res_type="kaiser_fast" + ) + if noise.shape[1] < nsamples: + noise = np.pad( + noise, [(0, 0), (0, nsamples - noise.shape[1])], mode="wrap" + ) + else: + noise = noise[:, :nsamples] + + noise_power = (noise**2).mean() + scale = ( + 10 ** (-noise_db / 20) + * np.sqrt(power) + / np.sqrt(max(noise_power, 1e-10)) + ) + speech = speech + scale * noise + return speech, noise + + def _speech_process( + self, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, Union[str, np.ndarray]]: + assert check_argument_types() + if self.speech_name in data: + if self.train and (self.rirs is not None or self.noises is not None): + speech = data[self.speech_name] + + # speech: (Nmic, Time) + if speech.ndim == 1: + speech = speech[None, :] + else: + speech = speech.T + # Calc power on non silence region + power = (speech[detect_non_silence(speech)] ** 2).mean() + + # 1. Convolve RIR + if self.rirs is not None and self.rir_apply_prob >= np.random.random(): + speech, _ = self._convolve_rir(speech, power, self.rirs) + + # 2. Add Noise + if ( + self.noises is not None + and self.noise_apply_prob >= np.random.random() + ): + speech, _ = self._add_noise( + speech, + power, + self.noises, + self.noise_db_low, + self.noise_db_high, + ) + + speech = speech.T + ma = np.max(np.abs(speech)) + if ma > 1.0: + speech /= ma + data[self.speech_name] = speech + + if self.train and self.data_aug: + if self.data_aug_prob > 0 and self.data_aug_prob >= np.random.random(): + data[self.speech_name] = self.data_aug( + data[self.speech_name], self.fs + ) + + if self.speech_volume_normalize is not None: + speech = data[self.speech_name] + ma = np.max(np.abs(speech)) + data[self.speech_name] = speech * self.speech_volume_normalize / ma + assert check_return_type(data) + return data + + def _text_process( + self, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, np.ndarray]: + if self.text_name in data and self.tokenizer is not None: + text = data[self.text_name] + if isinstance(text, np.ndarray): + return data + text = self.text_cleaner(text) + tokens = self.tokenizer.text2tokens(text) + text_ints = self.token_id_converter.tokens2ids(tokens) + if len(text_ints) > 500: + logging.warning( + "The length of the text output exceeds 500, " + "which may cause OOM on the GPU." + "Please ensure that the data processing is correct and verify it." 
+ ) + if "prompt" in data: + actual_token = ( + self.token_id_converter.tokenizer.tokenizer.convert_ids_to_tokens( + text_ints + ) + ) + if self.use_lang_prompt: + if data["prompt"] == "<|nospeech|>": + actual_token = [data["prompt"]] + else: + actual_token = data["prompt"].split() + actual_token[2:] + elif self.use_nlp_prompt: + prompt_tokens = self.tokenizer.text2tokens(data["prompt"]) + actual_token = [actual_token[0]] + prompt_tokens + actual_token[2:] + else: + if len(data["prompt"].split()) > 1: + actual_token = ( + [actual_token[0]] + + data["prompt"].split() + + actual_token[2:] + ) + else: + actual_token[1] = data["prompt"] + text_ints = ( + self.token_id_converter.tokenizer.tokenizer.convert_tokens_to_ids( + actual_token + ) + ) + data[self.text_name] = np.array(text_ints, dtype=np.int64) + if "prompt" in data: + whisper_tokenizer = self.token_id_converter.tokenizer.tokenizer + if len(data["prompt"].split()) > 1: + data["prompt"] = np.array( + whisper_tokenizer.convert_tokens_to_ids(data["prompt"].split()), + dtype=np.int64, + ) + else: + data["prompt"] = np.array( + [whisper_tokenizer.convert_tokens_to_ids(data["prompt"])], + dtype=np.int64, + ) + if self.aux_task_names is not None and self.tokenizer is not None: + for name in self.aux_task_names: + if name in data: + text = data[name] + text = self.text_cleaner(text) + tokens = self.tokenizer.text2tokens(text) + text_ints = self.token_id_converter.tokens2ids(tokens) + data[name] = np.array(text_ints, dtype=np.int64) + assert check_return_type(data) + return data + + def __call__( + self, uid: str, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, np.ndarray]: + assert check_argument_types() + + data = self._speech_process(data) + data = self._text_process(data) + return data + +class DynamicMixingPreprocessor(AbsPreprocessor): + def __init__( + self, + train: bool, + source_scp: str = None, + ref_num: int = 2, + dynamic_mixing_gain_db: float = 0.0, + speech_name: str = "speech_mix", + speech_ref_name_prefix: str = "speech_ref", + mixture_source_name: str = None, + utt2spk: str = None, + categories: Optional[List] = None, + ): + super().__init__(train) + self.source_scp = source_scp + self.ref_num = ref_num + self.dynamic_mixing_gain_db = dynamic_mixing_gain_db + self.speech_name = speech_name + self.speech_ref_name_prefix = speech_ref_name_prefix + # mixture_source_name: the key to select source utterances from dataloader + if mixture_source_name is None: + self.mixture_source_name = f"{speech_ref_name_prefix}1" + else: + self.mixture_source_name = mixture_source_name + + self.sources = {} + assert ( + source_scp is not None + ), f"Please pass `source_scp` to {type(self).__name__}" + with open(source_scp, "r", encoding="utf-8") as f: + for line in f: + sps = line.strip().split(None, 1) + assert len(sps) == 2 + self.sources[sps[0]] = sps[1] + + self.utt2spk = {} + if utt2spk is None: + # if utt2spk is not provided, create a dummy utt2spk with uid. 
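+            # Each utterance is then treated as its own speaker, so the
+            # same-speaker check in _pick_source_utterances_ only prevents
+            # reusing the very same utterance in a mixture.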
+ for key in self.sources.keys(): + self.utt2spk[key] = key + else: + with open(utt2spk, "r", encoding="utf-8") as f: + for line in f: + sps = line.strip().split(None, 1) + assert len(sps) == 2 + self.utt2spk[sps[0]] = sps[1] + + for key in self.sources.keys(): + assert key in self.utt2spk + + self.source_keys = list(self.sources.keys()) + + # Map each category into a unique integer + self.categories = {} + if categories: + count = 0 + for c in categories: + if c not in self.categories: + self.categories[c] = count + count += 1 + + def _pick_source_utterances_(self, uid): + # return (ref_num - 1) uid of reference sources. + + source_keys = [uid] + + spk_ids = [self.utt2spk[uid]] + + retry_cnt = 0 + while len(source_keys) < self.ref_num: + picked = random.choice(self.source_keys) + spk_id = self.utt2spk[picked] + + # make one utterance or one speaker only appears once in mixing. + if (picked not in source_keys) and (spk_id not in spk_ids): + source_keys.append(picked) + else: + retry_cnt += 1 + if retry_cnt > 10: + source_keys.append(picked) + logging.warning( + "Can not find speech source from different speaker " + f"for {retry_cnt} times." + "There may be problems with training data. " + "Please check the utt2spk file." + ) + + return source_keys[1:] + + def _read_source_(self, key, speech_length): + source, _ = soundfile.read( + self.sources[key], + dtype=np.float32, + always_2d=False, + ) + + if speech_length > source.shape[0]: + pad = speech_length - source.shape[0] + source = np.pad(source, (0, pad), "reflect") + else: + source = source[0:speech_length] + + assert speech_length == source.shape[0] + + return source + + def _mix_speech_(self, uid, data): + # pick sources + source_keys = self._pick_source_utterances_(uid) + + # load audios + speech_length = data[self.mixture_source_name].shape[0] + ref_audios = [self._read_source_(key, speech_length) for key in source_keys] + ref_audios = [data[self.mixture_source_name]] + ref_audios + + # apply random gain to speech sources + + gain_in_db = [ + random.uniform(-self.dynamic_mixing_gain_db, self.dynamic_mixing_gain_db) + for i in range(len(ref_audios)) + ] + gain = [10 ** (g_db / 20.0) for g_db in gain_in_db] + + ref_audios = [ref * g for ref, g in zip(ref_audios, gain)] + + speech_mix = np.sum(np.array(ref_audios), axis=0) + + for i, ref in enumerate(ref_audios): + data[f"{self.speech_ref_name_prefix}{i+1}"] = ref + data[self.speech_name] = speech_mix + + return data + + def __call__( + self, uid: str, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, np.ndarray]: + # TODO(Chenda): need to test for multi-channel data. 
+ assert ( + len(data[self.mixture_source_name].shape) == 1 + ), "Multi-channel input has not been tested" + + # Add the category information (an integer) to `data` + if not self.categories and "category" in data: + raise ValueError( + "categories must be set in the config file when utt2category files " + "exist in the data directory (e.g., dump/raw/*/utt2category)" + ) + if self.categories and "category" in data: + category = data.pop("category") + assert category in self.categories, category + data["utt2category"] = np.array([self.categories[category]]) + + if self.train: + data = self._mix_speech_(uid, data) + + assert check_return_type(data) + return data + + +class EnhPreprocessor(CommonPreprocessor): + """Preprocessor for Speech Enhancement (Enh) task.""" + + def __init__( + self, + train: bool, + rir_scp: str = None, + rir_apply_prob: float = 1.0, + noise_scp: str = None, + noise_apply_prob: float = 1.0, + noise_db_range: str = "3_10", + short_noise_thres: float = 0.5, + speech_volume_normalize: float = None, + speech_name: str = "speech_mix", + speech_ref_name_prefix: str = "speech_ref", + noise_ref_name_prefix: str = "noise_ref", + dereverb_ref_name_prefix: str = "dereverb_ref", + use_reverberant_ref: bool = False, + num_spk: int = 1, + num_noise_type: int = 1, + sample_rate: int = 8000, + force_single_channel: bool = False, + channel_reordering: bool = False, + categories: Optional[List] = None, + data_aug_effects: List = None, + data_aug_num: List[int] = [1, 1], + data_aug_prob: float = 0.0, + speech_segment: Optional[int] = None, + avoid_allzero_segment: bool = True, + flexible_numspk: bool = False, + ): + super().__init__( + train=train, + token_type=None, + token_list=None, + bpemodel=None, + text_cleaner=None, + g2p_type=None, + unk_symbol="", + space_symbol="", + non_linguistic_symbols=None, + delimiter=None, + rir_scp=rir_scp, + rir_apply_prob=rir_apply_prob, + noise_scp=noise_scp, + noise_apply_prob=noise_apply_prob, + noise_db_range=noise_db_range, + short_noise_thres=short_noise_thres, + speech_volume_normalize=speech_volume_normalize, + speech_name=speech_name, + fs=sample_rate, + data_aug_effects=data_aug_effects, + data_aug_num=data_aug_num, + data_aug_prob=data_aug_prob, + ) + self.speech_ref_name_prefix = speech_ref_name_prefix + self.noise_ref_name_prefix = noise_ref_name_prefix + self.dereverb_ref_name_prefix = dereverb_ref_name_prefix + self.use_reverberant_ref = use_reverberant_ref + self.num_spk = num_spk + self.num_noise_type = num_noise_type + self.sample_rate = sample_rate + self.rir_scp = rir_scp + self.noise_scp = noise_scp + self.noise_db_range = noise_db_range + # Whether to always convert the signals to single-channel + self.force_single_channel = force_single_channel + # If True, randomly reorder the channels of the multi-channel signals + self.channel_reordering = channel_reordering + + # If specified, the audios will be chomped to the specified length + self.speech_segment = speech_segment + # Only used when `speech_segment` is specified. + # If True, make sure all chomped segments are not all-zero. 
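+        # (_random_crop_range below retries up to max_trials times to find
+        # such a segment before giving up with a warning)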
+ self.avoid_allzero_segment = avoid_allzero_segment + + # If True, load variable numbers of speakers in each sample, and + # self.num_spk is regarded as the maximum possible number of speakers + self.flexible_numspk = flexible_numspk + + # Map each category into a unique integer + self.categories = {} + if categories: + count = 0 + for c in categories: + if c not in self.categories: + self.categories[c] = count + count += 1 + + if self.speech_volume_normalize is not None: + sps = speech_volume_normalize.split("_") + if len(sps) == 1: + self.volume_low, self.volume_high = float(sps[0]) + elif len(sps) == 2: + self.volume_low, self.volume_high = float(sps[0]), float(sps[1]) + else: + raise ValueError( + "Format error for --speech_volume_normalize: " + f"'{speech_volume_normalize}'" + ) + + if (self.rirs is not None and self.rir_apply_prob > 0) or ( + self.noises is not None and self.noise_apply_prob > 0 + ): + logging.warning( + "Note: Please ensure the sampling rates of all data, including audios " + f"and RIRs, are all equal to {self.sample_rate} Hz when applying " + "dynamic mixing." + ) + + def __basic_str__(self): + msg = f", num_spk={self.num_spk}" + for key in ( + "force_single_channel", + "channel_reordering", + "speech_volume_normalize", + ): + if getattr(self, key): + msg += f", {key}={getattr(self, key)}" + if self.rirs is not None and self.rir_apply_prob > 0: + msg += f", sample_rate={self.sample_rate}" + msg += f", rir_scp={self.rir_scp}, rir_apply_prob={self.rir_apply_prob}" + if self.use_reverberant_ref: + msg += f", use_reverberant_ref={self.use_reverberant_ref}" + if self.noises is not None and self.noise_apply_prob > 0: + msg += f", noise_scp={self.noise_scp}" + msg += f", noise_apply_prob={self.noise_apply_prob}" + msg += f", noise_db_range={self.noise_db_range}" + if self.data_aug and self.data_aug_prob > 0: + msg += f", data_aug={self.data_aug}, data_aug_prob={self.data_aug_prob}" + if self.speech_segment: + msg += f", speech_segment={self.speech_segment}" + msg += f", avoid_allzero_segment={self.avoid_allzero_segment}" + if self.flexible_numspk: + msg += f", flexible_numspk={self.flexible_numspk}" + if self.categories: + if len(self.categories) <= 10: + msg += f", categories={self.categories}" + else: + msg += f", num_category={len(self.categories)}" + return msg + + def __repr__(self): + name = self.__class__.__module__ + "." 
+ self.__class__.__name__ + msg = f"{name}(train={self.train}" + msg += self.__basic_str__() + return msg + ")" + + def _ensure_2d(self, signal): + if isinstance(signal, tuple): + return tuple(self._ensure_2d(sig) for sig in signal) + elif isinstance(signal, list): + return [self._ensure_2d(sig) for sig in signal] + else: + # (Nmic, Time) + return signal[None, :] if signal.ndim == 1 else signal.T + + def _get_early_signal(self, speech, rir, power): + predelay = 50 # milliseconds + dt = np.argmax(rir, axis=1).min() + et = dt + (predelay * self.sample_rate) // 1000 + rir_early = rir[:, :et] + speech2 = scipy.signal.convolve(speech, rir_early, mode="full")[ + :, : speech.shape[1] + ] + # Reverse mean power to the original power + power2 = (speech2[detect_non_silence(speech2)] ** 2).mean() + speech2 = np.sqrt(power / max(power2, 1e-10)) * speech2 + return speech2 + + def _apply_to_all_signals(self, data_dict, func, num_spk): + data_dict[self.speech_name] = func(data_dict[self.speech_name]) + + for n in range(self.num_noise_type): + noise_name = self.noise_ref_name_prefix + str(n + 1) + if noise_name in data_dict: + data_dict[noise_name] = func(data_dict[noise_name]) + + for spk in range(num_spk): + speech_ref_name = self.speech_ref_name_prefix + str(spk + 1) + if self.train or speech_ref_name in data_dict: + data_dict[speech_ref_name] = func(data_dict[speech_ref_name]) + + dereverb_ref_name = self.dereverb_ref_name_prefix + str(spk + 1) + if dereverb_ref_name in data_dict: + data_dict[dereverb_ref_name] = func(data_dict[dereverb_ref_name]) + + def _random_crop_range( + self, data_dict, num_spk, tgt_length, uid=None, max_trials=10 + ): + # Randomly crop the signals to the length `tgt_length` + assert tgt_length > 0, tgt_length + speech_refs = [ + data_dict[self.speech_ref_name_prefix + str(spk + 1)] + for spk in range(num_spk) + ] + length = speech_refs[0].shape[0] + if length <= tgt_length: + if length < tgt_length: + logging.warning( + f"The sample ({uid}) is not cropped due to its short length " + f"({length} < {tgt_length})." + ) + return 0, length + + start = np.random.randint(0, length - tgt_length) + count = 1 + if self.avoid_allzero_segment: + # try to find a segment region that ensures all references are non-allzero + while any_allzero([sf[start : start + tgt_length] for sf in speech_refs]): + count += 1 + if count > max_trials: + logging.warning( + f"Can't find non-allzero segments for all references in {uid}." + ) + break + if start > 0: + start = np.random.randint(0, start) + else: + break + return start, start + tgt_length + + def _speech_process( + self, uid: str, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, Union[str, np.ndarray]]: + assert check_argument_types() + + if self.speech_name not in data: + assert check_return_type(data) + return data + + num_spk = self.num_spk + + # Add the category information (an integer) to `data` + if not self.categories and "category" in data: + raise ValueError( + "categories must be set in the config file when utt2category files " + "exist in the data directory (e.g., dump/raw/*/utt2category)" + ) + + # Add the sampling rate information (an integer) to `data` + if "fs" in data: + fs = int(data.pop("fs")) + data["utt2fs"] = np.array([fs]) + else: + fs = self.sample_rate + + sref_name = self.speech_ref_name_prefix + "1" + if self.flexible_numspk and sref_name in data: + # The number of speaker varies in each sample. + # Different speaker signals are stacked in the first dimension. 
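+            # i.e. speech_ref1 holds an array of shape (num_spk, ...), which is
+            # split back into speech_ref1 .. speech_ref{num_spk} just below.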
+ dref_name = self.dereverb_ref_name_prefix + "1" + num_spk = len(data[sref_name]) + for i in range(2, self.num_spk + 1): + data.pop(self.speech_ref_name_prefix + str(i), None) + data.pop(self.dereverb_ref_name_prefix + str(i), None) + # Divide the stacked signals into single speaker signals for consistency + for i in range(num_spk - 1, -1, -1): + idx = str(i + 1) + # make sure no np.nan paddings are in the data + assert not np.isnan(np.sum(data[sref_name][i])), uid + data[self.speech_ref_name_prefix + idx] = data[sref_name][i] + if dref_name in data: + # make sure no np.nan paddings are in the data + assert not np.isnan(np.sum(data[dref_name][i])), uid + data[self.dereverb_ref_name_prefix + idx] = data[dref_name][i] + + if self.train: + if self.speech_segment is not None: + speech_segment = self.speech_segment // self.sample_rate * fs + start, end = self._random_crop_range( + data, num_spk, speech_segment, uid=uid + ) + self._apply_to_all_signals(data, lambda x: x[start:end], num_spk) + # clean speech signal (Nmic, Time) + speech_ref = [ + self._ensure_2d(data[self.speech_ref_name_prefix + str(i + 1)]) + for i in range(num_spk) + ] + + # dereverberated (noisy) signal (Nmic, Time) + if self.dereverb_ref_name_prefix + "1" in data: + dereverb_speech_ref = [ + self._ensure_2d(data[self.dereverb_ref_name_prefix + str(i + 1)]) + for i in range(num_spk) + if self.dereverb_ref_name_prefix + str(i + 1) in data + ] + assert len(dereverb_speech_ref) in (1, num_spk), len( + dereverb_speech_ref + ) + else: + dereverb_speech_ref = None + + # Calc power on non silence region + power_ref = [ + (sref[detect_non_silence(sref)] ** 2).mean() for sref in speech_ref + ] + + speech_mix = self._ensure_2d(data[self.speech_name]) + # 1. Convolve RIR + if self.rirs is not None and self.rir_apply_prob >= np.random.random(): + speech_ref, rir_ref = zip( + *[ + self._convolve_rir( + sp, + power, + self.rirs, + tgt_fs=fs, + single_channel=self.force_single_channel, + ) + for sp, power in zip(speech_ref, power_ref) + ] + ) + if self.force_single_channel: + speech_ref = list(map(lambda x: x[:1], speech_ref)) + rir_ref = list(map(lambda x: x[:1], rir_ref)) + + if self.use_reverberant_ref: + for spk in range(num_spk): + suffix = str(spk + 1) + speech_ref_name = self.speech_ref_name_prefix + suffix + # (Time, Nmic) + data[speech_ref_name] = speech_ref[spk].T + + if dereverb_speech_ref is not None: + if spk == 0 or len(dereverb_speech_ref) > 1: + dereverb_name = self.dereverb_ref_name_prefix + suffix + data[dereverb_name] = self._get_early_signal( + speech_ref[spk], rir_ref[spk], power_ref[spk] + ).T + else: + for spk in range(num_spk): + suffix = str(spk + 1) + speech_ref_name = self.speech_ref_name_prefix + suffix + # clean speech with early reflections (Time, Nmic) + data[speech_ref_name] = self._get_early_signal( + speech_ref[spk], rir_ref[spk], power_ref[spk] + ).T + + if dereverb_speech_ref is not None: + if spk == 0 or len(dereverb_speech_ref) > 1: + dereverb_name = self.dereverb_ref_name_prefix + suffix + data[dereverb_name] = data[speech_ref_name] + + if self.noise_ref_name_prefix + "1" in data: + noise = data[self.noise_ref_name_prefix + "1"] + speech_mix = sum(speech_ref) + noise + else: + speech_mix = sum(speech_ref) + + # Add category information for dynamic mixing + # "_reverb" means dereverberation is required + # "_both" means both reverberant and dereverberated signals are required + if "category" in data: + if self.use_reverberant_ref: + if dereverb_speech_ref is None: + if 
data["category"].endswith("_reverb"): + data["category"] = data["category"][:-7] + if data["category"].endswith("_both"): + data["category"] = data["category"][:-5] + else: + if not data["category"].endswith("_both"): + data["category"] = data["category"] + "_both" + elif not data["category"].endswith("_reverb"): + data["category"] = data["category"] + "_reverb" + + # 2. Add Noise + if self.noises is not None and self.noise_apply_prob >= np.random.random(): + speech_mix = sum(speech_ref) + if self.force_single_channel and speech_mix.shape[0] > 1: + speech_mix = speech_mix[:1] + + power_mix = (speech_mix[detect_non_silence(speech_mix)] ** 2).mean() + speech_mix, noise = self._add_noise( + speech_mix, + power_mix, + self.noises, + self.noise_db_low, + self.noise_db_high, + tgt_fs=fs, + single_channel=self.force_single_channel, + ) + + name = self.noise_ref_name_prefix + "1" + if name in data: + data[name] = noise.T + for n in range(1, self.num_noise_type): + name = self.noise_ref_name_prefix + str(n + 1) + data.pop(name, None) + + if self.data_aug: + if self.data_aug_prob > 0 and self.data_aug_prob >= np.random.random(): + # Currently, we only apply data augmentation to the mixture. + # So, some effects should not be used for Enh, such as pitch_shift, + # speed_perturb, time_stretch, polarity_inverse, reverse, etc. + speech_mix = self.data_aug( + speech_mix.T if speech_mix.shape[0] > 1 else speech_mix[0], + self.sample_rate, + ) + + data[self.speech_name] = speech_mix.T + ma = np.max(np.abs(data[self.speech_name])) + if ma > 1.0: + self._apply_to_all_signals(data, lambda x: x / ma, num_spk) + + self._apply_to_all_signals(data, lambda x: x.squeeze(), num_spk) + + if self.force_single_channel: + self._apply_to_all_signals( + data, lambda x: x if x.ndim == 1 else x[:, 0], num_spk + ) + + if self.speech_volume_normalize is not None: + if self.train: + volume_scale = np.random.uniform(self.volume_low, self.volume_high) + else: + # use a fixed scale to make it deterministic + volume_scale = self.volume_low + ma = np.max(np.abs(data[self.speech_name])) + self._apply_to_all_signals(data, lambda x: x * volume_scale / ma, num_spk) + + if self.categories and "category" in data: + category = data.pop("category") + if not re.fullmatch(r"\d+ch.*", category): + speech_mix = data[self.speech_name] + nch = 1 if speech_mix.ndim == 1 else speech_mix.shape[-1] + category = f"{nch}ch_" + category + assert category in self.categories, category + data["utt2category"] = np.array([self.categories[category]]) + + speech_mix = data[self.speech_name] + # Reorder channels of the multi-channel signals + if speech_mix.ndim > 1 and self.channel_reordering and self.train: + num_ch = speech_mix.shape[-1] + # chs = np.random.choice(range(num_ch), size=num_ch, replace=False).tolist() + chs = np.random.permutation(num_ch).tolist() + data[self.speech_name] = speech_mix[..., chs] + for i in range(num_spk): + k = self.speech_ref_name_prefix + str(i + 1) + if self.train: + assert k in data, (data.keys(), k) + if k in data and data[k].ndim > 1: + assert data[k].shape == speech_mix.shape + data[k] = data[k][..., chs] + + assert check_return_type(data) + return data + + def __call__( + self, uid: str, data: Dict[str, Union[str, np.ndarray]] + ) -> Dict[str, np.ndarray]: + assert check_argument_types() + + data = self._speech_process(uid, data) + data = self._text_process(data) + return data