diff --git a/audio_diffusion_pytorch/__init__.py b/audio_diffusion_pytorch/__init__.py index 41fde5c..470d2d1 100644 --- a/audio_diffusion_pytorch/__init__.py +++ b/audio_diffusion_pytorch/__init__.py @@ -32,6 +32,7 @@ AutoEncoder1d, MultiEncoder1d, T5Embedder, + Tanh, UNet1d, UNetConditional1d, Variational, diff --git a/audio_diffusion_pytorch/modules.py b/audio_diffusion_pytorch/modules.py index 46bbfac..4a59268 100644 --- a/audio_diffusion_pytorch/modules.py +++ b/audio_diffusion_pytorch/modules.py @@ -1,6 +1,6 @@ import math from math import pi -from typing import Any, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch import torch.nn as nn @@ -9,7 +9,7 @@ from einops_exts import rearrange_many from torch import Tensor, einsum -from .utils import default, exists, prod +from .utils import default, exists, prod, to_list """ Utils @@ -1341,6 +1341,15 @@ def forward( return (out, dict(loss=loss, mean=mean, logvar=logvar)) if with_info else out +class Tanh(Bottleneck): + def forward( + self, x: Tensor, with_info: bool = False + ) -> Union[Tensor, Tuple[Tensor, Any]]: + x = torch.tanh(x) + info: Dict = dict() + return (x, info) if with_info else x + + class AutoEncoder1d(nn.Module): def __init__( self, @@ -1353,12 +1362,12 @@ def __init__( factors: Sequence[int], num_blocks: Sequence[int], use_noisy: bool = False, - bottleneck: Optional[Bottleneck] = None, + bottleneck: Union[Bottleneck, List[Bottleneck]] = [], use_magnitude_channels: bool = False, ): super().__init__() num_layers = len(multipliers) - 1 - self.bottleneck = bottleneck + self.bottlenecks = to_list(bottleneck) self.use_noisy = use_noisy self.use_magnitude_channels = use_magnitude_channels @@ -1424,8 +1433,8 @@ def encode( xs += [x] info = dict(xs=xs) - if exists(self.bottleneck): - x, info_bottleneck = self.bottleneck(x, with_info=True) + for bottleneck in self.bottlenecks: + x, info_bottleneck = bottleneck(x, with_info=True) info = {**info, **info_bottleneck} return (x, info) if with_info else x diff --git a/setup.py b/setup.py index 9a9cbfc..3ed067e 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.70", + version="0.0.71", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown",