multimodal-patch-embeddings

Details

This repo contains the code for training and inference of a distilled, smaller CLIP ViT model. The distilled model has 21.3 million parameters. The vision transformer uses a novel, much simpler architecture: it has no CLS embedding and no projection layer at the end. See the VisionTransformerExtraHead class in src/model.py for the implementation. Check out the article here: https://www.tinyvolt.com/research/multimodal-patch-embeddings

Multimodal patch embeddings

What makes this model so special is that the embedding of each of the image patches is in the same embedding space as the final embedding. In fact, the final embedding is just a convex sum of the patch embeddings. This allows one to compare the text embedding with each of the 64 image patch embeddings.
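
The snippet below is a minimal sketch of this idea. The shapes match the usage example further down; the pooling itself lives inside VisionTransformerExtraHead, so the random tensors here are only stand-ins:

import torch

# Stand-in tensors: 64 patch embeddings of width 512 and their attention weights.
patch_embeds = torch.randn(1, 64, 512)               # per-patch embeddings
weights = torch.softmax(torch.randn(1, 1, 64), -1)   # convex weights (non-negative, sum to 1)

# The final image embedding is a convex combination of the patch embeddings,
# so it lives in the same embedding space as every individual patch.
image_embed = (weights @ patch_embeds).squeeze(1)    # shape (1, 512)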

Model output(s)

The ViT model maps an image to an embedding. By default, the model outputs the embedding (of shape (B, 512)) and the probability distribution over the image patches (of shape (B, 1, 64), where 64 is the number of patches). However, if you want the embedding of each image patch instead, pass the extra parameter return_all_embeds during inference:

import torch

# make sure you are in `src/` directory
from model import VisionTransformerExtraHead
from _types import vit_extended_same_norm_masked_28_args_16_heads_512_width as vit_multimodal_patch_args
vit = VisionTransformerExtraHead(**vit_multimodal_patch_args.model_dump())

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    # default: pooled image embedding and the attention weights over the patches
    y, attn = vit(x)
    print(y.shape, attn.shape)
    # return_all_embeds=True: one embedding per image patch instead of the pooled one
    y, attn = vit(x, return_all_embeds=True)
    print(y.shape, attn.shape)

This will print the following:

torch.Size([1, 512]) torch.Size([1, 1, 64])
torch.Size([1, 64, 512]) torch.Size([1, 1, 64])
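
Since the patch embeddings share the space of the final embedding, you can score each patch against a text embedding. The snippet below is a sketch only: it assumes you already have a 512-dimensional text embedding (e.g. from the CLIP text encoder the model was distilled against), and the random tensors are placeholders for the real outputs.

import torch
import torch.nn.functional as F

# Placeholder for the (1, 64, 512) output of vit(x, return_all_embeds=True)
patch_embeds = torch.randn(1, 64, 512)
# Placeholder for a (1, 512) text embedding from a matching text encoder
text_embed = torch.randn(1, 512)

# Cosine similarity between the text and every patch gives a 64-value relevance map,
# which can be reshaped to the 8x8 patch grid for visualization.
sims = F.cosine_similarity(patch_embeds, text_embed.unsqueeze(1), dim=-1)  # (1, 64)
heatmap = sims.reshape(1, 8, 8)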

Directory structure

.
├── LICENSE
├── README.md
├── assets
│   ├── attention_comparison_no_cls
│   ├── patch_activations
│   └── search_comparison_no_cls
├── checkpoints
│   ├── checkpoint_epoch24_vit_extended_dim_2024-04-11_19-18-30.pt
│   └── checkpoint_epoch31_vit_extended_dim_same_norm_attn_mask_2024-04-27_20-29-33.pt
├── poetry.lock
├── pyproject.toml
└── src
    ├── _types.py
    ├── data.py
    ├── loss.py
    ├── main.py
    ├── model.py
    ├── notebooks
    │   ├── mm_patch_embed.ipynb
    │   └── vit_no_cls.ipynb
    └── utils.py

Downloading the checkpoints

- Download the checkpoints from here and put them in the checkpoints folder.

The checkpoint checkpoint_epoch24_vit_extended_dim_2024-04-11_19-18-30.pt was not trained with the attention mask. It also does not force the patch embeddings to have the same norm before taking the convex sum. As a result, it does not need (or contain) the scale parameter defined in the VisionTransformerExtraHead class. To load this checkpoint, you can do the following:

import torch

# make sure you are in `src/` directory
from model import VisionTransformerExtraHead
from _types import vit_extended_28_args_16_heads_512_width as vit_no_cls
vit = VisionTransformerExtraHead(**vit_no_cls.model_dump())

# during inference, make sure to set `same_norm` to `False`
x = torch.randn(1,3,224,224)
y, attn = vit(x, same_norm=False, return_all_embeds=False)

Please note that this checkpoint does not have multimodal patch embeddings.
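
To actually load the trained weights into the instantiated model, something like the following should work. This is a sketch: it assumes the .pt file stores the model state dict directly and that you run it from the src/ directory; adjust the path and unwrapping if your setup differs.

import torch

# Assumes the checkpoint file holds the model state dict directly; if it wraps
# additional training state, extract the model weights first.
state_dict = torch.load(
    "../checkpoints/checkpoint_epoch24_vit_extended_dim_2024-04-11_19-18-30.pt",
    map_location="cpu",
)
vit.load_state_dict(state_dict)
vit.eval()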

Setting up

poetry install

Results

The images in assets/patch_activations show patch activations for different prompts.
