Skip to content

Commit

Permalink
feat: remove unsuccessful convmean, provide context only during downsa…
Browse files Browse the repository at this point in the history
…mpling
  • Loading branch information
flavioschneider committed Sep 6, 2022
1 parent d12456b commit f46557b
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 11 deletions.
4 changes: 0 additions & 4 deletions audio_diffusion_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ def __init__(
channels: int,
patch_size: int,
kernel_sizes_init: Sequence[int],
out_means: int,
multipliers: Sequence[int],
factors: Sequence[int],
num_blocks: Sequence[int],
Expand Down Expand Up @@ -51,7 +50,6 @@ def __init__(
resnet_groups=resnet_groups,
kernel_multiplier_downsample=kernel_multiplier_downsample,
kernel_sizes_init=kernel_sizes_init,
out_means=out_means,
multipliers=multipliers,
factors=factors,
num_blocks=num_blocks,
Expand Down Expand Up @@ -100,7 +98,6 @@ def __init__(self, *args, **kwargs):
channels=128,
patch_size=16,
kernel_sizes_init=[1, 3, 7],
out_means=4,
multipliers=[1, 2, 4, 4, 4, 4, 4],
factors=[4, 4, 4, 2, 2, 2],
num_blocks=[2, 2, 2, 2, 2, 2],
Expand Down Expand Up @@ -136,7 +133,6 @@ def __init__(self, factor: int, in_channels: int = 1, *args, **kwargs):
in_channels=in_channels,
channels=128,
patch_size=16,
out_means=4,
kernel_sizes_init=[1, 3, 7],
multipliers=[1, 2, 4, 4, 4, 4, 4],
factors=[4, 4, 4, 2, 2, 2],
Expand Down
8 changes: 2 additions & 6 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,6 @@ def __init__(
use_skip_scale: bool,
use_attention_bottleneck: bool,
out_channels: Optional[int] = None,
out_means: int = 1,
context_channels: Optional[Sequence[int]] = None,
):
super().__init__()
Expand Down Expand Up @@ -802,11 +801,10 @@ def __init__(
attention_features=attention_features,
)

context_channels = context_channels + [0] # Upsample skips first context
self.upsamples = nn.ModuleList(
[
UpsampleBlock1d(
in_channels=channels * multipliers[i + 1] + context_channels[i + 2],
in_channels=channels * multipliers[i + 1],
out_channels=channels * multipliers[i],
time_context_features=time_context_features,
num_layers=num_blocks[i] + (1 if attentions[i] else 0),
Expand All @@ -833,8 +831,7 @@ def __init__(
num_groups=resnet_groups,
time_context_features=time_context_features,
),
ConvMean1d(
num_means=out_means,
Conv1d(
in_channels=channels,
out_channels=out_channels * patch_size,
kernel_size=1,
Expand Down Expand Up @@ -889,7 +886,6 @@ def forward(
for i, upsample in enumerate(self.upsamples):
skips = skips_list.pop()
x = upsample(x, skips, t)
x = self.add_context(x, context, layer=len(self.upsamples) - i)

x = self.to_out(x) # t?

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.21",
version="0.0.22",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit f46557b

Please sign in to comment.