From 57f9ee4e1aad28e7c715a8e889a344d35c58d23d Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Mon, 2 Oct 2023 15:09:42 -0700
Subject: [PATCH] buy into the conclusions of a new research paper

---
 README.md                            | 11 ++++++
 setup.py                             |  2 +-
 voicebox_pytorch/voicebox_pytorch.py | 53 ++++++++++++++++++++++++----
 3 files changed, 59 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index a95cadd..02c8679 100644
--- a/README.md
+++ b/README.md
@@ -203,3 +203,14 @@ sampled = cfm_wrapper.sample(cond = cond) # (2, 1024, 512)
     year = {2023}
 }
 ```
+
+```bibtex
+@misc{darcet2023vision,
+    title = {Vision Transformers Need Registers},
+    author = {Timothée Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
+    year = {2023},
+    eprint = {2309.16588},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
diff --git a/setup.py b/setup.py
index 115086b..0de28d9 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'voicebox-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.14',
+  version = '0.2.0',
   license='MIT',
   description = 'Voicebox - Pytorch',
   author = 'Phil Wang',
diff --git a/voicebox_pytorch/voicebox_pytorch.py b/voicebox_pytorch/voicebox_pytorch.py
index ab1e36d..2cc0e00 100644
--- a/voicebox_pytorch/voicebox_pytorch.py
+++ b/voicebox_pytorch/voicebox_pytorch.py
@@ -167,7 +167,7 @@ def forward(self, x):
 # https://arxiv.org/abs/2104.09864

 class RotaryEmbedding(Module):
-    def __init__(self, dim, theta = 10000):
+    def __init__(self, dim, theta = 50000):
         super().__init__()
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer("inv_freq", inv_freq)
@@ -176,8 +176,11 @@ def __init__(self, dim, theta = 10000):
     def device(self):
         return self.inv_freq.device

-    def forward(self, seq_len):
-        t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
+    def forward(self, t):
+        if not torch.is_tensor(t):
+            t = torch.arange(t, device = self.device)
+
+        t = t.type_as(self.inv_freq)
         freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
         freqs = torch.cat((freqs, freqs), dim = -1)
         return freqs
@@ -310,7 +313,8 @@ def __init__(
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        attn_dropout=0,
+        attn_dropout = 0,
+        num_register_tokens = 0,
         attn_flash = False,
         adaptive_rmsnorm = False,
         adaptive_rmsnorm_cond_dim_in = None,
@@ -323,6 +327,9 @@ def __init__(

         self.rotary_emb = RotaryEmbedding(dim = dim_head)

+        self.num_register_tokens = num_register_tokens
+        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
+
         if adaptive_rmsnorm:
             rmsnorm_klass = partial(AdaptiveRMSNorm, cond_dim = adaptive_rmsnorm_cond_dim_in)
         else:
@@ -344,20 +351,48 @@ def __init__(

         self.final_norm = RMSNorm(dim)

+    @property
+    def device(self):
+        return next(self.parameters()).device
+
     def forward(
         self,
         x,
         mask = None,
         adaptive_rmsnorm_cond = None
     ):
+        batch, seq_len, *_ = x.shape
+
+        register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)
+
+        # add register tokens to the left
+
+        x, ps = pack([register_tokens, x], 'b * d')
+
+        if exists(mask):
+            mask = F.pad(mask, (self.num_register_tokens, 0), value = True)
+
+        # keep track of skip connections
+
         skip_connects = []

-        rotary_emb = self.rotary_emb(x.shape[-2])
+        # rotary embeddings
+
+        main_positions = torch.arange(seq_len, device = self.device, dtype = torch.long)
+        register_positions = torch.arange(self.num_register_tokens, device = self.device, dtype = torch.long)
+        register_positions -= 10000
+        positions = torch.cat((register_positions, main_positions))
+
+        rotary_emb = self.rotary_emb(positions)
+
+        # adaptive rmsnorm

         rmsnorm_kwargs = dict()
         if exists(adaptive_rmsnorm_cond):
             rmsnorm_kwargs = dict(cond = adaptive_rmsnorm_cond)

+        # going through the attention layers
+
         for skip_combiner, attn_prenorm, attn, ff_prenorm, ff in self.layers:

             # in the paper, they use a u-net like skip connection
@@ -376,6 +411,10 @@ def forward(
             ff_input = ff_prenorm(x, **rmsnorm_kwargs)
             x = ff(ff_input) + x

+        # remove the register tokens
+
+        _, x = unpack(x, ps, 'b * d')
+
         return self.final_norm(x)

 # encoder decoders
@@ -783,6 +822,7 @@ def __init__(
         conv_pos_embed_groups = None,
         attn_dropout=0,
         attn_flash = False,
+        num_register_tokens = 16,
         p_drop_prob = 0.3, # p_drop in paper
         frac_lengths_mask: Tuple[float, float] = (0.7, 1.),
         condition_on_text = True
@@ -837,8 +877,9 @@ def __init__(
             dim_head = dim_head,
             heads = heads,
             ff_mult = ff_mult,
-            attn_dropout=attn_dropout,
+            attn_dropout = attn_dropout,
             attn_flash = attn_flash,
+            num_register_tokens = num_register_tokens,
             adaptive_rmsnorm = True,
             adaptive_rmsnorm_cond_dim_in = time_hidden_dim
         )
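
The core of the change above is the register-token mechanism from the cited paper: learned tokens are packed onto the left of the sequence, the padding mask is extended so they are always attendable, and they are stripped off again after the layers run. Below is a minimal standalone sketch of that flow, assuming only `torch` and `einops`; `RegisterTokenSketch` and the single `nn.MultiheadAttention` layer standing in for the full transformer stack are illustrative names, not the repository's API.

```python
import torch
import torch.nn.functional as F
from torch import nn
from einops import repeat, pack, unpack

class RegisterTokenSketch(nn.Module):
    # illustrative stand-in for the patched Transformer, not the repo's class
    def __init__(self, dim, num_register_tokens = 16):
        super().__init__()
        self.num_register_tokens = num_register_tokens
        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))
        # one attention layer stands in for the full layer stack
        self.attn = nn.MultiheadAttention(dim, num_heads = 8, batch_first = True)

    def forward(self, x, mask = None):
        batch = x.shape[0]

        # prepend the learned register tokens to the left of the sequence
        registers = repeat(self.register_tokens, 'n d -> b n d', b = batch)
        x, ps = pack([registers, x], 'b * d')

        # registers must always be attendable, so left-pad the mask with True
        if mask is not None:
            mask = F.pad(mask, (self.num_register_tokens, 0), value = True)

        # nn.MultiheadAttention treats True as "ignore", hence the inversion
        key_padding_mask = ~mask if mask is not None else None
        x, _ = self.attn(x, x, x, key_padding_mask = key_padding_mask)

        # strip the register tokens back off before returning
        _, x = unpack(x, ps, 'b * d')
        return x

x = torch.randn(2, 1024, 512)
mask = torch.ones(2, 1024, dtype = torch.bool)
out = RegisterTokenSketch(512)(x, mask)
assert out.shape == x.shape  # (2, 1024, 512), registers removed
```

The patch additionally gives the register tokens rotary positions offset far to the left (`register_positions -= 10000`) so they do not sit adjacent to real timesteps; that detail is omitted from the sketch for brevity.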