Skip to content

Commit

Permalink
Merge pull request #142 from sensein/141-task-add-audio-windowing-ite…
Browse files Browse the repository at this point in the history
…rator

Add audio windowing iterator to Audio
  • Loading branch information
fabiocat93 authored Sep 20, 2024
2 parents d926070 + 60f6bf3 commit 43e716f
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 1 deletion.
27 changes: 26 additions & 1 deletion src/senselab/audio/data_structures/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

import os
import uuid
from typing import Dict, List, Optional, Tuple, Union
import warnings
from typing import Dict, Generator, List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -102,6 +103,30 @@ def __eq__(self, other: object) -> bool:
return self.id() == other.id()
return False

def window_generator(self, window_size: int, step_size: int) -> Generator[torch.Tensor, None, None]:
    """Creates a sliding window generator for the audio waveform.

    Windows are taken along the last dimension of ``self.waveform``. Only
    full-length windows are produced, so trailing samples that cannot fill a
    whole window are not yielded. Because this is a generator, the validation
    and the warning below run when iteration starts, not when the method is
    called.

    Args:
        window_size: Size of each window (number of samples). Must be positive.
        step_size: Step size for sliding the window (number of samples). Must be positive.

    Yields:
        Consecutive slices of the waveform, each ``window_size`` samples long
        along the last dimension.

    Raises:
        ValueError: If window_size or step_size is not positive.

    Warns:
        UserWarning: If step_size is greater than window_size, because the
            samples between consecutive windows are skipped.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be a positive integer, got {window_size}.")
    if step_size <= 0:
        # A non-positive step would never advance and the generator would not terminate.
        raise ValueError(f"step_size must be a positive integer, got {step_size}.")

    if step_size > window_size:
        # Implicit concatenation keeps source indentation out of the message text.
        warnings.warn(
            "Step size is greater than window size. "
            "Some of audio will not be included in the windows."
        )

    num_samples = self.waveform.size(-1)

    # The last valid start is num_samples - window_size (inclusive), so that a
    # window ending exactly on the final sample is still produced. The range is
    # empty when the window is longer than the audio.
    for start in range(0, num_samples - window_size + 1, step_size):
        yield self.waveform[..., start : start + window_size]


def batch_audios(audios: List[Audio]) -> Tuple[torch.Tensor, Union[int, List[int]], List[Dict]]:
"""Batches the Audios together into a single Tensor, keeping individual Audio information separate.
Expand Down
143 changes: 143 additions & 0 deletions src/tests/audio/data_structures/audio_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Module for testing Audio data structures."""

import warnings

import torch
import torchaudio

Expand Down Expand Up @@ -75,3 +77,144 @@ def test_audio_from_numpy(mono_audio_sample: Audio) -> None:
assert torch.equal(
mono_audio_sample.waveform, audio_from_numpy.waveform
), "NumPy audio should've been converted to Tensor"


def test_window_generator_overlap(mono_audio_sample: Audio) -> None:
    """Checks the window count for mono audio when consecutive windows overlap."""
    window_size, step_size = 1024, 512
    total_samples = mono_audio_sample.waveform.size(-1)

    # Number of full windows that fit when sliding by step_size.
    expected_windows = 1 + (total_samples - window_size) // step_size

    windows = [chunk for chunk in mono_audio_sample.window_generator(window_size, step_size)]
    yielded = len(windows)

    assert yielded == expected_windows, (
        f"Should yield {expected_windows} "
        f"windows when step size is less than window size. Yielded {yielded}."
    )


def test_window_generator_exact_fit(mono_audio_sample: Audio) -> None:
    """Checks the window count for mono audio when windows are back-to-back (step == window)."""
    window_size = step_size = 1024
    total_samples = mono_audio_sample.waveform.size(-1)

    # Number of full windows that fit when sliding by step_size.
    expected_windows = 1 + (total_samples - window_size) // step_size

    windows = [chunk for chunk in mono_audio_sample.window_generator(window_size, step_size)]
    yielded = len(windows)

    assert yielded == expected_windows, (
        f"Should yield {expected_windows} "
        f"window when step size equals window size. Yielded {yielded}."
    )


def test_window_generator_step_greater_than_window(mono_audio_sample: Audio) -> None:
    """Tests window generator when step size is greater than window size.

    Verifies both the number of yielded windows and that the generator issues
    exactly one step-size warning while being consumed.
    """
    window_size = 1024
    step_size = 2048  # Step size greater than window size
    audio_length = mono_audio_sample.waveform.size(-1)

    # Consume the generator once, inside the recording context, so the expected
    # warning is captured instead of leaking into the test session output.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        windows = list(mono_audio_sample.window_generator(window_size, step_size))

    expected_windows = (audio_length - window_size) // step_size + 1
    assert len(windows) == expected_windows, (
        f"Should yield {expected_windows} "
        f"windows when step size is greater than window size. Yielded {len(windows)}."
    )

    # Count only the warning under test so unrelated library warnings cannot break the check.
    step_warnings = [item for item in caught if "Step size is greater than window size" in str(item.message)]
    assert len(step_warnings) == 1, "Should issue a warning when step size is greater than window size."


def test_window_generator_overlap_stereo(stereo_audio_sample: Audio) -> None:
    """Checks the window count for stereo audio when consecutive windows overlap."""
    window_size, step_size = 1024, 512
    total_samples = stereo_audio_sample.waveform.size(-1)

    # Number of full windows that fit when sliding by step_size.
    expected_windows = 1 + (total_samples - window_size) // step_size

    windows = [chunk for chunk in stereo_audio_sample.window_generator(window_size, step_size)]
    yielded = len(windows)

    assert yielded == expected_windows, (
        f"Should yield {expected_windows} "
        f"windows when step size is less than window size. Yielded {yielded}."
    )


def test_window_generator_exact_fit_stereo(stereo_audio_sample: Audio) -> None:
    """Checks the window count for stereo audio when windows are back-to-back (step == window)."""
    window_size = step_size = 1024
    total_samples = stereo_audio_sample.waveform.size(-1)

    # Number of full windows that fit when sliding by step_size.
    expected_windows = 1 + (total_samples - window_size) // step_size

    windows = [chunk for chunk in stereo_audio_sample.window_generator(window_size, step_size)]
    yielded = len(windows)

    assert yielded == expected_windows, (
        f"Should yield {expected_windows} "
        f"windows when step size equals window size. Yielded {yielded}."
    )


def test_window_generator_step_greater_than_window_stereo(stereo_audio_sample: Audio) -> None:
    """Tests window generator when step size is greater than window size for stereo audio.

    Verifies both the number of yielded windows and that the generator issues
    exactly one step-size warning while being consumed.
    """
    window_size = 1
    step_size = 2  # Step size greater than window size
    audio_length = stereo_audio_sample.waveform.size(-1)

    # Consume the generator once, inside the recording context, so the expected
    # warning is captured instead of leaking into the test session output.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        windows = list(stereo_audio_sample.window_generator(window_size, step_size))

    expected_windows = (audio_length - window_size) // step_size + 1
    assert len(windows) == expected_windows, (
        f"Should yield {expected_windows} "
        f"windows when step size is greater than window size. Yielded {len(windows)}."
    )

    # Count only the warning under test so unrelated library warnings cannot break the check.
    step_warnings = [item for item in caught if "Step size is greater than window size" in str(item.message)]
    assert len(step_warnings) == 1, "Should issue a warning when step size is greater than window size."


def test_window_generator_window_greater_than_audio_mono(mono_audio_sample: Audio) -> None:
    """Checks that mono audio shorter than the window produces no windows at all."""
    step_size = 512
    total_samples = mono_audio_sample.waveform.size(1)
    # Make the window 1000 samples longer than the audio so it can never fit.
    window_size = 1000 + total_samples

    windows = [chunk for chunk in mono_audio_sample.window_generator(window_size, step_size)]
    yielded = len(windows)

    assert yielded == 0, (
        "Should yield no windows when window size is greater "
        f"than audio length. Yielded {yielded}."
    )


def test_window_generator_window_greater_than_audio_stereo(stereo_audio_sample: Audio) -> None:
    """Checks that stereo audio shorter than the window produces no windows at all."""
    step_size = 512
    total_samples = stereo_audio_sample.waveform.size(1)
    # Make the window 1000 samples longer than the audio so it can never fit.
    window_size = 1000 + total_samples

    windows = [chunk for chunk in stereo_audio_sample.window_generator(window_size, step_size)]
    yielded = len(windows)

    assert yielded == 0, (
        "Should yield no windows when window size is "
        f"greater than audio length. Yielded {yielded}."
    )


def test_window_generator_step_greater_than_audio_mono(mono_audio_sample: Audio) -> None:
    """Checks the window count for mono audio when a single step jumps past the whole clip."""
    window_size = 1024
    total_samples = mono_audio_sample.waveform.size(1)
    # A step 1000 samples longer than the audio overshoots the end immediately.
    step_size = 1000 + total_samples

    # Number of full windows that fit when sliding by step_size.
    expected_windows = 1 + (total_samples - window_size) // step_size

    windows = [chunk for chunk in mono_audio_sample.window_generator(window_size, step_size)]
    yielded = len(windows)

    assert yielded == expected_windows, (
        f"Should yield {expected_windows} "
        f"windows when step size is greater than audio length. Yielded {yielded}."
    )


def test_window_generator_step_greater_than_audio_stereo(stereo_audio_sample: Audio) -> None:
    """Checks the window count for stereo audio when a single step jumps past the whole clip."""
    window_size = 1024
    total_samples = stereo_audio_sample.waveform.size(1)
    # A step 1000 samples longer than the audio overshoots the end immediately.
    step_size = 1000 + total_samples

    # Number of full windows that fit when sliding by step_size.
    expected_windows = 1 + (total_samples - window_size) // step_size

    windows = [chunk for chunk in stereo_audio_sample.window_generator(window_size, step_size)]
    yielded = len(windows)

    assert yielded == expected_windows, (
        f"Should yield {expected_windows} "
        f"windows when step size is greater than audio length. Yielded {yielded}."
    )

0 comments on commit 43e716f

Please sign in to comment.