Merge pull request #133 from shahules786/dev#91
Padding transform
Showing 2 changed files with 219 additions and 0 deletions.
New file: tests for the Padding transform.

@@ -0,0 +1,134 @@
import unittest

import numpy as np
import pytest
import torch
from numpy.testing import assert_almost_equal

from torch_audiomentations.augmentations.padding import Padding


class TestPadding(unittest.TestCase):
    def test_padding_end(self):
        audio_samples = torch.rand(size=(2, 2, 32000), dtype=torch.float32)
        augment = Padding(
            min_fraction=0.2,
            max_fraction=0.5,
            pad_section="end",
            p=1.0,
            output_type="dict",
        )
        padded_samples = augment(audio_samples).samples

        self.assertEqual(audio_samples.shape, padded_samples.shape)
        # min_fraction=0.2 of 32000 samples means at least 6400 trailing zeros
        assert_almost_equal(padded_samples[..., -6400:].numpy(), np.zeros((2, 2, 6400)))

    def test_padding_start(self):
        audio_samples = torch.rand(size=(2, 2, 32000), dtype=torch.float32)
        augment = Padding(
            min_fraction=0.2,
            max_fraction=0.5,
            pad_section="start",
            p=1.0,
            output_type="dict",
        )
        padded_samples = augment(audio_samples).samples

        self.assertEqual(audio_samples.shape, padded_samples.shape)
        assert_almost_equal(padded_samples[..., :6400].numpy(), np.zeros((2, 2, 6400)))

    def test_padding_zero(self):
        # With p=0.0 the transform must be a no-op
        audio_samples = torch.rand(size=(2, 2, 32000), dtype=torch.float32)
        augment = Padding(min_fraction=0.2, max_fraction=0.5, p=0.0, output_type="dict")
        padded_samples = augment(audio_samples).samples

        self.assertEqual(audio_samples.shape, padded_samples.shape)
        assert_almost_equal(audio_samples.numpy(), padded_samples.numpy())

    def test_padding_perexample(self):
        audio_samples = torch.rand(size=(10, 2, 32000), dtype=torch.float32)
        augment = Padding(
            min_fraction=0.2,
            max_fraction=0.5,
            pad_section="start",
            p=0.5,
            mode="per_example",
            p_mode="per_example",
            output_type="dict",
        )

        padded_samples = augment(audio_samples).samples.numpy()
        num_unprocessed_examples = 0
        num_processed_examples = 0
        for i, sample in enumerate(padded_samples):
            if np.allclose(audio_samples[i], sample):
                num_unprocessed_examples += 1
            else:
                num_processed_examples += 1

        # Zeroing sections of non-negative random samples can only lower the sum
        self.assertLess(padded_samples.sum(), audio_samples.numpy().sum())

    def test_padding_perchannel(self):
        audio_samples = torch.rand(size=(10, 2, 32000), dtype=torch.float32)
        augment = Padding(
            min_fraction=0.2,
            max_fraction=0.5,
            pad_section="start",
            p=0.5,
            mode="per_channel",
            p_mode="per_channel",
            output_type="dict",
        )

        padded_samples = augment(audio_samples).samples.numpy()
        num_unprocessed_examples = 0
        num_processed_examples = 0
        for i, sample in enumerate(padded_samples):
            if np.allclose(audio_samples[i], sample):
                num_unprocessed_examples += 1
            else:
                num_processed_examples += 1

        self.assertLess(padded_samples.sum(), audio_samples.numpy().sum())

    def test_padding_variability_perexample(self):
        audio_samples = torch.rand(size=(10, 2, 32000), dtype=torch.float32)
        augment = Padding(
            min_fraction=0.2,
            max_fraction=0.5,
            pad_section="start",
            p=0.5,
            mode="per_example",
            p_mode="per_example",
            output_type="dict",
        )

        padded_samples = augment(audio_samples).samples.numpy()
        num_unprocessed_examples = 0
        num_processed_examples = 0
        for i, sample in enumerate(padded_samples):
            if np.allclose(audio_samples[i], sample):
                num_unprocessed_examples += 1
            else:
                num_processed_examples += 1

        # With p=0.5 over 10 examples, expect a mix of processed and untouched examples
        self.assertEqual(num_processed_examples + num_unprocessed_examples, 10)
        self.assertGreater(num_processed_examples, 2)
        self.assertLess(num_unprocessed_examples, 8)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
    def test_padding_cuda(self):
        audio_samples = torch.rand(
            size=(2, 2, 32000), dtype=torch.float32, device=torch.device("cuda")
        )
        augment = Padding(min_fraction=0.2, max_fraction=0.5, p=1.0, output_type="dict")
        padded_samples = augment(audio_samples).samples

        self.assertEqual(audio_samples.shape, padded_samples.shape)
        assert_almost_equal(
            padded_samples[..., -6400:].cpu().numpy(), np.zeros((2, 2, 6400))
        )
New file: the Padding transform implementation (torch_audiomentations/augmentations/padding.py, per the import path used in the tests).

@@ -0,0 +1,85 @@
import torch
from typing import Optional
from torch import Tensor

from ..core.transforms_interface import BaseWaveformTransform
from ..utils.object_dict import ObjectDict


class Padding(BaseWaveformTransform):
    """
    Replace a random fraction of the waveform, at either the start or the end,
    with silence. The tensor shape is left unchanged.
    """

    supported_modes = {"per_batch", "per_example", "per_channel"}
    supports_multichannel = True
    requires_sample_rate = False

    supports_target = True
    requires_target = False

    def __init__(
        self,
        min_fraction=0.1,
        max_fraction=0.5,
        pad_section="end",
        mode="per_batch",
        p=0.5,
        p_mode: Optional[str] = None,
        sample_rate: Optional[int] = None,
        target_rate: Optional[int] = None,
        output_type: Optional[str] = None,
    ):
        super().__init__(
            mode=mode,
            p=p,
            p_mode=p_mode,
            sample_rate=sample_rate,
            target_rate=target_rate,
            output_type=output_type,
        )
        self.min_fraction = min_fraction
        self.max_fraction = max_fraction
        self.pad_section = pad_section
        if not self.min_fraction >= 0.0:
            raise ValueError(
                "minimum fraction should be greater than or equal to zero."
            )
        if self.min_fraction > self.max_fraction:
            raise ValueError(
                "minimum fraction should be less than or equal to maximum fraction."
            )
        assert self.pad_section in (
            "start",
            "end",
        ), 'pad_section must be "start" or "end"'

    def randomize_parameters(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ):
        # Draw one pad length per example, as a fraction of the input length
        input_length = samples.shape[-1]
        self.transform_parameters["pad_length"] = torch.randint(
            int(input_length * self.min_fraction),
            int(input_length * self.max_fraction),
            (samples.shape[0],),
        )

    def apply_transform(
        self,
        samples: Tensor,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ) -> ObjectDict:
        for i, index in enumerate(self.transform_parameters["pad_length"]):
            if index == 0:
                # Guard against the zero-length case: samples[i, :, -0:] would
                # otherwise select (and zero out) the entire example
                continue
            if self.pad_section == "start":
                samples[i, :, :index] = 0.0
            else:
                samples[i, :, -index:] = 0.0

        return ObjectDict(
            samples=samples,
            sample_rate=sample_rate,
            targets=targets,
            target_rate=target_rate,
        )
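For context, a minimal usage sketch of the new transform, mirroring the call convention exercised by the tests above (a float tensor of shape (batch_size, num_channels, num_samples) in, an ObjectDict out). The tensor shape and parameter values here are illustrative only:

import torch

from torch_audiomentations.augmentations.padding import Padding

# Replace between 10% and 50% of each example's tail with silence
augment = Padding(
    min_fraction=0.1,
    max_fraction=0.5,
    pad_section="end",
    mode="per_example",
    p_mode="per_example",
    p=1.0,
    output_type="dict",
)

audio = torch.rand(size=(4, 2, 32000), dtype=torch.float32)
padded = augment(audio).samples  # same shape; trailing sections are now zeros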