Add 2 more tests
ibevers committed Jul 19, 2024
1 parent ecc0f70 commit 3ed0b60
Showing 2 changed files with 91 additions and 12 deletions.
32 changes: 22 additions & 10 deletions src/senselab/audio/tasks/classification/speech_emotion_recognition.py
@@ -6,18 +6,20 @@
 from transformers import AutoConfig, pipeline
 
 from senselab.audio.data_structures.audio import Audio
+from senselab.utils.data_structures.device import DeviceType, _select_device_and_dtype
 from senselab.utils.data_structures.model import HFModel
 
 
-def audio_classification_with_hf_models(audios: List[Audio], model: HFModel) -> List[List[Dict]]:
+def audio_classification_with_hf_models(audios: List[Audio], model: HFModel, batch_size: int = 8) -> List[List[Dict]]:
     """General audio classification functionality utilizing HuggingFace pipelines.
 
     Classifies all audios, with no underlying assumptions about what the classification labels are,
     and returns the output that the pipeline gives.
 
     Args:
-        audios: List of Audio objects that we want to run classification on
-        model: The HuggingFace model that will be used for running the inference
+        audios: List of Audio objects that we want to run classification on.
+        model: The HuggingFace model that will be used for running the inference.
+        batch_size: The size of the batches for processing.
 
     Returns:
         List of Lists of Dictionaries, where each List corresponds to the audio it was run on and the List of
@@ -37,13 +39,24 @@ def audio_classification_with_hf_models(audios: List[Audio], model: HFModel) ->
             UserWarning(f"The model '{model.path_or_uri}' has not been tagged as an Inference Endpoint and \
                 so we cannot guarantee its input and outputs are as expected")
         )
 
-    classification_pipeline = pipeline(task="audio-classification", model=model.path_or_uri, revision=model.revision)
+    device, _ = _select_device_and_dtype(compatible_devices=[DeviceType.CUDA, DeviceType.CPU])
+    classification_pipeline = pipeline(
+        task="audio-classification",
+        model=model.path_or_uri,
+        revision=model.revision,
+        device=0 if device == DeviceType.CUDA else -1,
+    )
 
+    # Convert audio waveforms to a format suitable for the pipeline
+    waveforms = [
+        audio.waveform.numpy().squeeze()
+        for audio in audios
+    ]
+
+    # Run the classification pipeline in batches
     classification_outputs = []
 
-    # TODO: figure out adding batching and GPU support
-    for audio in audios:
-        classification_outputs.append(classification_pipeline(audio.waveform.numpy().squeeze()))
+    for output in classification_pipeline(waveforms, batch_size=batch_size, truncation="only_first"):
+        classification_outputs.append(output)
 
     return classification_outputs

@@ -87,7 +100,6 @@ def speech_emotion_recognition_with_hf_models(audios: List[Audio], model: HFMode
     )
 
     audio_classifications = audio_classification_with_hf_models(audios, model)
-    # print(audio_classifications)
     ser_output = []
     for classification in audio_classifications:
         classification_output = {}
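For orientation, here is a minimal sketch of how the batched API added above might be called. The synthetic Audio construction (a generated sine tone) is an illustrative assumption, not part of the commit; it presumes senselab's Audio accepts a (channels, samples) waveform tensor and a sampling rate.

import torch

from senselab.audio.data_structures.audio import Audio
from senselab.audio.tasks.classification.speech_emotion_recognition import (
    audio_classification_with_hf_models,
)
from senselab.utils.data_structures.model import HFModel

# One second of a 440 Hz tone at 16 kHz, shaped (1, samples) for mono audio.
# Assumption: Audio is constructed from a waveform tensor and a sampling rate.
t = torch.linspace(0, 1, 16000)
audio = Audio(waveform=torch.sin(2 * torch.pi * 440 * t).unsqueeze(0), sampling_rate=16000)

model = HFModel(
    path_or_uri="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition",
    revision="main",
)

# batch_size=8 is the new default; larger values trade memory for throughput.
results = audio_classification_with_hf_models([audio], model, batch_size=8)
for entry in results[0]:  # one List[Dict] per input audio; each dict has "label" and "score"
    print(entry["label"], entry["score"])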
71 changes: 69 additions & 2 deletions src/tests/audio/tasks/classification_test.py
@@ -2,12 +2,46 @@
 
 import os
 
+import pytest
+
 from senselab.audio.data_structures.audio import Audio
-from senselab.audio.tasks.classification.speech_emotion_recognition import speech_emotion_recognition_with_hf_models
+from senselab.audio.tasks.classification.speech_emotion_recognition import (
+    audio_classification_with_hf_models,
+    speech_emotion_recognition_with_hf_models,
+)
 from senselab.utils.data_structures.model import HFModel
 
 if os.getenv("GITHUB_ACTIONS") != "true":
 
+    @pytest.fixture
+    def valid_model() -> HFModel:
+        """Fixture for generating a valid HFModel."""
+        return HFModel(path_or_uri="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", revision="main")
+
+    def test_audio_classification_with_hf_models(resampled_mono_audio_sample: Audio) -> None:
+        """Tests the audio classification functionality with HuggingFace models.
+
+        This test uses a real HuggingFace model and pipeline to classify a dummy audio sample.
+        It verifies that the classification function processes the input correctly and returns
+        the expected output.
+
+        Args:
+            resampled_mono_audio_sample: A fixture that provides a dummy Audio object.
+        """
+        # Real model
+        model = HFModel(path_or_uri="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", revision="main")
+
+        # Run the classification function
+        result = audio_classification_with_hf_models([resampled_mono_audio_sample], model)
+
+        # Verify the result
+        assert len(result) == 1
+        assert isinstance(result[0], list)
+        assert len(result[0]) > 0  # Ensure there's at least one classification result
+        assert isinstance(result[0][0], dict)
+        assert "label" in result[0][0]
+        assert "score" in result[0][0]
+
     def test_speech_emotion_recognition(resampled_mono_audio_sample: Audio) -> None:
         """Tests speech emotion recognition."""
         # Discrete test
@@ -32,4 +66,37 @@ def test_speech_emotion_recognition(resampled_mono_audio_sample: Audio) -> None:
             arousal, valence, or dominance"
         assert set(continuous_values.keys()) == set(["arousal", "valence", "dominance"])
 
-    # TODO add tests
+    def test_speech_emotion_recognition_stereo_raises_value_error(resampled_stereo_audio_sample: Audio) -> None:
+        """Tests that speech emotion recognition raises ValueError with stereo audio samples."""
+        resampled_stereo_audio_samples = [resampled_stereo_audio_sample]
+
+        with pytest.raises(ValueError, match="We expect a single channel audio input for AudioClassificationPipeline"):
+            speech_emotion_recognition_with_hf_models(
+                resampled_stereo_audio_samples,
+                HFModel(path_or_uri="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"),
+            )
+
+    def test_batch_processing_consistency(resampled_mono_audio_sample: Audio, valid_model: HFModel) -> None:
+        """Test batch processing consistency for different batch sizes."""
+        audios = [resampled_mono_audio_sample] * 3  # Duplicate the sample to create a list
+        result_batch_1 = audio_classification_with_hf_models(audios, valid_model, batch_size=1)
+        result_batch_5 = audio_classification_with_hf_models(audios, valid_model, batch_size=5)
+        result_batch_10 = audio_classification_with_hf_models(audios, valid_model, batch_size=10)
+        assert len(result_batch_1) == len(result_batch_10) == len(result_batch_5)
+
+    def test_speech_emotion_recognition_with_correct_labels(
+        resampled_mono_audio_sample: Audio, valid_model: HFModel
+    ) -> None:
+        """Test that the emotion recognition output contains expected emotion labels."""
+        result = speech_emotion_recognition_with_hf_models([resampled_mono_audio_sample], valid_model)
+        assert len(result) == 1
+        assert isinstance(result[0], tuple)
+        assert isinstance(result[0][0], str)
+        assert isinstance(result[0][1], dict)
+
+        expected_emotions = ["happy", "sad", "neutral", "positive", "negative", "anger", "disgust", "fear"]
+        for emotion in expected_emotions:
+            if emotion in result[0][1]:
+                break
+        else:
+            pytest.fail("None of the expected emotion labels found in the output.")
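A note on fixtures: the tests above rely on resampled_mono_audio_sample and resampled_stereo_audio_sample, which are defined elsewhere in the suite (typically a conftest.py not shown in this diff). A hypothetical sketch of what such fixtures could look like, again assuming senselab's Audio wraps a (channels, samples) tensor and a sampling rate:

import pytest
import torch

from senselab.audio.data_structures.audio import Audio


@pytest.fixture
def resampled_mono_audio_sample() -> Audio:
    """One second of 16 kHz mono noise, shaped (1, samples)."""
    return Audio(waveform=torch.rand(1, 16000), sampling_rate=16000)


@pytest.fixture
def resampled_stereo_audio_sample() -> Audio:
    """One second of 16 kHz stereo noise, shaped (2, samples)."""
    return Audio(waveform=torch.rand(2, 16000), sampling_rate=16000)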
