From 74f71bb5fcd6962c109ed8021b9829e1e521c311 Mon Sep 17 00:00:00 2001
From: Satrajit Ghosh
Date: Thu, 11 Apr 2024 20:24:00 -0400
Subject: [PATCH 1/3] ref: plotting to improve speed

---
 .flake8                 |   1 +
 src/b2aiprep/cli.py     |  18 +++++-
 src/b2aiprep/process.py | 140 +++++++++++++++++++++++++++++++++-------
 3 files changed, 134 insertions(+), 25 deletions(-)

diff --git a/.flake8 b/.flake8
index 7da1f96..670e3f8 100644
--- a/.flake8
+++ b/.flake8
@@ -1,2 +1,3 @@
 [flake8]
 max-line-length = 100
+ignore = E203

diff --git a/src/b2aiprep/cli.py b/src/b2aiprep/cli.py
index 537da76..f038a69 100644
--- a/src/b2aiprep/cli.py
+++ b/src/b2aiprep/cli.py
@@ -40,6 +40,8 @@ def main():
 @click.option("--save_figures/--no-save_figures", default=False, show_default=True)
 @click.option("--n_mels", type=int, default=20, show_default=True)
 @click.option("--n_coeff", type=int, default=20, show_default=True)
+@click.option("--win_length", type=int, default=20, show_default=True)
+@click.option("--hop_length", type=int, default=10, show_default=True)
 @click.option("--compute_deltas/--no-compute_deltas", default=True, show_default=True)
 @click.option("--speech2text/--no-speech2text", type=bool, default=False, show_default=True)
 @click.option("--opensmile", nargs=2, default=["eGeMAPSv02", "Functionals"], show_default=True)
@@ -51,6 +53,8 @@ def convert(
     save_figures,
     n_mels,
     n_coeff,
+    win_length,
+    hop_length,
     compute_deltas,
     speech2text,
     opensmile,
@@ -65,9 +69,12 @@ def convert(
         extract_text=speech2text,
         n_mels=n_mels,
         n_coeff=n_coeff,
+        win_length=win_length,
+        hop_length=hop_length,
         compute_deltas=compute_deltas,
         opensmile_feature_set=opensmile[0],
         opensmile_feature_level=opensmile[1],
+        device="cpu",
     )


@@ -77,6 +84,8 @@ def convert(
 @click.option("--save_figures/--no-save_figures", default=False, show_default=True)
 @click.option("--n_mels", type=int, default=20, show_default=True)
 @click.option("--n_coeff", type=int, default=20, show_default=True)
+@click.option("--win_length", type=int, default=20, show_default=True)
+@click.option("--hop_length", type=int, default=10, show_default=True)
 @click.option("--compute_deltas/--no-compute_deltas", default=True, show_default=True)
 @click.option(
     "-p",
@@ -102,6 +111,8 @@ def batchconvert(
     save_figures,
     n_mels,
     n_coeff,
+    win_length,
+    hop_length,
     compute_deltas,
     plugin,
     cache,
@@ -120,12 +131,15 @@ def batchconvert(
     featurize_task = featurize_pdt(
         n_mels=n_mels,
         n_coeff=n_coeff,
+        win_length=win_length,
+        hop_length=hop_length,
         compute_deltas=compute_deltas,
         cache_dir=Path(cache).absolute(),
         save_figures=save_figures,
         extract_text=speech2text,
         opensmile_feature_set=opensmile[0],
         opensmile_feature_level=opensmile[1],
+        device="cpu",
     )

     with open(csvfile, "r") as f:
@@ -187,7 +201,7 @@ def gen():
         yield torch.load(val)

     print(f"Input: {len(results)} files. Processed: {len(stored_results)}")
-    to_hf_dataset(gen, Path(outdir) / "hf_dataset")
+    to_hf_dataset(gen, Path(outdir))


 @main.command()
@@ -292,7 +306,7 @@ def createbatchcsv(input_dir, out_file):

     # out_file is where a csv file will be saved and should be in the format 'path/name.csv'
     input_dir = Path(input_dir)
-    audiofiles = glob(f"{input_dir}/**/*.wav", recursive=True)
+    audiofiles = sorted(glob(f"{input_dir}/**/*.wav", recursive=True))

     with open(out_file, "w") as f:

diff --git a/src/b2aiprep/process.py b/src/b2aiprep/process.py
index cc8f862..4518a98 100644
--- a/src/b2aiprep/process.py
+++ b/src/b2aiprep/process.py
@@ -95,7 +95,11 @@ def verify_speaker_from_files(


 def specgram(
-    audio: Audio, n_fft: int = 512, win_length: int = 20, hop_length: int = 10, toDb: bool = False
+    audio: Audio,
+    n_fft: ty.Optional[int] = None,
+    win_length: int = 20,
+    hop_length: int = 10,
+    toDb: bool = False,
 ) -> torch.tensor:
     """Compute the spectrogram using STFT of the audio signal
@@ -110,7 +114,7 @@
         sample_rate=audio.sample_rate,
         win_length=win_length,
         hop_length=hop_length,
-        n_fft=n_fft or int(400 * audio.sample_rate / 16000),
+        n_fft=n_fft or int(win_length * audio.sample_rate / 1000),
     )
     stft = compute_STFT(audio.signal.unsqueeze(0))
     spec = spf.spectral_magnitude(stft.squeeze(), power=1)
@@ -191,6 +195,87 @@ def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None, **kwargs)
     ax.matshow(specgram, origin="lower", aspect="auto", **kwargs)  # , interpolation="nearest")


+def plot_save_figure(audio, log_spec, prefix, outdir):
+    duration = len(audio.signal) / audio.sample_rate
+    win_length = duration
+    # for large plots, determine a balance between the number of subplots
+    # and the padding required.
+    if duration > 20:
+        nplots = 100
+        for win_length_try in range(10, 20):
+            nplots_try = int(torch.ceil(torch.tensor(duration) / win_length_try))
+            if nplots_try < nplots:
+                nplots = nplots_try
+                win_length = win_length_try
+    else:
+        nplots = 1
+    freq_steps = 4
+    freq_ticks = torch.arange(0, log_spec.shape[0], log_spec.shape[0] // (freq_steps - 1))
+    freq_axis = (
+        (torch.round(freq_ticks * audio.sample_rate / 2 / log_spec.shape[0])).numpy().astype(int)
+    )
+
+    signal = audio.signal
+    sr = audio.sample_rate
+    # This factor is used to decimate the waveform, which provides
+    # the biggest speedup.
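+    # e.g. a 60 s clip at 16 kHz with 15 s panels gives a factor of 4 ** 4 = 256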
+    decimate_factor = 4 ** (int(len(signal) // (win_length * sr)))
+    signal = signal[::decimate_factor]
+    sr = sr // decimate_factor
+
+    if nplots > 1:
+        # pad signal and spectrogram to fill up plot
+        signal_pad = (nplots * win_length * sr) - len(signal)
+        signal = torch.nn.ZeroPad1d((0, signal_pad))(signal.T).T
+        spec_pad = int((len(signal) / sr) / (duration / log_spec.shape[1])) - log_spec.shape[1]
+        log_spec = torch.nn.ZeroPad2d((0, spec_pad))(log_spec)
+
+    N = len(signal)
+    ymax = torch.abs(signal).max()
+
+    fig, axs = plt.subplots(
+        nplots * 2,
+        1,
+        figsize=(8, 3 * nplots),
+        sharex=True,
+    )
+    fig.subplots_adjust(hspace=0.0)
+    for idx in range(nplots):
+        if nplots > 1:
+            waveform = signal[(N // nplots * idx) : (N // nplots * (idx + 1))]
+            spec = log_spec[
+                :, (log_spec.shape[1] // nplots * idx) : (log_spec.shape[1] // nplots * (idx + 1))
+            ]
+        else:
+            waveform = signal
+            spec = log_spec
+        # Plot waveform
+        ax = axs[2 * idx + 0]
+        ax.set_ylim([-ymax, ymax])
+        timestamps = torch.arange(0, len(waveform)) / len(waveform) * spec.shape[1]
+        ax.plot(timestamps, waveform, linewidth=1)
+        ax.grid(True)
+        # Plot spectrogram
+        ax = axs[2 * idx + 1]
+        ax.matshow(spec, origin="lower", aspect="auto")
+        ax.set_yticks(freq_ticks)
+        ax.set_yticklabels(freq_axis)
+        ax.set_ylabel("Freq (Hz)")
+        if idx == nplots - 1:
+            ax.xaxis.set_ticks_position("bottom")
+            xticks = torch.arange(0, spec.shape[1], spec.shape[1] // win_length)
+            xticklabels = torch.round(xticks / spec.shape[1] * win_length).numpy().astype(int)
+            ax.set_xticks(xticks)
+            ax.set_xticklabels(xticklabels)
+            ax.set_xlabel("Time (s)")
+    fig.suptitle(f"Waveform and spectrogram of {prefix}")
+    fig.tight_layout()
+    outfig = outdir / f"{prefix}_specgram.png"
+    fig.savefig(outfig, bbox_inches="tight")
+    plt.close(fig)
+    return outfig
+
+
 def to_features(
     filename: Path,
     subject: ty.Optional[str] = None,
@@ -199,6 +284,8 @@
     save_figures: bool = False,
     stt_kwargs: ty.Optional[ty.Dict] = None,
     extract_text: bool = False,
+    win_length: int = 20,
+    hop_length: int = 10,
     n_mels: int = 20,
     n_coeff: int = 20,
     compute_deltas: bool = True,
@@ -206,6 +293,7 @@
     opensmile_feature_level: str = "Functionals",
     return_features: bool = False,
     mpl_backend: str = "Agg",
+    device: ty.Optional[str] = None,
 ) -> ty.Tuple[ty.Dict, Path, ty.Optional[Path]]:
     """Compute features from audio file

@@ -215,6 +303,8 @@
     :param outdir: Output directory
     :param save_figures: Whether to save figures
     :param extract_text: Whether to extract text
+    :param win_length: Window length (ms)
+    :param hop_length: Hop length (ms)
     :param stt_kwargs: Keyword arguments for SpeechToText
     :param n_mels: Number of Mel bands
     :param n_coeff: Number of MFCC coefficients
@@ -223,6 +313,7 @@
     :param opensmile_feature_level: OpenSmile feature level
     :param return_features: Whether to return features
     :param mpl_backend: matplotlib backend
+    :param device: Acceleration device (e.g. "cuda" or "cpu" or "mps")
"cuda" or "cpu" or "mps") :return: Features dictionary :return: Path to features :return: Path to figures @@ -235,7 +326,8 @@ def to_features( md5sum = md5(f.read()).hexdigest() audio = Audio.from_file(str(filename)) audio = audio.to_16khz() - features_specgram = specgram(audio) + # set window and hop length to the same to not allow for good Griffin Lim reconstruction + features_specgram = specgram(audio, win_length=win_length, hop_length=hop_length) features_melfilterbank = melfilterbank(features_specgram, n_mels=n_mels) features_mfcc = MFCC(features_melfilterbank, n_coeff=n_coeff, compute_deltas=compute_deltas) features_opensmile = extract_opensmile(audio, opensmile_feature_set, opensmile_feature_level) @@ -249,12 +341,16 @@ def to_features( "checksum": md5sum, } if extract_text: + # Select the best device available + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + device = "mps" if torch.backends.mps.is_available() else device stt_kwargs_default = { "model_id": "openai/whisper-base", "max_new_tokens": 128, "chunk_length_s": 30, "batch_size": 16, - "device": None, + "device": device, } if stt_kwargs is not None: stt_kwargs_default.update(**stt_kwargs) @@ -278,20 +374,15 @@ def to_features( outfig = None if save_figures: # save spectogram as figure - - # general log spectrogram for image - features_specgram_log = specgram(audio, toDb=True) - - fig, axs = plt.subplots(2, 1) - plot_waveform( - audio.signal, sr=audio.sample_rate, title=f"Original waveform: {prefix}", ax=axs[0] - ) - plot_spectrogram(features_specgram_log.T, title="spectrogram", ax=axs[1]) - fig.tight_layout() - - outfig = outdir / f"{prefix}_specgram.png" - fig.savefig(outfig, bbox_inches="tight") - plt.close(fig) + """ + log_spec = specgram(audio, + win_length=20, + hop_length=10, + toDb=True) + """ + log_spec = 10.0 * torch.log10(torch.maximum(features_specgram, torch.tensor(1e-10))) + log_spec = torch.maximum(log_spec, log_spec.max() - 80) + outfig = plot_save_figure(audio, log_spec.T, prefix, outdir) return features if return_features else None, outfile, outfig @@ -361,9 +452,13 @@ def __init__( ) -> None: torch_dtype = torch.float32 - if device is not None and "cuda" in device: - # If CUDA is available, set use_gpu to True - if torch.cuda.is_available(): + if device is not None: + if device == "cpu": + pass + elif device == "cuda" and torch.cuda.is_available(): + # If CUDA is available, set use_gpu to True + torch_dtype = torch.float16 + elif device == "mps" and torch.backends.mps.is_available(): torch_dtype = torch.float16 # If CUDA is not available, raise an error else: @@ -415,5 +510,4 @@ def transcribe(self, audio: Audio, language: str = None): def to_hf_dataset(generator, outdir: Path) -> None: # Create a Hugging Face dataset from the data ds = Dataset.from_generator(generator) - # Save the dataset to a JSON file - ds.save_to_disk(outdir) + ds.to_parquet(outdir / "b2aivoice.parquet") From 57720de8bb9aa39fb1a9d29689f0b341b507b53d Mon Sep 17 00:00:00 2001 From: Satrajit Ghosh Date: Thu, 11 Apr 2024 21:24:53 -0400 Subject: [PATCH 2/3] fix: keep decimation factor limited --- src/b2aiprep/process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/b2aiprep/process.py b/src/b2aiprep/process.py index 4518a98..ecc5d51 100644 --- a/src/b2aiprep/process.py +++ b/src/b2aiprep/process.py @@ -219,7 +219,7 @@ def plot_save_figure(audio, log_spec, prefix, outdir): sr = audio.sample_rate # This factor is used to decimate the waveform, which provides # the biggest speedup. 
-    decimate_factor = 4 ** (int(len(signal) // (win_length * sr)))
+    decimate_factor = min(4 ** (int(len(signal) // (win_length * sr))), 100)
     signal = signal[::decimate_factor]
     sr = sr // decimate_factor

From f08b9ba46857966d7dcc2055f57d5913c8583797 Mon Sep 17 00:00:00 2001
From: Satrajit Ghosh
Date: Thu, 11 Apr 2024 21:32:03 -0400
Subject: [PATCH 3/3] lower decimation factor

---
 src/b2aiprep/process.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/b2aiprep/process.py b/src/b2aiprep/process.py
index ecc5d51..1d700de 100644
--- a/src/b2aiprep/process.py
+++ b/src/b2aiprep/process.py
@@ -219,7 +219,7 @@ def plot_save_figure(audio, log_spec, prefix, outdir):
     sr = audio.sample_rate
     # This factor is used to decimate the waveform, which provides
     # the biggest speedup.
-    decimate_factor = min(4 ** (int(len(signal) // (win_length * sr))), 100)
+    decimate_factor = min(4 ** (int(len(signal) // (win_length * sr))), 80)
     signal = signal[::decimate_factor]
     sr = sr // decimate_factor
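
A short note on the two heuristics these patches introduce, with a minimal standalone sketch (not part of the patches themselves): stft_sizes_from_ms and decimation_factor are hypothetical helper names used only for illustration, the cap of 80 is the value PATCH 3 settles on, and the 15 s panel window matches what plot_save_figure would pick for a 60 s clip.

    import torch

    def stft_sizes_from_ms(win_ms: int, hop_ms: int, sr: int) -> tuple:
        # specgram() takes win_length/hop_length in milliseconds after PATCH 1
        # and derives n_fft from the window size in samples
        return int(win_ms * sr / 1000), int(hop_ms * sr / 1000)

    def decimation_factor(n_samples: int, panel_s: float, sr: int, cap: int = 80) -> int:
        # power-of-4 waveform decimation that grows with clip length relative
        # to the plotting panel window, capped as in PATCH 2/3
        return min(4 ** int(n_samples // (panel_s * sr)), cap)

    sr = 16000
    print(stft_sizes_from_ms(20, 10, sr))   # (320, 160) samples
    signal = torch.randn(60 * sr)           # hypothetical 60 s clip at 16 kHz
    factor = decimation_factor(len(signal), 15, sr)
    print(factor, len(signal[::factor]))    # 80 12000

Uncapped, as in PATCH 1, the same 60 s clip would be decimated by a factor of 256 rather than 80.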