Skip to content

Commit

Permalink
Fix a bug where BandPassFilter didn't work on GPU
Browse files Browse the repository at this point in the history
  • Loading branch information
iver56 committed Apr 21, 2022
1 parent ccd8cf8 commit 818e561
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
23 changes: 23 additions & 0 deletions tests/test_band_pass_filter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import pytest
import torch

from torch_audiomentations import BandPassFilter
Expand All @@ -23,3 +24,25 @@ def test_band_pass_filter(self):
).samples.numpy()
assert processed_samples.shape == samples.shape
assert processed_samples.dtype == np.float32

@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
def test_band_pass_filter_cuda(self):
samples = np.array(
[
[[0.75, 0.5, -0.25, -0.125, 0.0], [0.65, 0.5, -0.25, -0.125, 0.0]],
[[0.3, 0.5, -0.25, -0.125, 0.0], [0.9, 0.5, -0.25, -0.125, 0.0]],
[[0.9, 0.5, -0.25, -1.06, 0.0], [0.9, 0.5, -0.25, -1.12, 0.0]],
],
dtype=np.float32,
)
sample_rate = 16000

augment = BandPassFilter(p=1.0, output_type="dict")
for _ in range(20):
processed_samples = (
augment(samples=torch.from_numpy(samples).cuda(), sample_rate=sample_rate)
.samples.cpu()
.numpy()
)
assert processed_samples.shape == samples.shape
assert processed_samples.dtype == np.float32
7 changes: 6 additions & 1 deletion torch_audiomentations/augmentations/band_pass_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,12 @@ def get_dist(min_freq, max_freq):
)

bandwidth_dist = torch.distributions.Uniform(
low=self.min_bandwidth_fraction, high=self.max_bandwidth_fraction,
low=torch.tensor(
self.min_bandwidth_fraction, dtype=torch.float32, device=samples.device
),
high=torch.tensor(
self.max_bandwidth_fraction, dtype=torch.float32, device=samples.device
)
)
self.transform_parameters["bandwidth"] = bandwidth_dist.sample(
sample_shape=(batch_size,)
Expand Down

0 comments on commit 818e561

Please sign in to comment.