Skip to content

Commit

Permalink
Merge pull request #133 from shahules786/dev#91
Browse files Browse the repository at this point in the history
Padding transform
  • Loading branch information
iver56 authored Apr 20, 2022
2 parents 7482d04 + a3a96d5 commit df57948
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 0 deletions.
134 changes: 134 additions & 0 deletions tests/test_padding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import unittest
import numpy as np
import torch
from numpy.testing import assert_almost_equal
import pytest

from torch_audiomentations.augmentations.padding import Padding


class TestPadding(unittest.TestCase):
def test_padding_end(self):

audio_samples = torch.rand(size=(2, 2, 32000), dtype=torch.float32)
augment = Padding(
min_fraction=0.2,
max_fraction=0.5,
pad_section="end",
p=1.0,
output_type="dict",
)
padded_samples = augment(audio_samples).samples

self.assertEqual(audio_samples.shape, padded_samples.shape)
assert_almost_equal(padded_samples[..., -6400:].numpy(), np.zeros((2, 2, 6400)))

def test_padding_start(self):

audio_samples = torch.rand(size=(2, 2, 32000), dtype=torch.float32)
augment = Padding(
min_fraction=0.2,
max_fraction=0.5,
pad_section="start",
p=1.0,
output_type="dict",
)
padded_samples = augment(audio_samples).samples

self.assertEqual(audio_samples.shape, padded_samples.shape)
assert_almost_equal(padded_samples[..., :6400].numpy(), np.zeros((2, 2, 6400)))

def test_padding_zero(self):

audio_samples = torch.rand(size=(2, 2, 32000), dtype=torch.float32)
augment = Padding(min_fraction=0.2, max_fraction=0.5, p=0.0, output_type="dict")
padded_samples = augment(audio_samples).samples

self.assertEqual(audio_samples.shape, padded_samples.shape)
assert_almost_equal(audio_samples.numpy(), padded_samples.numpy())

def test_padding_perexample(self):

audio_samples = torch.rand(size=(10, 2, 32000), dtype=torch.float32)
augment = Padding(
min_fraction=0.2,
max_fraction=0.5,
pad_section="start",
p=0.5,
mode="per_example",
p_mode="per_example",
output_type="dict",
)

padded_samples = augment(audio_samples).samples.numpy()
num_unprocessed_examples = 0.0
num_processed_examples = 0.0
for i, sample in enumerate(padded_samples):
if np.allclose(audio_samples[i], sample):
num_unprocessed_examples += 1
else:
num_processed_examples += 1

self.assertLess(padded_samples.sum(), audio_samples.numpy().sum())

def test_padding_perchannel(self):

audio_samples = torch.rand(size=(10, 2, 32000), dtype=torch.float32)
augment = Padding(
min_fraction=0.2,
max_fraction=0.5,
pad_section="start",
p=0.5,
mode="per_channel",
p_mode="per_channel",
output_type="dict",
)

padded_samples = augment(audio_samples).samples.numpy()
num_unprocessed_examples = 0.0
num_processed_examples = 0.0
for i, sample in enumerate(padded_samples):
if np.allclose(audio_samples[i], sample):
num_unprocessed_examples += 1
else:
num_processed_examples += 1

self.assertLess(padded_samples.sum(), audio_samples.numpy().sum())

def test_padding_variability_perexample(self):

audio_samples = torch.rand(size=(10, 2, 32000), dtype=torch.float32)
augment = Padding(
min_fraction=0.2,
max_fraction=0.5,
pad_section="start",
p=0.5,
mode="per_example",
p_mode="per_example",
output_type="dict",
)

padded_samples = augment(audio_samples).samples.numpy()
num_unprocessed_examples = 0.0
num_processed_examples = 0.0
for i, sample in enumerate(padded_samples):
if np.allclose(audio_samples[i], sample):
num_unprocessed_examples += 1
else:
num_processed_examples += 1

self.assertEqual(num_processed_examples + num_unprocessed_examples, 10)
self.assertGreater(num_processed_examples, 2)
self.assertLess(num_unprocessed_examples, 8)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
def test_padding_cuda(self):

audio_samples = torch.rand(
size=(2, 2, 32000), dtype=torch.float32, device=torch.device("cuda")
)
augment = Padding(min_fraction=0.2, max_fraction=0.5, p=1.0, output_type="dict")
padded_samples = augment(audio_samples).samples

self.assertEqual(audio_samples.shape, padded_samples.shape)
assert_almost_equal(padded_samples[..., -6400:].cpu().numpy(), np.zeros((2, 2, 6400)))
85 changes: 85 additions & 0 deletions torch_audiomentations/augmentations/padding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import torch
from typing import Optional
from torch import Tensor

from ..core.transforms_interface import BaseWaveformTransform
from ..utils.object_dict import ObjectDict


class Padding(BaseWaveformTransform):

supported_modes = {"per_batch", "per_example", "per_channel"}
supports_multichannel = True
requires_sample_rate = False

supports_target = True
requires_target = False

def __init__(
self,
min_fraction=0.1,
max_fraction=0.5,
pad_section="end",
mode="per_batch",
p=0.5,
p_mode: Optional[str] = None,
sample_rate: Optional[int] = None,
target_rate: Optional[int] = None,
output_type: Optional[str] = None,
):
super().__init__(
mode=mode,
p=p,
p_mode=p_mode,
sample_rate=sample_rate,
target_rate=target_rate,
output_type=output_type,
)
self.min_fraction = min_fraction
self.max_fraction = max_fraction
self.pad_section = pad_section
if not self.min_fraction >= 0.0:
raise ValueError("minimum fraction should be greater than zero.")
if self.min_fraction > self.max_fraction:
raise ValueError(
"minimum fraction should be less than or equal to maximum fraction."
)
assert self.pad_section in (
"start",
"end",
), 'pad_section must be "start" or "end"'

def randomize_parameters(
self,
samples: Tensor = None,
sample_rate: Optional[int] = None,
targets: Optional[Tensor] = None,
target_rate: Optional[int] = None,
):
input_length = samples.shape[-1]
self.transform_parameters["pad_length"] = torch.randint(
int(input_length * self.min_fraction),
int(input_length * self.max_fraction),
(samples.shape[0],),
)

def apply_transform(
self,
samples: Tensor,
sample_rate: Optional[int] = None,
targets: Optional[int] = None,
target_rate: Optional[int] = None,
) -> ObjectDict:

for i, index in enumerate(self.transform_parameters["pad_length"]):
if self.pad_section == "start":
samples[i, :, :index] = 0.0
else:
samples[i, :, -index:] = 0.0

return ObjectDict(
samples=samples,
sample_rate=sample_rate,
targets=targets,
target_rate=target_rate,
)

0 comments on commit df57948

Please sign in to comment.