Commit

when cross attending in look vit, make sure context tokens are normalized
lucidrains committed Jul 19, 2024
1 parent ec6c48b commit 4b2c00c
Showing 2 changed files with 14 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.7.2',
+  version = '1.7.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
16 changes: 13 additions & 3 deletions vit_pytorch/look_vit.py
@@ -66,6 +66,7 @@ def __init__(
         heads = 8,
         dim_head = 64,
         dropout = 0.,
+        cross_attend = False,
         reuse_attention = False
     ):
         super().__init__()
@@ -74,10 +75,13 @@ def __init__(
         self.scale = dim_head ** -0.5
         self.heads = heads
         self.reuse_attention = reuse_attention
+        self.cross_attend = cross_attend

         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)

         self.norm = LayerNorm(dim) if not reuse_attention else nn.Identity()
+        self.norm_context = LayerNorm(dim) if cross_attend else nn.Identity()
+
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)

@@ -99,7 +103,13 @@ def forward(
         attn = None
     ):
         x = self.norm(x)
-        context = default(context, x)
+
+        assert not (exists(context) ^ self.cross_attend)
+
+        if self.cross_attend:
+            context = self.norm_context(context)
+        else:
+            context = x

         v = self.to_v(context)
         v = self.split_heads(v)
@@ -179,8 +189,8 @@ def __init__(
             layers.append(ModuleList([
                 Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
                 MLP(dim = dim, factor = mlp_factor, dropout = dropout),
-                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout),
-                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, reuse_attention = True),
+                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True),
+                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True, reuse_attention = True),
                 LayerNorm(dim),
                 MLP(dim = dim, factor = highres_mlp_factor, dropout = dropout)
             ]))
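For context, here is a minimal standalone sketch of the pattern this commit introduces (illustrative only; the class and argument names below are not the library's actual look_vit API): when a block is constructed with cross_attend = True, the context tokens get their own LayerNorm before the key/value projection, and the assert forces callers to pass context exactly when cross attending.

from torch import nn

class CrossAttendSketch(nn.Module):
    # illustrative sketch of the commit's pattern, not the real look_vit Attention
    def __init__(self, dim, heads = 8, dim_head = 64, cross_attend = False):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.cross_attend = cross_attend

        self.norm = nn.LayerNorm(dim)
        # the fix: context tokens receive their own pre-norm when cross attending
        self.norm_context = nn.LayerNorm(dim) if cross_attend else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, context = None):
        # context must be passed if and only if the module was built for cross attention
        assert not ((context is not None) ^ self.cross_attend)

        x = self.norm(x)
        context = self.norm_context(context) if self.cross_attend else x

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        # split heads: (b, n, h * d) -> (b, h, n, d)
        q, k, v = (t.reshape(t.shape[0], t.shape[1], self.heads, -1).transpose(1, 2) for t in (q, k, v))

        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim = -1)
        out = (attn @ v).transpose(1, 2).flatten(2)
        return self.to_out(out)

Under those assumptions, a block built with cross_attend = True is called as block(tokens, context = highres_tokens), while a self attending block (the default) is called with tokens alone, mirroring how the LookViT layers above wire their cross attention.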
