Skip to content

Commit

Permalink
feat: update upsampler with proper resampling method, randomize in-ba…
Browse files Browse the repository at this point in the history
…tch resampling with multiple factors
  • Loading branch information
flavioschneider committed Sep 27, 2022
1 parent da94c00 commit 6992cc9
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 14 deletions.
37 changes: 24 additions & 13 deletions audio_diffusion_pytorch/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import random
from typing import Any, Optional, Sequence, Tuple, Union

import torch
Expand All @@ -15,7 +14,7 @@
Schedule,
)
from .modules import Bottleneck, MultiEncoder1d, UNet1d, UNetConditional1d
from .utils import default, exists, to_list
from .utils import default, downsample, exists, to_list, upsample

"""
Diffusion Classes (generic for 1d data)
Expand Down Expand Up @@ -68,29 +67,41 @@ class DiffusionUpsampler1d(Model1d):
# Constructor: stores the upsampling factor(s) as a list and forwards the
# channel configuration, plus any extra arguments, to the parent constructor.
def __init__(
self, factor: Union[int, Sequence[int]], in_channels: int, *args, **kwargs
):
# NOTE(review): this listing is a rendered diff without +/- markers; the two
# assignments below look like the old (`factor`) and the new (`factors`)
# spelling of the same attribute - confirm which one the real file keeps.
self.factor = to_list(factor)
self.factors = to_list(factor)
# Defaults handed to the parent: `context_channels` is a one-element list
# holding the same channel count as the input.
default_kwargs = dict(
in_channels=in_channels,
context_channels=[in_channels],
)
# Caller kwargs are merged last, so they override the defaults above.
super().__init__(*args, **{**default_kwargs, **kwargs}) # type: ignore

# NOTE(review): this appears to be the pre-change `forward` (the removed side
# of the diff); no return statement is visible within these lines - confirm
# against the real file before relying on it.
def forward(self, x: Tensor, factor: Optional[int] = None, **kwargs) -> Tensor:
# Either user provides factor or we pick one at random
# (`random.choice` is an argument expression, so it is evaluated even when
# the caller supplies `factor`.)
factor = default(factor, random.choice(self.factor))
# Downsample by picking every `factor` item
# (plain strided slicing along dim 2)
downsampled = x[:, :, ::factor]
# Upsample by interleaving to get context
# (each kept item is repeated `factor` times along dim 2)
channels = torch.repeat_interleave(downsampled, repeats=factor, dim=2)
def random_reupsample(self, x: Tensor) -> Tensor:
    """Return a degraded copy of `x` for use as conditioning context.

    Every batch item is assigned one entry of `self.factors` at random, then
    downsampled and upsampled again by that factor.

    The result is written into a clone of `x`. The caller passes both `x`
    and the returned tensor to the diffusion model, so writing into `x`
    itself would replace the clean signal with the degraded one and would
    also mutate the caller's batch.

    Args:
        x: Batch of waveforms; dimension 0 is the batch dimension.

    Returns:
        A new tensor shaped like `x` holding the re-upsampled items.
    """
    factors = self.factors
    batch_size = x.shape[0]
    # One index into `factors` per batch item.
    factor_batch_idx = torch.randint(0, len(factors), (batch_size,))
    # Never write into the input: it is still needed untouched by the caller.
    reupsampled_batch = x.clone()

    for i, factor in enumerate(factors):
        selected = factor_batch_idx == i
        # Skip factors that no batch item drew.
        if not selected.any():
            continue
        # Boolean-mask indexing yields a copy of the selected items.
        downsampled = downsample(x[selected], factor=factor)
        # NOTE(review): assumes the round trip preserves the length, i.e. the
        # time dimension is a multiple of `factor` - confirm with callers.
        reupsampled_batch[selected] = upsample(downsampled, factor=factor)
    return reupsampled_batch

def forward(self, x: Tensor, **kwargs) -> Tensor:
    """Run the diffusion model on `x`, conditioned on a re-upsampled version.

    The context is produced by `random_reupsample` and handed to
    `self.diffusion` as the single entry of `channels_list`; any extra
    keyword arguments are passed through unchanged.
    """
    context = self.random_reupsample(x)
    context_list = [context]
    return self.diffusion(x, channels_list=context_list, **kwargs)

# Sampling entry point: builds the conditioning channels from an undersampled
# input, draws noise of the same shape, and delegates to the parent `sample`.
# NOTE(review): `*args` is accepted here but is not forwarded to
# `super().sample` below - confirm whether that is intended.
def sample( # type: ignore
self, undersampled: Tensor, factor: Optional[int] = None, *args, **kwargs
):
# NOTE(review): this listing is a rendered diff without +/- markers; the first
# factor/channels pair below looks like the old implementation and the second
# pair like its replacement - confirm which one the real file keeps.
# Either user provides factor or we pick the first
factor = default(factor, self.factor[0])
# Upsample channels by interleaving
channels = torch.repeat_interleave(undersampled, repeats=factor, dim=2)
factor = default(factor, self.factors[0])
# Upsample channels
channels = upsample(undersampled, factor=factor)
# Starting noise has the same shape, dtype and device as the context.
noise = torch.randn_like(channels)
default_kwargs = dict(channels_list=[channels])
# Caller kwargs are merged last, so they override `channels_list` if given.
return super().sample(noise, **{**default_kwargs, **kwargs}) # type: ignore
Expand Down
48 changes: 48 additions & 0 deletions audio_diffusion_pytorch/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import math
from functools import reduce
from inspect import isfunction
from typing import Callable, List, Optional, Sequence, TypeVar, Union

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from typing_extensions import TypeGuard

T = TypeVar("T")
Expand Down Expand Up @@ -35,3 +40,46 @@ def to_list(val: Union[T, Sequence[T]]) -> List[T]:

def prod(vals: Sequence[int]) -> int:
    """Return the product of all items in `vals`.

    An empty sequence yields 1 (the empty product) instead of raising
    `TypeError`, which `reduce` does when it has no initial value.
    """
    # Seeding with 1 leaves non-empty results unchanged.
    return reduce(lambda acc, val: acc * val, vals, 1)


"""
DSP Utils
"""


def resample(
    waveforms: Tensor,
    factor_in: int,
    factor_out: int,
    rolloff: float = 0.99,
    lowpass_filter_width: int = 6,
) -> Tensor:
    """Resample waveforms by `factor_out / factor_in` with sinc interpolation.

    Adapted from torchaudio. A bank of `factor_out` windowed-sinc kernels is
    built and applied as a strided 1d convolution (stride `factor_in`).

    Args:
        waveforms: 3-dimensional tensor, unpacked as (batch, channels, time).
        factor_in: Input rate factor; also the convolution stride.
        factor_out: Output rate factor; also the number of kernels.
        rolloff: Fraction of the smaller factor used as the lowpass cutoff.
        lowpass_filter_width: Half-width of the sinc filter (clamp limit of
            the kernel argument); larger values give a longer kernel.

    Returns:
        Tensor with the same first two dimensions and a last dimension of
        length `int(factor_out * time / factor_in)`.
    """
    batch, _, num_frames = waveforms.shape
    target_frames = int(factor_out * num_frames / factor_in)
    device, dtype = waveforms.device, waveforms.dtype

    cutoff = min(factor_in, factor_out) * rolloff
    half_width = math.ceil(lowpass_filter_width * factor_in / cutoff)

    # Tap positions, shape (1, 1, taps), in units of input samples.
    taps = torch.arange(
        -half_width, half_width + factor_in, device=device, dtype=dtype
    )
    taps = taps[None, None] / factor_in
    # One fractional offset per output phase, shape (factor_out, 1, 1).
    phases = torch.arange(0, -factor_out, step=-1, device=device, dtype=dtype)
    phases = phases[:, None, None] / factor_out

    t = phases + taps
    t = t * cutoff
    t = t.clamp(-lowpass_filter_width, lowpass_filter_width)
    t = t * math.pi

    # Squared-cosine window times sinc, with sinc defined as 1 where t == 0.
    window = torch.cos(t / lowpass_filter_width / 2) ** 2
    gain = cutoff / factor_in
    sinc = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
    kernels = sinc * (window * gain)

    # Fold channels into the batch so every channel is filtered independently.
    flat = rearrange(waveforms, "b c t -> (b c) t")
    flat = F.pad(flat, (half_width, half_width + factor_in))
    filtered = F.conv1d(flat[:, None], kernels, stride=factor_in)
    # Interleave the kernel outputs along time and restore (batch, channels).
    filtered = rearrange(filtered, "(b c) k l -> b c (l k)", b=batch)
    return filtered[..., :target_frames]


def downsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
    """Reduce the rate of `waveforms` by `factor`.

    Thin wrapper around `resample` with `factor_in=factor` and
    `factor_out=1`; extra keyword arguments are forwarded to `resample`.
    """
    return resample(waveforms, factor, 1, **kwargs)


def upsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
    """Increase the rate of `waveforms` by `factor`.

    Thin wrapper around `resample` with `factor_in=1` and
    `factor_out=factor`; extra keyword arguments are forwarded to `resample`.
    """
    return resample(waveforms, 1, factor, **kwargs)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.48",
version="0.0.49",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 6992cc9

Please sign in to comment.