Merge pull request #35 from sensein/enh/speedup
ref: plotting to improve speed
satra authored Apr 12, 2024
2 parents d47f5fc + f08b9ba commit ac97893
Showing 3 changed files with 134 additions and 25 deletions.
1 change: 1 addition & 0 deletions .flake8
@@ -1,2 +1,3 @@
[flake8]
max-line-length = 100
ignore = E203
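Note: E203 flags whitespace before ":", which Black's slice formatting produces by design, so projects that format with Black commonly add it to the ignore list.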
18 changes: 16 additions & 2 deletions src/b2aiprep/cli.py
@@ -40,6 +40,8 @@ def main():
@click.option("--save_figures/--no-save_figures", default=False, show_default=True)
@click.option("--n_mels", type=int, default=20, show_default=True)
@click.option("--n_coeff", type=int, default=20, show_default=True)
@click.option("--win_length", type=int, default=20, show_default=True)
@click.option("--hop_length", type=int, default=10, show_default=True)
@click.option("--compute_deltas/--no-compute_deltas", default=True, show_default=True)
@click.option("--speech2text/--no-speech2text", type=bool, default=False, show_default=True)
@click.option("--opensmile", nargs=2, default=["eGeMAPSv02", "Functionals"], show_default=True)
@@ -51,6 +53,8 @@ def convert(
save_figures,
n_mels,
n_coeff,
win_length,
hop_length,
compute_deltas,
speech2text,
opensmile,
@@ -65,9 +69,12 @@ def convert(
extract_text=speech2text,
n_mels=n_mels,
n_coeff=n_coeff,
win_length=win_length,
hop_length=hop_length,
compute_deltas=compute_deltas,
opensmile_feature_set=opensmile[0],
opensmile_feature_level=opensmile[1],
device="cpu",
)
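For illustration, the new --win_length and --hop_length options can be exercised through click's test runner. This is a sketch only: the positional arguments to convert are elided from this diff, so the input file and output directory below are assumptions.

# Hypothetical invocation of convert with the new window/hop options;
# the positional arguments are assumptions, not shown in this diff.
from click.testing import CliRunner

from b2aiprep.cli import main

runner = CliRunner()
result = runner.invoke(
    main,
    ["convert", "audio.wav", "outdir", "--win_length", "25", "--hop_length", "10"],
)
print(result.output)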


@@ -77,6 +84,8 @@ def convert(
@click.option("--save_figures/--no-save_figures", default=False, show_default=True)
@click.option("--n_mels", type=int, default=20, show_default=True)
@click.option("--n_coeff", type=int, default=20, show_default=True)
@click.option("--win_length", type=int, default=20, show_default=True)
@click.option("--hop_length", type=int, default=10, show_default=True)
@click.option("--compute_deltas/--no-compute_deltas", default=True, show_default=True)
@click.option(
"-p",
@@ -102,6 +111,8 @@ def batchconvert(
save_figures,
n_mels,
n_coeff,
win_length,
hop_length,
compute_deltas,
plugin,
cache,
@@ -120,12 +131,15 @@ def batchconvert(
featurize_task = featurize_pdt(
n_mels=n_mels,
n_coeff=n_coeff,
win_length=win_length,
hop_length=hop_length,
compute_deltas=compute_deltas,
cache_dir=Path(cache).absolute(),
save_figures=save_figures,
extract_text=speech2text,
opensmile_feature_set=opensmile[0],
opensmile_feature_level=opensmile[1],
device="cpu",
)

with open(csvfile, "r") as f:
@@ -187,7 +201,7 @@ def gen():
yield torch.load(val)

print(f"Input: {len(results)} files. Processed: {len(stored_results)}")
- to_hf_dataset(gen, Path(outdir) / "hf_dataset")
+ to_hf_dataset(gen, Path(outdir))


@main.command()
@@ -292,7 +306,7 @@ def createbatchcsv(input_dir, out_file):

# out_file is where a csv file will be saved and should be in the format 'path/name/csv'
input_dir = Path(input_dir)
- audiofiles = glob(f"{input_dir}/**/*.wav", recursive=True)
+ audiofiles = sorted(glob(f"{input_dir}/**/*.wav", recursive=True))
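Sorting the glob results makes the generated CSV deterministic across runs and filesystems; glob alone returns files in arbitrary order.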

with open(out_file, "w") as f:

140 changes: 117 additions & 23 deletions src/b2aiprep/process.py
@@ -95,7 +95,11 @@ def verify_speaker_from_files(


def specgram(
- audio: Audio, n_fft: int = 512, win_length: int = 20, hop_length: int = 10, toDb: bool = False
+ audio: Audio,
+ n_fft: ty.Optional[int] = None,
+ win_length: int = 20,
+ hop_length: int = 10,
+ toDb: bool = False,
) -> torch.tensor:
"""Compute the spectrogram using STFT of the audio signal
@@ -110,7 +114,7 @@ def specgram(
sample_rate=audio.sample_rate,
win_length=win_length,
hop_length=hop_length,
- n_fft=n_fft or int(400 * audio.sample_rate / 16000),
+ n_fft=n_fft or int(win_length * audio.sample_rate / 1000),
)
stft = compute_STFT(audio.signal.unsqueeze(0))
spec = spf.spectral_magnitude(stft.squeeze(), power=1)
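The new default ties the FFT size to the window length in milliseconds instead of a fixed 400-sample window. Checking the arithmetic at the 16 kHz rate the pipeline resamples to:

# n_fft derived from win_length (ms) and the sample rate (Hz)
sample_rate = 16000
win_length = 20                               # ms
n_fft = int(win_length * sample_rate / 1000)  # 20 * 16 = 320 samples
# previous default: int(400 * sample_rate / 16000) = 400 samples, regardless of win_length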
@@ -191,6 +195,87 @@ def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None, **kwargs)
ax.matshow(specgram, origin="lower", aspect="auto", **kwargs) # , interpolation="nearest")


def plot_save_figure(audio, log_spec, prefix, outdir):
duration = len(audio.signal) / audio.sample_rate
win_length = duration
# for large plots, determine a balance between the number of subplots
# and the padding required.
if duration > 20:
nplots = 100
for win_length_try in range(10, 20):
nplots_try = int(torch.ceil(torch.tensor(duration) / win_length_try))
if nplots_try < nplots:
nplots = nplots_try
win_length = win_length_try
else:
nplots = 1
freq_steps = 4
freq_ticks = torch.arange(0, log_spec.shape[0], log_spec.shape[0] // (freq_steps - 1))
freq_axis = (
(torch.round(freq_ticks * audio.sample_rate / 2 / log_spec.shape[0])).numpy().astype(int)
)

signal = audio.signal
sr = audio.sample_rate
# This factor is used to decimate the waveform, which provides
# the biggest speedup.
decimate_factor = min(4 ** (int(len(signal) // (win_length * sr))), 80)
signal = signal[::decimate_factor]
sr = sr // decimate_factor

if nplots > 1:
# pad signal and spectrogram to fill up plot
signal_pad = (nplots * win_length * sr) - len(signal)
signal = torch.nn.ZeroPad1d((0, signal_pad))(signal.T).T
spec_pad = int((len(signal) / sr) / (duration / log_spec.shape[1])) - log_spec.shape[1]
log_spec = torch.nn.ZeroPad2d((0, spec_pad))(log_spec)

N = len(signal)
ymax = torch.abs(signal).max()

fig, axs = plt.subplots(
nplots * 2,
1,
figsize=(8, 3 * nplots),
sharex=True,
)
fig.subplots_adjust(hspace=0.0)
for idx in range(nplots):
if nplots > 1:
waveform = signal[(N // nplots * idx) : (N // nplots * (idx + 1))]
spec = log_spec[
:, (log_spec.shape[1] // nplots * idx) : (log_spec.shape[1] // nplots * (idx + 1))
]
else:
waveform = signal
spec = log_spec
# Plot waveform
ax = axs[2 * idx + 0]
ax.set_ylim([-ymax, ymax])
timestamps = torch.arange(0, len(waveform)) / len(waveform) * spec.shape[1]
ax.plot(timestamps, waveform, linewidth=1)
ax.grid(True)
# Plot spectrogram
ax = axs[2 * idx + 1]
ax.matshow(spec, origin="lower", aspect="auto")
ax.set_yticks(freq_ticks)
ax.set_yticklabels(freq_axis)
ax.set_ylabel("Freq (Hz)")
if idx == nplots - 1:
ax.xaxis.set_ticks_position("bottom")
xticks = torch.arange(0, spec.shape[1], spec.shape[1] // win_length)
xticklabels = torch.round(xticks / spec.shape[1] * win_length).numpy().astype(int)
ax.set_xticks(xticks)
ax.set_xticklabels(xticklabels)
ax.set_xlabel("Time (s)")
fig.suptitle(f"Waveform and spectrogram of {prefix}")
fig.tight_layout()
outfig = outdir / f"{prefix}_specgram.png"
fig.savefig(outfig, bbox_inches="tight")
plt.close(fig)
return outfig
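To make the layout and decimation choices above concrete, here is the same arithmetic traced for a hypothetical 45-second recording at 16 kHz:

import torch

duration = 45.0  # seconds (hypothetical)

# pick the window length in [10, 20) that needs the fewest subplot rows
nplots, win_length = 100, duration
for win_length_try in range(10, 20):
    nplots_try = int(torch.ceil(torch.tensor(duration) / win_length_try))
    if nplots_try < nplots:
        nplots, win_length = nplots_try, win_length_try
print(nplots, win_length)  # -> 3 15

# waveform decimation factor, capped at 80
sr = 16000
n_samples = int(duration * sr)  # 720000 samples
decimate_factor = min(4 ** (n_samples // (win_length * sr)), 80)
print(decimate_factor)  # -> 64, i.e. 4**3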


def to_features(
filename: Path,
subject: ty.Optional[str] = None,
@@ -199,13 +284,16 @@ def to_features(
save_figures: bool = False,
stt_kwargs: ty.Optional[ty.Dict] = None,
extract_text: bool = False,
win_length: int = 20,
hop_length: int = 10,
n_mels: int = 20,
n_coeff: int = 20,
compute_deltas: bool = True,
opensmile_feature_set: str = "eGeMAPSv02",
opensmile_feature_level: str = "Functionals",
return_features: bool = False,
mpl_backend: str = "Agg",
device: ty.Optional[str] = None,
) -> ty.Tuple[ty.Dict, Path, ty.Optional[Path]]:
"""Compute features from audio file
@@ -215,6 +303,8 @@
:param outdir: Output directory
:param save_figures: Whether to save figures
:param extract_text: Whether to extract text
:param win_length: Window length (ms)
:param hop_length: Hop length (ms)
:param stt_kwargs: Keyword arguments for SpeechToText
:param n_mels: Number of Mel bands
:param n_coeff: Number of MFCC coefficients
@@ -223,6 +313,7 @@
:param opensmile_feature_level: OpenSmile feature level
:param return_features: Whether to return features
:param mpl_backend: matplotlib backend
:param device: Acceleration device (e.g. "cuda" or "cpu" or "mps")
:return: Features dictionary
:return: Path to features
:return: Path to figures
@@ -235,7 +326,8 @@
md5sum = md5(f.read()).hexdigest()
audio = Audio.from_file(str(filename))
audio = audio.to_16khz()
- features_specgram = specgram(audio)
+ # choose window and hop length so as not to allow good Griffin-Lim reconstruction
+ features_specgram = specgram(audio, win_length=win_length, hop_length=hop_length)
features_melfilterbank = melfilterbank(features_specgram, n_mels=n_mels)
features_mfcc = MFCC(features_melfilterbank, n_coeff=n_coeff, compute_deltas=compute_deltas)
features_opensmile = extract_opensmile(audio, opensmile_feature_set, opensmile_feature_level)
@@ -249,12 +341,16 @@
"checksum": md5sum,
}
if extract_text:
# Select the best device available
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
device = "mps" if torch.backends.mps.is_available() else device
stt_kwargs_default = {
"model_id": "openai/whisper-base",
"max_new_tokens": 128,
"chunk_length_s": 30,
"batch_size": 16,
- "device": None,
+ "device": device,
}
if stt_kwargs is not None:
stt_kwargs_default.update(**stt_kwargs)
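Note the precedence in the device selection added above: MPS, when available, overrides CUDA, and CPU is the fallback. An equivalent standalone sketch (pick_device is a hypothetical helper, not part of this commit):

import torch

def pick_device() -> str:
    # mirrors the selection order used in to_features
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"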
@@ -278,20 +374,15 @@
outfig = None
if save_figures:
# save spectrogram as figure

- # general log spectrogram for image
- features_specgram_log = specgram(audio, toDb=True)
-
- fig, axs = plt.subplots(2, 1)
- plot_waveform(
- audio.signal, sr=audio.sample_rate, title=f"Original waveform: {prefix}", ax=axs[0]
- )
- plot_spectrogram(features_specgram_log.T, title="spectrogram", ax=axs[1])
- fig.tight_layout()
-
- outfig = outdir / f"{prefix}_specgram.png"
- fig.savefig(outfig, bbox_inches="tight")
- plt.close(fig)
+ """
+ log_spec = specgram(audio,
+ win_length=20,
+ hop_length=10,
+ toDb=True)
+ """
+ log_spec = 10.0 * torch.log10(torch.maximum(features_specgram, torch.tensor(1e-10)))
+ log_spec = torch.maximum(log_spec, log_spec.max() - 80)
+ outfig = plot_save_figure(audio, log_spec.T, prefix, outdir)

return features if return_features else None, outfile, outfig
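The dB conversion above reuses the magnitude spectrogram already computed for the features, avoiding the second STFT pass that specgram(audio, toDb=True) previously ran purely for plotting. The clamp keeps only the top 80 dB of dynamic range, analogous to torchaudio's top_db convention. A minimal standalone sketch:

import torch

spec = torch.rand(257, 100) + 1e-6  # stand-in magnitude spectrogram
log_spec = 10.0 * torch.log10(torch.maximum(spec, torch.tensor(1e-10)))
log_spec = torch.maximum(log_spec, log_spec.max() - 80)  # floor at 80 dB below the peak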

@@ -361,9 +452,13 @@ def __init__(
) -> None:

torch_dtype = torch.float32
- if device is not None and "cuda" in device:
- # If CUDA is available, set use_gpu to True
- if torch.cuda.is_available():
+ if device is not None:
+ if device == "cpu":
+ pass
+ elif device == "cuda" and torch.cuda.is_available():
+ # if CUDA is available, use half precision
+ torch_dtype = torch.float16
+ elif device == "mps" and torch.backends.mps.is_available():
+ torch_dtype = torch.float16
# If CUDA is not available, raise an error
else:
@@ -415,5 +510,4 @@ def transcribe(self, audio: Audio, language: str = None):
def to_hf_dataset(generator, outdir: Path) -> None:
# Create a Hugging Face dataset from the data
ds = Dataset.from_generator(generator)
- # Save the dataset to a JSON file
- ds.save_to_disk(outdir)
+ ds.to_parquet(outdir / "b2aivoice.parquet")
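With save_to_disk replaced by a single parquet file, the dataset can be read back directly (a sketch; the path is illustrative):

from datasets import Dataset

ds = Dataset.from_parquet("outdir/b2aivoice.parquet")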
