From 58db626a7d551b3cf8fd7dab510e20b7cad5d7ca Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Sat, 15 Oct 2022 10:52:26 +0200 Subject: [PATCH] feat: add mean and logvar info to variational bottleneck --- audio_diffusion_pytorch/modules.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/audio_diffusion_pytorch/modules.py b/audio_diffusion_pytorch/modules.py index b8248c5..ab3c19d 100644 --- a/audio_diffusion_pytorch/modules.py +++ b/audio_diffusion_pytorch/modules.py @@ -1211,7 +1211,7 @@ def forward( logvar = torch.clamp(logvar, -30.0, 20.0) out = gaussian_sample(mean, logvar) loss = kl_loss(mean, logvar) * self.loss_weight - return (out, dict(loss=loss)) if with_info else out + return (out, dict(loss=loss, mean=mean, logvar=logvar)) if with_info else out class AutoEncoder1d(nn.Module): diff --git a/setup.py b/setup.py index 18d770d..a5937bc 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.64", + version="0.0.65", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown",