From 4b2c00cb630ee64ab0eaa6a9d455c1670183a0fc Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 19 Jul 2024 10:23:12 -0700
Subject: [PATCH] when cross attending in look vit, make sure context tokens are normalized

---
 setup.py                |  2 +-
 vit_pytorch/look_vit.py | 16 +++++++++++++---
 2 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/setup.py b/setup.py
index c815ca8..eea2d68 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.7.2',
+  version = '1.7.3',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,
diff --git a/vit_pytorch/look_vit.py b/vit_pytorch/look_vit.py
index 3b7b27b..1651796 100644
--- a/vit_pytorch/look_vit.py
+++ b/vit_pytorch/look_vit.py
@@ -66,6 +66,7 @@ def __init__(
         heads = 8,
         dim_head = 64,
         dropout = 0.,
+        cross_attend = False,
         reuse_attention = False
     ):
         super().__init__()
@@ -74,10 +75,13 @@ def __init__(
         self.scale = dim_head ** -0.5
         self.heads = heads
         self.reuse_attention = reuse_attention
+        self.cross_attend = cross_attend
 
         self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
 
         self.norm = LayerNorm(dim) if not reuse_attention else nn.Identity()
+        self.norm_context = LayerNorm(dim) if cross_attend else nn.Identity()
+
         self.attend = nn.Softmax(dim = -1)
         self.dropout = nn.Dropout(dropout)
 
@@ -99,7 +103,13 @@ def forward(
         attn = None
     ):
         x = self.norm(x)
-        context = default(context, x)
+
+        assert not (exists(context) ^ self.cross_attend)
+
+        if self.cross_attend:
+            context = self.norm_context(context)
+        else:
+            context = x
 
         v = self.to_v(context)
         v = self.split_heads(v)
@@ -179,8 +189,8 @@ def __init__(
             layers.append(ModuleList([
                 Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = dropout),
                 MLP(dim = dim, factor = mlp_factor, dropout = dropout),
-                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout),
-                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, reuse_attention = True),
+                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True),
+                Attention(dim = dim, dim_head = cross_attn_dim_head, heads = cross_attn_heads, dropout = dropout, cross_attend = True, reuse_attention = True),
                 LayerNorm(dim),
                 MLP(dim = dim, factor = highres_mlp_factor, dropout = dropout)
             ]))
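
Note: the patch ties passing context to the new cross_attend flag via an XOR assert, and gives the context tokens their own LayerNorm before keys/values are projected from them; previously, context = default(context, x) let an externally supplied context flow in unnormalized. Below is a minimal, self-contained sketch of that pre-normalization pattern. The class name and projection layout are assumptions for illustration only, not vit-pytorch's actual Attention API.

import torch
from torch import nn

class CrossAttentionSketch(nn.Module):
    # sketch only: mirrors the patch's idea (context gets its own norm when cross attending)
    def __init__(self, dim, heads = 8, dim_head = 64, cross_attend = False):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.cross_attend = cross_attend

        self.norm = nn.LayerNorm(dim)
        # the fix this patch makes: normalize context tokens when cross attending
        self.norm_context = nn.LayerNorm(dim) if cross_attend else nn.Identity()

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, context = None):
        # passing context must agree with cross_attend, mirroring the patch's assert
        assert not ((context is not None) ^ self.cross_attend)

        x = self.norm(x)
        context = self.norm_context(context) if self.cross_attend else x

        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        # split heads: (b, n, h * d) -> (b, h, n, d)
        b, n, _ = q.shape
        q = q.view(b, n, self.heads, -1).transpose(1, 2)
        k = k.view(b, k.shape[1], self.heads, -1).transpose(1, 2)
        v = v.view(b, v.shape[1], self.heads, -1).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim = -1)
        out = attn @ v

        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

if __name__ == '__main__':
    cross_attn = CrossAttentionSketch(dim = 256, cross_attend = True)
    tokens  = torch.randn(1, 64, 256)    # coarse tokens attending ...
    highres = torch.randn(1, 256, 256)   # ... to high resolution tokens
    out = cross_attn(tokens, context = highres)  # context is normalized internally
    print(out.shape)                              # torch.Size([1, 64, 256])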