feat: add dynamic thresholding
flavioschneider committed Aug 11, 2022
1 parent 661c392, commit a5f1069
Showing 3 changed files with 33 additions and 25 deletions.
README.md (3 changes: 2 additions & 1 deletion)
@@ -73,7 +73,8 @@ from audio_diffusion_pytorch import Diffusion, LogNormalDistribution
 diffusion = Diffusion(
     net=unet,
     sigma_distribution=LogNormalDistribution(mean = -3.0, std = 1.0),
-    sigma_data=0.1
+    sigma_data=0.1,
+    dynamic_threshold=0.95
 )
 
 x = torch.randn(3, 1, 2 ** 18) # Batch of training audio samples
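
Note on the new argument: dynamic_threshold is the quantile used for Imagen-style dynamic thresholding (Saharia et al., 2022). At 0.95, each denoised sample is clamped to the 95th percentile of its own absolute values (never below 1.0) and rescaled into [-1, 1]; at 0.0 the denoiser falls back to a static clamp to [-1, 1]. A minimal sketch of the effect on one over-saturated sample, assuming only PyTorch:

import torch

x = torch.tensor([[-3.0, -0.5, 0.2, 2.0]])                # one over-saturated sample
s = torch.quantile(x.abs(), 0.95, dim=-1).clamp(min=1.0)  # per-sample scale, here ~2.85
y = x.clamp(-s, s) / s                                    # every value now lies in [-1, 1]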
audio_diffusion_pytorch/diffusion.py (52 changes: 28 additions & 24 deletions)
@@ -87,13 +87,7 @@ def __init__(
         self.s_churn = s_churn
 
     def step(
-        self,
-        x: Tensor,
-        fn: Callable,
-        sigma: float,
-        sigma_next: float,
-        gamma: float,
-        clamp: bool = True,
+        self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
     ) -> Tensor:
         """Algorithm 2 (step)"""
         # Select temporarily increased noise level
@@ -102,12 +96,12 @@
         epsilon = self.s_noise * torch.randn_like(x)
         x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
         # Evaluate ∂x/∂sigma at sigma_hat
-        d = (x_hat - fn(x_hat, sigma=sigma_hat, clamp=clamp)) / sigma_hat
+        d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
         # Take euler step from sigma_hat to sigma_next
         x_next = x_hat + (sigma_next - sigma_hat) * d
         # Second order correction
         if sigma_next != 0:
-            model_out_next = fn(x_next, sigma=sigma_next, clamp=clamp)
+            model_out_next = fn(x_next, sigma=sigma_next)
             d_prime = (x_next - model_out_next) / sigma_next
             x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
         return x_next
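
For reference, the two model evaluations above implement the stochastic second-order (Heun) step of Algorithm 2 in Karras et al. (2022), https://arxiv.org/abs/2206.00364; in that paper's notation:

\[
\hat{\sigma} = \sigma(1+\gamma), \qquad
\hat{x} = x + \sqrt{\hat{\sigma}^{2} - \sigma^{2}}\,\epsilon, \qquad
\epsilon \sim S_{\mathrm{noise}} \cdot \mathcal{N}(0, I)
\]
\[
d = \frac{\hat{x} - D(\hat{x}; \hat{\sigma})}{\hat{\sigma}}, \qquad
x_{\mathrm{next}} = \hat{x} + (\sigma_{\mathrm{next}} - \hat{\sigma})\, d
\]
\[
\sigma_{\mathrm{next}} \neq 0: \qquad
d' = \frac{x_{\mathrm{next}} - D(x_{\mathrm{next}}; \sigma_{\mathrm{next}})}{\sigma_{\mathrm{next}}}, \qquad
x_{\mathrm{next}} = \hat{x} + (\sigma_{\mathrm{next}} - \hat{\sigma}) \cdot \tfrac{1}{2}(d + d')
\]

(In the paper the correction uses the interval \(\sigma_{\mathrm{next}} - \hat{\sigma}\); the unchanged context line above keeps sigma - sigma_hat.)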
@@ -140,25 +134,18 @@ def __init__(self, rho: float = 1.0):
         super().__init__()
         self.rho = rho
 
-    def step(
-        self,
-        x: Tensor,
-        fn: Callable,
-        sigma: float,
-        sigma_next: float,
-        clamp: bool = True,
-    ) -> Tensor:
+    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
         # Sigma steps
         r = self.rho
         sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
         sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
         sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
         # Derivative at sigma (∂x/∂sigma)
-        d = (x - fn(x, sigma=sigma, clamp=clamp)) / sigma
+        d = (x - fn(x, sigma=sigma)) / sigma
         # Denoise to midpoint
         x_mid = x + d * (sigma_mid - sigma)
         # Derivative at sigma_mid (∂x_mid/∂sigma_mid)
-        d_mid = (x_mid - fn(x_mid, sigma=sigma_mid, clamp=clamp)) / sigma_mid
+        d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
         # Denoise to next
         x = x + d_mid * (sigma_down - sigma)
         # Add randomness
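
The sigma bookkeeping above is the usual ancestral split; from the definitions in the code,

\[
\sigma_{\mathrm{up}}^{2} = \sigma_{\mathrm{next}}^{2}\,\frac{\sigma^{2} - \sigma_{\mathrm{next}}^{2}}{\sigma^{2}},
\qquad
\sigma_{\mathrm{down}} = \sqrt{\sigma_{\mathrm{next}}^{2} - \sigma_{\mathrm{up}}^{2}} = \frac{\sigma_{\mathrm{next}}^{2}}{\sigma},
\qquad
\sigma_{\mathrm{up}}^{2} + \sigma_{\mathrm{down}}^{2} = \sigma_{\mathrm{next}}^{2},
\]

so the deterministic update only denoises down to \(\sigma_{\mathrm{down}}\), and the randomness added afterwards restores the total variance to \(\sigma_{\mathrm{next}}^{2}\).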
@@ -178,6 +165,11 @@ def forward(
""" Diffusion Classes """


def pad_dims(x: Tensor, ndim: int) -> Tensor:
# Pads additional ndims to the right of the tensor
return x.view(*x.shape, *((1,) * ndim))


class Diffusion(nn.Module):
"""Elucidated Diffusion: https://arxiv.org/abs/2206.00364"""

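The new pad_dims helper appends singleton dimensions so that a per-batch statistic can broadcast against the full tensor. A quick illustration:

import torch

def pad_dims(x, ndim):  # as defined in the hunk above
    return x.view(*x.shape, *((1,) * ndim))

scale = torch.randn(4)           # one value per batch element, shape (4,)
scale = pad_dims(scale, ndim=2)  # shape (4, 1, 1), broadcasts over (4, channels, time)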
@@ -187,12 +179,14 @@ def __init__(
         *,
         sigma_distribution: Distribution,
         sigma_data: float,  # data distribution standard deviation
+        dynamic_threshold: float = 0.0,
     ):
         super().__init__()
 
         self.net = net
         self.sigma_data = sigma_data
         self.sigma_distribution = sigma_distribution
+        self.dynamic_threshold = dynamic_threshold
 
     def c_skip(self, sigmas: Tensor) -> Tensor:
         return (self.sigma_data ** 2) / (sigmas ** 2 + self.sigma_data ** 2)
@@ -211,7 +205,6 @@ def denoise_fn(
         x_noisy: Tensor,
         sigmas: Optional[Tensor] = None,
         sigma: Optional[float] = None,
-        clamp: bool = False,
     ) -> Tensor:
         batch, device = x_noisy.shape[0], x_noisy.device
 
@@ -230,9 +223,20 @@ def denoise_fn(
         x_denoised = (
             self.c_skip(sigmas_padded) * x_noisy + self.c_out(sigmas_padded) * x_pred
         )
-        x_denoised = x_denoised.clamp(-1.0, 1) if clamp else x_denoised
-
-        return x_denoised
+
+        # Dynamic thresholding
+        if self.dynamic_threshold == 0.0:
+            return x_denoised.clamp(-1.0, 1.0)
+        else:
+            # Find dynamic threshold quantile for each batch
+            x_flat = rearrange(x_denoised, "b ... -> b (...)")
+            scale = torch.quantile(x_flat.abs(), self.dynamic_threshold, dim=-1)
+            # Clamp to a min of 1.0
+            scale.clamp_(min=1.0)
+            # Clamp all values and scale
+            scale = pad_dims(scale, ndim=x_denoised.ndim - scale.ndim)
+            x_denoised = x_denoised.clamp(-scale, scale) / scale
+            return x_denoised
 
     def loss_weight(self, sigmas: Tensor) -> Tensor:
         # Computes weight depending on data distribution
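
The else branch follows dynamic thresholding as introduced in Imagen (Saharia et al., 2022): instead of clamping every sample to a fixed [-1, 1], each batch element is clamped to a high quantile of its own magnitudes and rescaled, which suppresses over-saturated outliers while leaving in-range values nearly untouched. A self-contained sketch of the shapes and output range, assuming torch and einops as already imported by this module:

import torch
from einops import rearrange

def pad_dims(x, ndim):  # as added earlier in this commit
    return x.view(*x.shape, *((1,) * ndim))

x_denoised = 3.0 * torch.randn(2, 1, 8)                     # toy (batch, channels, time), outside [-1, 1]
x_flat = rearrange(x_denoised, "b ... -> b (...)")          # (2, 8)
scale = torch.quantile(x_flat.abs(), 0.95, dim=-1)          # one threshold per batch element, (2,)
scale.clamp_(min=1.0)                                       # never threshold below 1.0
scale = pad_dims(scale, ndim=x_denoised.ndim - scale.ndim)  # (2, 1, 1)
out = x_denoised.clamp(-scale, scale) / scale
assert out.abs().max() <= 1.0                               # rescaled into [-1, 1]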
@@ -335,12 +339,12 @@ def step(
         # Add increased noise to mixed value
         x_hat = x * ~inpaint_mask + inpaint * inpaint_mask + noise
         # Evaluate ∂x/∂sigma at sigma_hat
-        d = (x_hat - self.denoise_fn(x_hat, sigma=sigma_hat, clamp=clamp)) / sigma_hat
+        d = (x_hat - self.denoise_fn(x_hat, sigma=sigma_hat)) / sigma_hat
         # Take euler step from sigma_hat to sigma_next
         x_next = x_hat + (sigma_next - sigma_hat) * d
         # Second order correction
         if sigma_next != 0:
-            model_out_next = self.denoise_fn(x_next, sigma=sigma_next, clamp=clamp)
+            model_out_next = self.denoise_fn(x_next, sigma=sigma_next)
             d_prime = (x_next - model_out_next) / sigma_next
             x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
         # Renoise for next resampling step
audio_diffusion_pytorch/model.py (3 changes: 3 additions & 0 deletions)
@@ -37,6 +37,7 @@ def __init__(
         use_attention_bottleneck: bool,
         diffusion_sigma_distribution: Distribution,
         diffusion_sigma_data: int,
+        diffusion_dynamic_threshold: float,
         out_channels: Optional[int] = None,
     ):
         super().__init__()
@@ -66,6 +67,7 @@ def __init__(
             net=self.unet,
             sigma_distribution=diffusion_sigma_distribution,
             sigma_data=diffusion_sigma_data,
+            dynamic_threshold=diffusion_dynamic_threshold,
         )
 
     def forward(self, x: Tensor) -> Tensor:
@@ -105,6 +107,7 @@ def __init__(self, *args, **kwargs):
             use_learned_time_embedding=True,
             diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
             diffusion_sigma_data=0.1,
+            diffusion_dynamic_threshold=0.95,
         )
         super().__init__(*args, **{**default_kwargs, **kwargs})
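
Because user kwargs are merged over default_kwargs in the super().__init__ call above, the new 0.95 default can be overridden (or disabled) per instantiation. The merge semantics in isolation:

default_kwargs = dict(diffusion_dynamic_threshold=0.95)
kwargs = dict(diffusion_dynamic_threshold=0.0)  # caller opts out of dynamic thresholding
merged = {**default_kwargs, **kwargs}           # right-hand dict wins on key collisions
assert merged["diffusion_dynamic_threshold"] == 0.0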
