diff --git a/README.md b/README.md
index 53e10c4..f1b326c 100644
--- a/README.md
+++ b/README.md
@@ -73,7 +73,8 @@ from audio_diffusion_pytorch import Diffusion, LogNormalDistribution
 diffusion = Diffusion(
     net=unet,
     sigma_distribution=LogNormalDistribution(mean = -3.0, std = 1.0),
-    sigma_data=0.1
+    sigma_data=0.1,
+    dynamic_threshold=0.95
 )
 
 x = torch.randn(3, 1, 2 ** 18) # Batch of training audio samples
diff --git a/audio_diffusion_pytorch/diffusion.py b/audio_diffusion_pytorch/diffusion.py
index f21146c..fc61979 100644
--- a/audio_diffusion_pytorch/diffusion.py
+++ b/audio_diffusion_pytorch/diffusion.py
@@ -87,13 +87,7 @@ def __init__(
         self.s_churn = s_churn
 
     def step(
-        self,
-        x: Tensor,
-        fn: Callable,
-        sigma: float,
-        sigma_next: float,
-        gamma: float,
-        clamp: bool = True,
+        self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
     ) -> Tensor:
         """Algorithm 2 (step)"""
         # Select temporarily increased noise level
@@ -102,12 +96,12 @@ def step(
         epsilon = self.s_noise * torch.randn_like(x)
         x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
         # Evaluate ∂x/∂sigma at sigma_hat
-        d = (x_hat - fn(x_hat, sigma=sigma_hat, clamp=clamp)) / sigma_hat
+        d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
         # Take euler step from sigma_hat to sigma_next
         x_next = x_hat + (sigma_next - sigma_hat) * d
         # Second order correction
         if sigma_next != 0:
-            model_out_next = fn(x_next, sigma=sigma_next, clamp=clamp)
+            model_out_next = fn(x_next, sigma=sigma_next)
             d_prime = (x_next - model_out_next) / sigma_next
             x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
         return x_next
@@ -140,25 +134,18 @@ def __init__(self, rho: float = 1.0):
         super().__init__()
         self.rho = rho
 
-    def step(
-        self,
-        x: Tensor,
-        fn: Callable,
-        sigma: float,
-        sigma_next: float,
-        clamp: bool = True,
-    ) -> Tensor:
+    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
         # Sigma steps
         r = self.rho
         sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
         sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
         sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
         # Derivative at sigma (∂x/∂sigma)
-        d = (x - fn(x, sigma=sigma, clamp=clamp)) / sigma
+        d = (x - fn(x, sigma=sigma)) / sigma
         # Denoise to midpoint
         x_mid = x + d * (sigma_mid - sigma)
         # Derivative at sigma_mid (∂x_mid/∂sigma_mid)
-        d_mid = (x_mid - fn(x_mid, sigma=sigma_mid, clamp=clamp)) / sigma_mid
+        d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
         # Denoise to next
         x = x + d_mid * (sigma_down - sigma)
         # Add randomness
@@ -178,6 +165,11 @@ def forward(
 """ Diffusion Classes """
 
 
+def pad_dims(x: Tensor, ndim: int) -> Tensor:
+    # Pads additional ndims to the right of the tensor
+    return x.view(*x.shape, *((1,) * ndim))
+
+
 class Diffusion(nn.Module):
     """Elucidated Diffusion: https://arxiv.org/abs/2206.00364"""
 
@@ -187,12 +179,14 @@ def __init__(
         *,
         sigma_distribution: Distribution,
         sigma_data: float,  # data distribution standard deviation
+        dynamic_threshold: float = 0.0,
     ):
         super().__init__()
 
         self.net = net
         self.sigma_data = sigma_data
         self.sigma_distribution = sigma_distribution
+        self.dynamic_threshold = dynamic_threshold
 
     def c_skip(self, sigmas: Tensor) -> Tensor:
         return (self.sigma_data ** 2) / (sigmas ** 2 + self.sigma_data ** 2)
@@ -211,7 +205,6 @@ def denoise_fn(
         x_noisy: Tensor,
         sigmas: Optional[Tensor] = None,
         sigma: Optional[float] = None,
-        clamp: bool = False,
     ) -> Tensor:
         batch, device = x_noisy.shape[0], x_noisy.device
 
@@ -230,9 +223,20 @@ def denoise_fn(
 
         x_denoised = (
             self.c_skip(sigmas_padded) * x_noisy + self.c_out(sigmas_padded) * x_pred
         )
-        x_denoised = x_denoised.clamp(-1.0, 1) if clamp else x_denoised
-        return x_denoised
+        # Dynamic thresholding
+        if self.dynamic_threshold == 0.0:
+            return x_denoised.clamp(-1.0, 1.0)
+        else:
+            # Find dynamic threshold quantile for each batch
+            x_flat = rearrange(x_denoised, "b ... -> b (...)")
+            scale = torch.quantile(x_flat.abs(), self.dynamic_threshold, dim=-1)
+            # Clamp to a min of 1.0
+            scale.clamp_(min=1.0)
+            # Clamp all values and scale
+            scale = pad_dims(scale, ndim=x_denoised.ndim - scale.ndim)
+            x_denoised = x_denoised.clamp(-scale, scale) / scale
+            return x_denoised
 
     def loss_weight(self, sigmas: Tensor) -> Tensor:
         # Computes weight depending on data distribution
@@ -335,12 +339,12 @@ def step(
         # Add increased noise to mixed value
         x_hat = x * ~inpaint_mask + inpaint * inpaint_mask + noise
         # Evaluate ∂x/∂sigma at sigma_hat
-        d = (x_hat - self.denoise_fn(x_hat, sigma=sigma_hat, clamp=clamp)) / sigma_hat
+        d = (x_hat - self.denoise_fn(x_hat, sigma=sigma_hat)) / sigma_hat
         # Take euler step from sigma_hat to sigma_next
         x_next = x_hat + (sigma_next - sigma_hat) * d
         # Second order correction
         if sigma_next != 0:
-            model_out_next = self.denoise_fn(x_next, sigma=sigma_next, clamp=clamp)
+            model_out_next = self.denoise_fn(x_next, sigma=sigma_next)
             d_prime = (x_next - model_out_next) / sigma_next
             x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
         # Renoise for next resampling step
diff --git a/audio_diffusion_pytorch/model.py b/audio_diffusion_pytorch/model.py
index 325eb35..f5b9ca8 100644
--- a/audio_diffusion_pytorch/model.py
+++ b/audio_diffusion_pytorch/model.py
@@ -37,6 +37,7 @@ def __init__(
         use_attention_bottleneck: bool,
         diffusion_sigma_distribution: Distribution,
         diffusion_sigma_data: int,
+        diffusion_dynamic_threshold: float,
         out_channels: Optional[int] = None,
     ):
         super().__init__()
@@ -66,6 +67,7 @@ def __init__(
             net=self.unet,
             sigma_distribution=diffusion_sigma_distribution,
             sigma_data=diffusion_sigma_data,
+            dynamic_threshold=diffusion_dynamic_threshold,
         )
 
     def forward(self, x: Tensor) -> Tensor:
@@ -105,6 +107,7 @@ def __init__(self, *args, **kwargs):
             use_learned_time_embedding=True,
             diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
             diffusion_sigma_data=0.1,
+            diffusion_dynamic_threshold=0.95,
        )
         super().__init__(*args, **{**default_kwargs, **kwargs})
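
For reference, the dynamic thresholding added to `denoise_fn` above can be read as the self-contained sketch below. The standalone helper name `dynamic_threshold` and the example shapes are illustrative only; the library applies this logic inline using `pad_dims`.

```python
import torch
from einops import rearrange


def dynamic_threshold(x: torch.Tensor, quantile: float = 0.95) -> torch.Tensor:
    # Per batch element, find the `quantile`-quantile of |x|
    x_flat = rearrange(x, "b ... -> b (...)")
    scale = torch.quantile(x_flat.abs(), quantile, dim=-1)
    # Never threshold below 1.0, so already in-range samples are left untouched
    scale.clamp_(min=1.0)
    # Broadcast the per-batch scale, clamp to [-scale, scale], rescale back to [-1, 1]
    scale = scale.view(*scale.shape, *((1,) * (x.ndim - scale.ndim)))
    return x.clamp(-scale, scale) / scale


x = torch.randn(3, 1, 2 ** 10)  # e.g. a small batch of denoised audio (hypothetical shape)
y = dynamic_threshold(x, quantile=0.95)
```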