feat: add magnitude channels option, change to quantile norm
flavioschneider committed Oct 5, 2022
1 parent 330c60d commit aaaa699
Showing 4 changed files with 33 additions and 39 deletions.
1 change: 1 addition & 0 deletions audio_diffusion_pytorch/model.py
@@ -220,6 +220,7 @@ def get_default_model_kwargs():
use_nearest_upsample=False,
use_skip_scale=True,
use_context_time=True,
use_magnitude_channels=False,
diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
diffusion_sigma_data=0.1,
diffusion_dynamic_threshold=0.0,
54 changes: 31 additions & 23 deletions audio_diffusion_pytorch/modules.py
@@ -8,7 +8,7 @@
from einops_exts import rearrange_many
from torch import Tensor, einsum

from .utils import default, exists, prod, wave_norm, wave_unnorm
from .utils import default, exists, prod

"""
Utils
@@ -785,6 +785,15 @@ def forward(
return x


def get_norm_scale(x: Tensor, quantile: float):
return torch.quantile(x.abs(), quantile, dim=-1, keepdim=True) + 1e-7


def merge_magnitude_channels(x: Tensor):
waveform, magnitude = torch.chunk(x, chunks=2, dim=1)
return torch.sigmoid(waveform) * torch.tanh(magnitude)


"""
UNet
"""
@@ -809,8 +818,8 @@ def __init__(
use_nearest_upsample: bool,
use_skip_scale: bool,
use_context_time: bool,
norm: float = 0.0,
norm_alpha: float = 20.0,
use_magnitude_channels: bool,
norm_quantile: float = 0.0,
out_channels: Optional[int] = None,
context_features: Optional[int] = None,
context_channels: Optional[Sequence[int]] = None,
@@ -824,9 +833,6 @@ def __init__(
use_context_channels = len(context_channels) > 0
context_mapping_features = None

self.use_norm = norm > 0.0
self.norm = norm
self.norm_alpha = norm_alpha
self.num_layers = num_layers
self.use_context_time = use_context_time
self.use_context_features = use_context_features
@@ -841,6 +847,10 @@ def __init__(
self.has_context = has_context
self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]

self.use_norm = norm_quantile > 0.0
self.norm_quantile = norm_quantile
self.use_magnitude_channels = use_magnitude_channels

assert (
len(factors) == num_layers
and len(attentions) >= num_layers
@@ -943,7 +953,7 @@ def __init__(

self.to_out = Unpatcher(
in_channels=channels,
out_channels=out_channels,
out_channels=out_channels * (2 if use_magnitude_channels else 1),
blocks=patch_blocks,
factor=patch_factor,
context_mapping_features=context_mapping_features,
@@ -1002,10 +1012,11 @@ def forward(
# Concat context channels at layer 0 if provided
channels = self.get_channels(channels_list, layer=0)
x = torch.cat([x, channels], dim=1) if exists(channels) else x
# Compute mapping from time and features
mapping = self.get_mapping(time, features)

if self.use_norm:
x = wave_norm(x, peak=self.norm, alpha=self.norm_alpha)
# Compute norm scale
scale = get_norm_scale(x, self.norm_quantile) if self.use_norm else 1.0
x = x / scale

x = self.to_in(x, mapping)
skips_list = [x]
@@ -1026,10 +1037,10 @@ def forward(
x += skips_list.pop()
x = self.to_out(x, mapping)

if self.use_norm:
x = wave_unnorm(x, peak=self.norm, alpha=self.norm_alpha)
if self.use_magnitude_channels:
x = merge_magnitude_channels(x)

return x
return x * scale


class FixedEmbedding(nn.Module):
@@ -1130,16 +1141,13 @@ def __init__(
num_blocks: Sequence[int],
use_noisy: bool = False,
bottleneck: Optional[Bottleneck] = None,
norm: float = 0.0,
norm_alpha: float = 20.0,
use_magnitude_channels: bool = False,
):
super().__init__()
num_layers = len(multipliers) - 1
self.bottleneck = bottleneck
self.use_noisy = use_noisy
self.use_norm = norm > 0.0
self.norm = norm
self.norm_alpha = norm_alpha
self.use_magnitude_channels = use_magnitude_channels

assert len(factors) >= num_layers and len(num_blocks) >= num_layers

@@ -1181,16 +1189,14 @@ def __init__(

self.to_out = Unpatcher(
in_channels=channels * (use_noisy + 1),
out_channels=in_channels,
out_channels=in_channels * (2 if use_magnitude_channels else 1),
blocks=patch_blocks,
factor=patch_factor,
)

def encode(
self, x: Tensor, with_info: bool = False
) -> Union[Tensor, Tuple[Tensor, Any]]:
if self.use_norm:
x = wave_norm(x, peak=self.norm, alpha=self.norm_alpha)

x = self.to_in(x)
for downsample in self.downsamples:
@@ -1206,12 +1212,14 @@ def decode(self, x: Tensor) -> Tensor:
if self.use_noisy:
x = torch.cat([x, torch.randn_like(x)], dim=1)
x = upsample(x)

if self.use_noisy:
x = torch.cat([x, torch.randn_like(x)], dim=1)

x = self.to_out(x)

if self.use_norm:
x = wave_unnorm(x, peak=self.norm, alpha=self.norm_alpha)
if self.use_magnitude_channels:
x = merge_magnitude_channels(x)

return x

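The two helpers added above replace the old sigmoid-based wave_norm/wave_unnorm pair: the input is divided by a per-channel quantile of its absolute amplitude before entering the network and multiplied by the same scale on the way out, and when use_magnitude_channels is enabled the network emits twice as many output channels, which merge_magnitude_channels collapses back into a single waveform. Below is a minimal sketch of that round trip, assuming only torch; the two helpers are copied from the diff, while toy_net and the quantile value 0.95 are placeholders for illustration (the new norm_quantile default of 0.0 leaves the scaling disabled).

import torch
from torch import Tensor


def get_norm_scale(x: Tensor, quantile: float) -> Tensor:
    # Per-item, per-channel scale taken from a quantile of the absolute amplitude
    # (the epsilon avoids division by zero on silent inputs).
    return torch.quantile(x.abs(), quantile, dim=-1, keepdim=True) + 1e-7


def merge_magnitude_channels(x: Tensor) -> Tensor:
    # Split the doubled channels and combine them into a single bounded output.
    waveform, magnitude = torch.chunk(x, chunks=2, dim=1)
    return torch.sigmoid(waveform) * torch.tanh(magnitude)


def toy_net(x: Tensor, out_channels: int) -> Tensor:
    # Hypothetical stand-in for the UNet body: just tiles the input to the requested channel count.
    return x.repeat(1, out_channels // x.shape[1], 1)


x = torch.randn(2, 1, 2**15)              # [batch, channels, samples]
scale = get_norm_scale(x, quantile=0.95)  # norm_quantile > 0.0 enables the scaling
y = toy_net(x / scale, out_channels=2)    # output channels double when use_magnitude_channels=True
y = merge_magnitude_channels(y) * scale   # back to one channel, scale restored
print(y.shape)                            # torch.Size([2, 1, 32768])
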
15 changes: 0 additions & 15 deletions audio_diffusion_pytorch/utils.py
@@ -83,18 +83,3 @@ def downsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:

def upsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
return resample(waveforms, factor_in=1, factor_out=factor, **kwargs)


def wave_norm(x: Tensor, peak: float = 0.5, alpha: float = 20.0) -> Tensor:
x = x.clip(-1, 1)
x = 2 * torch.sigmoid(alpha * x) - 1
x = x.clip(-1, 1)
return x * peak


def wave_unnorm(x: Tensor, peak: float = 0.5, alpha: float = 20.0) -> Tensor:
x = x / peak
x = x.clip(-1, 1)
x = (1.0 / alpha) * torch.log((x + 1) / (1 - x))
x = x.clip(-1, 1)
return x
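For contrast with the new linear quantile scaling, the removed pair above is a fixed sigmoid companding curve and its logit inverse; away from the clipping boundary the round trip is lossless. A quick self-contained check, assuming only torch, with both functions copied verbatim from the removed lines:

import torch
from torch import Tensor


def wave_norm(x: Tensor, peak: float = 0.5, alpha: float = 20.0) -> Tensor:
    x = x.clip(-1, 1)
    x = 2 * torch.sigmoid(alpha * x) - 1
    x = x.clip(-1, 1)
    return x * peak


def wave_unnorm(x: Tensor, peak: float = 0.5, alpha: float = 20.0) -> Tensor:
    x = x / peak
    x = x.clip(-1, 1)
    x = (1.0 / alpha) * torch.log((x + 1) / (1 - x))
    x = x.clip(-1, 1)
    return x


# Inside (-1, 1) the logit undoes the sigmoid: (1/alpha) * log(sigma / (1 - sigma)) == x.
x = torch.empty(1, 2, 1024).uniform_(-0.5, 0.5)
assert torch.allclose(wave_unnorm(wave_norm(x)), x, atol=1e-3)
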
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.56",
version="0.0.57",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
