From aaaa6998d218d445c82c0473de9c3eaf8037ec2f Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Wed, 5 Oct 2022 17:00:06 +0200 Subject: [PATCH] feat: add magnitude channels option, change to quantile norm --- audio_diffusion_pytorch/model.py | 1 + audio_diffusion_pytorch/modules.py | 54 +++++++++++++++++------------- audio_diffusion_pytorch/utils.py | 15 --------- setup.py | 2 +- 4 files changed, 33 insertions(+), 39 deletions(-) diff --git a/audio_diffusion_pytorch/model.py b/audio_diffusion_pytorch/model.py index aae678c..c3e3b3e 100644 --- a/audio_diffusion_pytorch/model.py +++ b/audio_diffusion_pytorch/model.py @@ -220,6 +220,7 @@ def get_default_model_kwargs(): use_nearest_upsample=False, use_skip_scale=True, use_context_time=True, + use_magnitude_channels=False, diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0), diffusion_sigma_data=0.1, diffusion_dynamic_threshold=0.0, diff --git a/audio_diffusion_pytorch/modules.py b/audio_diffusion_pytorch/modules.py index f290dba..9061f43 100644 --- a/audio_diffusion_pytorch/modules.py +++ b/audio_diffusion_pytorch/modules.py @@ -8,7 +8,7 @@ from einops_exts import rearrange_many from torch import Tensor, einsum -from .utils import default, exists, prod, wave_norm, wave_unnorm +from .utils import default, exists, prod """ Utils @@ -785,6 +785,15 @@ def forward( return x +def get_norm_scale(x: Tensor, quantile: float): + return torch.quantile(x.abs(), quantile, dim=-1, keepdim=True) + 1e-7 + + +def merge_magnitude_channels(x: Tensor): + waveform, magnitude = torch.chunk(x, chunks=2, dim=1) + return torch.sigmoid(waveform) * torch.tanh(magnitude) + + """ UNet """ @@ -809,8 +818,8 @@ def __init__( use_nearest_upsample: bool, use_skip_scale: bool, use_context_time: bool, - norm: float = 0.0, - norm_alpha: float = 20.0, + use_magnitude_channels: bool, + norm_quantile: float = 0.0, out_channels: Optional[int] = None, context_features: Optional[int] = None, context_channels: Optional[Sequence[int]] = None, @@ -824,9 +833,6 @@ def __init__( use_context_channels = len(context_channels) > 0 context_mapping_features = None - self.use_norm = norm > 0.0 - self.norm = norm - self.norm_alpha = norm_alpha self.num_layers = num_layers self.use_context_time = use_context_time self.use_context_features = use_context_features @@ -841,6 +847,10 @@ def __init__( self.has_context = has_context self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))] + self.use_norm = norm_quantile > 0.0 + self.norm_quantile = norm_quantile + self.use_magnitude_channels = use_magnitude_channels + assert ( len(factors) == num_layers and len(attentions) >= num_layers @@ -943,7 +953,7 @@ def __init__( self.to_out = Unpatcher( in_channels=channels, - out_channels=out_channels, + out_channels=out_channels * (2 if use_magnitude_channels else 1), blocks=patch_blocks, factor=patch_factor, context_mapping_features=context_mapping_features, @@ -1002,10 +1012,11 @@ def forward( # Concat context channels at layer 0 if provided channels = self.get_channels(channels_list, layer=0) x = torch.cat([x, channels], dim=1) if exists(channels) else x + # Compute mapping from time and features mapping = self.get_mapping(time, features) - - if self.use_norm: - x = wave_norm(x, peak=self.norm, alpha=self.norm_alpha) + # Compute norm scale + scale = get_norm_scale(x, self.norm_quantile) if self.use_norm else 1.0 + x = x / scale x = self.to_in(x, mapping) skips_list = [x] @@ -1026,10 +1037,10 @@ def forward( x += skips_list.pop() x = self.to_out(x, mapping) - if self.use_norm: - x = wave_unnorm(x, peak=self.norm, alpha=self.norm_alpha) + if self.use_magnitude_channels: + x = merge_magnitude_channels(x) - return x + return x * scale class FixedEmbedding(nn.Module): @@ -1130,16 +1141,13 @@ def __init__( num_blocks: Sequence[int], use_noisy: bool = False, bottleneck: Optional[Bottleneck] = None, - norm: float = 0.0, - norm_alpha: float = 20.0, + use_magnitude_channels: bool = False, ): super().__init__() num_layers = len(multipliers) - 1 self.bottleneck = bottleneck self.use_noisy = use_noisy - self.use_norm = norm > 0.0 - self.norm = norm - self.norm_alpha = norm_alpha + self.use_magnitude_channels = use_magnitude_channels assert len(factors) >= num_layers and len(num_blocks) >= num_layers @@ -1181,7 +1189,7 @@ def __init__( self.to_out = Unpatcher( in_channels=channels * (use_noisy + 1), - out_channels=in_channels, + out_channels=in_channels * (2 if use_magnitude_channels else 1), blocks=patch_blocks, factor=patch_factor, ) @@ -1189,8 +1197,6 @@ def __init__( def encode( self, x: Tensor, with_info: bool = False ) -> Union[Tensor, Tuple[Tensor, Any]]: - if self.use_norm: - x = wave_norm(x, peak=self.norm, alpha=self.norm_alpha) x = self.to_in(x) for downsample in self.downsamples: @@ -1206,12 +1212,14 @@ def decode(self, x: Tensor) -> Tensor: if self.use_noisy: x = torch.cat([x, torch.randn_like(x)], dim=1) x = upsample(x) + if self.use_noisy: x = torch.cat([x, torch.randn_like(x)], dim=1) + x = self.to_out(x) - if self.use_norm: - x = wave_unnorm(x, peak=self.norm, alpha=self.norm_alpha) + if self.use_magnitude_channels: + x = merge_magnitude_channels(x) return x diff --git a/audio_diffusion_pytorch/utils.py b/audio_diffusion_pytorch/utils.py index 96dd831..791da02 100644 --- a/audio_diffusion_pytorch/utils.py +++ b/audio_diffusion_pytorch/utils.py @@ -83,18 +83,3 @@ def downsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor: def upsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor: return resample(waveforms, factor_in=1, factor_out=factor, **kwargs) - - -def wave_norm(x: Tensor, peak: float = 0.5, alpha: float = 20.0) -> Tensor: - x = x.clip(-1, 1) - x = 2 * torch.sigmoid(alpha * x) - 1 - x = x.clip(-1, 1) - return x * peak - - -def wave_unnorm(x: Tensor, peak: float = 0.5, alpha: float = 20.0) -> Tensor: - x = x / peak - x = x.clip(-1, 1) - x = (1.0 / alpha) * torch.log((x + 1) / (1 - x)) - x = x.clip(-1, 1) - return x diff --git a/setup.py b/setup.py index 883bb6d..afd12d8 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.56", + version="0.0.57", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown",