generated from sensein/python-package-template
-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding utility functions #48
Merged
Merged
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
9993456
adding speech to text evaluation task
fabiocat93 62a88aa
adding cca and cka functions
fabiocat93 97e6102
adding cosine similarity function
fabiocat93 31466dd
adding cross correlation
fabiocat93 4790680
adding eer function
fabiocat93 710e2d1
fixing spell issue
fabiocat93 0291061
fixing typing issue
fabiocat93 a74517e
adding preprocessing functions
fabiocat93 66ff00b
treating cka kernels with enum
fabiocat93 3ba1620
fixing style issues
fabiocat93 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,16 @@ | ||
"""This module implements some utilities for the preprocessing task.""" | ||
|
||
from typing import Any, Dict, List | ||
from typing import Any, Dict, List, Tuple | ||
|
||
import torch | ||
import torchaudio.functional as F | ||
from datasets import Dataset | ||
|
||
from senselab.utils.data_structures.audio import Audio | ||
from senselab.utils.tasks.input_output import ( | ||
_from_dict_to_hf_dataset, | ||
_from_hf_dataset_to_dict, | ||
) | ||
from senselab.utils.tasks.input_output import _from_dict_to_hf_dataset, _from_hf_dataset_to_dict | ||
|
||
|
||
def resample_audio_dataset(audios: List[Audio], resample_rate: int, rolloff: float = 0.99) -> List[Audio]: | ||
def resample_audios(audios: List[Audio], resample_rate: int, rolloff: float = 0.99) -> List[Audio]: | ||
"""Resamples all Audios to a given sampling rate. | ||
|
||
Takes a list of audios and resamples each into the new sampling rate. Notably does not assume any | ||
|
@@ -30,13 +27,91 @@ def resample_audio_dataset(audios: List[Audio], resample_rate: int, rolloff: flo | |
""" | ||
resampled_audios = [] | ||
for audio in audios: | ||
resampled = F.resample(audio.audio_data, audio.sampling_rate, resample_rate, rolloff=rolloff) | ||
resampled = F.resample(audio.waveform, audio.sampling_rate, resample_rate, rolloff=rolloff) | ||
resampled_audios.append( | ||
Audio(waveform=resampled, sampling_rate=resample_rate, metadata=audio.metadata, path_or_id=audio.path_or_id) | ||
) | ||
return resampled_audios | ||
|
||
|
||
def downmix_audios_to_mono(audios: List[Audio]) -> List[Audio]: | ||
"""Downmixes a list of Audio objects to mono by averaging all channels. | ||
|
||
Args: | ||
audios (List[Audio]): A list of Audio objects with a tensor representing the audio waveform. | ||
Shape: (num_channels, num_samples). | ||
|
||
Returns: | ||
List[Audio]: The list of audio objects with a mono waveform averaged from all channels. Shape: (num_samples). | ||
""" | ||
down_mixed_audios = [] | ||
for audio in audios: | ||
down_mixed_audios.append( | ||
Audio( | ||
audio_data=resampled, sampling_rate=resample_rate, metadata=audio.metadata, path_or_id=audio.path_or_id | ||
waveform=audio.waveform.mean(dim=0, keepdim=True), | ||
sampling_rate=audio.sampling_rate, | ||
metadata=audio.metadata.copy(), | ||
) | ||
) | ||
return resampled_audios | ||
|
||
return down_mixed_audios | ||
|
||
|
||
def select_channel_from_audios(audios: List[Audio], channel_index: int) -> List[Audio]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. basically see all of my comments for the above function |
||
"""Selects a specific channel from a list of Audio objects. | ||
|
||
Args: | ||
audios (List[Audio]): A list of Audio objects with a tensor representing the audio waveform. | ||
Shape: (num_channels, num_samples). | ||
channel_index (int): The index of the channel to select. | ||
|
||
Returns: | ||
List[Audio]: The list of audio objects with the selected channel. Shape: (num_channels, num_samples). | ||
""" | ||
mono_channel_audios = [] | ||
for audio in audios: | ||
if audio.waveform.size(0) <= channel_index: | ||
raise ValueError("channel_index should be valid") | ||
mono_channel_audios.append( | ||
Audio( | ||
waveform=audio.waveform[channel_index, :], | ||
sampling_rate=audio.sampling_rate, | ||
metadata=audio.metadata.copy(), | ||
) | ||
) | ||
return mono_channel_audios | ||
|
||
|
||
def chunk_audios(data: List[Tuple[Audio, Tuple[float, float]]]) -> List[Audio]: | ||
"""Chunks the input audios based on the start and end timestamp. | ||
|
||
Args: | ||
data: List of tuples containing an Audio object and a tuple with start and end (in seconds) for chunking. | ||
|
||
Returns: | ||
List of Audios that have been chunked based on the provided timestamps | ||
""" | ||
chunked_audios = [] | ||
|
||
for audio, timestamps in data: | ||
start, end = timestamps | ||
if start < 0: | ||
raise ValueError("Start time must be greater than or equal to 0.") | ||
duration = audio.waveform.shape[1] / audio.sampling_rate | ||
if end > duration: | ||
raise ValueError(f"End time must be less than the duration of the audio file ({duration} seconds).") | ||
start_sample = int(start * audio.sampling_rate) | ||
end_sample = int(end * audio.sampling_rate) | ||
chunked_waveform = audio.waveform[:, start_sample:end_sample] | ||
chunked_audios.append( | ||
Audio( | ||
waveform=chunked_waveform, | ||
sampling_rate=audio.sampling_rate, | ||
metadata=audio.metadata, | ||
path_or_id=f"{audio.path_or_id}_chunk_{start}_{end}", # TODO: Fix this | ||
) | ||
) | ||
return chunked_audios | ||
|
||
|
||
def resample_hf_dataset(dataset: Dict[str, Any], resample_rate: int, rolloff: float = 0.99) -> Dict[str, Any]: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
"""This module implements some utilities for evaluating a transcription.""" | ||
|
||
import jiwer | ||
|
||
|
||
def calculate_wer(reference: str, hypothesis: str) -> float: | ||
"""Calculate the Word Error Rate (WER) between the reference and hypothesis. | ||
|
||
Args: | ||
reference (str): The ground truth text. | ||
hypothesis (str): The predicted text. | ||
|
||
Returns: | ||
float: The WER score. | ||
|
||
Examples: | ||
>>> calculate_wer("hello world", "hello duck") | ||
0.5 | ||
""" | ||
return jiwer.wer(reference, hypothesis) | ||
|
||
|
||
def calculate_mer(reference: str, hypothesis: str) -> float: | ||
"""Calculate the Match Error Rate (MER) between the reference and hypothesis. | ||
|
||
Args: | ||
reference (str): The ground truth text. | ||
hypothesis (str): The predicted text. | ||
|
||
Returns: | ||
float: The MER score. | ||
|
||
Examples: | ||
>>> calculate_mer("hello world", "hello duck") | ||
0.5 | ||
""" | ||
return jiwer.mer(reference, hypothesis) | ||
|
||
|
||
def calculate_wil(reference: str, hypothesis: str) -> float: | ||
"""Calculate the Word Information Lost (WIL) between the reference and hypothesis. | ||
|
||
Args: | ||
reference (str): The ground truth text. | ||
hypothesis (str): The predicted text. | ||
|
||
Returns: | ||
float: The WIL score. | ||
|
||
Examples: | ||
>>> calculate_wil("hello world", "hello duck") | ||
0.75 | ||
""" | ||
return jiwer.wil(reference, hypothesis) | ||
|
||
|
||
def calculate_wip(reference: str, hypothesis: str) -> float: | ||
"""Calculate the Word Information Preserved (WIP) between the reference and hypothesis. | ||
|
||
Args: | ||
reference (str): The ground truth text. | ||
hypothesis (str): The predicted text. | ||
|
||
Returns: | ||
float: The WIP score. | ||
|
||
Examples: | ||
>>> calculate_wip("hello world", "hello duck") | ||
0.25 | ||
""" | ||
return jiwer.wip(reference, hypothesis) | ||
|
||
|
||
def calculate_cer(reference: str, hypothesis: str) -> float: | ||
"""Calculate the Character Error Rate (CER) between the reference and hypothesis. | ||
|
||
Args: | ||
reference (str): The ground truth text. | ||
hypothesis (str): The predicted text. | ||
|
||
Returns: | ||
float: The CER score. | ||
|
||
Examples: | ||
>>> calculate_cer("hello world", "hello duck") | ||
0.45454545454545453 | ||
""" | ||
return jiwer.cer(reference, hypothesis) |
17 changes: 17 additions & 0 deletions
17
src/senselab/audio/tasks/speech_to_text_evaluation_pydra.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
"""This module defines a pydra API for the speech to text evaluation task.""" | ||
|
||
import pydra | ||
|
||
from senselab.audio.tasks.speech_to_text_evaluation import ( | ||
calculate_cer, | ||
calculate_mer, | ||
calculate_wer, | ||
calculate_wil, | ||
calculate_wip, | ||
) | ||
|
||
calculate_wer_pt = pydra.mark.task(calculate_wer) | ||
calculate_mer_pt = pydra.mark.task(calculate_mer) | ||
calculate_wil_pt = pydra.mark.task(calculate_wil) | ||
calculate_wip_pt = pydra.mark.task(calculate_wip) | ||
calculate_cer_pt = pydra.mark.task(calculate_cer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess this gets into a later comment I made, but at least in the Audio class as of right now, we standardize the waveform field to be (num_channels, num_samples) and perhaps to keep consistency we keep it as (1, num_samples)?