Commit

do a final norm after all the transformer layers, and for duration prediction, project each token to dimension of 1 and squeeze out

lucidrains committed Aug 6, 2023
1 parent 2cd46a1 commit 156bd6a
Showing 2 changed files with 13 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'voicebox-pytorch',
packages = find_packages(exclude=[]),
- version = '0.0.5',
+ version = '0.0.6',
license='MIT',
description = 'Voicebox - Pytorch',
author = 'Phil Wang',
15 changes: 12 additions & 3 deletions voicebox_pytorch/voicebox_pytorch.py
@@ -224,6 +224,8 @@ def __init__(
FeedForward(dim = dim, mult = ff_mult)
]))

+ self.final_norm = RMSNorm(dim)

def forward(self, x):
skip_connects = []

@@ -243,7 +245,7 @@ def forward(self, x):
x = attn(x, rotary_emb = rotary_emb) + x
x = ff(x) + x

- return x
+ return self.final_norm(x)

# both duration and main denoising model are transformers
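
For context on the first change: with pre-norm residual blocks, the output of the last transformer layer is never itself normalized, so a closing norm after the stack keeps the representation well scaled before any prediction head. Below is a minimal sketch of the idea, assuming an RMSNorm of the usual form; the repository defines its own RMSNorm elsewhere in voicebox_pytorch.py, and this is not that exact code.

```python
import torch
import torch.nn.functional as F
from torch import nn

class RMSNorm(nn.Module):
    # hypothetical restatement of an RMSNorm layer for illustration only
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5                      # restore magnitude after unit normalization
        self.gamma = nn.Parameter(torch.ones(dim))   # learned per-channel gain

    def forward(self, x):
        # normalize each token vector, then rescale
        return F.normalize(x, dim = -1) * self.scale * self.gamma

# pre-norm blocks leave the final residual stream unnormalized,
# so the stack applies a closing norm before returning
final_norm = RMSNorm(512)
x = torch.randn(2, 100, 512)   # (batch, seq, dim)
out = final_norm(x)            # same shape, normalized per token
```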

@@ -286,6 +288,11 @@ def __init__(
attn_flash = attn_flash
)

+ self.to_pred = nn.Sequential(
+     nn.Linear(dim, 1),
+     Rearrange('... 1 -> ...')
+ )
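
The new duration head pairs a projection down to a single channel with einops' Rearrange layer, which squeezes out the trailing singleton axis so each token ends up as one scalar. A small usage sketch with illustrative shapes:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim = 512
to_pred = nn.Sequential(
    nn.Linear(dim, 1),             # one scalar duration prediction per token
    Rearrange('... 1 -> ...')      # drop the trailing singleton dimension
)

tokens = torch.randn(2, 100, dim)  # (batch, seq, dim)
durations = to_pred(tokens)
print(durations.shape)             # torch.Size([2, 100])
```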

@torch.inference_mode()
def forward_with_cond_scale(
self,
@@ -344,8 +351,6 @@ def forward(
return F.l1_loss(x, target)

loss = F.l1_loss(x, target, reduction = 'none')

- loss = reduce(loss, 'b n d -> b n', 'mean')
loss = loss.masked_fill(mask, 0.)

# masked mean
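
With the head above emitting shape (batch, seq), the per-token L1 loss no longer has a feature dimension to reduce over, which is why the reduce line is dropped and only the masked mean over valid positions remains. The body of that masked mean is not shown in this hunk; the sketch below is one plausible way to compute it, with hypothetical tensors (mask is True at positions to ignore, matching the masked_fill above):

```python
import torch
import torch.nn.functional as F

# hypothetical shapes: x and target are (batch, seq) duration predictions,
# mask is True at padded positions that should not contribute to the loss
x      = torch.randn(2, 100)
target = torch.randn(2, 100)
mask   = torch.zeros(2, 100, dtype = torch.bool)
mask[:, 80:] = True   # e.g. padded tail of each sequence

loss = F.l1_loss(x, target, reduction = 'none')
loss = loss.masked_fill(mask, 0.)

# masked mean: average only over the positions that were kept
num  = loss.sum(dim = -1)
den  = (~mask).sum(dim = -1).clamp(min = 1)
loss = (num / den).mean()
```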
@@ -397,6 +402,8 @@ def __init__(
attn_flash = attn_flash
)

+ self.to_pred = nn.Linear(dim, dim, bias = False)

@torch.inference_mode()
def forward_with_cond_scale(
self,
@@ -465,6 +472,8 @@ def forward(

x = self.transformer(x)

+ x = self.to_pred(x)

# split out time embedding

_, x = unpack(x, ps, 'b * d')
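
The unpack(x, ps, 'b * d') above reverses an einops pack done earlier in the forward pass, splitting the time-conditioning embedding back off the sequence; the new output projection is applied before that split. A sketch of the pack/unpack round trip with illustrative names and shapes (the earlier pack call is assumed, not shown in this hunk):

```python
import torch
from einops import pack, unpack

batch, seq, dim = 2, 100, 512
time_emb = torch.randn(batch, dim)        # one conditioning token per sample
tokens   = torch.randn(batch, seq, dim)

# prepend the time embedding as an extra "token"; ps records how to split later
x, ps = pack([time_emb, tokens], 'b * d')  # x: (2, 101, 512)

# ... transformer and output projection would run on x here ...

# recover the two pieces; the first is the time embedding, the second the tokens
time_emb_out, tokens_out = unpack(x, ps, 'b * d')
print(time_emb_out.shape, tokens_out.shape)  # torch.Size([2, 512]) torch.Size([2, 100, 512])
```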
