Skip to content

Commit

Permalink
feat: diffusion multiencoder default to 0 out layer
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Sep 27, 2022
1 parent 60adadc commit da94c00
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
3 changes: 3 additions & 0 deletions audio_diffusion_pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def __init__(
encoder_channels: int,
bottleneck: Optional[Bottleneck] = None,
encoder_num_blocks: Optional[Sequence[int]] = None,
encoder_out_layers: int = 0,
**kwargs
):
self.in_channels = in_channels
Expand All @@ -125,6 +126,7 @@ def __init__(
patch_blocks=patch_blocks,
patch_factor=patch_factor,
num_layers=encoder_depth,
num_layers_out=encoder_out_layers,
latent_channels=encoder_channels,
multipliers=multipliers,
factors=factors,
Expand Down Expand Up @@ -159,6 +161,7 @@ def forward( # type: ignore
latent = self.encode(x)

channels_list = self.multiencoder.decode(latent)
print([x.shape for x in channels_list])
loss = self.diffusion(x, channels_list=channels_list, **kwargs)
return (loss, info) if with_info else loss

Expand Down
17 changes: 13 additions & 4 deletions audio_diffusion_pytorch/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,22 +1192,27 @@ def __init__(
self,
in_channels: int,
channels: int,
patch_factor: int,
patch_blocks: int,
resnet_groups: int,
kernel_multiplier_downsample: int,
num_layers: int,
num_layers_out: int,
latent_channels: int,
multipliers: Sequence[int],
factors: Sequence[int],
num_blocks: Sequence[int],
patch_factor: int = 2,
):
super().__init__()
# Latent space factor
self.factor = (patch_factor ** patch_blocks) * prod(factors[0:num_layers])
self.num_layers = num_layers
self.num_layers_out = num_layers_out
self.channels_list = self.get_channels_list(
in_channels, channels, multipliers, num_layers
in_channels, channels, multipliers, num_layers, num_layers_out
)

assert num_layers_out <= num_layers
assert (
len(multipliers) >= num_layers + 1
and len(factors) >= num_layers
Expand Down Expand Up @@ -1261,7 +1266,7 @@ def __init__(
use_skip=False,
extract_channels=channels * multipliers[i],
)
for i in reversed(range(num_layers))
for i in reversed(range(num_layers - num_layers_out, num_layers))
]
)

Expand All @@ -1278,9 +1283,12 @@ def get_channels_list(
channels: int,
multipliers: Sequence[int],
num_layers: int,
num_layers_out: int,
) -> List[int]:
channels_list = [in_channels]
channels_list += [channels * m for m in multipliers[1 : num_layers + 1]]
empty_channels = num_layers - num_layers_out
channels_list = [0] * empty_channels + channels_list[-num_layers_out - 1 :]
return channels_list

def encode(self, x: Tensor) -> Tensor:
Expand All @@ -1297,6 +1305,7 @@ def decode(self, latent: Tensor) -> List[Tensor]:
for upsample in self.upsamples:
channels_list += [channels]
x, channels = upsample(x)
x = self.to_out(x)
if self.num_layers_out == self.num_layers:
x = self.to_out(x)
channels_list += [x]
return channels_list[::-1]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.47",
version="0.0.48",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit da94c00

Please sign in to comment.