From ca78cd95ad76705c027a4f49bcb511d66aef750d Mon Sep 17 00:00:00 2001 From: fabiocat93 Date: Mon, 30 Sep 2024 23:47:57 -0400 Subject: [PATCH 1/3] improving plotting utilities --- .../tasks/features_extraction/torchaudio.py | 8 +- src/senselab/audio/tasks/plotting/__init__.py | 3 + src/senselab/audio/tasks/plotting/plotting.py | 90 +++++++++++++++++-- 3 files changed, 91 insertions(+), 10 deletions(-) create mode 100644 src/senselab/audio/tasks/plotting/__init__.py diff --git a/src/senselab/audio/tasks/features_extraction/torchaudio.py b/src/senselab/audio/tasks/features_extraction/torchaudio.py index f1857d2..f186378 100644 --- a/src/senselab/audio/tasks/features_extraction/torchaudio.py +++ b/src/senselab/audio/tasks/features_extraction/torchaudio.py @@ -10,7 +10,7 @@ def extract_spectrogram_from_audios( audios: List[Audio], - n_fft: int = 400, + n_fft: int = 1024, win_length: Optional[int] = None, hop_length: Optional[int] = None, ) -> List[Dict[str, torch.Tensor]]: @@ -18,7 +18,7 @@ def extract_spectrogram_from_audios( Args: audios (List[Audio]): List of Audio objects. - n_fft (int): Size of FFT, creates n_fft // 2 + 1 bins. Default is 400. + n_fft (int): Size of FFT, creates n_fft // 2 + 1 bins. Default is 1024. win_length (int): Window size. Default is None, using n_fft. hop_length (int): Length of hop between STFT windows. Default is None, using win_length // 2. @@ -42,7 +42,7 @@ def extract_spectrogram_from_audios( def extract_mel_spectrogram_from_audios( audios: List[Audio], - n_fft: Optional[int] = 400, + n_fft: Optional[int] = 1024, win_length: Optional[int] = None, hop_length: Optional[int] = None, n_mels: int = 128, @@ -51,7 +51,7 @@ def extract_mel_spectrogram_from_audios( Args: audios (List[Audio]): List of Audio objects. - n_fft (int): Size of FFT, creates n_fft // 2 + 1 bins. Default is 400. + n_fft (int): Size of FFT, creates n_fft // 2 + 1 bins. Default is 1024. win_length (int): Window size. Default is None, using n_fft. 
hop_length (int): Length of hop between STFT windows. Default is None, using win_length // 2. n_mels (int): Number of mel filter banks. Default is 128. diff --git a/src/senselab/audio/tasks/plotting/__init__.py b/src/senselab/audio/tasks/plotting/__init__.py new file mode 100644 index 0000000..47a1e7d --- /dev/null +++ b/src/senselab/audio/tasks/plotting/__init__.py @@ -0,0 +1,3 @@ +"""This module contains functions for plotting audio-related data.""" + +from .plotting import play_audio, plot_specgram, plot_waveform # noqa: F401 diff --git a/src/senselab/audio/tasks/plotting/plotting.py b/src/senselab/audio/tasks/plotting/plotting.py index 68ff437..1755d2b 100644 --- a/src/senselab/audio/tasks/plotting/plotting.py +++ b/src/senselab/audio/tasks/plotting/plotting.py @@ -1,10 +1,18 @@ """This module contains functions for plotting audio-related data.""" +from typing import Any + import matplotlib.pyplot as plt +import numpy as np import torch from IPython.display import Audio, display from senselab.audio.data_structures import Audio as AudioData +from senselab.audio.tasks.features_extraction.torchaudio import ( + extract_mel_spectrogram_from_audios, + extract_spectrogram_from_audios, +) +from senselab.utils.data_structures import logger def plot_waveform(audio: AudioData, title: str = "Waveform", fast: bool = False) -> None: @@ -42,31 +50,100 @@ def plot_waveform(audio: AudioData, title: str = "Waveform", fast: bool = False) plt.show(block=False) -def plot_specgram(spectrogram: torch.Tensor, sample_rate: int, title: str = "Spectrogram") -> None: +def plot_specgram(audio: AudioData, mel_scale: bool = False, title: str = "Spectrogram", **spect_kwargs: Any) -> None: # noqa : ANN401 """Plots the spectrogram of an Audio object. Args: - spectrogram (torch.Tensor): A tensor representing the spectrogram. - sample_rate (int): The sampling rate of the audio data. + audio: (AudioData): An instance of Audio containing waveform data and sampling rate. 
+ mel_scale (bool): Whether to plot a mel spectrogram or a regular spectrogram. title (str): Title of the spectrogram plot. + **spect_kwargs: Additional keyword arguments to pass to the spectrogram function. Todo: - Add option to save the plot - Add option to choose the size of the Figure """ + def _power_to_db( + spectrogram: np.ndarray, ref: float = 1.0, amin: float = 1e-10, top_db: float = 80.0 + ) -> np.ndarray: + """Converts a power spectrogram (amplitude squared) to decibel (dB) units. + + Args: + spectrogram (np.ndarray): Power spectrogram to convert. + ref (float): Reference power. Default is 1.0. + amin (float): Minimum power. Default is 1e-10. + top_db (float): Dynamic range in dB; output is clipped at top_db below its peak. Default is 80.0. + + Returns: + np.ndarray: Decibel spectrogram. + """ + S = np.asarray(spectrogram) + + if amin <= 0: + raise ValueError("amin must be strictly positive") + + if np.issubdtype(S.dtype, np.complexfloating): + logger.warning( + "_power_to_db was called on complex input so phase " + "information will be discarded. 
To suppress this warning, " + "call power_to_db(np.abs(D)**2) instead.", + stacklevel=2, + ) + magnitude = np.abs(S) + else: + magnitude = S + + if callable(ref): + # User supplied a function to calculate reference power + ref_value = ref(magnitude) + else: + ref_value = np.abs(ref) + + log_spec: np.ndarray = 10.0 * np.log10(np.maximum(amin, magnitude)) + log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value)) + + if top_db is not None: + if top_db < 0: + raise ValueError("top_db must be non-negative") + log_spec = np.maximum(log_spec, log_spec.max() - top_db) + + return log_spec + + # Extract the spectrogram + if mel_scale: + spectrogram = extract_mel_spectrogram_from_audios([audio], **spect_kwargs)[0]["mel_spectrogram"] + y_axis_label = "Mel Frequency" + else: + spectrogram = extract_spectrogram_from_audios([audio], **spect_kwargs)[0]["spectrogram"] + y_axis_label = "Frequency [Hz]" + if spectrogram.dim() != 2: raise ValueError("Spectrogram must be a 2D tensor.") + + # Determine time and frequency scale + num_frames = spectrogram.size(1) + num_freq_bins = spectrogram.size(0) + + # Time axis in seconds + time_axis = (audio.waveform.size(-1) / audio.sampling_rate) * (torch.arange(0, num_frames).float() / num_frames) + + # Frequency axis in Hz (for non-mel spectrograms) + if mel_scale: + freq_axis = torch.arange(num_freq_bins) # For mel spectrogram, keep the bins as discrete values + else: + freq_axis = torch.linspace(0, audio.sampling_rate / 2, num_freq_bins) + plt.figure(figsize=(10, 4)) plt.imshow( - spectrogram.numpy(), + _power_to_db(spectrogram.numpy()), aspect="auto", origin="lower", - extent=(0, spectrogram.size(1) / sample_rate, 0, sample_rate / 2), + extent=(float(time_axis[0]), float(time_axis[-1]), float(freq_axis[0]), float(freq_axis[-1])), cmap="viridis", ) plt.colorbar(label="Magnitude (dB)") plt.title(title) - plt.ylabel("Frequency [Hz]") + plt.ylabel(y_axis_label) plt.xlabel("Time [Sec]") plt.show(block=False) @@ -90,3 +167,4 @@ def 
play_audio(audio: AudioData) -> None: display(Audio((waveform[0], waveform[1]), rate=sample_rate)) else: raise ValueError("Waveform with more than 2 channels are not supported.") + \ No newline at end of file From 5160191f2396d58dd37d2014d3f36b5f71acf96e Mon Sep 17 00:00:00 2001 From: fabiocat93 Date: Mon, 30 Sep 2024 23:48:54 -0400 Subject: [PATCH 2/3] fixing style issues --- src/senselab/audio/tasks/plotting/plotting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/senselab/audio/tasks/plotting/plotting.py b/src/senselab/audio/tasks/plotting/plotting.py index 1755d2b..5b2511b 100644 --- a/src/senselab/audio/tasks/plotting/plotting.py +++ b/src/senselab/audio/tasks/plotting/plotting.py @@ -63,6 +63,7 @@ def plot_specgram(audio: AudioData, mel_scale: bool = False, title: str = "Spect - Add option to save the plot - Add option to choose the size of the Figure """ + def _power_to_db( spectrogram: np.ndarray, ref: float = 1.0, amin: float = 1e-10, top_db: float = 80.0 ) -> np.ndarray: @@ -167,4 +168,3 @@ def play_audio(audio: AudioData) -> None: display(Audio((waveform[0], waveform[1]), rate=sample_rate)) else: raise ValueError("Waveform with more than 2 channels are not supported.") - \ No newline at end of file From dbc68cdf20a4e31ee07447cdf2ca91d63a331b6d Mon Sep 17 00:00:00 2001 From: fabiocat93 Date: Mon, 30 Sep 2024 23:54:07 -0400 Subject: [PATCH 3/3] adjusting tests --- src/tests/audio/tasks/features_extraction_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tests/audio/tasks/features_extraction_test.py b/src/tests/audio/tasks/features_extraction_test.py index 2eb65c8..2b6e347 100644 --- a/src/tests/audio/tasks/features_extraction_test.py +++ b/src/tests/audio/tasks/features_extraction_test.py @@ -34,7 +34,7 @@ def test_extract_spectrogram_from_audios(resampled_mono_audio_sample: Audio) -> assert all(isinstance(spec["spectrogram"], torch.Tensor) for spec in result) # Spectrogram shape is (freq, time) assert 
all(spec["spectrogram"].dim() == 2 for spec in result) - assert all(spec["spectrogram"].shape[0] == 201 for spec in result) + assert all(spec["spectrogram"].shape[0] == 513 for spec in result) def test_extract_mel_spectrogram_from_audios(resampled_mono_audio_sample: Audio) -> None: