diff --git a/src/senselab/utils/data_structures/dataset.py b/src/senselab/utils/data_structures/dataset.py
index 1b200de7..2ad98a1d 100644
--- a/src/senselab/utils/data_structures/dataset.py
+++ b/src/senselab/utils/data_structures/dataset.py
@@ -15,7 +15,12 @@ class Participant(BaseModel):
-    """Data structure for a participant in a dataset."""
+    """Data structure for a participant in a dataset.
+
+    Attributes:
+        id: The ID of the participant. If not provided, it is generated using uuid4.
+        metadata: Dictionary of user-specified metadata related to the participant.
+    """
 
     id: str = Field(default_factory=lambda: str(uuid.uuid4()))
     metadata: Dict = Field(default={})
@@ -33,7 +38,12 @@ def __eq__(self, other: object) -> bool:
 
 class Session(BaseModel):
-    """Data structure for a session in a dataset."""
+    """Data structure for a session in a dataset.
+
+    Attributes:
+        id: The ID of the session. If not provided, it is generated using uuid4.
+        metadata: Dictionary of user-specified metadata related to the session.
+    """
 
     id: str = Field(default_factory=lambda: str(uuid.uuid4()))
     metadata: Dict = Field(default={})
@@ -60,9 +70,9 @@ class SenselabDataset(BaseModel):
         audios: List of Audios that are generated based on list of audio filepaths
         videos: List of Videos generated from a list of video filepaths
         metadata: Metadata related to the dataset overall but not necessarily the metadata of
-            indivudal audios in the dataset
-        sessions: Session ID mapping to Session instance
-        participants: Mapping of participant ID to a Participant instance
+            individual audios or videos in the dataset
+        sessions: Mapping of session IDs to Session instances
+        participants: Mapping of participant IDs to Participant instances
     """
 
     participants: Dict[str, Participant] = Field(default_factory=dict)
@@ -136,6 +146,8 @@ def create_bids_dataset(cls, bids_root_filepath: str) -> "SenselabDataset":
         """
         pass
 
+    # TODO: Decide if this method and audio_merge_from_pydra_task should be defined elsewhere,
+    # like in a Pydra helper class
     def create_audio_split_for_pydra_task(self, batch_size: int = 1) -> List[List[Audio]]:
         """Splits the audio data for Pydra tasks.
@@ -165,7 +177,7 @@ def create_audio_split_for_pydra_task(self, batch_size: int = 1) -> List[List[Au
         return [[audio] for audio in self.audios]
 
     def audio_merge_from_pydra_task(self, audios_to_merge: List[List[Audio]]) -> None:
-        """Write later.
+        """TODO: Write later.
 
         Logic Pydra:
         audios: List of audios that want to give to task
@@ -180,26 +192,64 @@ def audio_merge_from_pydra_task(self, audios_to_merge: List[List[Audio]]) -> Non
             self.audios.append(audio_output)
 
     def add_participant(self, participant: Participant) -> None:
-        """Add a participant to the dataset."""
+        """Add a participant to the dataset.
+
+        Adds a new participant to the dataset if they are not already in it.
+
+        Args:
+            participant: Instance of a Participant that we want to add to the dataset.
+
+        Raises:
+            ValueError: If the participant ID is already in the dataset, meaning the ID
+                is non-unique and/or the participant was already added.
+        """
         if participant.id in self.participants:
             raise ValueError(f"Participant with ID {participant.id} already exists.")
         self.participants[participant.id] = participant
 
     def add_session(self, session: Session) -> None:
-        """Add a session to the dataset."""
+        """Add a session to the dataset.
+
+        Adds a new session to the dataset if it is not already present.
+
+        Args:
+            session: Instance of a Session that we want to add to the dataset.
+
+        Raises:
+            ValueError: If the session ID is already in the dataset, meaning the ID
+                is non-unique and/or the session was already added.
+        """
        if session.id in self.sessions:
             raise ValueError(f"Session with ID {session.id} already exists.")
         self.sessions[session.id] = session
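
# Usage sketch for the add_* methods above (hedged: assumes a bare SenselabDataset()
# constructs with all-default fields, which this diff does not show).
from senselab.utils.data_structures.dataset import Participant, SenselabDataset, Session

dataset = SenselabDataset()
alice = Participant(metadata={"age": 30})  # id is auto-generated via uuid4
dataset.add_participant(alice)
dataset.add_session(Session(id="session-1"))

# Re-adding an existing ID raises ValueError instead of silently overwriting.
try:
    dataset.add_participant(alice)
except ValueError as err:
    print(err)  # "Participant with ID ... already exists."
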
 
     def get_participants(self) -> List[Participant]:
-        """Get the list of participants in the dataset."""
+        """Get the list of participants in the dataset.
+
+        Returns:
+            participants (List[Participant]): All of the Participant instances in the dataset.
+                Warning: The instances are returned as-is, so changes to the underlying
+                participants in this list will automatically be reflected in the dataset.
+        """
         return list(self.participants.values())
 
     def get_sessions(self) -> List[Session]:
-        """Get the list of sessions in the dataset."""
+        """Get the list of sessions in the dataset.
+
+        Returns:
+            sessions (List[Session]): All of the Session instances in the dataset.
+                Warning: The instances are returned as-is, so changes to the underlying
+                sessions in this list will automatically be reflected in the dataset.
+        """
         return list(self.sessions.values())
 
     def _get_dict_representation(self) -> Dict:
+        """Internal function for generating a dictionary representation of the dataset.
+
+        Returns:
+            A dictionary representation of the dataset where the keys are participants,
+            sessions, audios, videos, and metadata.
+        """
         audio_data: Dict[str, List] = {}
         video_data: Dict[str, List] = {}
         senselab_dict: Dict[str, Union[Dict[str, List], List]] = {
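
# Sketch of the aliasing behavior the getter warnings describe (hedged: assumes a bare
# SenselabDataset() constructs with all-default fields).
from senselab.utils.data_structures.dataset import Participant, SenselabDataset

dataset = SenselabDataset()
dataset.add_participant(Participant(id="p1"))

# get_participants() returns the stored instances, not copies, so in-place edits
# propagate back to the dataset.
dataset.get_participants()[0].metadata["group"] = "control"
assert dataset.participants["p1"].metadata["group"] == "control"
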
@@ -268,7 +318,14 @@ def _get_dict_representation(self) -> Dict:
         return senselab_dict
 
     def convert_senselab_dataset_to_hf_datasets(self) -> Dict[str, Dataset]:
-        """Converts Senselab datasets into HuggingFace datasets."""
+        """Converts Senselab datasets into HuggingFace datasets.
+
+        Returns:
+            A dictionary of HuggingFace datasets that represent the underlying Senselab dataset.
+            Currently only supports creating HuggingFace datasets for the Audio(s) and Video(s)
+            in the SenselabDataset. Videos are not natively supported in HuggingFace, so they
+            are treated as Sequences of images with a frame_rate.
+        """
         senselab_dict = self._get_dict_representation()
         # print(senselab_dict['videos']['audio'][0])
@@ -297,34 +354,75 @@ def convert_senselab_dataset_to_hf_datasets(self) -> Dict[str, Dataset]:
         return hf_datasets
 
     @classmethod
-    def convert_hf_dataset_to_senselab_dataset(cls, hf_datasets: Dict[str, Dataset]) -> "SenselabDataset":
-        """Converts HuggingFace dataset to a Senselab dataset."""
+    def convert_hf_dataset_to_senselab_dataset(
+        cls, hf_datasets: Dict[str, Dataset], metadata: Dict = {}, transfer_metadata: bool = False
+    ) -> "SenselabDataset":
+        """Converts HuggingFace dataset(s) to a Senselab dataset.
+
+        Converts HuggingFace dataset(s) to a SenselabDataset, where each component of a
+        SenselabDataset (e.g. audios, videos, sessions, participants) is stored under a
+        different key in the provided dictionary.
+
+        Args:
+            hf_datasets: Dictionary of the different individual components that make up a
+                SenselabDataset. Audios must be organized as HuggingFace Audio(s), which is a
+                dictionary with an array attribute, a sampling rate, and a path. Videos are not
+                natively supported by HuggingFace, so SenselabDataset expects a sequence of frames
+                that are each an image, a frame_rate argument, and optionally the audio associated
+                with the video.
+            metadata: Dictionary of additional dataset-level metadata. Differs from the metadata
+                inside each provided HuggingFace dataset and is unaffected by transfer_metadata.
+            transfer_metadata: Specifies whether to generate metadata from extraneous attributes
+                in the HuggingFace dataset in addition to any specified in a metadata field;
+                otherwise those attributes are ignored. Defaults to False.
+
+        Returns:
+            The SenselabDataset generated from the provided fields in the hf_datasets dictionary.
+            Currently does not support converting sessions or participants.
+        """
         audios = []
         videos = []
         sessions: Dict[str, Session] = {}
         participants: Dict[str, Participant] = {}
+
         if "audios" in hf_datasets:
             audio_dataset = hf_datasets["audios"]
-            for audio in audio_dataset:
+            for audio in audio_dataset:  # Equivalent to iterating over each row in the Dataset
+                audio_metadata = audio["metadata"] if "metadata" in audio else {}
+                if transfer_metadata:
+                    for feature in audio_dataset.features:
+                        if feature == "metadata" or feature == "audio":
+                            continue
+                        audio_metadata[feature] = audio[feature]
                 audios.append(
                     Audio(
                         waveform=audio["audio"]["array"],
                         sampling_rate=audio["audio"]["sampling_rate"],
                         orig_path_or_id=audio["audio"]["path"],
-                        metadata=audio["metadata"] if "metadata" in audio else {},
+                        metadata=audio_metadata,
                     )
                 )
 
         if "videos" in hf_datasets:
             video_dataset = hf_datasets["videos"]
             for video in video_dataset:
+                video_metadata = video["metadata"] if "metadata" in video else {}
+                if transfer_metadata:
+                    for feature in video_dataset.features:
+                        if (
+                            feature == "metadata"
+                            or feature == "frames"
+                            or feature == "frame_rate"
+                            or feature == "path"
+                            or feature == "audio"
+                        ):
+                            continue
+                        video_metadata[feature] = video[feature]
                 videos.append(
                     Video(
                         frames=video["frames"]["image"],
                         frame_rate=video["frame_rate"],
-                        metadata=video["metadata"] if "metadata" in video else {},
+                        metadata=video_metadata,
                         orig_path_or_id=video["path"],
-                        audio=Audio(
+                        audio=Audio(  # Assumes the audio's metadata is stored one level higher, within the video's metadata
                             waveform=video["audio"]["array"],
                             sampling_rate=video["audio"]["sampling_rate"],
                             orig_path_or_id=video["audio"]["path"],
@@ -338,7 +436,9 @@ def convert_hf_dataset_to_senselab_dataset(cls, hf_datasets: Dict[str, Dataset])
         if "participants" in hf_datasets:
             pass
 
-        return SenselabDataset(participants=participants, sessions=sessions, audios=audios, videos=videos)
+        return SenselabDataset(
+            participants=participants, sessions=sessions, audios=audios, videos=videos, metadata=metadata.copy()
+        )
 
     def __eq__(self, other: object) -> bool:
         """Overloads the default BaseModel equality to correctly check that datasets are equivalent."""
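
# Usage sketch for the new metadata / transfer_metadata parameters (the RAVDESS dataset here
# mirrors the test below; any HF audio dataset with extra columns would behave the same way).
from datasets import load_dataset

from senselab.utils.data_structures.dataset import SenselabDataset

hf_audio_ds = load_dataset("xbgoose/ravdess", split="train")

# Without transfer_metadata, extra columns are dropped and each Audio's metadata stays {}.
plain = SenselabDataset.convert_hf_dataset_to_senselab_dataset({"audios": hf_audio_ds})

# With transfer_metadata=True, every column other than "audio" and "metadata" is copied into
# each Audio's metadata; the metadata argument fills the dataset-level metadata.
enriched = SenselabDataset.convert_hf_dataset_to_senselab_dataset(
    {"audios": hf_audio_ds}, metadata={"source": "ravdess"}, transfer_metadata=True
)
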
diff --git a/src/tests/utils/data_structures/dataset_test.py b/src/tests/utils/data_structures/dataset_test.py
index 55a7ca0c..68d8b576 100644
--- a/src/tests/utils/data_structures/dataset_test.py
+++ b/src/tests/utils/data_structures/dataset_test.py
@@ -4,6 +4,7 @@
 import pytest
 import torch
 import torchaudio
+from datasets import load_dataset
 
 from senselab.audio.data_structures.audio import Audio
 from senselab.utils.data_structures.dataset import Participant, SenselabDataset, Session
@@ -227,3 +228,23 @@ def test_convert_senselab_dataset_to_hf_datasets() -> None:
         reconverted_dataset.videos[0].audio.waveform, dataset.videos[0].audio.waveform, rtol=0, atol=1e-4
     )
     assert reconverted_dataset.videos[0].frame_rate == dataset.videos[0].frame_rate
+
+
+def test_convert_hf_dataset_to_senselab_dataset() -> None:
+    """Use an existing HF dataset to show that Senselab properly converts and maintains an HF Dataset."""
+    ravdness = load_dataset("xbgoose/ravdess", split="train")
+    ravdness_features = list(ravdness.features)
+    ravdness_features.remove("audio")
+    if "metadata" in ravdness_features:
+        ravdness_features.remove("metadata")
+    senselab_ravdness = SenselabDataset.convert_hf_dataset_to_senselab_dataset(
+        {"audios": ravdness}, transfer_metadata=True
+    )
+
+    assert len(senselab_ravdness.audios) == 1440
+    assert set(senselab_ravdness.audios[0].metadata.keys()) == set(ravdness_features)
+
+    senselab_ravdness = SenselabDataset.convert_hf_dataset_to_senselab_dataset({"audios": ravdness})
+
+    assert len(senselab_ravdness.audios) == 1440
+    assert senselab_ravdness.audios[0].metadata == {}
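
# A possible network-free companion test (hedged sketch: the test name, the in-memory rows,
# and the assumption that Audio accepts a plain Python list as a waveform are hypothetical).
from datasets import Dataset

from senselab.utils.data_structures.dataset import SenselabDataset


def test_convert_hf_dataset_to_senselab_dataset_in_memory() -> None:
    """Convert a tiny in-memory HF dataset without downloading anything."""
    hf_audios = Dataset.from_dict(
        {
            "audio": [{"array": [0.0] * 16000, "sampling_rate": 16000, "path": "tone.wav"}],
            "emotion": ["neutral"],
        }
    )
    converted = SenselabDataset.convert_hf_dataset_to_senselab_dataset(
        {"audios": hf_audios}, transfer_metadata=True
    )

    assert len(converted.audios) == 1
    assert converted.audios[0].metadata == {"emotion": "neutral"}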