Commit

buy into the conclusions of a new research paper
lucidrains committed Oct 2, 2023
1 parent 68ac954 commit 57f9ee4
Showing 3 changed files with 59 additions and 7 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -203,3 +203,14 @@ sampled = cfm_wrapper.sample(cond = cond) # (2, 1024, 512)
year = {2023}
}
```

```bibtex
@misc{darcet2023vision,
title = {Vision Transformers Need Registers},
author = {Timothée Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
year = {2023},
eprint = {2309.16588},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
```
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'voicebox-pytorch',
packages = find_packages(exclude=[]),
version = '0.1.14',
version = '0.2.0',
license='MIT',
description = 'Voicebox - Pytorch',
author = 'Phil Wang',
53 changes: 47 additions & 6 deletions voicebox_pytorch/voicebox_pytorch.py
@@ -167,7 +167,7 @@ def forward(self, x):
# https://arxiv.org/abs/2104.09864

class RotaryEmbedding(Module):
def __init__(self, dim, theta = 10000):
def __init__(self, dim, theta = 50000):
super().__init__()
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
@@ -176,8 +176,11 @@ def __init__(self, dim, theta = 10000):
def device(self):
return self.inv_freq.device

def forward(self, seq_len):
t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
def forward(self, t):
if not torch.is_tensor(t):
t = torch.arange(t, device = self.device)

t = t.type_as(self.inv_freq)
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
freqs = torch.cat((freqs, freqs), dim = -1)
return freqs
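
The change above lets `RotaryEmbedding.forward` accept either an integer sequence length (the old behaviour) or an explicit tensor of positions. A short usage sketch, assuming the class is importable from `voicebox_pytorch.voicebox_pytorch` with the package installed:

```python
# Usage sketch only; assumes voicebox-pytorch >= 0.2.0 is installed and that
# RotaryEmbedding is importable from the module shown in this diff.
import torch
from voicebox_pytorch.voicebox_pytorch import RotaryEmbedding

rotary = RotaryEmbedding(dim = 64)

# old calling style: an integer length, positions default to arange(seq_len)
freqs = rotary(1024)                      # (1024, 64)

# new calling style: an explicit tensor of positions, e.g. register tokens
# given positions far to the left of the main sequence
positions = torch.cat((torch.arange(16) - 10000, torch.arange(1024)))
freqs = rotary(positions)                 # (1040, 64)
```

Accepting explicit positions is what allows the register tokens below to be placed far to the left of the main sequence.
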
@@ -310,7 +313,8 @@ def __init__(
dim_head = 64,
heads = 8,
ff_mult = 4,
attn_dropout=0,
attn_dropout = 0,
num_register_tokens = 0,
attn_flash = False,
adaptive_rmsnorm = False,
adaptive_rmsnorm_cond_dim_in = None,
@@ -323,6 +327,9 @@ def __init__(

self.rotary_emb = RotaryEmbedding(dim = dim_head)

self.num_register_tokens = num_register_tokens
self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

if adaptive_rmsnorm:
rmsnorm_klass = partial(AdaptiveRMSNorm, cond_dim = adaptive_rmsnorm_cond_dim_in)
else:
@@ -344,20 +351,48 @@ def __init__(

self.final_norm = RMSNorm(dim)

@property
def device(self):
return next(self.parameters()).device

def forward(
self,
x,
mask = None,
adaptive_rmsnorm_cond = None
):
batch, seq_len, *_ = x.shape

register_tokens = repeat(self.register_tokens, 'n d -> b n d', b = batch)

# add register tokens to the left

x, ps = pack([register_tokens, x], 'b * d')

if exists(mask):
mask = F.pad(mask, (self.num_register_tokens, 0), value = True)

# keep track of skip connections

skip_connects = []

rotary_emb = self.rotary_emb(x.shape[-2])
# rotary embeddings

main_positions = torch.arange(seq_len, device = self.device, dtype = torch.long)
register_positions = torch.arange(self.num_register_tokens, device = self.device, dtype = torch.long)
register_positions -= 10000
positions = torch.cat((register_positions, main_positions))

rotary_emb = self.rotary_emb(positions)

# adaptive rmsnorm

rmsnorm_kwargs = dict()
if exists(adaptive_rmsnorm_cond):
rmsnorm_kwargs = dict(cond = adaptive_rmsnorm_cond)

# going through the attention layers

for skip_combiner, attn_prenorm, attn, ff_prenorm, ff in self.layers:

# in the paper, they use a u-net like skip connection
Expand All @@ -376,6 +411,10 @@ def forward(
ff_input = ff_prenorm(x, **rmsnorm_kwargs)
x = ff(ff_input) + x

# remove the register tokens

_, x = unpack(x, ps, 'b * d')

return self.final_norm(x)

# encoder decoders
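
The forward pass above follows the register-token recipe from Darcet et al.: learned tokens are packed onto the left of the sequence, the padding mask is extended with `True` for those slots, and the tokens are stripped again before the final norm. A minimal standalone sketch of just that pack/unpack pattern — the class name `RegisterTokenSketch` is hypothetical and the attention stack is elided:

```python
# Standalone sketch of the register-token pattern used above; RegisterTokenSketch
# is a hypothetical name and the attention / feedforward layers are elided.
import torch
import torch.nn.functional as F
from torch import nn
from einops import repeat, pack, unpack

class RegisterTokenSketch(nn.Module):
    def __init__(self, dim, num_register_tokens = 16):
        super().__init__()
        self.num_register_tokens = num_register_tokens
        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim))

    def forward(self, x, mask = None):
        batch = x.shape[0]

        # prepend the learned register tokens to every sequence in the batch
        registers = repeat(self.register_tokens, 'n d -> b n d', b = batch)
        x, ps = pack([registers, x], 'b * d')

        # register slots are always attendable, so left-pad the mask with True
        if mask is not None:
            mask = F.pad(mask, (self.num_register_tokens, 0), value = True)

        # ... attention and feedforward layers would operate on the packed x here ...

        # strip the register tokens so the caller only sees the main sequence
        _, x = unpack(x, ps, 'b * d')
        return x

sketch = RegisterTokenSketch(dim = 512)
out = sketch(torch.randn(2, 1024, 512))    # (2, 1024, 512)
```

Packing on the left means the mask only needs a left pad, and `unpack` with the saved packed shapes `ps` recovers the original sequence length exactly.
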
@@ -783,6 +822,7 @@ def __init__(
conv_pos_embed_groups = None,
attn_dropout=0,
attn_flash = False,
num_register_tokens = 16,
p_drop_prob = 0.3, # p_drop in paper
frac_lengths_mask: Tuple[float, float] = (0.7, 1.),
condition_on_text = True
@@ -837,8 +877,9 @@ def __init__(
dim_head = dim_head,
heads = heads,
ff_mult = ff_mult,
attn_dropout=attn_dropout,
attn_dropout = attn_dropout,
attn_flash = attn_flash,
num_register_tokens = num_register_tokens,
adaptive_rmsnorm = True,
adaptive_rmsnorm_cond_dim_in = time_hidden_dim
)
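
For users of the library, the visible change is a new `num_register_tokens` argument (default 16) that is threaded through to the Transformer. A hedged construction sketch — the class name `VoiceBox` and the `dim`/`depth` arguments come from the repository README rather than this diff:

```python
# Hedged sketch of the user-facing change; assumes the constructor in this hunk
# is VoiceBox exported from the package root, and that dim / depth are accepted
# as in the repository README.
from voicebox_pytorch import VoiceBox

model = VoiceBox(
    dim = 512,                    # assumed, per README
    depth = 2,                    # assumed, per README
    dim_head = 64,
    heads = 16,
    condition_on_text = False,
    num_register_tokens = 16      # new in 0.2.0; defaults to 16
)
```

Passing `num_register_tokens = 0` should recover the pre-0.2.0 behaviour, since the packed register block is then empty, though that is an assumption rather than something this commit tests.
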
