Skip to content

Commit

Permalink
feat: add convout kernels option
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Sep 8, 2022
1 parent 28ec933 commit e0ba36c
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 13 deletions.
2 changes: 2 additions & 0 deletions audio_diffusion_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
diffusion_dynamic_threshold: float,
out_channels: Optional[int] = None,
context_channels: Optional[Sequence[int]] = None,
**kwargs
):
super().__init__()

Expand All @@ -66,6 +67,7 @@ def __init__(
use_skip_scale=use_skip_scale,
out_channels=out_channels,
context_channels=context_channels,
**kwargs
)

self.diffusion = Diffusion(
Expand Down
54 changes: 42 additions & 12 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,6 @@ def ConvTranspose1d(*args, **kwargs) -> nn.Module:
return nn.ConvTranspose1d(*args, **kwargs)


class ConvMean1d(nn.Module):
    """Ensemble of identically-configured Conv1d layers averaged elementwise.

    Builds ``num_means`` convolutions sharing the same constructor arguments
    and, on the forward pass, returns the mean of their outputs.
    """

    def __init__(self, num_means: int, *args, **kwargs):
        super().__init__()
        # One independent Conv1d per ensemble member; extra args/kwargs are
        # forwarded unchanged to each constructor.
        self.convs = nn.ModuleList([Conv1d(*args, **kwargs) for _ in range(num_means)])

    def forward(self, x: Tensor) -> Tensor:
        outputs = [conv(x) for conv in self.convs]
        # Stack along a new leading dim and average it away: (n, b, c, t) -> (b, c, t).
        return torch.stack(outputs, dim=0).mean(dim=0)


def Downsample1d(
in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
) -> nn.Module:
Expand Down Expand Up @@ -709,6 +698,40 @@ def forward(self, x: Tensor, t: Optional[Tensor] = None) -> Tensor:
"""


class ConvOut1d(nn.Module):
    """Residual output refinement with two banks of parallel convolutions.

    Each bank applies one ``Conv1d`` per entry of ``kernel_sizes`` and sums
    all outputs together with the bank's own input (a residual connection),
    so the temporal length and channel count are preserved end to end.

    Args:
        in_channels: Channels of the input tensor. Must equal
            ``out_channels``: the forward pass sums the raw input with the
            convolution outputs, which requires matching channel counts.
        out_channels: Channels produced by every convolution.
        kernel_sizes: Kernel size per parallel convolution. All must be odd
            so that ``(k - 1) // 2`` padding preserves the temporal length
            (even kernels would shift the length and break the residual sum).

    Raises:
        ValueError: If ``in_channels != out_channels`` or any kernel size
            is even.
    """

    def __init__(
        self, in_channels: int, out_channels: int, kernel_sizes: Sequence[int]
    ):
        super().__init__()

        # Fail fast with a clear message instead of a cryptic shape error
        # deep inside forward().
        if in_channels != out_channels:
            raise ValueError(
                f"ConvOut1d requires in_channels == out_channels for its "
                f"residual sum, got {in_channels} and {out_channels}"
            )
        if any(kernel_size % 2 == 0 for kernel_size in kernel_sizes):
            raise ValueError(
                f"ConvOut1d requires odd kernel sizes to preserve length, "
                f"got {list(kernel_sizes)}"
            )

        def conv_bank() -> nn.ModuleList:
            # One length-preserving convolution per requested kernel size.
            return nn.ModuleList(
                [
                    Conv1d(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=kernel_size,
                        padding=(kernel_size - 1) // 2,
                    )
                    for kernel_size in kernel_sizes
                ]
            )

        self.block1 = conv_bank()
        self.block2 = conv_bank()

    def forward(self, x: Tensor) -> Tensor:
        # Bank 1: residual sum of the input and every parallel conv output.
        x = torch.stack([x] + [conv(x) for conv in self.block1]).sum(dim=0)
        # Bank 2: same pattern applied to the refined signal.
        x = torch.stack([x] + [conv(x) for conv in self.block2]).sum(dim=0)
        return x


class UNet1d(nn.Module):
def __init__(
self,
Expand All @@ -730,6 +753,7 @@ def __init__(
use_attention_bottleneck: bool,
out_channels: Optional[int] = None,
context_channels: Optional[Sequence[int]] = None,
kernel_sizes_out: Optional[Sequence[int]] = None,
):
super().__init__()

Expand Down Expand Up @@ -835,14 +859,20 @@ def __init__(
in_channels=channels + context_channels[1],
out_channels=channels,
num_groups=resnet_groups,
time_context_features=time_context_features,
),
Conv1d(
in_channels=channels,
out_channels=out_channels * patch_size,
kernel_size=1,
),
Rearrange("b (c p) l -> b c (l p)", p=patch_size),
ConvOut1d(
in_channels=out_channels,
out_channels=out_channels,
kernel_sizes=kernel_sizes_out,
)
if exists(kernel_sizes_out)
else nn.Identity(),
)

def get_context(
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.26",
version="0.0.27",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit e0ba36c

Please sign in to comment.