Skip to content

Commit

Permalink
changing device policy for data augmentation #2
Browse files Browse the repository at this point in the history
  • Loading branch information
fabiocat93 committed Oct 1, 2024
1 parent 6a59426 commit 1552144
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
5 changes: 3 additions & 2 deletions src/senselab/audio/tasks/data_augmentation/api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""This module provides the API for data augmentation tasks."""

from typing import List, Union
from typing import List, Optional, Union

from audiomentations import Compose as AudiomentationsCompose
from torch_audiomentations import Compose as TorchAudiomentationsCompose
Expand All @@ -14,7 +14,7 @@
def augment_audios(
audios: List[Audio],
augmentation: Union[TorchAudiomentationsCompose, AudiomentationsCompose],
device: DeviceType = DeviceType.CPU,
device: Optional[DeviceType] = None,
) -> List[Audio]:
"""Augments all provided audios.
Expand All @@ -25,6 +25,7 @@ def augment_audios(
audios: List of Audios whose data will be augmented with the given augmentations.
augmentation: A composition of augmentations (torch-audiomentations or audiomentations).
device: The device to use for augmenting (relevant for torch-audiomentations).
Defaults is None.
Returns:
List of augmented audios.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""This module implements some utilities for audio data augmentation with torch_audiomentations."""

from typing import List
from typing import List, Optional

import pydra
import torch
Expand All @@ -15,7 +15,7 @@


def augment_audios_with_torch_audiomentations(
audios: List[Audio], augmentation: Compose, device: DeviceType = DeviceType.CPU
audios: List[Audio], augmentation: Compose, device: Optional[DeviceType] = None
) -> List[Audio]:
"""Augments all provided audios with a given augmentation, either individually or all batched together.
Expand All @@ -30,7 +30,8 @@ def augment_audios_with_torch_audiomentations(
output_type set to "dict"
device: The device to use for augmenting. If the chosen device
is MPS or CUDA then the audios are all batched together, so for optimal performance, batching should
be done by passing a batch_size worth of audios ar a time
be done by passing a batch_size worth of audios ar a time.
Default is None, which will select the device automatically.
Returns:
List of audios that has passed the all of input audios through the provided augmentation. This does
Expand Down

0 comments on commit 1552144

Please sign in to comment.