diff --git a/README.md b/README.md
index b42c8fd..ee4efe8 100644
--- a/README.md
+++ b/README.md
@@ -139,6 +139,7 @@ unet = UNet1d(
     use_nearest_upsample=False,
     use_skip_scale=True,
     use_context_time=True,
+    use_magnitude_channels=False
 )
 
 x = torch.randn(3, 1, 2 ** 16)
@@ -151,13 +152,20 @@ y = unet(x, t) # [3, 1, 32768], compute 3 samples of ~1.5 seconds at 22050Hz wit
 
 #### Training
 ```python
-from audio_diffusion_pytorch import Diffusion, LogNormalDistribution
+from audio_diffusion_pytorch import KDiffusion, VDiffusion, LogNormalDistribution, VDistribution
 
-diffusion = Diffusion(
+# Either use KDiffusion
+diffusion = KDiffusion(
     net=unet,
     sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
     sigma_data=0.1,
-    dynamic_threshold=0.95
+    dynamic_threshold=0.0
+)
+
+# Or use VDiffusion
+diffusion = VDiffusion(
+    net=unet,
+    sigma_distribution=VDistribution()
 )
 
 x = torch.randn(3, 1, 2 ** 18) # Batch of training audio samples
@@ -239,6 +247,7 @@ y_long = composer(y, keep_start=True) # [1, 1, 98304]
 - [x] Add conditional model with classifier-free guidance.
 - [x] Add option to provide context features mapping.
 - [x] Add option to change number of (cross) attention blocks.
+- [x] Add `VDiffusion` option.
 - [ ] Add flash attention.
diff --git a/audio_diffusion_pytorch/__init__.py b/audio_diffusion_pytorch/__init__.py
index 0fc2666..bd2c838 100644
--- a/audio_diffusion_pytorch/__init__.py
+++ b/audio_diffusion_pytorch/__init__.py
@@ -7,10 +7,13 @@
     Distribution,
     KarrasSampler,
     KarrasSchedule,
+    KDiffusion,
     LogNormalDistribution,
     Sampler,
     Schedule,
     SpanBySpanComposer,
+    VDiffusion,
+    VDistribution,
 )
 from .model import (
     AudioDiffusionAutoencoder,
diff --git a/audio_diffusion_pytorch/diffusion.py b/audio_diffusion_pytorch/diffusion.py
index e5b4bf1..41ac613 100644
--- a/audio_diffusion_pytorch/diffusion.py
+++ b/audio_diffusion_pytorch/diffusion.py
@@ -1,4 +1,4 @@
-from math import sqrt
+from math import atan, pi, sqrt
 from typing import Any, Callable, Optional, Tuple
 
 import torch
@@ -9,6 +9,10 @@ from .utils import default, exists
 
+"""
+Diffusion Training
+"""
+
 """ Distributions """
@@ -23,17 +27,227 @@ def __init__(self, mean: float, std: float):
         self.std = std
 
     def __call__(
-        self, num_samples, device: torch.device = torch.device("cpu")
+        self, num_samples: int, device: torch.device = torch.device("cpu")
     ) -> Tensor:
         normal = self.mean + self.std * torch.randn((num_samples,), device=device)
         return normal.exp()
 
 
+class VDistribution(Distribution):
+    def __init__(
+        self,
+        min_value: float = 0.0,
+        max_value: float = float("inf"),
+        sigma_data: float = 1.0,
+    ):
+        self.min_value = min_value
+        self.max_value = max_value
+        self.sigma_data = sigma_data
+
+    def __call__(
+        self, num_samples: int, device: torch.device = torch.device("cpu")
+    ) -> Tensor:
+        sigma_data = self.sigma_data
+        min_cdf = atan(self.min_value / sigma_data) * 2 / pi
+        max_cdf = atan(self.max_value / sigma_data) * 2 / pi
+        u = (max_cdf - min_cdf) * torch.rand((num_samples,), device=device) + min_cdf
+        return torch.tan(u * pi / 2) * sigma_data
+
+
+""" Diffusion Classes """
+
+
+def pad_dims(x: Tensor, ndim: int) -> Tensor:
+    # Pads additional ndims to the right of the tensor
+    return x.view(*x.shape, *((1,) * ndim))
+
+
+def clip(x: Tensor, dynamic_threshold: float = 0.0):
+    if dynamic_threshold == 0.0:
+        return x.clamp(-1.0, 1.0)
+    else:
+        # Dynamic thresholding
+        # Find dynamic threshold quantile for each batch
+        x_flat = rearrange(x, "b ... -> b (...)")
+        scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
+        # Clamp to a min of 1.0
+        scale.clamp_(min=1.0)
+        # Clamp all values and scale
+        scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
+        x = x.clamp(-scale, scale) / scale
+        return x
+
+
+def to_batch(
+    batch_size: int,
+    device: torch.device,
+    x: Optional[float] = None,
+    xs: Optional[Tensor] = None,
+) -> Tensor:
+    assert exists(x) ^ exists(xs), "Either x or xs must be provided"
+    # If x provided use the same for all batch items
+    if exists(x):
+        xs = torch.full(size=(batch_size,), fill_value=x).to(device)
+    assert exists(xs)
+    return xs
+
+
+class Diffusion(nn.Module):
+
+    """Base diffusion class"""
+
+    def denoise_fn(
+        self,
+        x_noisy: Tensor,
+        sigmas: Optional[Tensor] = None,
+        sigma: Optional[float] = None,
+        **kwargs,
+    ) -> Tensor:
+        raise NotImplementedError("Diffusion class missing denoise_fn")
+
+    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
+        raise NotImplementedError("Diffusion class missing forward function")
+
+
+class VDiffusion(Diffusion):
+    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
+        super().__init__()
+        self.net = net
+        self.sigma_distribution = sigma_distribution
+
+    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
+        sigma_data = 1.0
+        sigmas = rearrange(sigmas, "b -> b 1 1")
+        c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
+        c_out = -sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
+        c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
+        return c_skip, c_out, c_in
+
+    def sigma_to_t(self, sigmas: Tensor) -> Tensor:
+        return sigmas.atan() / pi * 2
+
+    def t_to_sigma(self, t: Tensor) -> Tensor:
+        return (t * pi / 2).tan()
+
+    def denoise_fn(
+        self,
+        x_noisy: Tensor,
+        sigmas: Optional[Tensor] = None,
+        sigma: Optional[float] = None,
+        **kwargs,
+    ) -> Tensor:
+        batch_size, device = x_noisy.shape[0], x_noisy.device
+        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
+
+        # Predict network output and add skip connection
+        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
+        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
+        x_denoised = c_skip * x_noisy + c_out * x_pred
+        return x_denoised
+
+    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
+        batch_size, device = x.shape[0], x.device
+
+        # Sample amount of noise to add for each batch element
+        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
+        sigmas_padded = rearrange(sigmas, "b -> b 1 1")
+
+        # Add noise to input
+        noise = default(noise, lambda: torch.randn_like(x))
+        x_noisy = x + sigmas_padded * noise
+
+        # Compute model output
+        c_skip, c_out, c_in = self.get_scale_weights(sigmas)
+        x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
+
+        # Compute v-objective target
+        v_target = (x - c_skip * x_noisy) / c_out
+
+        # Compute loss
+        loss = F.mse_loss(x_pred, v_target)
+        return loss
+
+
+class KDiffusion(Diffusion):
+    """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""
+
+    def __init__(
+        self,
+        net: nn.Module,
+        *,
+        sigma_distribution: Distribution,
+        sigma_data: float,  # data distribution standard deviation
+        dynamic_threshold: float = 0.0,
+    ):
+        super().__init__()
+        self.net = net
+        self.sigma_data = sigma_data
+        self.sigma_distribution = sigma_distribution
+        self.dynamic_threshold = dynamic_threshold
+
+    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
+        sigma_data = self.sigma_data
+        c_noise = torch.log(sigmas) * 0.25
+        sigmas = rearrange(sigmas, "b -> b 1 1")
+        c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
+        c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
+        c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
+        return c_skip, c_out, c_in, c_noise
+
+    def denoise_fn(
+        self,
+        x_noisy: Tensor,
+        sigmas: Optional[Tensor] = None,
+        sigma: Optional[float] = None,
+        **kwargs,
+    ) -> Tensor:
+        batch_size, device = x_noisy.shape[0], x_noisy.device
+        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
+
+        # Predict network output and add skip connection
+        c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
+        x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
+        x_denoised = c_skip * x_noisy + c_out * x_pred
+
+        # Clips in [-1,1] range, with dynamic thresholding if provided
+        return clip(x_denoised, dynamic_threshold=self.dynamic_threshold)
+
+    def loss_weight(self, sigmas: Tensor) -> Tensor:
+        # Computes weight depending on data distribution
+        return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2
+
+    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
+        batch_size, device = x.shape[0], x.device
+
+        # Sample amount of noise to add for each batch element
+        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
+        sigmas_padded = rearrange(sigmas, "b -> b 1 1")
+
+        # Add noise to input
+        noise = default(noise, lambda: torch.randn_like(x))
+        x_noisy = x + sigmas_padded * noise
+
+        # Compute denoised values
+        x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
+
+        # Compute weighted loss
+        losses = F.mse_loss(x_denoised, x, reduction="none")
+        losses = reduce(losses, "b ... -> b", "mean")
+        losses = losses * self.loss_weight(sigmas)
+        loss = losses.mean()
+
+        return loss
+
+
+"""
+Diffusion Sampling
+"""
+
 """ Schedules """
 
 
 class Schedule(nn.Module):
-    """Interface used by different schedules"""
+    """Interface used by different sampling schedules"""
 
     def forward(self, num_steps: int, device: torch.device) -> Tensor:
         raise NotImplementedError()
@@ -62,8 +276,6 @@ def forward(self, num_steps: int, device: Any) -> Tensor:
 
 """ Samplers """
 
-""" Many methods inspired by https://github.com/crowsonkb/k-diffusion/ """
-
 
 class Sampler(nn.Module):
     def forward(
@@ -229,104 +441,7 @@ def inpaint(
         return source * mask + x * ~mask
 
 
-""" Diffusion Classes """
-
-
-def pad_dims(x: Tensor, ndim: int) -> Tensor:
-    # Pads additional ndims to the right of the tensor
-    return x.view(*x.shape, *((1,) * ndim))
-
-
-class Diffusion(nn.Module):
-    """Elucidated Diffusion: https://arxiv.org/abs/2206.00364"""
-
-    def __init__(
-        self,
-        net: nn.Module,
-        *,
-        sigma_distribution: Distribution,
-        sigma_data: float,  # data distribution standard deviation
-        dynamic_threshold: float = 0.0,
-    ):
-        super().__init__()
-
-        self.net = net
-        self.sigma_data = sigma_data
-        self.sigma_distribution = sigma_distribution
-        self.dynamic_threshold = dynamic_threshold
-
-    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
-        sigma_data = self.sigma_data
-        sigmas_padded = rearrange(sigmas, "b -> b 1 1")
-        c_skip = (sigma_data ** 2) / (sigmas_padded ** 2 + sigma_data ** 2)
-        c_out = (
-            sigmas_padded * sigma_data * (sigma_data ** 2 + sigmas_padded ** 2) ** -0.5
-        )
-        c_in = (sigmas_padded ** 2 + sigma_data ** 2) ** -0.5
-        c_noise = torch.log(sigmas) * 0.25
-        return c_skip, c_out, c_in, c_noise
-
-    def denoise_fn(
-        self,
-        x_noisy: Tensor,
-        sigmas: Optional[Tensor] = None,
-        sigma: Optional[float] = None,
-        **kwargs,
-    ) -> Tensor:
-        batch, device = x_noisy.shape[0], x_noisy.device
-
-        assert exists(sigmas) ^ exists(sigma), "Either sigmas or sigma must be provided"
-
-        # If sigma provided use the same for all batch items (used for sampling)
-        if exists(sigma):
-            sigmas = torch.full(size=(batch,), fill_value=sigma).to(device)
-
-        assert exists(sigmas)
-
-        # Predict network output and add skip connection
-        c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
-        x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
-        x_denoised = c_skip * x_noisy + c_out * x_pred
-
-        # Dynamic thresholding
-        if self.dynamic_threshold == 0.0:
-            return x_denoised.clamp(-1.0, 1.0)
-        else:
-            # Find dynamic threshold quantile for each batch
-            x_flat = rearrange(x_denoised, "b ... -> b (...)")
-            scale = torch.quantile(x_flat.abs(), self.dynamic_threshold, dim=-1)
-            # Clamp to a min of 1.0
-            scale.clamp_(min=1.0)
-            # Clamp all values and scale
-            scale = pad_dims(scale, ndim=x_denoised.ndim - scale.ndim)
-            x_denoised = x_denoised.clamp(-scale, scale) / scale
-            return x_denoised
-
-    def loss_weight(self, sigmas: Tensor) -> Tensor:
-        # Computes weight depending on data distribution
-        return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2
-
-    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
-        batch, device = x.shape[0], x.device
-
-        # Sample amount of noise to add for each batch element
-        sigmas = self.sigma_distribution(num_samples=batch, device=device)
-        sigmas_padded = rearrange(sigmas, "b -> b 1 1")
-
-        # Add noise to input
-        noise = default(noise, lambda: torch.randn_like(x))
-        x_noisy = x + sigmas_padded * noise
-
-        # Compute denoised values
-        x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
-
-        # Compute weighted loss
-        losses = F.mse_loss(x_denoised, x, reduction="none")
-        losses = reduce(losses, "b ... -> b", "mean")
-        losses = losses * self.loss_weight(sigmas)
-        loss = losses.mean()
-
-        return loss
+""" Main Classes """
 
 
 class DiffusionSampler(nn.Module):
diff --git a/audio_diffusion_pytorch/model.py b/audio_diffusion_pytorch/model.py
index d6ad9d7..c6bc1e1 100644
--- a/audio_diffusion_pytorch/model.py
+++ b/audio_diffusion_pytorch/model.py
@@ -4,14 +4,14 @@ from torch import Tensor, nn
 
 from .diffusion import (
-    ADPM2Sampler,
-    Diffusion,
+    AEulerSampler,
     DiffusionSampler,
     Distribution,
     KarrasSchedule,
-    LogNormalDistribution,
     Sampler,
     Schedule,
+    VDiffusion,
+    VDistribution,
 )
 from .modules import (
     Bottleneck,
@@ -31,8 +31,6 @@ class Model1d(nn.Module):
     def __init__(
         self,
         diffusion_sigma_distribution: Distribution,
-        diffusion_sigma_data: int,
-        diffusion_dynamic_threshold: float,
         use_classifier_free_guidance: bool = False,
         **kwargs
     ):
@@ -42,11 +40,8 @@ def __init__(
 
         self.unet = UNet(**kwargs)
 
-        self.diffusion = Diffusion(
-            net=self.unet,
-            sigma_distribution=diffusion_sigma_distribution,
-            sigma_data=diffusion_sigma_data,
-            dynamic_threshold=diffusion_dynamic_threshold,
+        self.diffusion = VDiffusion(
+            net=self.unet, sigma_distribution=diffusion_sigma_distribution
         )
 
     def forward(self, x: Tensor, **kwargs) -> Tensor:
@@ -245,16 +240,14 @@ def get_default_model_kwargs():
         use_skip_scale=True,
         use_context_time=True,
         use_magnitude_channels=False,
-        diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
-        diffusion_sigma_data=0.1,
-        diffusion_dynamic_threshold=0.0,
+        diffusion_sigma_distribution=VDistribution(),
     )
 
 
 def get_default_sampling_kwargs():
     return dict(
         sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
-        sampler=ADPM2Sampler(rho=1.0),
+        sampler=AEulerSampler(),
     )
diff --git a/setup.py b/setup.py
index 059afeb..1edd201 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.59",
+    version="0.0.60",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
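
Beyond the README hunk above, a minimal smoke test of the new `VDiffusion` training path can be sketched as follows. This assumes the package at this version is installed; `ToyNet` is a hypothetical stand-in for `UNet1d`, since `VDiffusion` only assumes a module mapping `(x, t)` to a tensor shaped like `x`.

```python
import torch
import torch.nn as nn

from audio_diffusion_pytorch import VDiffusion, VDistribution

# ToyNet is a hypothetical stand-in for UNet1d: VDiffusion only assumes
# net(x, t, **kwargs) returns a tensor shaped like x, with t of shape (b,).
class ToyNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv1d(1, 1, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
        # Broadcast the per-batch noise level over channels and time
        return self.conv(x) + t.view(-1, 1, 1)

diffusion = VDiffusion(net=ToyNet(), sigma_distribution=VDistribution())
x = torch.randn(2, 1, 2 ** 10)  # toy batch of mono audio
loss = diffusion(x)             # scalar v-objective MSE loss
loss.backward()
```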
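The `sigma_to_t`/`t_to_sigma` pair and the `VDistribution` sampler are two views of the same change of variables: `t = atan(sigma / sigma_data) * 2 / pi` maps `sigma` in `(0, inf)` onto `t` in `(0, 1)`, and `tan(t * pi / 2) * sigma_data` inverts it. Note that the inverse-CDF draw `u` must be uniform in `[min_cdf, max_cdf]`, hence `torch.rand` rather than `torch.randn` in `VDistribution.__call__`. A self-contained sanity check:

```python
import torch
from math import atan, pi

# Round trip of the sigma <-> t mapping used by VDiffusion (sigma_data = 1).
sigmas = torch.tensor([0.1, 1.0, 5.0])
t = sigmas.atan() / pi * 2                         # sigma_to_t
assert torch.allclose((t * pi / 2).tan(), sigmas)  # t_to_sigma inverts it

# VDistribution inverse-CDF sampling: u uniform in [min_cdf, max_cdf].
sigma_data, min_value, max_value = 1.0, 0.0, float("inf")
min_cdf = atan(min_value / sigma_data) * 2 / pi  # = 0.0
max_cdf = atan(max_value / sigma_data) * 2 / pi  # = 1.0
u = (max_cdf - min_cdf) * torch.rand(8) + min_cdf
assert (torch.tan(u * pi / 2) * sigma_data >= 0).all()  # valid noise levels
```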
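For completeness, sampling with the swapped-in defaults (`AEulerSampler` plus `KarrasSchedule`) would look roughly like the sketch below. The `DiffusionSampler` keyword signature is not shown in this diff, so treat the argument names here as an assumption based on the `get_default_sampling_kwargs` helper above; only `sigma_schedule` and `sampler` are confirmed by the patch.

```python
import torch
from audio_diffusion_pytorch import AEulerSampler, DiffusionSampler, KarrasSchedule

# Assumed DiffusionSampler signature; `diffusion` is a trained KDiffusion or
# VDiffusion instance as constructed in the README hunk above.
sampler = DiffusionSampler(
    diffusion,
    num_steps=50,
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
    sampler=AEulerSampler(),
)
y = sampler(noise=torch.randn(1, 1, 2 ** 18))  # denoise pure noise into audio
```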