Skip to content

Commit

Permalink
Merge pull request #149 from sensein/dataset_update
Browse files Browse the repository at this point in the history
Dataset class update (adding functionalities plus adjusting the documentation)
  • Loading branch information
fabiocat93 authored Aug 16, 2024
2 parents 41dc1b1 + 9537fc0 commit 9e693a9
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 18 deletions.
136 changes: 118 additions & 18 deletions src/senselab/utils/data_structures/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@


class Participant(BaseModel):
"""Data structure for a participant in a dataset."""
"""Data structure for a participant in a dataset.
Attributes:
id: The ID of the participant. If not provided, generated using uuid4
metadata: Dictionary of user specified metadata related to the participant
"""

id: str = Field(default_factory=lambda: str(uuid.uuid4()))
metadata: Dict = Field(default={})
Expand All @@ -33,7 +38,12 @@ def __eq__(self, other: object) -> bool:


class Session(BaseModel):
"""Data structure for a session in a dataset."""
"""Data structure for a session in a dataset.
Attributes:
id: the ID of the session. If not provided, generated using uuid4
metadata: Dictionary of user specified metadata related to the session
"""

id: str = Field(default_factory=lambda: str(uuid.uuid4()))
metadata: Dict = Field(default={})
Expand All @@ -60,9 +70,9 @@ class SenselabDataset(BaseModel):
audios: List of Audios that are generated based on list of audio filepaths
videos: List of Videos generated from a list of video filepaths
metadata: Metadata related to the dataset overall but not necessarily the metadata of
individual audios in the dataset
sessions: Session ID mapping to Session instance
participants: Mapping of participant ID to a Participant instance
individual audios or videos in the dataset
sessions: Mapping of Session IDs to Session instances
participants: Mapping of participant IDs to Participant instances
"""

participants: Dict[str, Participant] = Field(default_factory=dict)
Expand Down Expand Up @@ -136,6 +146,8 @@ def create_bids_dataset(cls, bids_root_filepath: str) -> "SenselabDataset":
"""
pass

