Skip to content

Commit

Permalink
Merge pull request asteroid-team#73 from asteroid-team/ij/use-new-rfft
Browse files Browse the repository at this point in the history
Use torch.fft.rfft instead of the deprecated torch.rfft when possible
  • Loading branch information
iver56 authored Dec 18, 2020
2 parents e1ee448 + 2aa6bfa commit 8e1a348
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 29 deletions.
File renamed without changes.
36 changes: 7 additions & 29 deletions torch_audiomentations/utils/convolution.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torch

from torch_audiomentations.utils.fft import rfft, irfft

_NEXT_FAST_LEN = {}


Expand Down Expand Up @@ -35,24 +37,7 @@ def next_fast_len(size):
next_size += 1


def _complex_mul(a, b):
"""
Note: This function was originally copied from the https://github.com/pyro-ppl/pyro
repository, where the license was Apache 2.0. Any modifications to the original code can be
found at https://github.com/asteroid-team/torch-audiomentations/commits
:param a:
:param b:
:return:
"""

ar, ai = a.unbind(-1)
br, bi = b.unbind(-1)
return torch.stack([ar * br - ai * bi, ar * bi + ai * br], dim=-1)


def convolve(signal, kernel, mode="full", method="fft"):
def convolve(signal, kernel, mode="full"):
"""
Computes the 1-d convolution of signal by kernel using FFTs.
The two arguments should have the same rightmost dim, but may otherwise be
Expand All @@ -72,9 +57,6 @@ def convolve(signal, kernel, mode="full", method="fft"):
``max(m, n)`` if mode is 'same'.
:rtype torch.Tensor:
"""
if method != "fft":
raise NotImplementedError('Only method="fft" is supported')

m = signal.size(-1)
n = kernel.size(-1)
if mode == "full":
Expand All @@ -90,14 +72,10 @@ def convolve(signal, kernel, mode="full", method="fft"):
padded_size = m + n - 1
# Round up for cheaper fft.
fast_ftt_size = next_fast_len(padded_size)
f_signal = torch.rfft(
torch.nn.functional.pad(signal, (0, fast_ftt_size - m)), 1, onesided=False
)
f_kernel = torch.rfft(
torch.nn.functional.pad(kernel, (0, fast_ftt_size - n)), 1, onesided=False
)
f_result = _complex_mul(f_signal, f_kernel)
result = torch.irfft(f_result, 1, onesided=False)
f_signal = rfft(signal, n=fast_ftt_size)
f_kernel = rfft(kernel, n=fast_ftt_size)
f_result = f_signal * f_kernel
result = irfft(f_result, n=fast_ftt_size)

start_idx = (padded_size - truncate) // 2
return result[..., start_idx : start_idx + truncate]
31 changes: 31 additions & 0 deletions torch_audiomentations/utils/fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

# Note: Code in this module has been copied from the https://github.com/pyro-ppl/pyro
# repository, where the license was Apache 2.0. Any modifications to the original code can be
# found at https://github.com/asteroid-team/torch-audiomentations/commits

import torch

try:
# This works in PyTorch>=1.7
from torch.fft import irfft, rfft
except ModuleNotFoundError:
# This works in PyTorch<=1.6
def rfft(input, n=None):
if n is not None:
m = input.size(-1)
if n > m:
input = torch.nn.functional.pad(input, (0, n - m))
elif n < m:
input = input[..., :n]
return torch.view_as_complex(torch.rfft(input, 1))

def irfft(input, n=None):
if torch.is_complex(input):
input = torch.view_as_real(input)
else:
input = torch.nn.functional.pad(input[..., None], (0, 1))
if n is None:
n = 2 * (input.size(-1) - 1)
return torch.irfft(input, 1, signal_sizes=(n,))

0 comments on commit 8e1a348

Please sign in to comment.