diff --git a/setup.py b/setup.py
index 25bca34..00bea0c 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'voicebox-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.1',
+  version = '0.2.2',
   license='MIT',
   description = 'Voicebox - Pytorch',
   author = 'Phil Wang',
diff --git a/voicebox_pytorch/voicebox_pytorch.py b/voicebox_pytorch/voicebox_pytorch.py
index 71297f4..07f3d8c 100644
--- a/voicebox_pytorch/voicebox_pytorch.py
+++ b/voicebox_pytorch/voicebox_pytorch.py
@@ -329,7 +329,10 @@ def __init__(
         self.rotary_emb = RotaryEmbedding(dim = dim_head)
 
         self.num_register_tokens = num_register_tokens
-        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
+        self.has_register_tokens = num_register_tokens > 0
+
+        if self.has_register_tokens:
+            self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
 
         if adaptive_rmsnorm:
             rmsnorm_klass = partial(AdaptiveRMSNorm, cond_dim = adaptive_rmsnorm_cond_dim_in)
@@ -364,14 +367,15 @@ def forward(
     ):
         batch, seq_len, *_ = x.shape
 
-        register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
-
         # add register tokens to the left
 
-        x, ps = pack([register_tokens, x], 'b * d')
+        if self.has_register_tokens:
+            register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
+
+            x, ps = pack([register_tokens, x], 'b * d')
 
-        if exists(mask):
-            mask = F.pad(mask, (self.num_register_tokens, 0), value = True)
+            if exists(mask):
+                mask = F.pad(mask, (self.num_register_tokens, 0), value = True)
 
         # keep track of skip connections
 
@@ -414,7 +418,8 @@ def forward(
 
         # remove the register tokens
 
-        _, x = unpack(x, ps, 'b * d')
+        if self.has_register_tokens:
+            _, x = unpack(x, ps, 'b * d')
 
         return self.final_norm(x)
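
For reference, a minimal self-contained sketch of the pattern this diff introduces: register tokens become optional, the parameter is only allocated when num_register_tokens > 0, and the pack/unpack round trip is skipped entirely otherwise. RegisterTokenWrapper is a hypothetical stand-in for the Transformer in voicebox_pytorch.py; the attention stack is elided (replaced by a comment) so only the register-token flow is shown.

import torch
import torch.nn.functional as F
from torch import nn
from einops import repeat, pack, unpack

class RegisterTokenWrapper(nn.Module):
    def __init__(self, dim, num_register_tokens = 0):
        super().__init__()
        self.num_register_tokens = num_register_tokens
        self.has_register_tokens = num_register_tokens > 0

        # the parameter only exists when register tokens are requested,
        # so num_register_tokens = 0 adds no weights to the module
        if self.has_register_tokens:
            self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

    def forward(self, x, mask = None):
        batch = x.shape[0]

        if self.has_register_tokens:
            # broadcast the learned tokens across the batch and pack them to the left
            register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
            x, ps = pack([register_tokens, x], 'b * d')

            if mask is not None:
                # register tokens should always be attended to, so left-pad the mask with True
                mask = F.pad(mask, (self.num_register_tokens, 0), value = True)

        # ... the transformer layers would run on (x, mask) here ...

        if self.has_register_tokens:
            # strip the register tokens, recovering the original sequence length
            _, x = unpack(x, ps, 'b * d')

        return x

# quick check: output length matches input length whether or not register tokens are used
wrapper = RegisterTokenWrapper(dim = 512, num_register_tokens = 8)
out = wrapper(torch.randn(2, 100, 512), mask = torch.ones(2, 100, dtype = torch.bool))
assert out.shape == (2, 100, 512)

Guarding both the parameter creation and the forward-pass packing behind the same has_register_tokens flag keeps the two code paths consistent: a module configured with num_register_tokens = 0 carries no dead nn.Parameter and never references an undefined ps at unpack time.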