diff --git a/audio_diffusion_pytorch/__init__.py b/audio_diffusion_pytorch/__init__.py
index bd2c838..d211371 100644
--- a/audio_diffusion_pytorch/__init__.py
+++ b/audio_diffusion_pytorch/__init__.py
@@ -24,4 +24,10 @@
     DiffusionUpsampler1d,
     Model1d,
 )
-from .modules import AutoEncoder1d, MultiEncoder1d, UNet1d, UNetConditional1d
+from .modules import (
+    AutoEncoder1d,
+    MultiEncoder1d,
+    UNet1d,
+    UNetConditional1d,
+    Variational,
+)
diff --git a/audio_diffusion_pytorch/model.py b/audio_diffusion_pytorch/model.py
index c6bc1e1..f3c3f4f 100644
--- a/audio_diffusion_pytorch/model.py
+++ b/audio_diffusion_pytorch/model.py
@@ -5,9 +5,10 @@ from .diffusion import (
     AEulerSampler,
+    Diffusion,
     DiffusionSampler,
-    Distribution,
     KarrasSchedule,
+    KDiffusion,
     Sampler,
     Schedule,
     VDiffusion,
     VDistribution,
 )
@@ -20,7 +21,7 @@
     UNet1d,
     UNetConditional1d,
 )
-from .utils import default, downsample, exists, to_list, upsample
+from .utils import default, downsample, exists, groupby_kwargs_prefix, to_list, upsample
 
 """
 Diffusion Classes (generic for 1d data)
 """
@@ -29,20 +30,20 @@
 class Model1d(nn.Module):
     def __init__(
-        self,
-        diffusion_sigma_distribution: Distribution,
-        use_classifier_free_guidance: bool = False,
-        **kwargs
+        self, diffusion_type: str, use_classifier_free_guidance: bool = False, **kwargs
     ):
         super().__init__()
+        diffusion_kwargs, kwargs = groupby_kwargs_prefix("diffusion_", kwargs)
         UNet = UNetConditional1d if use_classifier_free_guidance else UNet1d
         self.unet = UNet(**kwargs)
-        self.diffusion = VDiffusion(
-            net=self.unet, sigma_distribution=diffusion_sigma_distribution
-        )
+
+        if diffusion_type == "v":
+            self.diffusion: Diffusion = VDiffusion(net=self.unet, **diffusion_kwargs)
+        elif diffusion_type == "k":
+            self.diffusion = KDiffusion(net=self.unet, **diffusion_kwargs)
+        else:
+            raise ValueError(f"diffusion_type must be v or k, found {diffusion_type}")
 
     def forward(self, x: Tensor, **kwargs) -> Tensor:
         return self.diffusion(x, **kwargs)
@@ -53,7 +54,7 @@ def sample(
         num_steps: int,
         sigma_schedule: Schedule,
         sampler: Sampler,
-        **kwargs
+        **kwargs,
     ) -> Tensor:
         diffusion_sampler = DiffusionSampler(
             diffusion=self.diffusion,
@@ -71,7 +72,7 @@ def __init__(
         factor: Union[int, Sequence[int]],
         factor_features: Optional[int] = None,
         *args,
-        **kwargs
+        **kwargs,
     ):
         self.factors = to_list(factor)
         self.use_conditioning = exists(factor_features)
@@ -144,7 +145,7 @@ def __init__(
         bottleneck: Optional[Bottleneck] = None,
         encoder_num_blocks: Optional[Sequence[int]] = None,
         encoder_out_layers: int = 0,
-        **kwargs
+        **kwargs,
     ):
         self.in_channels = in_channels
         encoder_num_blocks = default(encoder_num_blocks, num_blocks)
@@ -240,6 +241,7 @@ def get_default_model_kwargs():
         use_skip_scale=True,
         use_context_time=True,
         use_magnitude_channels=False,
+        diffusion_type="v",
         diffusion_sigma_distribution=VDistribution(),
     )
@@ -289,7 +291,7 @@ def __init__(
         embedding_features: int,
         embedding_max_length: int,
         embedding_mask_proba: float = 0.1,
-        **kwargs
+        **kwargs,
     ):
         self.embedding_mask_proba = embedding_mask_proba
         default_kwargs = dict(
diff --git a/audio_diffusion_pytorch/modules.py b/audio_diffusion_pytorch/modules.py
index e3305e7..3d798c9 100644
--- a/audio_diffusion_pytorch/modules.py
+++ b/audio_diffusion_pytorch/modules.py
@@ -1138,10 +1138,46 @@ def forward(  # type: ignore
 class Bottleneck(nn.Module):
     """Bottleneck interface (subclass can be provided to (Diffusion)Autoencoder1d)"""
 
-    def forward(self, x: Tensor) -> Tuple[Tensor, Any]:
+    def forward(
+        self, x: Tensor, with_info: bool = False
+    ) -> Union[Tensor, Tuple[Tensor, Any]]:
         raise NotImplementedError()
 
 
+def gaussian_sample(mean: Tensor, logvar: Tensor) -> Tensor:
+    std = torch.exp(0.5 * logvar)
+    eps = torch.randn_like(std)
+    sample = mean + std * eps
+    return sample
+
+
+def kl_loss(mean: Tensor, logvar: Tensor) -> Tensor:
+    losses = mean ** 2 + logvar.exp() - logvar - 1
+    loss = reduce(losses, "b ... -> 1", "mean")
+    return loss
+
+
+class Variational(Bottleneck):
+    def __init__(self, channels: int, loss_weight: float = 1.0):
+        super().__init__()
+        self.loss_weight = loss_weight
+        self.to_mean_and_logvar = Conv1d(
+            in_channels=channels,
+            out_channels=channels * 2,
+            kernel_size=1,
+        )
+
+    def forward(
+        self, x: Tensor, with_info: bool = False
+    ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        mean_and_logvar = self.to_mean_and_logvar(x)
+        mean, logvar = torch.chunk(mean_and_logvar, chunks=2, dim=1)
+        logvar = torch.clamp(logvar, -30.0, 20.0)
+        out = gaussian_sample(mean, logvar)
+        loss = kl_loss(mean, logvar) * self.loss_weight
+        return (out, dict(loss=loss)) if with_info else out
+
+
 class AutoEncoder1d(nn.Module):
     def __init__(
         self,
@@ -1208,18 +1244,28 @@ def __init__(
             factor=patch_factor,
         )
 
-    def encode(
+    def forward(
         self, x: Tensor, with_info: bool = False
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        z, info = self.encode(x, with_info=True)
+        y = self.decode(z)
+        return (y, info) if with_info else y
+
+    def encode(
+        self, x: Tensor, with_info: bool = False
+    ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        xs = []
         x = self.to_in(x)
         for downsample in self.downsamples:
             x = downsample(x)
+            xs += [x]
+        info = dict(xs=xs)
 
         if exists(self.bottleneck):
-            x, info = self.bottleneck(x)
-            return (x, info) if with_info else x
-        return x
+            x, info_bottleneck = self.bottleneck(x, with_info=True)
+            info = {**info, **info_bottleneck}
+
+        return (x, info) if with_info else x
 
     def decode(self, x: Tensor) -> Tensor:
         for upsample in self.upsamples:
diff --git a/audio_diffusion_pytorch/utils.py b/audio_diffusion_pytorch/utils.py
index 791da02..4240211 100644
--- a/audio_diffusion_pytorch/utils.py
+++ b/audio_diffusion_pytorch/utils.py
@@ -1,7 +1,7 @@
 import math
 from functools import reduce
 from inspect import isfunction
-from typing import Callable, List, Optional, Sequence, TypeVar, Union
+from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
 
 import torch
 import torch.nn.functional as F
@@ -42,6 +42,25 @@ def prod(vals: Sequence[int]) -> int:
     return reduce(lambda x, y: x * y, vals)
 
 
+"""
+Kwargs Utils
+"""
+
+
+def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
+    return_dicts: Tuple[Dict, Dict] = ({}, {})
+    for key in d.keys():
+        no_prefix = int(not key.startswith(prefix))
+        return_dicts[no_prefix][key] = d[key]
+    return return_dicts
+
+
+def groupby_kwargs_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
+    kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
+    kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
+    return kwargs_no_prefix, kwargs
+
+
 """
 DSP Utils
 """
diff --git a/setup.py b/setup.py
index 1edd201..3eccf6a 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.60",
+    version="0.0.61",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
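
A quick sketch of the new kwargs routing (illustrative values, assumes this patch is applied): groupby_kwargs_prefix splits a kwargs dict on a prefix and strips it, which is how Model1d now routes diffusion_-prefixed arguments to VDiffusion or KDiffusion (selected by diffusion_type="v" or "k") while the remaining kwargs configure the UNet.

    from audio_diffusion_pytorch.utils import groupby_kwargs_prefix

    kwargs = dict(diffusion_sigma_distribution="dist", channels=128)
    diffusion_kwargs, rest = groupby_kwargs_prefix("diffusion_", kwargs)
    assert diffusion_kwargs == {"sigma_distribution": "dist"}  # prefix stripped
    assert rest == {"channels": 128}  # left untouched for the UNet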
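
And a minimal sketch of the new Variational bottleneck used standalone, as defined above: to_mean_and_logvar doubles the channel count, torch.chunk splits it into mean and logvar, and kl_loss is the closed-form KL to a unit Gaussian, mean(mean^2 + exp(logvar) - logvar - 1), i.e. twice the conventional 0.5-weighted term, with the constant absorbed into loss_weight. Tensor shapes below are arbitrary.

    import torch
    from audio_diffusion_pytorch import Variational

    bottleneck = Variational(channels=32, loss_weight=1e-4)
    x = torch.randn(2, 32, 256)  # (batch, channels, length)
    z = bottleneck(x)  # sampled latent, same shape as x
    z, info = bottleneck(x, with_info=True)  # info["loss"] is the weighted KL term
    assert z.shape == x.shape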