This codebase enables you to train your own CLIP-like model on a single GPU by aligning pretrained vision models, such as DINOv2, with language models like NV-Embed-v2. Our approach demonstrates that training only a lightweight alignment layer while keeping the backbones frozen is sufficient to bridge the vision and language representation spaces. Using just 23M web-collected and synthetic image-text pairs, we developed a foundational model called SAIL-L, which surpasses CLIP-L (LAION-400M) on various retrieval tasks and ImageNet while also serving as a strong vision encoder for building multimodal LLMs. We hope this codebase serves as a useful testbed for the resource-constrained community to explore multimodal representation learning in terms of new losses, new data combinations, and new modality-merging strategies.
- [2024/11/20] 🔥 SAIL Codebase Open-Sourced!
The repository now includes the complete pipeline for training data preparation and preprocessing, as well as the full training and evaluation codebase.
- **Clone the Repository**

```bash
git clone https://github.com/lezhang7/SAIL.git
pip install -r requirements.txt
```
- **Download the Alignment Layer Checkpoint**

You can download the pretrained alignment layer checkpoints from the links below:
| Data | Model | Alignment Layer | IN-1K | I2T R@1 (MSCOCO) | T2I R@1 (MSCOCO) | I2T R@1 (Flickr30k) | T2I R@1 (Flickr30k) | Text (Winoground) | Image (Winoground) | Group (Winoground) | Avg. (MMVP) |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 23M | SAIL-L (GTE) | download | 65.4 | 54.1 | 42.7 | 80.8 | 68.9 | 34.0 | 13.25 | 8.75 | 22.2 |
| 23M | SAIL-L (NV2) | download | 73.4 | 62.4 | 48.6 | 87.6 | 75.7 | 40.25 | 18.75 | 15.0 | 28.9 |
| LAION-400M | CLIP-L | – | 72.7 | 59.7 | 43.0 | 87.6 | 70.2 | 30.5 | 11.5 | 8.75 | 20.0 |
- **Run the Model**
```python
from model import create_model
from PIL import Image
import torch

# Path to the downloaded checkpoint
checkpoint_path = "checkpoint/sail_dinov2l_nv2.pt"

# Create the model; change text_model_name to `Alibaba-NLP/gte-large-en-v1.5` if using the sail_dinov2_gte checkpoint
model = create_model(
    text_model_name="nvidia/NV-Embed-v2",
    vision_model_name="facebook/dinov2-large",
    head_weights_path=checkpoint_path,
    target_dimension=1024,
)
model.eval()  # Set model to evaluation mode

# Prepare images and texts
image_processor = model.image_processor
texts = ["a dog", "a cat"]
dog_image = Image.open("asset/dog.jpg").convert("RGB")
cat_image = Image.open("asset/cat.jpg").convert("RGB")
images = image_processor(images=[dog_image, cat_image], return_tensors="pt")

# Generate features and probabilities
with torch.no_grad():
    image_features = model.encode_image(images, normalize=True)
    text_features = model.encode_text(texts, text_list=texts, normalize=True)
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# Print the label probabilities
print("Label probs:", text_probs)
```
The codebase builds upon OpenCLIP (for training SAIL) and LLaVA (for testing SAIL's vision encoder in MLLMs). Please ensure the necessary dependency packages for these frameworks are installed.
SAIL leverages high-quality, MLLM-enhanced captions for training, using datasets introduced in DreamLIP. To streamline this process, we provide a script for automated dataset preparation. Note that this process is time-intensive, as it involves handling 23M data samples.
```bash
cd data_preparation
bash download_mllm_enhanced_data.sh
```
- **Downloading Dataset:** The script downloads `.csv` files containing image URLs and their corresponding captions for datasets such as CC3M, CC12M, and YFCC15M.
- **Filtering Invalid Data:** Since some image URLs may have expired or the images may be corrupted, the downloaded images need to be filtered, and invalid samples must be removed from the `.csv` files (see the sketch after this list).
- **Manual Processing (Optional):** As each step can take a significant amount of time, we recommend running the commands manually based on your specific requirements.
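If you run the filtering step yourself, a minimal sketch of the idea is shown below. It assumes the downloaded `.csv` has an `image` column holding filenames relative to the download directory; the actual column names and paths in the DreamLIP CSVs may differ, so treat this as an illustration rather than the project's filtering script.

```python
import os
import pandas as pd
from PIL import Image

# Hypothetical paths and column name for illustration; adapt to the actual .csv layout.
csv_path = "DownloadCC3M/cc3m_captions.csv"
image_dir = "DownloadCC3M"

df = pd.read_csv(csv_path)

def is_valid(filename) -> bool:
    """Keep a row only if its image file exists and can be decoded."""
    path = os.path.join(image_dir, str(filename))
    if not os.path.isfile(path):
        return False
    try:
        with Image.open(path) as img:
            img.verify()  # cheap integrity check; does not decode full pixel data
        return True
    except Exception:
        return False

df = df[df["image"].apply(is_valid)]
df.to_csv("DownloadCC3M/cc3m_captions_filtered.csv", index=False)
print(f"Kept {len(df)} valid samples")
```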
Once the preprocessing is complete, update the dataset paths, including the `annotation` and `imagedir` fields, in `data/data_config.py`:
```python
DATADIR = {
    'dreamclipcc3m': {
        'annotation': f'{SCRATCH}/datasets/DownloadCC3M/cc3m_3long_1raw_captions_filterd.csv',
        'imagedir': f'{SCRATCH}/datasets/DownloadCC3M'
    }
}
```
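For reference, the two fields are consumed roughly as in the sketch below: `annotation` points to the caption `.csv` and `imagedir` to the image root. This is a simplified illustration, not the repository's actual dataset class; the `image` column name is an assumption, and the caption column (e.g. `raw_caption` or one of the MLLM-enhanced caption types) is selected later at encoding time.

```python
import os
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

class ImageTextCSV(Dataset):
    """Simplified sketch of how the 'annotation' and 'imagedir' fields are used."""

    def __init__(self, annotation, imagedir, caption_key="raw_caption", transform=None):
        self.df = pd.read_csv(annotation)   # captions + image filenames
        self.imagedir = imagedir            # root directory of downloaded images
        self.caption_key = caption_key      # e.g. raw_caption or an MLLM-enhanced caption column
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = Image.open(os.path.join(self.imagedir, row["image"])).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, row[self.caption_key]
```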
The training framework of SAIL consists of two main steps: Pre-encoding and Alignment Tuning. This efficient framework allows us to align the representation spaces of large pretrained unimodal models (e.g., DINOv2 and NV2) on a single `A100` GPU with a large batch size of `32,768`, requiring only about 5 hours of training during the alignment tuning stage.
We provide scripts to pre-encode image-text pairs into embeddings. The script will automatically download the required model checkpoints from either Hugging Face Transformers or TorchHub. By default, the vision model weights will be stored in the `model/backbone_checkpoints` directory. A conceptual sketch of what this step does is shown after the command below.

Note: Ensure your `transformers` library version is `>4.38.0` to support the SDPA or FlashAttention implementations for some of the models.
- Review `scripts/encode.sh` for detailed configurations. Choose the `text_model` or `vision_model` in `scripts/encode.sh`.
- Set the encoded domain to either `image` or `text`.
  - If encoding text, select the caption source: `raw_caption` or the high-quality captions `shortIB_captions`, `shortSV_captions`, `shortLLA_captions`, `longIB_captions`, `longSV_captions`, `longLLA_captions`.
- Execute the following command to encode the data. The embeddings will be saved in the `data/tensor_data` directory.
```bash
bash scripts/encode.sh
```
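For intuition, the sketch below shows what pre-encoding amounts to on the image side: run the frozen backbone once over the data and cache the pooled features to disk (the text side works the same way with the sentence-embedding model). The pooling choice and output filename here are illustrative assumptions; `scripts/encode.sh` handles batching, caption sources, and the storage layout.

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Frozen vision backbone; its weights are never updated.
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
vision = AutoModel.from_pretrained("facebook/dinov2-large").to(device).eval()

image_paths = ["asset/dog.jpg", "asset/cat.jpg"]  # in practice: the full 23M-pair dataset
embeddings = []

with torch.no_grad():
    for path in image_paths:
        image = Image.open(path).convert("RGB")
        inputs = processor(images=image, return_tensors="pt").to(device)
        # Use the CLS token as a global image feature (one possible pooling choice).
        feats = vision(**inputs).last_hidden_state[:, 0]
        embeddings.append(feats.cpu())

# The cached tensors are what the alignment-tuning stage trains on.
torch.save(torch.cat(embeddings), "data/tensor_data/dinov2_large_images.pt")
```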
To train SAIL, specify the text and image embedding data by updating the `text_embedding_list` and `image_embedding_list` in `scripts/sail_train.sh`.

- To enable multiple positive captions for contrastive loss, also update the `extra_text_embedding_list`.
- **Important:** Ensure embeddings of the same modality are derived from the same model.
Run the following command to train the alignment layer:
```bash
bash scripts/sail_train.sh
```
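Because both backbones stay frozen and their features are already cached, each training step only pushes pre-encoded embeddings through the small alignment heads and computes a CLIP-style contrastive loss, which is why a `32,768` batch fits on a single A100. The sketch below illustrates this stage with plain linear heads and hypothetical tensor files; the actual head architecture is selected by options such as `linear_type` in the training script.

```python
import torch
import torch.nn.functional as F

device = "cuda"
target_dim = 1024
batch_size = 32768

# Pre-encoded, frozen-backbone features (hypothetical file names).
image_emb = torch.load("data/tensor_data/dinov2_large_images.pt")  # [N, 1024]
text_emb = torch.load("data/tensor_data/nv_embed_v2_texts.pt")     # [N, 4096]

# Only these heads (and the temperature) are trained.
image_head = torch.nn.Linear(image_emb.shape[1], target_dim).to(device)
text_head = torch.nn.Linear(text_emb.shape[1], target_dim).to(device)
logit_scale = torch.nn.Parameter(torch.full((), 2.659, device=device))  # learnable temperature, ~log(1/0.07)

optimizer = torch.optim.AdamW(
    list(image_head.parameters()) + list(text_head.parameters()) + [logit_scale], lr=1e-4
)

for step in range(1000):
    idx = torch.randint(0, image_emb.shape[0], (batch_size,))
    img = F.normalize(image_head(image_emb[idx].to(device)), dim=-1)
    txt = F.normalize(text_head(text_emb[idx].to(device)), dim=-1)

    logits = logit_scale.exp() * img @ txt.T
    labels = torch.arange(batch_size, device=device)
    # Symmetric InfoNCE: match each image to its caption and vice versa.
    loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```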
To probe the alignment, execute:
```bash
bash scripts/alignment_probing.sh
```
We only save the alignment layer checkpoint, at `./logs/${output_name}`.
Evaluation scripts are provided in `scripts/sail_eval.sh`.
- Set the `vision_model`, `text_model`, and `checkpoint_path` in `scripts/sail_eval.sh`.
  - Ensure that the vision and text models match the embedding data used to train the alignment layers.
- Prepare the datasets:
  - Download MMVP_VLM and save it to `evaluation/MMVP_VLM`.
  - ImageNet and Winoground will be automatically downloaded and processed.
- Specify the task (`imagenetv1`, `winoground`, or `MMVP`) in `sail_eval.sh`, then evaluate by running:
```bash
bash scripts/sail_eval.sh
```
The evaluation results will be saved to `evaluation/eval_result/{task}`.
**Open-vocabulary semantic segmentation:** for instructions, please refer to here.
SAIL significantly enhances SSL models, such as DINOv2, as vision encoders for MLLMs. Specifically, we replace the vision encoder in LLaVA-1.5 with the SAIL vision encoder, which consists of a DINOv2 backbone combined with an alignment layer. This additional alignment layer dramatically improves DINOv2's performance on MLLM tasks, even surpassing language-supervised CLIP vision encoders on certain tasks! We provide a trained checkpoint at le723z/sail-llava-v1.5-7b.
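Conceptually, the vision tower handed to LLaVA is just the DINOv2 backbone with the trained alignment head applied to its patch features, roughly as sketched below. This is an illustration, not the repository's `vision_tower` implementation: a plain linear head stands in for the actual `linear_type star` head, and the real code loads the head weights from `--vlhead_weights_path` and selects features via `--mm_vision_select_layer`.

```python
import torch
from transformers import AutoModel

class SailVisionTower(torch.nn.Module):
    """Conceptual sketch: DINOv2 patch features passed through a SAIL-style alignment head."""

    def __init__(self, vision_model_name="facebook/dinov2-large", target_dimension=1024):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(vision_model_name)
        # A plain linear head stands in for the trained alignment layer,
        # which the real code loads from --vlhead_weights_path.
        self.alignment_head = torch.nn.Linear(self.backbone.config.hidden_size, target_dimension)

    def forward(self, pixel_values):
        outputs = self.backbone(pixel_values, output_hidden_states=True)
        patch_tokens = outputs.hidden_states[-2][:, 1:]  # second-to-last layer, CLS token dropped
        return self.alignment_head(patch_tokens)         # aligned patch features go to the LLM projector
```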
VTune indicates whether the vision encoder is fine-tuned during the instruction tuning stage.
| Model@224px | VTune | SEED-IMG | GQA | VizWiz | POPE | TextVQA | MMB | VQAv2 |
|---|---|---|---|---|---|---|---|---|
| DINOv2-L | ✗ | 61.47 | 61.08 | 44.12 | 85.5 | 45.37 | 56.96 | 74.4 |
| DINOv2-L | ✓ | 62.12 | 61.53 | 46.59 | 85.7 | 45.92 | 58.85 | 74.69 |
| SAIL-L | ✓ | 65.43 | 62.63 | 50.00 | 86.16 | 46.53 | 60.14 | 76.77 |
| CLIP-L/14 | ✗ | 64.05 | 61.58 | 48.87 | 85.74 | 54.56 | 63.06 | 75.32 |
| CLIP-L/14 | ✓ | 64.15 | 61.54 | 49.93 | 85.73 | 54.18 | 64.12 | 76.36 |
We follow the LLaVA-1.5 training process, including pretraining and fine-tuning. To get started, please prepare the data following the instructions in the original codebase:
- Pretraining data
- Visual instruction tuning data
Then install the dependency packages following here for training LLaVA-1.5. We recommend using `cudatoolkit/12.1.1` for reproducibility.
Specify `data_path` and `image_folder` for the dataset configuration. Set `vlhead_weights_path` and `vision_tower` for SAIL-LLaVA as follows:
```bash
deepspeed llava_train/train_mem.py \
--deepspeed llava_train/zero2.json \
--model_name_or_path lmsys/vicuna-7b-v1.5 \
--version plain \
--data_path blip_laion_cc_sbu_558k.json \
--target_dimension 1024 \
--linear_type star \
--image_folder llava-v1.5-7b/pretrain_data \
--vision_tower facebook/dinov2-large \
--mm_projector_type mlp2x_gelu \
--tune_mm_mlp_adapter True \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--num_train_epochs 1 \
--per_device_train_batch_size 64 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 2 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 4000 \
--save_total_limit 1 \
--learning_rate 0.001 \
--weight_decay 0.0 \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--fp16 False \
--bf16 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb \
--vlhead_weights_path logs/sail_l_nv2_merged23m/checkpoints/64.pt \
--tune_alignment_layer False \
--output_dir ./llava_checkpoints/sail_llava_pretrain
```
For fine-tuning, update the `data_path` and other configurations. Use the following command:
```bash
deepspeed llava_train/train_mem.py \
--deepspeed llava_train/zero3.json \
--model_name_or_path lmsys/vicuna-7b-v1.5 \
--version v1 \
--data_path $dataset/llava-v1.5-7b/instruct_tuning_data/llava_v1_5_mix665k.json \
--image_folder $dataset/llava-v1.5-7b/instruct_tuning_data/data \
--target_dimension 1024 \
--linear_type star \
--vision_tower facebook/dinov2-large \
--mm_projector_type mlp2x_gelu \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--image_aspect_ratio pad \
--group_by_modality_length True \
--bf16 True \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 2 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 400 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to wandb \
--vlhead_weights_path llava_checkpoints/sail_llava_pretrain/vlhead.bin \
--pretrain_mm_mlp_adapter llava_checkpoints/sail_llava_pretrain/mm_projector.bin \
--tune_alignment_layer True \
--unlock_vision_tower True \
--output_dir ./llava_checkpoints/sail_llava_finetune
```
We provide evaluation scripts in `scripts/llava_eval_scripts`. Download the evaluation datasets following here, update the dataset path in each task's bash script, and then run the evaluation to test the model's performance.
This project is based on open_clip, DreamLIP, and LLaVA; we appreciate their great work!