Adding utility functions #48

Merged: 10 commits, Jun 4, 2024
299 changes: 298 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions pyproject.toml
@@ -47,6 +47,8 @@ opensmile = "^2.5.0"
audiomentations = "^0.35.0"
torch-audiomentations = "^0.11.1"
sentence-transformers = "^2.7.0"
+jiwer = "^3.0.4"
+speechbrain = "^1.0.0"

[tool.poetry.group.dev]
optional = true
@@ -143,10 +145,10 @@ pattern = "default-unprefixed"

[tool.codespell]
skip = [
-    "./poetry.lock",
-    "./docs_style/pdoc-theme/syntax-highlighting.css"
+    "poetry.lock",
+    "docs_style/pdoc-theme/syntax-highlighting.css"
]
-ignore-words-list = ["senselab"]
+ignore-words-list = ["senselab", "nd", "astroid", "wil"]

[build-system]
requires = ["poetry-core>=1.0.0", "poetry-dynamic-versioning>=1.0.0,<2.0.0"]
4 changes: 2 additions & 2 deletions src/senselab/audio/tasks/data_augmentation.py
@@ -40,11 +40,11 @@ def augment_audio_dataset(audios: List[Audio], augmentation: Compose, batched: b
    new_audios = []
    if not batched:
        for audio in audios:
-            audio_to_augment = audio.audio_data.unsqueeze(0)
+            audio_to_augment = audio.waveform.unsqueeze(0)
            augmented_audio = augmentation(audio_to_augment, sample_rate=audio.sampling_rate).samples
            new_audios.append(
                Audio(
-                    audio_data=torch.squeeze(augmented_audio),
+                    waveform=torch.squeeze(augmented_audio),
                    sampling_rate=audio.sampling_rate,
                    metadata=audio.metadata.copy(),
                )
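
A minimal usage sketch (not part of the diff) showing the renamed `waveform` field flowing through this function. The Compose configuration is an assumption: the code above reads `.samples` off the augmentation output, which torch-audiomentations' object-dict output provides.

import torch
from torch_audiomentations import Compose, Gain

from senselab.audio.tasks.data_augmentation import augment_audio_dataset
from senselab.utils.data_structures.audio import Audio

# One second of stereo noise; Audio now takes `waveform` instead of `audio_data`.
audio = Audio(waveform=torch.randn(2, 16000), sampling_rate=16000)
augmentation = Compose(transforms=[Gain(min_gain_in_db=-6.0, max_gain_in_db=6.0, p=1.0)])
augmented = augment_audio_dataset([audio], augmentation, batched=False)
# torch.squeeze in the function drops the batch dimension added by unsqueeze(0).
assert augmented[0].waveform.shape == (2, 16000)
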
93 changes: 84 additions & 9 deletions src/senselab/audio/tasks/preprocessing.py
@@ -1,19 +1,16 @@
"""This module implements some utilities for the preprocessing task."""

-from typing import Any, Dict, List
+from typing import Any, Dict, List, Tuple

import torch
import torchaudio.functional as F
from datasets import Dataset

from senselab.utils.data_structures.audio import Audio
-from senselab.utils.tasks.input_output import (
-    _from_dict_to_hf_dataset,
-    _from_hf_dataset_to_dict,
-)
+from senselab.utils.tasks.input_output import _from_dict_to_hf_dataset, _from_hf_dataset_to_dict


-def resample_audio_dataset(audios: List[Audio], resample_rate: int, rolloff: float = 0.99) -> List[Audio]:
+def resample_audios(audios: List[Audio], resample_rate: int, rolloff: float = 0.99) -> List[Audio]:
    """Resamples all Audios to a given sampling rate.

    Takes a list of audios and resamples each into the new sampling rate. Notably does not assume any
@@ -30,13 +27,91 @@ def resample_audio_dataset(audios: List[Audio], resample_rate: int, rolloff: flo
"""
resampled_audios = []
for audio in audios:
resampled = F.resample(audio.audio_data, audio.sampling_rate, resample_rate, rolloff=rolloff)
resampled = F.resample(audio.waveform, audio.sampling_rate, resample_rate, rolloff=rolloff)
resampled_audios.append(
Audio(waveform=resampled, sampling_rate=resample_rate, metadata=audio.metadata, path_or_id=audio.path_or_id)
)
return resampled_audios
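
A quick illustrative check of the renamed function (not from the PR): resampling one second of 48 kHz stereo down to 16 kHz preserves the duration.

import torch

from senselab.audio.tasks.preprocessing import resample_audios
from senselab.utils.data_structures.audio import Audio

audio = Audio(waveform=torch.randn(2, 48000), sampling_rate=48000)  # 1 s, stereo
(resampled,) = resample_audios([audio], resample_rate=16000)
assert resampled.sampling_rate == 16000
assert resampled.waveform.shape == (2, 16000)  # same 1 s duration at the new rate
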


def downmix_audios_to_mono(audios: List[Audio]) -> List[Audio]:
    """Downmixes a list of Audio objects to mono by averaging all channels.

    Args:
        audios (List[Audio]): A list of Audio objects with a tensor representing the audio waveform.
            Shape: (num_channels, num_samples).

    Returns:
        List[Audio]: The list of audio objects with a mono waveform averaged from all channels.
            Shape: (1, num_samples).
    """
    down_mixed_audios = []
    for audio in audios:
        down_mixed_audios.append(
            Audio(
                waveform=audio.waveform.mean(dim=0, keepdim=True),
                sampling_rate=audio.sampling_rate,
                metadata=audio.metadata.copy(),
            )
        )
    return down_mixed_audios

[Review comment, Collaborator, on the Returns shape]: I guess this gets into a later comment I made, but at least in the Audio class as of right now, we standardize the waveform field to be (num_channels, num_samples) and perhaps to keep consistency we keep it as (1, num_samples)?
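
An illustrative check (not part of the diff) of the (1, num_samples) output shape that keepdim=True produces, in line with the review comment above:

import torch

from senselab.audio.tasks.preprocessing import downmix_audios_to_mono
from senselab.utils.data_structures.audio import Audio

# Two known channels, so the mono mean is easy to verify.
stereo = Audio(
    waveform=torch.stack([torch.ones(8000), torch.zeros(8000)]),
    sampling_rate=8000,
)
(mono,) = downmix_audios_to_mono([stereo])
assert mono.waveform.shape == (1, 8000)
assert torch.allclose(mono.waveform, torch.full((1, 8000), 0.5))  # mean of the two channels
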


[Review comment, Collaborator, on select_channel_from_audios]: basically see all of my comments for the above function

def select_channel_from_audios(audios: List[Audio], channel_index: int) -> List[Audio]:
    """Selects a specific channel from a list of Audio objects.

    Args:
        audios (List[Audio]): A list of Audio objects with a tensor representing the audio waveform.
            Shape: (num_channels, num_samples).
        channel_index (int): The index of the channel to select.

    Returns:
        List[Audio]: The list of audio objects with only the selected channel. Shape: (num_samples).
    """
    mono_channel_audios = []
    for audio in audios:
        if audio.waveform.size(0) <= channel_index:
            raise ValueError("channel_index should be valid")
        mono_channel_audios.append(
            Audio(
                waveform=audio.waveform[channel_index, :],
                sampling_rate=audio.sampling_rate,
                metadata=audio.metadata.copy(),
            )
        )
    return mono_channel_audios
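
Another illustrative sketch (not from the diff). Note that `waveform[channel_index, :]` drops the channel dimension, so, assuming the field validator keeps tensors as passed, the result is 1-D; that is exactly the shape-consistency question raised in the review comment.

import torch

from senselab.audio.tasks.preprocessing import select_channel_from_audios
from senselab.utils.data_structures.audio import Audio

stereo = Audio(waveform=torch.randn(2, 8000), sampling_rate=8000)
(right,) = select_channel_from_audios([stereo], channel_index=1)
assert right.waveform.shape == (8000,)  # 1-D; (1, 8000) would match the class convention

# An out-of-range index trips the guard above:
# select_channel_from_audios([stereo], channel_index=2)  # raises ValueError
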


def chunk_audios(data: List[Tuple[Audio, Tuple[float, float]]]) -> List[Audio]:
    """Chunks the input audios based on the start and end timestamp.

    Args:
        data: List of tuples containing an Audio object and a tuple with start and end (in seconds) for chunking.

    Returns:
        List of Audios that have been chunked based on the provided timestamps.
    """
    chunked_audios = []

    for audio, timestamps in data:
        start, end = timestamps
        if start < 0:
            raise ValueError("Start time must be greater than or equal to 0.")
        duration = audio.waveform.shape[1] / audio.sampling_rate
        if end > duration:
            raise ValueError(f"End time must be less than the duration of the audio file ({duration} seconds).")
        start_sample = int(start * audio.sampling_rate)
        end_sample = int(end * audio.sampling_rate)
        chunked_waveform = audio.waveform[:, start_sample:end_sample]
        chunked_audios.append(
            Audio(
                waveform=chunked_waveform,
                sampling_rate=audio.sampling_rate,
                metadata=audio.metadata,
                path_or_id=f"{audio.path_or_id}_chunk_{start}_{end}",  # TODO: Fix this
            )
        )
    return chunked_audios
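
A worked sketch of the sample arithmetic (illustrative): chunking 0.5 to 1.0 s of a 2 s, 16 kHz clip keeps samples 8000 through 16000.

import torch

from senselab.audio.tasks.preprocessing import chunk_audios
from senselab.utils.data_structures.audio import Audio

audio = Audio(waveform=torch.randn(1, 32000), sampling_rate=16000, path_or_id="clip")
(chunk,) = chunk_audios([(audio, (0.5, 1.0))])
assert chunk.waveform.shape == (1, 8000)  # int(0.5 * 16000) .. int(1.0 * 16000)
assert chunk.path_or_id == "clip_chunk_0.5_1.0"
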


def resample_hf_dataset(dataset: Dict[str, Any], resample_rate: int, rolloff: float = 0.99) -> Dict[str, Any]:
10 changes: 8 additions & 2 deletions src/senselab/audio/tasks/preprocessing_pydra.py
@@ -3,9 +3,15 @@
import pydra

from senselab.audio.tasks.preprocessing import (
-    resample_audio_dataset,
+    chunk_audios,
+    downmix_audios_to_mono,
+    resample_audios,
    resample_hf_dataset,
+    select_channel_from_audios,
)

+resample_audios_pt = pydra.mark.task(resample_audios)
+downmix_audios_to_mono_pt = pydra.mark.task(downmix_audios_to_mono)
+chunk_audios_pt = pydra.mark.task(chunk_audios)
resample_hf_dataset_pt = pydra.mark.task(resample_hf_dataset)
-resample_audio_dataset_pt = pydra.mark.task(resample_audio_dataset)
+select_channel_from_audios_pt = pydra.mark.task(select_channel_from_audios)
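
A sketch of how one of these tasks might run inside a workflow, assuming pydra's standard Workflow/Submitter API; the workflow name, inputs, and plugin choice are illustrative, not from this PR.

import pydra
import torch

from senselab.audio.tasks.preprocessing_pydra import resample_audios_pt
from senselab.utils.data_structures.audio import Audio

audios = [Audio(waveform=torch.randn(1, 48000), sampling_rate=48000)]

wf = pydra.Workflow(name="preproc", input_spec=["audios"], audios=audios)
wf.add(resample_audios_pt(name="resample", audios=wf.lzin.audios, resample_rate=16000))
wf.set_output([("resampled", wf.resample.lzout.out)])

with pydra.Submitter(plugin="cf") as submitter:
    submitter(wf)
print(wf.result().output.resampled[0].sampling_rate)  # 16000
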
88 changes: 88 additions & 0 deletions src/senselab/audio/tasks/speech_to_text_evaluation.py
@@ -0,0 +1,88 @@
"""This module implements some utilities for evaluating a transcription."""

import jiwer


def calculate_wer(reference: str, hypothesis: str) -> float:
"""Calculate the Word Error Rate (WER) between the reference and hypothesis.

Args:
reference (str): The ground truth text.
hypothesis (str): The predicted text.

Returns:
float: The WER score.

Examples:
>>> calculate_wer("hello world", "hello duck")
0.5
"""
return jiwer.wer(reference, hypothesis)


def calculate_mer(reference: str, hypothesis: str) -> float:
"""Calculate the Match Error Rate (MER) between the reference and hypothesis.

Args:
reference (str): The ground truth text.
hypothesis (str): The predicted text.

Returns:
float: The MER score.

Examples:
>>> calculate_mer("hello world", "hello duck")
0.5
"""
return jiwer.mer(reference, hypothesis)


def calculate_wil(reference: str, hypothesis: str) -> float:
"""Calculate the Word Information Lost (WIL) between the reference and hypothesis.

Args:
reference (str): The ground truth text.
hypothesis (str): The predicted text.

Returns:
float: The WIL score.

Examples:
>>> calculate_wil("hello world", "hello duck")
0.75
"""
return jiwer.wil(reference, hypothesis)


def calculate_wip(reference: str, hypothesis: str) -> float:
"""Calculate the Word Information Preserved (WIP) between the reference and hypothesis.

Args:
reference (str): The ground truth text.
hypothesis (str): The predicted text.

Returns:
float: The WIP score.

Examples:
>>> calculate_wip("hello world", "hello duck")
0.25
"""
return jiwer.wip(reference, hypothesis)


def calculate_cer(reference: str, hypothesis: str) -> float:
"""Calculate the Character Error Rate (CER) between the reference and hypothesis.

Args:
reference (str): The ground truth text.
hypothesis (str): The predicted text.

Returns:
float: The CER score.

Examples:
>>> calculate_cer("hello world", "hello duck")
0.45454545454545453
"""
return jiwer.cer(reference, hypothesis)
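
The doctest values above are internally consistent; an illustrative side-by-side run, with the arithmetic behind each number:

from senselab.audio.tasks.speech_to_text_evaluation import (
    calculate_cer,
    calculate_mer,
    calculate_wer,
    calculate_wil,
    calculate_wip,
)

ref, hyp = "hello world", "hello duck"
print(calculate_wer(ref, hyp))  # 0.5: 1 substituted word out of 2 reference words
print(calculate_mer(ref, hyp))  # 0.5: 1 error over the 2 words involved in the match
print(calculate_wip(ref, hyp))  # 0.25: (1 hit / 2 ref words) * (1 hit / 2 hyp words)
print(calculate_wil(ref, hyp))  # 0.75: 1 - WIP
print(calculate_cer(ref, hyp))  # ~0.4545: 5 character edits over 11 reference characters
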
17 changes: 17 additions & 0 deletions src/senselab/audio/tasks/speech_to_text_evaluation_pydra.py
@@ -0,0 +1,17 @@
"""This module defines a pydra API for the speech to text evaluation task."""

import pydra

from senselab.audio.tasks.speech_to_text_evaluation import (
    calculate_cer,
    calculate_mer,
    calculate_wer,
    calculate_wil,
    calculate_wip,
)

calculate_wer_pt = pydra.mark.task(calculate_wer)
calculate_mer_pt = pydra.mark.task(calculate_mer)
calculate_wil_pt = pydra.mark.task(calculate_wil)
calculate_wip_pt = pydra.mark.task(calculate_wip)
calculate_cer_pt = pydra.mark.task(calculate_cer)
16 changes: 8 additions & 8 deletions src/senselab/utils/data_structures/audio.py
@@ -24,7 +24,7 @@ class Audio(BaseModel):
    and has a unique identifier for every audio.

    Attributes:
-        audio_data: The actual audio data read from an audio file, stored as a torch.Tensor
+        waveform: The actual audio data read from an audio file, stored as a torch.Tensor
            of shape (num_channels, num_samples)
        sampling_rate: The sampling rate of the audio file
        path_or_id: A unique identifier for the audio, defined either by the path the audio was
@@ -33,13 +33,13 @@
            (e.g. participant demographics, audio settings, location information)
    """

-    audio_data: torch.Tensor
+    waveform: torch.Tensor
    sampling_rate: int
    path_or_id: Optional[str] = Field(default_factory=lambda: str(uuid.uuid4()))
    metadata: Dict = Field(default={})
    model_config = {"arbitrary_types_allowed": True}

-    @field_validator("audio_data", mode="before")
+    @field_validator("waveform", mode="before")
    def convert_to_tensor(
        cls, v: Union[List[float], List[List[float]], np.ndarray, torch.Tensor], info: ValidationInfo
    ) -> torch.Tensor:
@@ -69,13 +69,13 @@ def from_filepath(cls, filepath: str, metadata: Dict = {}) -> "Audio":
"""
array, sampling_rate = torchaudio.load(filepath)

return cls(audio_data=array, sampling_rate=sampling_rate, path_or_id=filepath, metadata=metadata)
return cls(waveform=array, sampling_rate=sampling_rate, path_or_id=filepath, metadata=metadata)

def __eq__(self, other: object) -> bool:
"""Overloads the default BaseModel equality to correctly check that torch.Tensors are equivalent."""
if isinstance(other, Audio):
return (
torch.equal(self.audio_data, other.audio_data)
torch.equal(self.waveform, other.waveform)
and self.sampling_rate == other.sampling_rate
and self.metadata == other.metadata
and self.path_or_id == other.path_or_id
@@ -108,7 +108,7 @@ def batch_audios(audios: List[Audio]) -> Tuple[torch.Tensor, Union[int, List[int
    metadatas = []
    for audio in audios:
        sampling_rates.append(audio.sampling_rate)
-        batched_audio.append(audio.audio_data)
+        batched_audio.append(audio.waveform)
        metadatas.append(audio.metadata)

    return_sampling_rates: List[int] | int = int(sampling_rates[0]) if len(set(sampling_rates)) == 1 else sampling_rates
@@ -149,7 +149,7 @@ def unbatch_audios(batched_audio: torch.Tensor, sampling_rates: int | List[int],
        sampling_rate = sampling_rates[i] if isinstance(sampling_rates, List) else sampling_rates
        metadata = metadatas[i]
        audio = batched_audio[i]
-        audios.append(Audio(audio_data=audio, sampling_rate=sampling_rate, metadata=metadata))
+        audios.append(Audio(waveform=audio, sampling_rate=sampling_rate, metadata=metadata))
    return audios


@@ -257,7 +257,7 @@ def generate_dataset_from_audio_data(
        sampling_rate = sampling_rates[i] if isinstance(sampling_rates, List) else sampling_rates
        audio_path_or_id = audio_paths_or_ids[i] if audio_paths_or_ids else None
        audio = Audio(
-            audio_data=audios_data[i],
+            waveform=audios_data[i],
            sampling_rate=sampling_rate,
            path_or_id=audio_path_or_id,
            metadata=audio_metadata,
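
A round-trip sketch for the renamed field in the batching helpers (illustrative; assumes batch_audios stacks equally shaped waveforms and returns the tensor, sampling rates, and metadata in that order):

import torch

from senselab.utils.data_structures.audio import Audio, batch_audios, unbatch_audios

audios = [Audio(waveform=torch.randn(1, 16000), sampling_rate=16000) for _ in range(3)]
batched, rates, metadatas = batch_audios(audios)
restored = unbatch_audios(batched, rates, metadatas)
assert restored[0].waveform.shape == (1, 16000)
assert rates == 16000  # a single int, since all sampling rates matched
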