From 42bf374f3e0ce00aabc72ac1f50bfc6811b94af5 Mon Sep 17 00:00:00 2001 From: Flavio Schneider Date: Sun, 25 Sep 2022 16:10:36 +0200 Subject: [PATCH] fix: cross attention skip connection --- audio_diffusion_pytorch/modules.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/audio_diffusion_pytorch/modules.py b/audio_diffusion_pytorch/modules.py index de0f5f0..b1ff82b 100644 --- a/audio_diffusion_pytorch/modules.py +++ b/audio_diffusion_pytorch/modules.py @@ -421,7 +421,7 @@ def __init__( def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor: x = self.attention(x) + x if self.use_cross_attention: - x = self.cross_attention(x, context=context) + x = self.cross_attention(x, context=context) + x x = self.feed_forward(x) + x return x diff --git a/setup.py b/setup.py index 65070fe..6a31af3 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="audio-diffusion-pytorch", packages=find_packages(exclude=[]), - version="0.0.44", + version="0.0.45", license="MIT", description="Audio Diffusion - PyTorch", long_description_content_type="text/markdown",