From 156bd6ada2017f3471b55e045b6628c56cd7b610 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Sat, 5 Aug 2023 17:55:57 -0700 Subject: [PATCH] do a final norm after all the transformer layers, and for duration prediction, project each token to dimension of 1 and squeeze out --- setup.py | 2 +- voicebox_pytorch/voicebox_pytorch.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index f9311eb..bf6edb0 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'voicebox-pytorch', packages = find_packages(exclude=[]), - version = '0.0.5', + version = '0.0.6', license='MIT', description = 'Voicebox - Pytorch', author = 'Phil Wang', diff --git a/voicebox_pytorch/voicebox_pytorch.py b/voicebox_pytorch/voicebox_pytorch.py index e316932..0e7de53 100644 --- a/voicebox_pytorch/voicebox_pytorch.py +++ b/voicebox_pytorch/voicebox_pytorch.py @@ -224,6 +224,8 @@ def __init__( FeedForward(dim = dim, mult = ff_mult) ])) + self.final_norm = RMSNorm(dim) + def forward(self, x): skip_connects = [] @@ -243,7 +245,7 @@ def forward(self, x): x = attn(x, rotary_emb = rotary_emb) + x x = ff(x) + x - return x + return self.final_norm(x) # both duration and main denoising model are transformers @@ -286,6 +288,11 @@ def __init__( attn_flash = attn_flash ) + self.to_pred = nn.Sequential( + nn.Linear(dim, 1), + Rearrange('... 1 -> ...') + ) + @torch.inference_mode() def forward_with_cond_scale( self, @@ -344,8 +351,6 @@ def forward( return F.l1_loss(x, target) loss = F.l1_loss(x, target, reduction = 'none') - - loss = reduce(loss, 'b n d -> b n', 'mean') loss = loss.masked_fill(mask, 0.) # masked mean @@ -397,6 +402,8 @@ def __init__( attn_flash = attn_flash ) + self.to_pred = nn.Linear(dim, dim, bias = False) + @torch.inference_mode() def forward_with_cond_scale( self, @@ -465,6 +472,8 @@ def forward( x = self.transformer(x) + x = self.to_pred(x) + # split out time embedding _, x = unpack(x, ps, 'b * d')