generated from sensein/python-package-template
Adding utility functions #48
Merged
Changes from 7 commits
10 commits
9993456
adding speech to text evaluation task
fabiocat93 62a88aa
adding cca and cka functions
fabiocat93 97e6102
adding cosine similarity function
fabiocat93 31466dd
adding cross correlation
fabiocat93 4790680
adding eer function
fabiocat93 710e2d1
fixing spell issue
fabiocat93 0291061
fixing typing issue
fabiocat93 a74517e
adding preprocessing functions
fabiocat93 66ff00b
treating cka kernels with enum
fabiocat93 3ba1620
fixing style issues
fabiocat93
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
"""This module implements some utilities for evaluating a transcription.""" | ||
|
||
import jiwer | ||
|
||
|
||
def calculate_wer(reference: str, hypothesis: str) -> float: | ||
"""Calculate the Word Error Rate (WER) between the reference and hypothesis. | ||
|
||
Args: | ||
reference (str): The ground truth text. | ||
hypothesis (str): The predicted text. | ||
|
||
Returns: | ||
float: The WER score. | ||
|
||
Examples: | ||
>>> calculate_wer("hello world", "hello duck") | ||
0.5 | ||
""" | ||
return jiwer.wer(reference, hypothesis) | ||
|
||
|
||
def calculate_mer(reference: str, hypothesis: str) -> float: | ||
"""Calculate the Match Error Rate (MER) between the reference and hypothesis. | ||
|
||
Args: | ||
reference (str): The ground truth text. | ||
hypothesis (str): The predicted text. | ||
|
||
Returns: | ||
float: The MER score. | ||
|
||
Examples: | ||
>>> calculate_mer("hello world", "hello duck") | ||
0.5 | ||
""" | ||
return jiwer.mer(reference, hypothesis) | ||
|
||
|
||
def calculate_wil(reference: str, hypothesis: str) -> float: | ||
"""Calculate the Word Information Lost (WIL) between the reference and hypothesis. | ||
|
||
Args: | ||
reference (str): The ground truth text. | ||
hypothesis (str): The predicted text. | ||
|
||
Returns: | ||
float: The WIL score. | ||
|
||
Examples: | ||
>>> calculate_wil("hello world", "hello duck") | ||
0.75 | ||
""" | ||
return jiwer.wil(reference, hypothesis) | ||
|
||
|
||
def calculate_wip(reference: str, hypothesis: str) -> float: | ||
"""Calculate the Word Information Preserved (WIP) between the reference and hypothesis. | ||
|
||
Args: | ||
reference (str): The ground truth text. | ||
hypothesis (str): The predicted text. | ||
|
||
Returns: | ||
float: The WIP score. | ||
|
||
Examples: | ||
>>> calculate_wip("hello world", "hello duck") | ||
0.25 | ||
""" | ||
return jiwer.wip(reference, hypothesis) | ||
|
||
|
||
def calculate_cer(reference: str, hypothesis: str) -> float: | ||
"""Calculate the Character Error Rate (CER) between the reference and hypothesis. | ||
|
||
Args: | ||
reference (str): The ground truth text. | ||
hypothesis (str): The predicted text. | ||
|
||
Returns: | ||
float: The CER score. | ||
|
||
Examples: | ||
>>> calculate_cer("hello world", "hello duck") | ||
0.45454545454545453 | ||
""" | ||
return jiwer.cer(reference, hypothesis) |
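As a rough illustration of what these wrappers delegate to jiwer, WER can be read as the word-level Levenshtein distance divided by the number of reference words, and CER as the same ratio over characters. The following is a minimal pure-Python sketch under that assumption. The names `edit_distance`, `word_error_rate`, and `char_error_rate` are hypothetical helpers, not part of jiwer or senselab, and jiwer's own text normalization is not reproduced here.

```python
from typing import Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance between two token sequences."""
    # previous[j] is the distance between the reference prefix
    # processed so far and hyp[:j].
    previous = list(range(len(hyp) + 1))
    for i, ref_token in enumerate(ref, start=1):
        current = [i]  # distance between ref[:i] and an empty hypothesis
        for j, hyp_token in enumerate(hyp, start=1):
            substitution = previous[j - 1] + (ref_token != hyp_token)
            deletion = previous[j] + 1
            insertion = current[j - 1] + 1
            current.append(min(substitution, deletion, insertion))
        previous = current
    return previous[-1]


def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level edit distance divided by the number of reference words."""
    ref_words = reference.split()
    if not ref_words:
        raise ValueError("The reference must contain at least one word.")
    return edit_distance(ref_words, hypothesis.split()) / len(ref_words)


def char_error_rate(reference: str, hypothesis: str) -> float:
    """Character-level edit distance divided by the reference length."""
    if not reference:
        raise ValueError("The reference must not be empty.")
    return edit_distance(list(reference), list(hypothesis)) / len(reference)
```

For the docstring inputs above, `word_error_rate("hello world", "hello duck")` gives 0.5 (one substitution over two reference words) and `char_error_rate` gives 5/11, which agrees with the values quoted in the docstrings.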
17 changes: 17 additions & 0 deletions
src/senselab/audio/tasks/speech_to_text_evaluation_pydra.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
"""This module defines a pydra API for the speech to text evaluation task.""" | ||
|
||
import pydra | ||
|
||
from senselab.audio.tasks.speech_to_text_evaluation import ( | ||
calculate_cer, | ||
calculate_mer, | ||
calculate_wer, | ||
calculate_wil, | ||
calculate_wip, | ||
) | ||
|
||
calculate_wer_pt = pydra.mark.task(calculate_wer) | ||
calculate_mer_pt = pydra.mark.task(calculate_mer) | ||
calculate_wil_pt = pydra.mark.task(calculate_wil) | ||
calculate_wip_pt = pydra.mark.task(calculate_wip) | ||
calculate_cer_pt = pydra.mark.task(calculate_cer) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
"""This module is for computing CCA and CKA.""" | ||
|
||
import torch | ||
|
||
|
||
def compute_cca(features_x: torch.Tensor, features_y: torch.Tensor) -> float: | ||
"""Compute the mean squared CCA correlation (R^2_{CCA}). | ||
|
||
Args: | ||
features_x (torch.Tensor): A num_examples x num_features matrix of features. | ||
features_y (torch.Tensor): A num_examples x num_features matrix of features. | ||
|
||
Returns: | ||
float: The mean squared CCA correlations between X and Y. | ||
""" | ||
qx, _ = torch.linalg.qr(features_x) | ||
qy, _ = torch.linalg.qr(features_y) | ||
result = torch.norm(qx.t() @ qy) ** 2 / min(features_x.shape[1], features_y.shape[1]) | ||
return result.item() if isinstance(result, torch.Tensor) else float(result) | ||
|
||
|
||
def compute_cka( | ||
features_x: torch.Tensor, features_y: torch.Tensor, kernel: str = "linear", threshold: float = 1.0 | ||
) -> float: | ||
"""Compute CKA between feature matrices. | ||
|
||
Args: | ||
features_x (torch.Tensor): A num_examples x num_features matrix of features. | ||
features_y (torch.Tensor): A num_examples x num_features matrix of features. | ||
kernel (str): Type of kernel to use ('linear' or 'rbf'). Default is 'linear'. | ||
threshold (float): Fraction of median Euclidean distance to use as RBF kernel bandwidth | ||
(used only if kernel is 'rbf'). | ||
|
||
Returns: | ||
float: The value of CKA between X and Y. | ||
""" | ||
|
||
def _gram_linear(x: torch.Tensor) -> torch.Tensor: | ||
"""Compute Gram (kernel) matrix for a linear kernel. | ||
|
||
Args: | ||
x (torch.Tensor): A num_examples x num_features matrix of features. | ||
|
||
Returns: | ||
torch.Tensor: A num_examples x num_examples Gram matrix of examples. | ||
""" | ||
return x @ x.t() | ||
|
||
def _gram_rbf(x: torch.Tensor, threshold: float = 1.0) -> torch.Tensor: | ||
"""Compute Gram (kernel) matrix for an RBF kernel. | ||
|
||
Args: | ||
x (torch.Tensor): A num_examples x num_features matrix of features. | ||
threshold (float): Fraction of median Euclidean distance to use as RBF kernel bandwidth. | ||
|
||
Returns: | ||
torch.Tensor: A num_examples x num_examples Gram matrix of examples. | ||
""" | ||
dot_products = x @ x.t() | ||
sq_norms = torch.diag(dot_products) | ||
sq_distances = -2 * dot_products + sq_norms[:, None] + sq_norms[None, :] | ||
sq_median_distance = torch.median(sq_distances) | ||
return torch.exp(-sq_distances / (2 * threshold**2 * sq_median_distance)) | ||
|
||
def _center_gram(gram: torch.Tensor) -> torch.Tensor: | ||
"""Center a symmetric Gram matrix. | ||
|
||
This is equivalent to centering the (possibly infinite-dimensional) features | ||
induced by the kernel before computing the Gram matrix. | ||
|
||
Args: | ||
gram (torch.Tensor): A num_examples x num_examples symmetric matrix. | ||
|
||
Returns: | ||
torch.Tensor: A symmetric matrix with centered columns and rows. | ||
|
||
Raises: | ||
ValueError: If the input is not a symmetric matrix. | ||
""" | ||
if not torch.allclose(gram, gram.t()): | ||
raise ValueError("Input must be a symmetric matrix.") | ||
|
||
n = gram.size(0) | ||
unit = torch.ones(n, n, device=gram.device, dtype=gram.dtype) | ||
eye = torch.eye(n, device=gram.device, dtype=gram.dtype) | ||
unit = unit / n | ||
haitch = eye - unit | ||
centered_gram = haitch.mm(gram).mm(haitch) | ||
return centered_gram | ||
|
||
def _cka(gram_x: torch.Tensor, gram_y: torch.Tensor) -> torch.Tensor: | ||
"""Compute CKA. | ||
|
||
Args: | ||
gram_x (torch.Tensor): A num_examples x num_examples Gram matrix. | ||
gram_y (torch.Tensor): A num_examples x num_examples Gram matrix. | ||
|
||
Returns: | ||
torch.Tensor: A scalar tensor holding the value of CKA between X and Y. | ||
""" | ||
gram_x = _center_gram(gram_x) | ||
gram_y = _center_gram(gram_y) | ||
|
||
scaled_hsic = torch.sum(gram_x * gram_y) | ||
|
||
normalization_x = torch.norm(gram_x) | ||
normalization_y = torch.norm(gram_y) | ||
return scaled_hsic / (normalization_x * normalization_y) | ||
|
||
if kernel == "linear": | ||
gram_x = _gram_linear(features_x) | ||
gram_y = _gram_linear(features_y) | ||
elif kernel == "rbf": | ||
gram_x = _gram_rbf(features_x, threshold) | ||
gram_y = _gram_rbf(features_y, threshold) | ||
else: | ||
raise ValueError("Unsupported kernel type. Use 'linear' or 'rbf'.") | ||
|
||
result = _cka(gram_x, gram_y) | ||
return result.item() if isinstance(result, torch.Tensor) else float(result) |
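For the linear kernel, centering the Gram matrix is equivalent to centering the feature columns, so CKA can also be written directly in feature space as ||Yc^T Xc||_F^2 / (||Xc^T Xc||_F ||Yc^T Yc||_F). The sketch below works through that form in pure Python on nested lists. It is only an illustration of the linear case, assuming non-constant features so that the denominator is non-zero; `center_columns`, `cross_product`, `frobenius_norm`, and `linear_cka` are hypothetical names that do not appear in this PR.

```python
from math import sqrt
from typing import List

Matrix = List[List[float]]


def center_columns(x: Matrix) -> Matrix:
    """Subtract the column mean from every entry (rows are examples)."""
    n = len(x)
    means = [sum(column) / n for column in zip(*x)]
    return [[value - mean for value, mean in zip(row, means)] for row in x]


def cross_product(x: Matrix, y: Matrix) -> Matrix:
    """Return x^T y for two matrices with the same number of rows."""
    return [
        [sum(a * b for a, b in zip(col_x, col_y)) for col_y in zip(*y)]
        for col_x in zip(*x)
    ]


def frobenius_norm(m: Matrix) -> float:
    """Square root of the sum of squared entries."""
    return sqrt(sum(value * value for row in m for value in row))


def linear_cka(x: Matrix, y: Matrix) -> float:
    """Linear CKA computed from centered features instead of Gram matrices."""
    if len(x) != len(y):
        raise ValueError("Both matrices need the same number of examples.")
    xc, yc = center_columns(x), center_columns(y)
    numerator = frobenius_norm(cross_product(yc, xc)) ** 2
    denominator = frobenius_norm(cross_product(xc, xc)) * frobenius_norm(
        cross_product(yc, yc)
    )
    return numerator / denominator
```

Comparing a matrix with itself, or with a copy whose columns are permuted and uniformly rescaled, should give a value of 1 up to floating-point error, which is a quick sanity check for any CKA implementation.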
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
"""This module defines a pydra API for the CCA and CKA tasks.""" | ||
|
||
import pydra | ||
|
||
from senselab.utils.tasks.cca_cka import compute_cca, compute_cka | ||
|
||
compute_cca_pt = pydra.mark.task(compute_cca) | ||
compute_cka_pt = pydra.mark.task(compute_cka) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
"""This module provides the implementation of cosine similarity.""" | ||
|
||
import torch | ||
|
||
|
||
def compute_cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor) -> float: | ||
"""Compute the cosine similarity between two torch tensors. | ||
|
||
Args: | ||
tensor1 (torch.Tensor): The first input tensor. | ||
tensor2 (torch.Tensor): The second input tensor. | ||
|
||
Returns: | ||
float: The cosine similarity between the two input tensors. | ||
|
||
Raises: | ||
ValueError: If the input tensors are not 1-dimensional or do not have the same shape. | ||
|
||
Examples: | ||
>>> tensor1 = torch.tensor([1.0, 2.0, 3.0]) | ||
>>> tensor2 = torch.tensor([4.0, 5.0, 6.0]) | ||
>>> compute_cosine_similarity(tensor1, tensor2) | ||
0.9746318461970762 | ||
|
||
>>> tensor1 = torch.tensor([1.0, 0.0, -1.0]) | ||
>>> tensor2 = torch.tensor([-1.0, 0.0, 1.0]) | ||
>>> compute_cosine_similarity(tensor1, tensor2) | ||
-1.0 | ||
|
||
Note: | ||
This function assumes the input tensors are 1-dimensional and have the same shape. | ||
""" | ||
if tensor1.dim() != 1 or tensor2.dim() != 1: | ||
raise ValueError("Input tensors must be 1-dimensional") | ||
if tensor1.shape != tensor2.shape: | ||
raise ValueError("Input tensors must have the same shape") | ||
|
||
dot_product = torch.dot(tensor1, tensor2) | ||
norm_tensor1 = torch.norm(tensor1) | ||
norm_tensor2 = torch.norm(tensor2) | ||
|
||
cosine_sim = dot_product / (norm_tensor1 * norm_tensor2) | ||
return cosine_sim.item() |
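One edge case the function above does not guard against is a zero-norm input, where the division yields NaN rather than an error. The following pure-Python sketch shows the same formula with an explicit check. It is an illustration only: `cosine_similarity` here is a hypothetical standalone helper operating on plain sequences, and raising `ValueError` for zero vectors is an assumption about the desired behaviour, not something this PR specifies.

```python
from math import sqrt
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("Input vectors must have the same length.")
    norm_a = sqrt(sum(value * value for value in a))
    norm_b = sqrt(sum(value * value for value in b))
    # Assumption: treat zero vectors as an error instead of returning NaN.
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("Cosine similarity is undefined for zero vectors.")
    dot_product = sum(x * y for x, y in zip(a, b))
    return dot_product / (norm_a * norm_b)
```

In double precision, the two docstring inputs evaluate to roughly 0.97463 and -1.0 respectively; single-precision tensors may differ in the trailing digits.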
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
"""This module defines a pydra API for computing cosine similarity.""" | ||
|
||
import pydra | ||
|
||
from senselab.utils.tasks.cosine_similarity import compute_cosine_similarity | ||
|
||
cosine_similarity_pt = pydra.mark.task(compute_cosine_similarity) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
"""This module contains functions for computing the normalized cross-correlation between two signals.""" | ||
|
||
import numpy as np | ||
import torch | ||
from scipy.signal import correlate | ||
|
||
|
||
def compute_normalized_cross_correlation(signal1: torch.Tensor, signal2: torch.Tensor) -> torch.Tensor: | ||
"""Calculate the normalized cross-correlation between two signals. | ||
|
||
Args: | ||
signal1 (torch.Tensor): The first input signal as a PyTorch tensor. | ||
signal2 (torch.Tensor): The second input signal as a PyTorch tensor. | ||
|
||
Returns: | ||
torch.Tensor: The normalized cross-correlation value between the two input signals. | ||
|
||
Examples: | ||
>>> signal1 = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0]) | ||
>>> signal2 = torch.tensor([2.0, 3.0, 4.0]) | ||
>>> compute_normalized_cross_correlation(signal1, signal2) | ||
tensor([0.1002, 0.2754, 0.5008, 0.7261, 0.9515, 0.5759, 0.2504]) | ||
|
||
Note: | ||
This function assumes the input signals are one-dimensional | ||
and contain sufficient elements for meaningful cross-correlation. | ||
""" | ||
# Ensure the inputs are 1D tensors | ||
if signal1.ndim != 1 or signal2.ndim != 1: | ||
raise ValueError("Input signals must be one-dimensional") | ||
|
||
# Convert PyTorch tensors to NumPy arrays | ||
signal1 = signal1.detach().cpu().numpy() | ||
signal2 = signal2.detach().cpu().numpy() | ||
|
||
# Calculate the energy of each signal | ||
energy_signal1 = np.sum(signal1**2) | ||
energy_signal2 = np.sum(signal2**2) | ||
|
||
# Check for zero energy to avoid division by zero | ||
if energy_signal1 == 0 or energy_signal2 == 0: | ||
raise ZeroDivisionError("One of the input signals has zero energy, causing division by zero in normalization") | ||
|
||
# Compute the cross-correlation | ||
cross_correlation = correlate(signal1, signal2) | ||
|
||
# Calculate the normalized cross-correlation | ||
normalized_cross_correlation = cross_correlation / np.sqrt(energy_signal1 * energy_signal2) | ||
|
||
return torch.as_tensor(normalized_cross_correlation) |
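To make the normalization concrete, the sketch below recomputes the same quantity in pure Python: the full cross-correlation over every lag, divided by the square root of the product of the two signal energies. It assumes the lag ordering of `scipy.signal.correlate` in its default full mode (from the most negative lag to the most positive). `normalized_cross_correlation` is a hypothetical standalone helper, and raising `ValueError` for zero-energy input is a choice made for this sketch.

```python
from math import sqrt
from typing import List, Sequence


def normalized_cross_correlation(a: Sequence[float], b: Sequence[float]) -> List[float]:
    """Full cross-correlation of a and b, scaled by sqrt(energy_a * energy_b)."""
    energy_a = sum(value * value for value in a)
    energy_b = sum(value * value for value in b)
    if energy_a == 0 or energy_b == 0:
        raise ValueError("Both signals need non-zero energy.")
    scale = sqrt(energy_a * energy_b)

    n, m = len(a), len(b)
    result = []
    # Slide b across a: lag is the offset of b's first sample relative to a.
    for lag in range(-(m - 1), n):
        total = sum(a[i + lag] * b[i] for i in range(m) if 0 <= i + lag < n)
        result.append(total / scale)
    return result
```

With this global normalization, the zero-lag value of a signal correlated with itself is 1, and by the Cauchy-Schwarz inequality no lag can exceed 1 in magnitude.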
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
"""This module defines a pydra API for computing cross correlation between two signals.""" | ||
|
||
import pydra | ||
|
||
from senselab.utils.tasks.cross_correlation import compute_normalized_cross_correlation | ||
|
||
compute_normalized_cross_correlation_pt = pydra.mark.task(compute_normalized_cross_correlation) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
"""This module implements some utilities for computing the Equal Error Rate (EER).""" | ||
|
||
from typing import Tuple | ||
|
||
import torch | ||
from speechbrain.utils.metric_stats import EER | ||
|
||
|
||
def compute_eer(predictions: torch.Tensor, targets: torch.Tensor) -> Tuple[float, float]: | ||
"""Compute the Equal Error Rate (EER). | ||
|
||
Args: | ||
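predictions (torch.Tensor): A 1D tensor of scores, forwarded as the first argument (`positive_scores`) of speechbrain's `EER`. | ||
targets (torch.Tensor): A 1D tensor of scores, forwarded as the second argument (`negative_scores`) of speechbrain's `EER`. | ||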
|
||
Returns: | ||
Tuple[float, float]: The EER and the threshold for the EER. | ||
""" | ||
return EER(predictions, targets) |
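The equal error rate is the operating point at which the false acceptance rate (negative scores above the threshold) equals the false rejection rate (positive scores at or below it). The sketch below is a simple pure-Python threshold sweep that illustrates the idea. It is not speechbrain's implementation: `equal_error_rate` is a hypothetical helper, it only tries the observed scores as candidate thresholds, and it reports the average of the two rates at the threshold where they are closest.

```python
from typing import Optional, Sequence, Tuple


def equal_error_rate(
    positive_scores: Sequence[float], negative_scores: Sequence[float]
) -> Tuple[float, float]:
    """Return (eer, threshold) from scores of positive and negative trials."""
    if not positive_scores or not negative_scores:
        raise ValueError("Both score lists must be non-empty.")

    best: Optional[Tuple[float, float, float]] = None  # (gap, eer, threshold)
    for threshold in sorted(set(positive_scores) | set(negative_scores)):
        # False rejection: a positive trial scored at or below the threshold.
        frr = sum(s <= threshold for s in positive_scores) / len(positive_scores)
        # False acceptance: a negative trial scored above the threshold.
        far = sum(s > threshold for s in negative_scores) / len(negative_scores)
        gap = abs(far - frr)
        if best is None or gap < best[0]:
            best = (gap, (far + frr) / 2, threshold)

    assert best is not None  # guaranteed by the non-empty check above
    return best[1], best[2]
```

Perfectly separated scores give an EER of 0, while overlapping score distributions push it towards 0.5.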
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
"""This module defines a pydra API for computing EER.""" | ||
|
||
import pydra | ||
|
||
from senselab.utils.tasks.eer import compute_eer | ||
|
||
compute_eer_pt = pydra.mark.task(compute_eer) |
We've broken a lot of the tasks up into individual files (i.e., individual modules), and I'm wondering whether we should consolidate a bit to make them easier to use. I'm not strongly opinionated on this, though.
This is a good point. I think we can proceed like this for the first release and do some restructuring next week.