diff --git a/src/senselab/audio/tasks/classification/speech_emotion_recognition.py b/src/senselab/audio/tasks/classification/speech_emotion_recognition.py
index 368a486e..29261ac4 100644
--- a/src/senselab/audio/tasks/classification/speech_emotion_recognition.py
+++ b/src/senselab/audio/tasks/classification/speech_emotion_recognition.py
@@ -6,18 +6,20 @@
 from transformers import AutoConfig, pipeline
 
 from senselab.audio.data_structures.audio import Audio
+from senselab.utils.data_structures.device import DeviceType, _select_device_and_dtype
 from senselab.utils.data_structures.model import HFModel
 
 
-def audio_classification_with_hf_models(audios: List[Audio], model: HFModel) -> List[List[Dict]]:
+def audio_classification_with_hf_models(audios: List[Audio], model: HFModel, batch_size: int = 8) -> List[List[Dict]]:
     """General audio classification functionality utilizing HuggingFace pipelines.
 
     Classifies all audios, with no underlying assumptions on what the classification labels are, and returns the
     output that the pipeline gives.
 
     Args:
-        audios: List of Audio objects that we want to run classification on
-        model: The HuggingFace model that will be used for running the inference
+        audios: List of Audio objects that we want to run classification on.
+        model: The HuggingFace model that will be used for running the inference.
+        batch_size: The size of the batches for processing.
 
     Returns:
         List of Lists of Dictionaries where each corresponds to the audio that it was run on and the List of
@@ -37,13 +39,24 @@ def audio_classification_with_hf_models(audios: List[Audio], model: HFModel) ->
             UserWarning(f"The model '{model.path_or_uri}' has not been tagged as an Inference Endpoint and \
                 so we cannot guarantee its input and outputs are as expected")
         )
-
-    classification_pipeline = pipeline(task="audio-classification", model=model.path_or_uri, revision=model.revision)
+    device, _ = _select_device_and_dtype(compatible_devices=[DeviceType.CUDA, DeviceType.CPU])
+    classification_pipeline = pipeline(
+        task="audio-classification",
+        model=model.path_or_uri,
+        revision=model.revision,
+        device=0 if device == DeviceType.CUDA else -1,
+    )
+
+    # Convert audio waveforms to a format suitable for the pipeline
+    waveforms = [
+        audio.waveform.numpy().squeeze()  # squeeze drops the channel dimension of mono audio
+        for audio in audios
+    ]
+
+    # Run the classification pipeline in batches
     classification_outputs = []
-
-    # TODO: figure out adding batching and GPU support
-    for audio in audios:
-        classification_outputs.append(classification_pipeline(audio.waveform.numpy().squeeze()))
+    for output in classification_pipeline(waveforms, batch_size=batch_size, truncation="only_first"):
+        classification_outputs.append(output)
     return classification_outputs
 
 
@@ -87,7 +100,6 @@ def speech_emotion_recognition_with_hf_models(audios: List[Audio], model: HFMode
         )
 
     audio_classifications = audio_classification_with_hf_models(audios, model)
-    # print(audio_classifications)
     ser_output = []
    for classification in audio_classifications:
        classification_output = {}
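
Note: below is a minimal usage sketch of the batched classification entry point introduced above (not part of the patch). It assumes an Audio.from_filepath constructor for loading files; if senselab builds Audio objects differently, adjust accordingly. The output shape (one list of {"label": ..., "score": ...} dicts per input audio) follows what the tests below assert.

    # Sketch only: exercises audio_classification_with_hf_models with the new batch_size argument.
    from senselab.audio.data_structures.audio import Audio
    from senselab.audio.tasks.classification.speech_emotion_recognition import audio_classification_with_hf_models
    from senselab.utils.data_structures.model import HFModel

    # Audio.from_filepath and the file names are illustrative assumptions.
    audios = [Audio.from_filepath("sample_1.wav"), Audio.from_filepath("sample_2.wav")]
    model = HFModel(path_or_uri="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", revision="main")

    # One list of classification dicts is returned per input audio.
    results = audio_classification_with_hf_models(audios, model, batch_size=4)
    for per_audio in results:
        best = max(per_audio, key=lambda d: d["score"])
        print(best["label"], best["score"])
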
diff --git a/src/tests/audio/tasks/classification_test.py b/src/tests/audio/tasks/classification_test.py
index dd35776c..3b3e2788 100644
--- a/src/tests/audio/tasks/classification_test.py
+++ b/src/tests/audio/tasks/classification_test.py
@@ -2,12 +2,46 @@
 
 import os
 
+import pytest
+
 from senselab.audio.data_structures.audio import Audio
-from senselab.audio.tasks.classification.speech_emotion_recognition import speech_emotion_recognition_with_hf_models
+from senselab.audio.tasks.classification.speech_emotion_recognition import (
+    audio_classification_with_hf_models,
+    speech_emotion_recognition_with_hf_models,
+)
 from senselab.utils.data_structures.model import HFModel
 
 
 if os.getenv("GITHUB_ACTIONS") != "true":
 
+    @pytest.fixture
+    def valid_model() -> HFModel:
+        """Fixture for generating a valid HFModel."""
+        return HFModel(path_or_uri="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", revision="main")
+
+    def test_audio_classification_with_hf_models(resampled_mono_audio_sample: Audio) -> None:
+        """Tests the audio classification functionality with HuggingFace models.
+
+        This test uses a real HuggingFace model and pipeline to classify a dummy audio sample.
+        It verifies that the classification function processes the input correctly and returns
+        the expected output.
+
+        Args:
+            resampled_mono_audio_sample: A fixture that provides a dummy Audio object.
+        """
+        # Real model
+        model = HFModel(path_or_uri="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", revision="main")
+
+        # Run the classification function
+        result = audio_classification_with_hf_models([resampled_mono_audio_sample], model)
+
+        # Verify the result
+        assert len(result) == 1
+        assert isinstance(result[0], list)
+        assert len(result[0]) > 0  # Ensure there's at least one classification result
+        assert isinstance(result[0][0], dict)
+        assert "label" in result[0][0]
+        assert "score" in result[0][0]
+
     def test_speech_emotion_recognition(resampled_mono_audio_sample: Audio) -> None:
         """Tests speech emotion recognition."""
         # Discrete test
@@ -32,4 +66,37 @@ def test_speech_emotion_recognition(resampled_mono_audio_sample: Audio) -> None:
             arousal, valence, or dominance"
         assert set(continuous_values.keys()) == set(["arousal", "valence", "dominance"])
 
-# TODO add tests
+    def test_speech_emotion_recognition_stereo_raises_value_error(resampled_stereo_audio_sample: Audio) -> None:
+        """Tests that speech emotion recognition raises ValueError with stereo audio samples."""
+        resampled_stereo_audio_samples = [resampled_stereo_audio_sample]
+
+        with pytest.raises(ValueError, match="We expect a single channel audio input for AudioClassificationPipeline"):
+            speech_emotion_recognition_with_hf_models(
+                resampled_stereo_audio_samples,
+                HFModel(path_or_uri="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"),
+            )
+
+    def test_batch_processing_consistency(resampled_mono_audio_sample: Audio, valid_model: HFModel) -> None:
+        """Test batch processing consistency for different batch sizes."""
+        audios = [resampled_mono_audio_sample] * 3  # Duplicate the sample to create a list
+        result_batch_1 = audio_classification_with_hf_models(audios, valid_model, batch_size=1)
+        result_batch_5 = audio_classification_with_hf_models(audios, valid_model, batch_size=5)
+        result_batch_10 = audio_classification_with_hf_models(audios, valid_model, batch_size=10)
+        assert len(result_batch_1) == len(result_batch_10) == len(result_batch_5)
+
+    def test_speech_emotion_recognition_with_correct_labels(
+        resampled_mono_audio_sample: Audio, valid_model: HFModel
+    ) -> None:
+        """Test that the emotion recognition output contains expected emotion labels."""
+        result = speech_emotion_recognition_with_hf_models([resampled_mono_audio_sample], valid_model)
+        assert len(result) == 1
+        assert isinstance(result[0], tuple)
+        assert isinstance(result[0][0], str)
+        assert isinstance(result[0][1], dict)
+
+        expected_emotions = ["happy", "sad", "neutral", "positive", "negative", "anger", "disgust", "fear"]
+        for emotion in expected_emotions:
+            if emotion in result[0][1]:
+                break
+        else:
+            pytest.fail("None of the expected emotion labels found in the output.")
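
As the assertions above suggest, the emotion recognition wrapper returns, per audio, a tuple of the top emotion label and a dictionary keyed by emotion labels. A short consumption sketch, again assuming Audio.from_filepath for loading and that the dictionary values are float scores:

    # Sketch only: reading the (top_label, label_to_score) tuple returned per audio.
    from senselab.audio.data_structures.audio import Audio
    from senselab.audio.tasks.classification.speech_emotion_recognition import speech_emotion_recognition_with_hf_models
    from senselab.utils.data_structures.model import HFModel

    audio = Audio.from_filepath("speech.wav")  # illustrative path and constructor
    model = HFModel(path_or_uri="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", revision="main")

    for top_label, scores in speech_emotion_recognition_with_hf_models([audio], model):
        print("predicted emotion:", top_label)
        for label, score in sorted(scores.items(), key=lambda kv: kv[1], reverse=True):
            print(f"  {label}: {score:.3f}")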