diff --git a/audio_diffusion_pytorch/model.py b/audio_diffusion_pytorch/model.py
index 53830ea..dccb844 100644
--- a/audio_diffusion_pytorch/model.py
+++ b/audio_diffusion_pytorch/model.py
@@ -44,6 +44,7 @@ def __init__(
         diffusion_dynamic_threshold: float,
         out_channels: Optional[int] = None,
         context_channels: Optional[Sequence[int]] = None,
+        **kwargs
     ):
         super().__init__()
 
@@ -66,6 +67,7 @@ def __init__(
             use_skip_scale=use_skip_scale,
             out_channels=out_channels,
             context_channels=context_channels,
+            **kwargs
         )
 
         self.diffusion = Diffusion(
diff --git a/audio_diffusion_pytorch/modules.py b/audio_diffusion_pytorch/modules.py
index 44ffd70..27ec905 100644
--- a/audio_diffusion_pytorch/modules.py
+++ b/audio_diffusion_pytorch/modules.py
@@ -25,17 +25,6 @@ def ConvTranspose1d(*args, **kwargs) -> nn.Module:
     return nn.ConvTranspose1d(*args, **kwargs)
 
 
-class ConvMean1d(nn.Module):
-    def __init__(self, num_means: int, *args, **kwargs):
-        super().__init__()
-        self.convs = nn.ModuleList([Conv1d(*args, **kwargs) for _ in range(num_means)])
-
-    def forward(self, x: Tensor) -> Tensor:
-        xs = torch.stack([conv(x) for conv in self.convs])
-        x = reduce(xs, "n b c t -> b c t", "mean")
-        return x
-
-
 def Downsample1d(
     in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
 ) -> nn.Module:
@@ -709,6 +698,40 @@ def forward(self, x: Tensor, t: Optional[Tensor] = None) -> Tensor:
 """
 
 
+class ConvOut1d(nn.Module):
+    def __init__(
+        self, in_channels: int, out_channels: int, kernel_sizes: Sequence[int]
+    ):
+        super().__init__()
+
+        self.block1 = nn.ModuleList(
+            Conv1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                padding=(kernel_size - 1) // 2,
+            )
+            for kernel_size in kernel_sizes
+        )
+
+        self.block2 = nn.ModuleList(
+            Conv1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                padding=(kernel_size - 1) // 2,
+            )
+            for kernel_size in kernel_sizes
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        xs = torch.stack([x] + [conv(x) for conv in self.block1])
+        x = reduce(xs, "n b c t -> b c t", "sum")
+        xs = torch.stack([x] + [conv(x) for conv in self.block2])
+        x = reduce(xs, "n b c t -> b c t", "sum")
+        return x
+
+
 class UNet1d(nn.Module):
     def __init__(
         self,
@@ -730,6 +753,7 @@ def __init__(
         use_attention_bottleneck: bool,
         out_channels: Optional[int] = None,
         context_channels: Optional[Sequence[int]] = None,
+        kernel_sizes_out: Optional[Sequence[int]] = None,
     ):
         super().__init__()
 
@@ -835,7 +859,6 @@ def __init__(
                 in_channels=channels + context_channels[1],
                 out_channels=channels,
                 num_groups=resnet_groups,
-                time_context_features=time_context_features,
             ),
             Conv1d(
                 in_channels=channels,
@@ -843,6 +866,13 @@ def __init__(
                 kernel_size=1,
             ),
             Rearrange("b (c p) l -> b c (l p)", p=patch_size),
+            ConvOut1d(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                kernel_sizes=kernel_sizes_out,
+            )
+            if exists(kernel_sizes_out)
+            else nn.Identity(),
         )
 
     def get_context(
diff --git a/setup.py b/setup.py
index 83cd055..829cc90 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.26",
+    version="0.0.27",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
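
Note: a minimal smoke-test sketch of the ConvOut1d head added above, not part of the patch. The import path follows this diff; tensor sizes and kernel sizes are illustrative placeholders.

    import torch
    from audio_diffusion_pytorch.modules import ConvOut1d

    # Each block stacks the identity with one conv per kernel size and sums
    # them; odd kernel sizes with padding=(k - 1) // 2 preserve the temporal
    # length, and in_channels must equal out_channels for the residual sums
    # to be shape-consistent (UNet1d sets both to its own out_channels).
    conv_out = ConvOut1d(in_channels=2, out_channels=2, kernel_sizes=[3, 5, 7])
    x = torch.randn(1, 2, 1024)  # (batch, channels, length)
    assert conv_out(x).shape == x.shape

With the new **kwargs passthrough in model.py, kernel_sizes_out can be supplied directly to the UNet1d-wrapping model class (e.g. kernel_sizes_out=[3, 5, 7]); when it is omitted, UNet1d substitutes nn.Identity() for the extra output stage.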