Skip to content

Commit

Permalink
feat: add variational bottleneck, add option to change diffusion type
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Oct 13, 2022
1 parent 68da808 commit d16dfa7
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 22 deletions.
8 changes: 7 additions & 1 deletion audio_diffusion_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,10 @@
DiffusionUpsampler1d,
Model1d,
)
from .modules import AutoEncoder1d, MultiEncoder1d, UNet1d, UNetConditional1d
from .modules import (
AutoEncoder1d,
MultiEncoder1d,
UNet1d,
UNetConditional1d,
Variational,
)
30 changes: 16 additions & 14 deletions audio_diffusion_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@

from .diffusion import (
AEulerSampler,
Diffusion,
DiffusionSampler,
Distribution,
KarrasSchedule,
KDiffusion,
Sampler,
Schedule,
VDiffusion,
Expand All @@ -20,7 +21,7 @@
UNet1d,
UNetConditional1d,
)
from .utils import default, downsample, exists, to_list, upsample
from .utils import default, downsample, exists, groupby_kwargs_prefix, to_list, upsample

"""
Diffusion Classes (generic for 1d data)
Expand All @@ -29,20 +30,20 @@

class Model1d(nn.Module):
def __init__(
self,
diffusion_sigma_distribution: Distribution,
use_classifier_free_guidance: bool = False,
**kwargs
self, diffusion_type: str, use_classifier_free_guidance: bool = False, **kwargs
):
super().__init__()
diffusion_kwargs, kwargs = groupby_kwargs_prefix("diffusion_", kwargs)

UNet = UNetConditional1d if use_classifier_free_guidance else UNet1d

self.unet = UNet(**kwargs)

self.diffusion = VDiffusion(
net=self.unet, sigma_distribution=diffusion_sigma_distribution
)
if diffusion_type == "v":
self.diffusion: Diffusion = VDiffusion(net=self.unet, **diffusion_kwargs)
elif diffusion_type == "k":
self.diffusion = KDiffusion(net=self.unet, **diffusion_kwargs)
else:
raise ValueError(f"diffusion_type must be v or k, found {diffusion_type}")

def forward(self, x: Tensor, **kwargs) -> Tensor:
return self.diffusion(x, **kwargs)
Expand All @@ -53,7 +54,7 @@ def sample(
num_steps: int,
sigma_schedule: Schedule,
sampler: Sampler,
**kwargs
**kwargs,
) -> Tensor:
diffusion_sampler = DiffusionSampler(
diffusion=self.diffusion,
Expand All @@ -71,7 +72,7 @@ def __init__(
factor: Union[int, Sequence[int]],
factor_features: Optional[int] = None,
*args,
**kwargs
**kwargs,
):
self.factors = to_list(factor)
self.use_conditioning = exists(factor_features)
Expand Down Expand Up @@ -144,7 +145,7 @@ def __init__(
bottleneck: Optional[Bottleneck] = None,
encoder_num_blocks: Optional[Sequence[int]] = None,
encoder_out_layers: int = 0,
**kwargs
**kwargs,
):
self.in_channels = in_channels
encoder_num_blocks = default(encoder_num_blocks, num_blocks)
Expand Down Expand Up @@ -240,6 +241,7 @@ def get_default_model_kwargs():
use_skip_scale=True,
use_context_time=True,
use_magnitude_channels=False,
diffusion_type="v",
diffusion_sigma_distribution=VDistribution(),
)

Expand Down Expand Up @@ -289,7 +291,7 @@ def __init__(
embedding_features: int,
embedding_max_length: int,
embedding_mask_proba: float = 0.1,
**kwargs
**kwargs,
):
self.embedding_mask_proba = embedding_mask_proba
default_kwargs = dict(
Expand Down
56 changes: 51 additions & 5 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,10 +1138,46 @@ def forward( # type: ignore
class Bottleneck(nn.Module):
"""Bottleneck interface (subclass can be provided to (Diffusion)Autoencoder1d)"""

def forward(self, x: Tensor) -> Tuple[Tensor, Any]:
def forward(
self, x: Tensor, with_info: bool = False
) -> Union[Tensor, Tuple[Tensor, Any]]:
raise NotImplementedError()


def gaussian_sample(mean: Tensor, logvar: Tensor) -> Tensor:
    """Draw a reparameterized sample from a diagonal Gaussian.

    Args:
        mean: Mean of the distribution.
        logvar: Log-variance of the distribution.

    Returns:
        ``mean + exp(0.5 * logvar) * noise`` where ``noise`` is standard
        normal noise shaped like ``logvar``.
    """
    # exp(0.5 * logvar) turns a log-variance into a standard deviation.
    scale = (0.5 * logvar).exp()
    noise = torch.randn_like(scale)
    return mean + scale * noise


def kl_loss(mean: Tensor, logvar: Tensor) -> Tensor:
    """Average KL-style penalty pulling a diagonal Gaussian towards N(0, 1).

    The per-element term is ``mean ** 2 + exp(logvar) - logvar - 1`` (twice
    the closed-form KL divergence between ``N(mean, exp(logvar))`` and the
    standard normal); it is averaged over every element of the input.

    Args:
        mean: Mean of the distribution.
        logvar: Log-variance of the distribution (same shape as ``mean``).

    Returns:
        A scalar (0-dim) tensor that stays attached to the autograd graph,
        so it can be added to a training loss and back-propagated.
    """
    losses = mean ** 2 + logvar.exp() - logvar - 1
    # Deliberately no `.item()`: that would return a detached Python float,
    # contradicting the `Tensor` return annotation and preventing the penalty
    # from ever contributing gradients. Callers that only want a number for
    # logging can call `.item()` on the result themselves.
    return losses.mean()


class Variational(Bottleneck):
    """Variational bottleneck: projects to a mean and log-variance and samples.

    A kernel-size-1 convolution maps ``channels`` input channels to
    ``2 * channels`` output channels, which are split in half along dim 1
    into a mean and a log-variance. The output is a Gaussian sample from
    those parameters; the weighted KL-style loss is optionally returned.
    """

    def __init__(self, channels: int, loss_weight: float = 1.0):
        """
        Args:
            channels: Number of channels of the incoming tensor.
            loss_weight: Multiplier applied to the KL-style loss.
        """
        super().__init__()
        self.loss_weight = loss_weight
        # Twice the channels: one half for the mean, one half for the logvar.
        self.to_mean_and_logvar = Conv1d(
            in_channels=channels, out_channels=2 * channels, kernel_size=1
        )

    def forward(
        self, x: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        """Sample from the predicted Gaussian.

        Args:
            x: Input tensor; channels are expected on dim 1
                (presumably (batch, channels, length) — TODO confirm).
            with_info: When true, also return ``dict(loss=...)``.

        Returns:
            The sampled tensor, or ``(sample, {"loss": weighted_loss})`` when
            ``with_info`` is true.
        """
        projected = self.to_mean_and_logvar(x)
        mean, logvar = projected.chunk(2, dim=1)
        # Bound the log-variance so exp() in sampling/loss stays finite.
        logvar = logvar.clamp(min=-30.0, max=20.0)
        sample = gaussian_sample(mean, logvar)
        weighted_loss = kl_loss(mean, logvar) * self.loss_weight
        if with_info:
            return sample, {"loss": weighted_loss}
        return sample


class AutoEncoder1d(nn.Module):
def __init__(
self,
Expand Down Expand Up @@ -1208,18 +1244,28 @@ def __init__(
factor=patch_factor,
)

def encode(
def forward(
self, x: Tensor, with_info: bool = False
) -> Union[Tensor, Tuple[Tensor, Any]]:
z, info = self.encode(x, with_info=True)
y = self.decode(z)
return (y, info) if with_info else y

def encode(
self, x: Tensor, with_info: bool = False
) -> Union[Tensor, Tuple[Tensor, Any]]:
xs = []
x = self.to_in(x)
for downsample in self.downsamples:
x = downsample(x)
xs += [x]
info = dict(xs=xs)

if exists(self.bottleneck):
x, info = self.bottleneck(x)
return (x, info) if with_info else x
return x
x, info_bottleneck = self.bottleneck(x, with_info=True)
info = {**info, **info_bottleneck}

return (x, info) if with_info else x

def decode(self, x: Tensor) -> Tensor:
for upsample in self.upsamples:
Expand Down
21 changes: 20 additions & 1 deletion audio_diffusion_pytorch/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
from functools import reduce
from inspect import isfunction
from typing import Callable, List, Optional, Sequence, TypeVar, Union
from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -42,6 +42,25 @@ def prod(vals: Sequence[int]) -> int:
return reduce(lambda x, y: x * y, vals)


"""
Kwargs Utils
"""


def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    """Partition a dict by whether each key starts with ``prefix``.

    Args:
        prefix: Prefix to test every key against.
        d: Dictionary with string keys.

    Returns:
        ``(with_prefix, without_prefix)``; keys are kept unchanged and each
        dict preserves the iteration order of ``d``.
    """
    matching: Dict = {}
    remaining: Dict = {}
    for name, value in d.items():
        if name.startswith(prefix):
            matching[name] = value
        else:
            remaining[name] = value
    return matching, remaining


def groupby_kwargs_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
    """Split kwargs by prefix and strip the prefix from the matching keys.

    Args:
        prefix: Prefix selecting the kwargs to extract.
        d: Keyword-argument dictionary.

    Returns:
        ``(extracted, rest)`` where ``extracted`` holds the entries whose key
        started with ``prefix`` (with the prefix removed from the key) and
        ``rest`` holds every other entry unchanged.
    """
    prefixed, others = group_dict_by_prefix(prefix, d)
    cut = len(prefix)
    stripped = {name[cut:]: value for name, value in prefixed.items()}
    return stripped, others


"""
DSP Utils
"""
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.60",
version="0.0.61",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit d16dfa7

Please sign in to comment.