# TODO Decide if this method and audio_merge_from_pydra_task should be defined elsewhere, like in
# a Pydra helper class
def create_audio_split_for_pydra_task(self, batch_size: int = 1) -> List[List[Audio]]:
"""Splits the audio data for Pydra tasks.
Expand Down Expand Up @@ -165,7 +177,7 @@ def create_audio_split_for_pydra_task(self, batch_size: int = 1) -> List[List[Au
return [[audio] for audio in self.audios]

def audio_merge_from_pydra_task(self, audios_to_merge: List[List[Audio]]) -> None:
"""Write later.
"""TODO: Write later.
Logic Pydra:
audios: List of audios that want to give to task
Expand All @@ -180,26 +192,64 @@ def audio_merge_from_pydra_task(self, audios_to_merge: List[List[Audio]]) -> Non
self.audios.append(audio_output)

def add_participant(self, participant: Participant) -> None:
"""Add a participant to the dataset."""
"""Add a participant to the dataset.
Adds a new participant to the dataset if they are not already in it.
Args:
participant: instance of a Participant that we want to add to the dataset
Raises:
ValueError: If the participant ID is already in the dataset, we raise a value error.
This means that either the ID is non-unique and/or the participant is already in the dataset.
"""
if participant.id in self.participants:
raise ValueError(f"Participant with ID {participant.id} already exists.")
self.participants[participant.id] = participant

def add_session(self, session: Session) -> None:
"""Add a session to the dataset."""
"""Add a session to the dataset.
Adds a new session to the dataset if it is not already in there.
Args:
session: instance of a Session that we want to add to the dataset
Raises:
ValueError: If the session ID is already in the dataset, we raise a value error.
This means that either the ID is non-unique and/or the session is already in the dataset.
"""
if session.id in self.sessions:
raise ValueError(f"Session with ID {session.id} already exists.")
self.sessions[session.id] = session

def get_participants(self) -> List[Participant]:
"""Get the list of participants in the dataset."""
"""Get the list of participants in the dataset.
Returns:
participants (List[Participant]): all of the instances of participants in the dataset
Warning: The instances are returned as is, so changes to the underlying participants in this
list will automatically be reflected in the dataset.
"""
return list(self.participants.values())

def get_sessions(self) -> List[Session]:
"""Get the list of sessions in the dataset."""
"""Get the list of sessions in the dataset.
Returns:
sessions (List[Session]): all of the instances of sessions in the dataset
Warning: The instances are returned as is, so changes to the underlying sessions in this
list will automatically be reflected in the dataset.
"""
return list(self.sessions.values())

def _get_dict_representation(self) -> Dict:
"""Internal function for generating a dictionary representation of the dataset.
Returns:
Generates a dictionary representation of the dataset where the keys are participants,
sessions, audios, videos, and metadata.
"""
audio_data: Dict[str, List] = {}
video_data: Dict[str, List] = {}
senselab_dict: Dict[str, Union[Dict[str, List], List]] = {
Expand Down Expand Up @@ -268,7 +318,14 @@ def _get_dict_representation(self) -> Dict:
return senselab_dict

def convert_senselab_dataset_to_hf_datasets(self) -> Dict[str, Dataset]:
"""Converts Senselab datasets into HuggingFace datasets."""
"""Converts Senselab datasets into HuggingFace datasets.
Returns:
A dictionary of HuggingFace datasets that represent the underlying Senselab dataset.
Currently only supports creating HuggingFace datasets for the Audio(s) and Video(s) in
the SenselabDataset. Videos in HuggingFace are not natively supported, so they are treated
as Sequences of images with a frame_rate.
"""
senselab_dict = self._get_dict_representation()

# print(senselab_dict['videos']['audio'][0])
Expand Down Expand Up @@ -297,34 +354,75 @@ def convert_senselab_dataset_to_hf_datasets(self) -> Dict[str, Dataset]:
return hf_datasets

@classmethod
def convert_hf_dataset_to_senselab_dataset(cls, hf_datasets: Dict[str, Dataset]) -> "SenselabDataset":
"""Converts HuggingFace dataset to a Senselab dataset."""
def convert_hf_dataset_to_senselab_dataset(
cls, hf_datasets: Dict[str, Dataset], metadata: Dict = {}, transfer_metadata: bool = False
) -> "SenselabDataset":
"""Converts HuggingFace dataset to a Senselab dataset.
Convert HuggingFace dataset(s) to a SenselabDataset, where each component of a SenselabDataset
(e.g. audios, videos, sessions, participants) is stored under a different key in the provided dictionary.
Args:
hf_datasets: Dictionary of the different individual components that make up a SenselabDataset. Audios must
be organized as HuggingFace Audio(s) which is a dictionary with an array attribute, sampling rate, and
a path. Videos are not natively supported by HuggingFace, so SenselabDataset expects a sequence of
frames that are each an image, a frame_rate argument, and optionally the associated audio with the
video.
metadata: Dictionary of additional dataset-level metadata. Differs from metadata inside each HuggingFace
dataset that is provided and is unaffected by transfer_metadata.
transfer_metadata: Specifies whether to generate metadata from extraneous attributes in the
HuggingFace Dataset in addition to any specified in a metadata field; otherwise they are ignored.
Defaults to False.
Returns:
The generated SenselabDataset from the provided fields in the hf_datasets dictionary. Currently does not
support converting sessions or participants.
"""
audios = []
videos = []
sessions: Dict[str, Session] = {}
participants: Dict[str, Participant] = {}

if "audios" in hf_datasets:
audio_dataset = hf_datasets["audios"]
for audio in audio_dataset:
for audio in audio_dataset: # Equivalent of running over each row in the Dataset
audio_metadata = audio["metadata"] if "metadata" in audio else {}
if transfer_metadata:
for feature in audio_dataset.features:
if feature == "metadata" or feature == "audio":
continue
audio_metadata[feature] = audio[feature]
audios.append(
Audio(
waveform=audio["audio"]["array"],
sampling_rate=audio["audio"]["sampling_rate"],
orig_path_or_id=audio["audio"]["path"],
metadata=audio["metadata"] if "metadata" in audio else {},
metadata=audio_metadata,
)
)

if "videos" in hf_datasets:
video_dataset = hf_datasets["videos"]
for video in video_dataset:
video_metadata = video["metadata"] if "metadata" in video else {}
if transfer_metadata:
for feature in video_dataset.features:
if (
feature == "metadata"
or feature == "frames"
or feature == "frame_rate"
or feature == "path"
or feature == "audio"
):
continue
video_metadata[feature] = video[feature]
videos.append(
Video(
frames=video["frames"]["image"],
frame_rate=video["frame_rate"],
metadata=video["metadata"] if "metadata" in video else {},
metadata=video_metadata,
orig_path_or_id=video["path"],
audio=Audio(
audio=Audio( # Assumes audio metadata is stored a level higher within the video's metadata
waveform=video["audio"]["array"],
sampling_rate=video["audio"]["sampling_rate"],
orig_path_or_id=video["audio"]["path"],
Expand All @@ -338,7 +436,9 @@ def convert_hf_dataset_to_senselab_dataset(cls, hf_datasets: Dict[str, Dataset])
if "participants" in hf_datasets:
pass

return SenselabDataset(participants=participants, sessions=sessions, audios=audios, videos=videos)
return SenselabDataset(
participants=participants, sessions=sessions, audios=audios, videos=videos, metadata=metadata.copy()
)

def __eq__(self, other: object) -> bool:
"""Overloads the default BaseModel equality to correctly check that datasets are equivalent."""
Expand Down
21 changes: 21 additions & 0 deletions src/tests/utils/data_structures/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
import torch
import torchaudio
from datasets import load_dataset

from senselab.audio.data_structures.audio import Audio
from senselab.utils.data_structures.dataset import Participant, SenselabDataset, Session
Expand Down Expand Up @@ -227,3 +228,23 @@ def test_convert_senselab_dataset_to_hf_datasets() -> None:
reconverted_dataset.videos[0].audio.waveform, dataset.videos[0].audio.waveform, rtol=0, atol=1e-4
)
assert reconverted_dataset.videos[0].frame_rate == dataset.videos[0].frame_rate


def test_convert_hf_dataset_to_senselab_dataset() -> None:
    """Convert an existing HF dataset to a SenselabDataset, with and without metadata transfer."""
    hf_ravdess = load_dataset("xbgoose/ravdess", split="train")

    # Columns other than the audio payload (and an explicit "metadata" column, if present)
    # are the ones expected to show up as per-audio metadata when transfer_metadata=True.
    feature_names = list(hf_ravdess.features)
    feature_names.remove("audio")
    expected_metadata_keys = {name for name in feature_names if name != "metadata"}

    # With transfer_metadata=True the extra columns become each Audio's metadata.
    converted_with_metadata = SenselabDataset.convert_hf_dataset_to_senselab_dataset(
        {"audios": hf_ravdess}, transfer_metadata=True
    )
    assert len(converted_with_metadata.audios) == 1440
    assert set(converted_with_metadata.audios[0].metadata.keys()) == expected_metadata_keys

    # With the default (transfer_metadata=False) the extra columns are ignored.
    converted_default = SenselabDataset.convert_hf_dataset_to_senselab_dataset({"audios": hf_ravdess})
    assert len(converted_default.audios) == 1440
    assert converted_default.audios[0].metadata == {}

0 comments on commit 9e693a9

Please sign in to comment.