Skip to content

Commit

Permalink
Merge pull request #170 from sensein/plot2
Browse files Browse the repository at this point in the history
Adjusting utilities for plotting spectrograms
  • Loading branch information
fabiocat93 authored Oct 1, 2024
2 parents b278c73 + dbc68cd commit a1f7ea9
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 11 deletions.
8 changes: 4 additions & 4 deletions src/senselab/audio/tasks/features_extraction/torchaudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@

def extract_spectrogram_from_audios(
audios: List[Audio],
n_fft: int = 400,
n_fft: int = 1024,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
) -> List[Dict[str, torch.Tensor]]:
"""Extract spectrograms from a list of audio objects.
Args:
audios (List[Audio]): List of Audio objects.
n_fft (int): Size of FFT, creates n_fft // 2 + 1 bins. Default is 400.
n_fft (int): Size of FFT, creates n_fft // 2 + 1 bins. Default is 1024.
win_length (int): Window size. Default is None, using n_fft.
hop_length (int): Length of hop between STFT windows. Default is None, using win_length // 2.
Expand All @@ -42,7 +42,7 @@ def extract_spectrogram_from_audios(

def extract_mel_spectrogram_from_audios(
audios: List[Audio],
n_fft: Optional[int] = 400,
n_fft: Optional[int] = 1024,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
n_mels: int = 128,
Expand All @@ -51,7 +51,7 @@ def extract_mel_spectrogram_from_audios(
Args:
audios (List[Audio]): List of Audio objects.
n_fft (int): Size of FFT, creates n_fft // 2 + 1 bins. Default is 400.
n_fft (int): Size of FFT, creates n_fft // 2 + 1 bins. Default is 1024.
win_length (int): Window size. Default is None, using n_fft.
hop_length (int): Length of hop between STFT windows. Default is None, using win_length // 2.
n_mels (int): Number of mel filter banks. Default is 128.
Expand Down
3 changes: 3 additions & 0 deletions src/senselab/audio/tasks/plotting/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""This module contains functions for plotting audio-related data."""

from .plotting import play_audio, plot_specgram, plot_waveform # noqa: F401
90 changes: 84 additions & 6 deletions src/senselab/audio/tasks/plotting/plotting.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
"""This module contains functions for plotting audio-related data."""

from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import Audio, display

from senselab.audio.data_structures import Audio as AudioData
from senselab.audio.tasks.features_extraction.torchaudio import (
extract_mel_spectrogram_from_audios,
extract_spectrogram_from_audios,
)
from senselab.utils.data_structures import logger


def plot_waveform(audio: AudioData, title: str = "Waveform", fast: bool = False) -> None:
Expand Down Expand Up @@ -42,31 +50,101 @@ def plot_waveform(audio: AudioData, title: str = "Waveform", fast: bool = False)
plt.show(block=False)


def plot_specgram(spectrogram: torch.Tensor, sample_rate: int, title: str = "Spectrogram") -> None:
def plot_specgram(audio: AudioData, mel_scale: bool = False, title: str = "Spectrogram", **spect_kwargs: Any) -> None: # noqa : ANN401
"""Plots the spectrogram of an Audio object.
Args:
spectrogram (torch.Tensor): A tensor representing the spectrogram.
sample_rate (int): The sampling rate of the audio data.
audio: (AudioData): An instance of Audio containing waveform data and sampling rate.
mel_scale (bool): Whether to plot a mel spectrogram or a regular spectrogram.
title (str): Title of the spectrogram plot.
**spect_kwargs: Additional keyword arguments to pass to the spectrogram function.
Todo:
- Add option to save the plot
- Add option to choose the size of the Figure
"""

def _power_to_db(
spectrogram: np.ndarray, ref: float = 1.0, amin: float = 1e-10, top_db: float = 80.0
) -> np.ndarray:
"""Converts a power spectrogram (amplitude squared) to decibel (dB) units.
Args:
spectrogram (np.ndarray): Power spectrogram to convert.
ref (float): Reference power. Default is 1.0.
amin (float): Minimum power. Default is 1e-10.
top_db (float): Minimum decibel. Default is 80.0.
Returns:
np.ndarray: Decibel spectrogram.
"""
S = np.asarray(spectrogram)

if amin <= 0:
raise ValueError("amin must be strictly positive")

if np.issubdtype(S.dtype, np.complexfloating):
logger.warning(
"_power_to_db was called on complex input so phase "
"information will be discarded. To suppress this warning, "
"call power_to_db(np.abs(D)**2) instead.",
stacklevel=2,
)
magnitude = np.abs(S)
else:
magnitude = S

if callable(ref):
# User supplied a function to calculate reference power
ref_value = ref(magnitude)
else:
ref_value = np.abs(ref)

log_spec: np.ndarray = 10.0 * np.log10(np.maximum(amin, magnitude))
log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))

if top_db is not None:
if top_db < 0:
raise ValueError("top_db must be non-negative")
log_spec = np.maximum(log_spec, log_spec.max() - top_db)

return log_spec

# Extract the spectrogram
if mel_scale:
spectrogram = extract_mel_spectrogram_from_audios([audio], **spect_kwargs)[0]["mel_spectrogram"]
y_axis_label = "Mel Frequency"
else:
spectrogram = extract_spectrogram_from_audios([audio], **spect_kwargs)[0]["spectrogram"]
y_axis_label = "Frequency [Hz]"

if spectrogram.dim() != 2:
raise ValueError("Spectrogram must be a 2D tensor.")

# Determine time and frequency scale
num_frames = spectrogram.size(1)
num_freq_bins = spectrogram.size(0)

# Time axis in seconds
time_axis = (audio.waveform.size(-1) / audio.sampling_rate) * (torch.arange(0, num_frames).float() / num_frames)

# Frequency axis in Hz (for non-mel spectrograms)
if mel_scale:
freq_axis = torch.arange(num_freq_bins) # For mel spectrogram, keep the bins as discrete values
else:
freq_axis = torch.linspace(0, audio.sampling_rate / 2, num_freq_bins)

plt.figure(figsize=(10, 4))
plt.imshow(
spectrogram.numpy(),
_power_to_db(spectrogram.numpy()),
aspect="auto",
origin="lower",
extent=(0, spectrogram.size(1) / sample_rate, 0, sample_rate / 2),
extent=(float(time_axis[0]), float(time_axis[-1]), float(freq_axis[0]), float(freq_axis[-1])),
cmap="viridis",
)
plt.colorbar(label="Magnitude (dB)")
plt.title(title)
plt.ylabel("Frequency [Hz]")
plt.ylabel(y_axis_label)
plt.xlabel("Time [Sec]")
plt.show(block=False)

Expand Down
2 changes: 1 addition & 1 deletion src/tests/audio/tasks/features_extraction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_extract_spectrogram_from_audios(resampled_mono_audio_sample: Audio) ->
assert all(isinstance(spec["spectrogram"], torch.Tensor) for spec in result)
# Spectrogram shape is (freq, time)
assert all(spec["spectrogram"].dim() == 2 for spec in result)
assert all(spec["spectrogram"].shape[0] == 201 for spec in result)
assert all(spec["spectrogram"].shape[0] == 513 for spec in result)


def test_extract_mel_spectrogram_from_audios(resampled_mono_audio_sample: Audio) -> None:
Expand Down

0 comments on commit a1f7ea9

Please sign in to comment.