Skip to content

Commit

Permalink
feat: option to condition upsampling with used factor
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Oct 6, 2022
1 parent aaaa699 commit e4c118f
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 13 deletions.
48 changes: 36 additions & 12 deletions audio_diffusion_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
Sampler,
Schedule,
)
from .modules import Bottleneck, MultiEncoder1d, UNet1d, UNetConditional1d
from .modules import (
Bottleneck,
MultiEncoder1d,
SinusoidalEmbedding,
UNet1d,
UNetConditional1d,
)
from .utils import default, downsample, exists, to_list, upsample

"""
Expand Down Expand Up @@ -65,46 +71,64 @@ def sample(

class DiffusionUpsampler1d(Model1d):
    """Diffusion model that receives a re-upsampled copy of its input as context.

    In ``forward`` every batch item is downsampled and upsampled again by a
    factor picked at random from ``factor``; the result is handed to
    ``self.diffusion`` through ``channels_list``. When ``factor_features`` is
    given, a sinusoidal embedding of the factor that was used is additionally
    passed as ``features``.

    Args:
        in_channels: Forwarded to the parent constructor as ``in_channels``
            and, wrapped in a list, as ``context_channels``.
        factor: One factor or a sequence of factors to pick from.
        factor_features: Size of the factor embedding. ``None`` disables
            conditioning on the factor.
        *args, **kwargs: Forwarded to the parent constructor; ``kwargs``
            override the defaults assembled here.
    """

    def __init__(
        self,
        in_channels: int,
        factor: Union[int, Sequence[int]],
        factor_features: Optional[int] = None,
        *args,
        **kwargs
    ):
        self.factors = to_list(factor)
        self.use_conditioning = exists(factor_features)

        default_kwargs = dict(
            in_channels=in_channels,
            context_channels=[in_channels],
            context_features=factor_features if self.use_conditioning else None,
        )
        super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore

        # Submodules can only be attached once the parent constructor has run.
        if self.use_conditioning:
            assert exists(factor_features)  # narrows Optional[int] for type checkers
            self.to_features = SinusoidalEmbedding(dim=factor_features)

    def random_reupsample(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Degrade each batch item with a randomly chosen factor.

        Returns:
            A tuple ``(degraded, factor_idx)``. ``degraded`` is a copy of ``x``
            in which every item went through ``downsample`` followed by
            ``upsample``. ``factor_idx`` has shape ``(batch,)`` and holds, per
            item, the index into ``self.factors`` of the factor that was used.
        """
        batch_size, factors = x.shape[0], self.factors
        # Draw the indices on the input's device so that everything derived
        # from them (masks, embeddings) lives next to the model's tensors.
        factor_idx = torch.randint(0, len(factors), (batch_size,), device=x.device)
        x = x.clone()

        for i, factor in enumerate(factors):
            selected = factor_idx == i
            # Skip factors that were not drawn for any item in this batch
            if selected.any():
                downsampled = downsample(x[selected], factor=factor)
                # Write the re-upsampled items back into the copy
                x[selected] = upsample(downsampled, factor=factor)
        return x, factor_idx

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """Run ``self.diffusion`` on ``x`` with a randomly degraded copy as context."""
        channels, factor_idx = self.random_reupsample(x)
        features = None
        if self.use_conditioning:
            # Embed the factor *value*, not its index in ``self.factors``:
            # ``sample`` embeds the value, and both paths have to agree.
            # assumes self.factors is a list of ints -- TODO confirm `to_list`
            factor_values = torch.tensor(self.factors, device=x.device)[factor_idx]
            features = self.to_features(factor_values)
        return self.diffusion(x, channels_list=[channels], features=features, **kwargs)

    def sample(  # type: ignore
        self, undersampled: Tensor, factor: Optional[int] = None, *args, **kwargs
    ):
        """Upsample ``undersampled`` by ``factor`` and sample conditioned on it.

        Args:
            undersampled: Input to be upsampled.
            factor: Upsampling factor; falls back to the first configured one.
            **kwargs: Forwarded to the parent ``sample``; they override the
                ``channels_list`` / ``features`` defaults assembled here.

        NOTE(review): positional ``*args`` are accepted but not forwarded to
        the parent ``sample`` -- confirm whether that is intended.
        """
        # Either user provides factor or we pick the first
        batch_size, device = undersampled.shape[0], undersampled.device
        factor = default(factor, self.factors[0])
        # Upsample channels by interpolation
        channels = upsample(undersampled, factor=factor)
        # Compute features if conditioning on factor
        factors = torch.tensor([factor] * batch_size, device=device)
        features = self.to_features(factors) if self.use_conditioning else None
        # Diffuse upsampled
        noise = torch.randn_like(channels)
        default_kwargs = dict(channels_list=[channels], features=features)
        return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore


Expand Down
14 changes: 14 additions & 0 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
from math import pi
from typing import Any, List, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -488,6 +489,19 @@ def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
"""


class SinusoidalEmbedding(nn.Module):
    """Fixed (parameter-free) sinusoidal embedding of a 1-D tensor of scalars.

    Each input value is multiplied by ``dim // 2`` geometrically spaced
    frequencies, starting at 1 and ending at 1/10000; the sines and the
    cosines of those products are concatenated along the last axis.

    Args:
        dim: Requested embedding size. The output width is ``2 * (dim // 2)``,
            i.e. one less than ``dim`` when ``dim`` is odd.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """Embed ``x`` of shape ``(n,)`` into a tensor of shape ``(n, 2 * (dim // 2))``.

        Raises:
            ValueError: If ``x`` is not one-dimensional.
        """
        if x.ndim != 1:
            raise ValueError(
                f"SinusoidalEmbedding expects a 1-D tensor, got {x.ndim} dimensions"
            )
        device, half_dim = x.device, self.dim // 2
        # With a single frequency (dim of 2 or 3) `half_dim - 1` is zero;
        # clamp the denominator so that case yields frequency 1 instead of
        # raising ZeroDivisionError.
        step = torch.tensor(math.log(10000) / max(half_dim - 1, 1), device=device)
        freqs = torch.exp(torch.arange(half_dim, device=device) * -step)
        # Outer product: (n, 1) * (1, half_dim) -> (n, half_dim)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class LearnedPositionalEmbedding(nn.Module):
"""Used for continuous time"""

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.57",
version="0.0.58",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit e4c118f

Please sign in to comment.