
Commit d60eefa

feat: update cross attention, option to change attention blocks, default to 16 patch factor

flavioschneider committed Sep 25, 2022
1 parent 9f57a87 commit d60eefa
Showing 4 changed files with 159 additions and 236 deletions.
15 changes: 9 additions & 6 deletions README.md
```diff
@@ -125,17 +125,16 @@ from audio_diffusion_pytorch import UNet1d
 unet = UNet1d(
     in_channels=1,
     channels=128,
-    patch_blocks=4,
-    patch_factor=2,
+    patch_blocks=16,
+    patch_factor=1,
     kernel_sizes_init=[1, 3, 7],
     multipliers=[1, 2, 4, 4, 4, 4, 4],
     factors=[4, 4, 4, 2, 2, 2],
-    attentions=[False, False, False, True, True, True],
+    attentions=[0, 0, 0, 1, 1, 1, 1],
     num_blocks=[2, 2, 2, 2, 2, 2],
     attention_heads=8,
     attention_features=64,
     attention_multiplier=2,
-    use_attention_bottleneck=True,
     resnet_groups=8,
     kernel_multiplier_downsample=2,
     use_nearest_upsample=False,
```
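The key semantic change in this hunk is that `attentions` switches from booleans to integers: each entry can now say *how many* attention blocks to use at that depth, not just whether to use one. Note also that `attentions` has one more entry than `factors`/`num_blocks`, which suggests the trailing entry now configures the bottleneck — consistent with `use_attention_bottleneck` being dropped in the same diff. A minimal standalone sketch of the idea (this is not the library's code; `attention_layout` and the block names are hypothetical):

```python
# Sketch: expand per-depth attention counts into a block layout.
# An entry of 0 means no attention at that depth; n > 0 means n attention
# blocks after the resnet blocks at that depth.

def attention_layout(attentions, num_blocks):
    """Return, per depth, a (hypothetical) list of block names."""
    layout = []
    for n_attn, n_res in zip(attentions, num_blocks):
        layout.append(["resnet"] * n_res + ["attention"] * n_attn)
    return layout

layers = attention_layout(
    attentions=[0, 0, 0, 1, 1, 1, 1],  # new integer form from the diff
    num_blocks=[2, 2, 2, 2, 2, 2, 2],  # illustrative: one count per depth
)
# Depths 0-2 use only resnet blocks; depths 3-6 each append one attention block.
```

With booleans the best you could express was "one attention block or none"; integer counts subsume the old flags (`True` → `1`, `False` → `0`) while allowing deeper attention stacks per resolution.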
```diff
@@ -229,16 +228,20 @@ y_long = composer(y, keep_start=True) # [1, 1, 98304]
 - [x] Add elucidated diffusion.
 - [x] Add ancestral DPM2 sampler.
 - [x] Add dynamic thresholding.
-- [x] Add (variational) autoencoder option to compress audio before diffusion.
+- [x] Add (variational) autoencoder option to compress audio before diffusion (removed).
 - [x] Fix inpainting and make it work with ADPM2 sampler.
 - [x] Add trainer with experiments.
 - [x] Add diffusion upsampler.
 - [x] Add ancestral euler sampler `AEulerSampler`.
 - [x] Add diffusion autoencoder.
 - [x] Add diffusion upsampler.
 - [x] Add autoencoder bottleneck option for quantization.
-- [x] Add option to provide context tokens (resnet cross attention).
+- [x] Add option to provide context tokens (cross attention).
 - [x] Add conditional model with classifier-free guidance.
 - [x] Add option to provide context features mapping.
+- [x] Add option to change number of (cross) attention blocks.
+- [ ] Add flash attention.


 ## Appreciation
```
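The TODO items above track providing context tokens via cross attention: the network's hidden states supply the queries while an external token sequence supplies the keys and values. As a refresher only (this is a generic single-head sketch in NumPy, not the repository's implementation):

```python
import numpy as np

def cross_attention(x, context, w_q, w_k, w_v):
    """Single-head cross attention.

    x:       [n, d]  hidden states (queries), e.g. U-Net features
    context: [m, d]  external context tokens (keys/values), e.g. embeddings
    """
    q = x @ w_q        # [n, d]
    k = context @ w_k  # [m, d]
    v = context @ w_v  # [m, d]
    scores = q @ k.T / np.sqrt(q.shape[-1])  # [n, m] scaled dot products
    # Softmax over the context dimension (numerically stabilized).
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v  # [n, d] context-informed output per query position

rng = np.random.default_rng(0)
d = 8
x = rng.normal(size=(4, d))        # 4 feature positions
context = rng.normal(size=(6, d))  # 6 context tokens
out = cross_attention(x, context, *(rng.normal(size=(d, d)) for _ in range(3)))
```

Because the softmax runs over the context axis, each output row is a convex combination of the (projected) context tokens, which is what lets a fixed-length token sequence condition every position of the audio feature map.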
7 changes: 3 additions & 4 deletions audio_diffusion_pytorch/model.py
```diff
@@ -202,17 +202,16 @@ def decode(self, latent: Tensor, **kwargs) -> Tensor:
 def get_default_model_kwargs():
     return dict(
         channels=128,
-        patch_blocks=4,
-        patch_factor=2,
+        patch_blocks=1,
+        patch_factor=16,
         kernel_sizes_init=[1, 3, 7],
         multipliers=[1, 2, 4, 4, 4, 4, 4],
         factors=[4, 4, 4, 2, 2, 2],
         num_blocks=[2, 2, 2, 2, 2, 2],
-        attentions=[False, False, False, True, True, True],
+        attentions=[0, 0, 0, 1, 1, 1, 1],
         attention_heads=8,
         attention_features=64,
         attention_multiplier=2,
-        use_attention_bottleneck=True,
         resnet_groups=8,
         kernel_multiplier_downsample=2,
         use_nearest_upsample=False,
```
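The commit title's "default to 16 patch factor" shows up in this hunk. Assuming the patching stage downsamples by `patch_factor` once per patch block, so the total reduction compounds as `patch_factor ** patch_blocks` (this composition rule is an assumption, not stated in the diff), the old and new defaults give the same overall 16x patching, just in a single block instead of four:

```python
def total_patching(patch_blocks, patch_factor):
    # Assumed composition: each patch block downsamples by patch_factor,
    # so patch_blocks of them compound multiplicatively.
    return patch_factor ** patch_blocks

old = total_patching(patch_blocks=4, patch_factor=2)   # 2 ** 4
new = total_patching(patch_blocks=1, patch_factor=16)  # 16 ** 1
# Both reduce the input length 16x before the U-Net body.
```

Under that assumption the change trades four shallow 2x patching steps for one 16x step, keeping the effective sequence-length reduction unchanged.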
