feat: add v-diffusion class/distribution, refactor diffusion
flavioschneider committed Oct 8, 2022
1 parent 125b938 commit 68da808
Showing 5 changed files with 241 additions and 121 deletions.
15 changes: 12 additions & 3 deletions README.md
@@ -139,6 +139,7 @@ unet = UNet1d(
use_nearest_upsample=False,
use_skip_scale=True,
use_context_time=True,
+use_magnitude_channels=False
)

x = torch.randn(3, 1, 2 ** 16)
@@ -151,13 +152,20 @@ y = unet(x, t) # [3, 1, 32768], compute 3 samples of ~1.5 seconds at 22050Hz wit

#### Training
```python
-from audio_diffusion_pytorch import Diffusion, LogNormalDistribution
+from audio_diffusion_pytorch import KDiffusion, VDiffusion, LogNormalDistribution, VDistribution

-diffusion = Diffusion(
+# Either use KDiffusion
+diffusion = KDiffusion(
    net=unet,
    sigma_distribution=LogNormalDistribution(mean = -3.0, std = 1.0),
    sigma_data=0.1,
-    dynamic_threshold=0.95
+    dynamic_threshold=0.0
)

+# Or use VDiffusion
+diffusion = VDiffusion(
+    net=unet,
+    sigma_distribution=VDistribution()
+)

x = torch.randn(3, 1, 2 ** 18) # Batch of training audio samples
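For intuition on the `LogNormalDistribution(mean=-3.0, std=1.0)` used by `KDiffusion` above, here is a library-free NumPy sketch of log-normal noise-level sampling. The helper name is hypothetical and the exact internals of the library's class are not shown in this diff; this is just the standard construction where `ln(sigma)` is normally distributed:

```python
import numpy as np

def sample_sigmas_lognormal(num, mean=-3.0, std=1.0, seed=0):
    # Draw noise levels sigma with ln(sigma) ~ N(mean, std^2).
    # A plausible reading of LogNormalDistribution(mean=-3.0, std=1.0);
    # the library's actual implementation may differ.
    rng = np.random.default_rng(seed)
    return np.exp(mean + std * rng.standard_normal(num))
```

With `mean=-3.0`, typical sigmas concentrate around `exp(-3) ≈ 0.05`, the same order of magnitude as the `sigma_data=0.1` passed alongside it.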
@@ -239,6 +247,7 @@ y_long = composer(y, keep_start=True) # [1, 1, 98304]
- [x] Add conditional model with classifier-free guidance.
- [x] Add option to provide context features mapping.
- [x] Add option to change number of (cross) attention blocks.
+- [x] Add `VDiffusion` option.
- [ ] Add flash attention.
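The `VDiffusion` class added in this commit refers to the v-objective parameterization (Salimans & Ho, 2022). As an illustration of what that target looks like, here is a small NumPy sketch using a cosine alpha/sigma schedule; the schedule choice and the helper name are assumptions, since the diff does not show the class internals:

```python
import numpy as np

def v_objective_pair(x0, eps, t):
    # Cosine schedule: alpha = cos(pi/2 * t), sigma = sin(pi/2 * t), t in [0, 1].
    alpha, sigma = np.cos(np.pi / 2 * t), np.sin(np.pi / 2 * t)
    x_t = alpha * x0 + sigma * eps   # noised input given to the network
    v = alpha * eps - sigma * x0     # target the network learns to predict
    return x_t, v
```

Both the clean signal and the noise are recoverable from a prediction of `v`: `x0 = alpha * x_t - sigma * v` and `eps = sigma * x_t + alpha * v`, which keeps the target well conditioned at both ends of the schedule.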


3 changes: 3 additions & 0 deletions audio_diffusion_pytorch/__init__.py
@@ -7,10 +7,13 @@
    Distribution,
    KarrasSampler,
    KarrasSchedule,
+   KDiffusion,
    LogNormalDistribution,
    Sampler,
    Schedule,
    SpanBySpanComposer,
+   VDiffusion,
+   VDistribution,
)
from .model import (
AudioDiffusionAutoencoder,
