diff --git a/audio_diffusion_pytorch/model.py b/audio_diffusion_pytorch/model.py
index 8fc009e..e466869 100644
--- a/audio_diffusion_pytorch/model.py
+++ b/audio_diffusion_pytorch/model.py
@@ -23,7 +23,6 @@ def __init__(
         channels: int,
         patch_size: int,
         kernel_sizes_init: Sequence[int],
-        out_means: int,
         multipliers: Sequence[int],
         factors: Sequence[int],
         num_blocks: Sequence[int],
@@ -51,7 +50,6 @@ def __init__(
             resnet_groups=resnet_groups,
             kernel_multiplier_downsample=kernel_multiplier_downsample,
             kernel_sizes_init=kernel_sizes_init,
-            out_means=out_means,
             multipliers=multipliers,
             factors=factors,
             num_blocks=num_blocks,
@@ -100,7 +98,6 @@ def __init__(self, *args, **kwargs):
             channels=128,
             patch_size=16,
             kernel_sizes_init=[1, 3, 7],
-            out_means=4,
             multipliers=[1, 2, 4, 4, 4, 4, 4],
             factors=[4, 4, 4, 2, 2, 2],
             num_blocks=[2, 2, 2, 2, 2, 2],
@@ -136,7 +133,6 @@ def __init__(self, factor: int, in_channels: int = 1, *args, **kwargs):
             in_channels=in_channels,
             channels=128,
             patch_size=16,
-            out_means=4,
             kernel_sizes_init=[1, 3, 7],
             multipliers=[1, 2, 4, 4, 4, 4, 4],
             factors=[4, 4, 4, 2, 2, 2],
diff --git a/audio_diffusion_pytorch/modules.py b/audio_diffusion_pytorch/modules.py
index 3d637ad..8bf12b1 100644
--- a/audio_diffusion_pytorch/modules.py
+++ b/audio_diffusion_pytorch/modules.py
@@ -724,7 +724,6 @@ def __init__(
         use_skip_scale: bool,
         use_attention_bottleneck: bool,
         out_channels: Optional[int] = None,
-        out_means: int = 1,
         context_channels: Optional[Sequence[int]] = None,
     ):
         super().__init__()
@@ -802,11 +801,10 @@ def __init__(
             attention_features=attention_features,
         )
 
-        context_channels = context_channels + [0]  # Upsample skips first context
         self.upsamples = nn.ModuleList(
             [
                 UpsampleBlock1d(
-                    in_channels=channels * multipliers[i + 1] + context_channels[i + 2],
+                    in_channels=channels * multipliers[i + 1],
                     out_channels=channels * multipliers[i],
                     time_context_features=time_context_features,
                     num_layers=num_blocks[i] + (1 if attentions[i] else 0),
@@ -833,8 +831,7 @@ def __init__(
                 num_groups=resnet_groups,
                 time_context_features=time_context_features,
             ),
-            ConvMean1d(
-                num_means=out_means,
+            Conv1d(
                 in_channels=channels,
                 out_channels=out_channels * patch_size,
                 kernel_size=1,
@@ -889,7 +886,6 @@ def forward(
         for i, upsample in enumerate(self.upsamples):
             skips = skips_list.pop()
             x = upsample(x, skips, t)
-            x = self.add_context(x, context, layer=len(self.upsamples) - i)
 
         x = self.to_out(x)  # t?
 
diff --git a/setup.py b/setup.py
index bb71c87..18b93f8 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.21",
+    version="0.0.22",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
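Note: with out_means and ConvMean1d gone, the UNet1d output head is a plain 1x1 convolution that maps channels to out_channels * patch_size (the patching is undone downstream). The following is a minimal standalone sketch of that head's shape only, not library code: torch.nn.Conv1d stands in for the package's Conv1d wrapper, channels=128 and patch_size=16 are the defaults visible in this diff, and out_channels=1 and the sequence length are assumed for illustration.

    # Hypothetical sketch of the new output projection's shape after the
    # ConvMean1d / out_means removal. nn.Conv1d stands in for the package's
    # own Conv1d wrapper; out_channels=1 and length=1024 are assumptions.
    import torch
    import torch.nn as nn

    channels, out_channels, patch_size = 128, 1, 16
    to_out_proj = nn.Conv1d(channels, out_channels * patch_size, kernel_size=1)

    x = torch.randn(2, channels, 1024)  # [batch, channels, patched length]
    y = to_out_proj(x)                  # [batch, out_channels * patch_size, patched length]
    print(y.shape)                      # torch.Size([2, 16, 1024])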