Paper | Project Page | HF Demo | Chinese Explainer
Struggling with 'index collapse' in vector quantization? Discover OptVQ, a solution designed to maximize codebook utilization and enhance reconstruction quality.
- [2024-12-16] We released the training code of OptVQ.
- [2024-11-26] We released the pre-trained models of OptVQ.
We conduct image reconstruction experiments on the ImageNet dataset, and the quantitative comparison is shown below:
Model | Latent Size | #Tokens | From Scratch | SSIM↑ | PSNR↑ | LPIPS↓ | rFID↓ |
---|---|---|---|---|---|---|---|
taming-VQGAN | 16 × 16 | 1,024 | √ | 0.521 | 23.30 | 0.195 | 6.25 |
MaskGiT-VQGAN | 16 × 16 | 1,024 | √ | - | - | - | 2.28 |
Mo-VQGAN | 16 × 16 × 4 | 1,024 | √ | 0.673 | 22.42 | 0.113 | 1.12 |
TiTok-S-128 | 128 | 4,096 | × | - | - | - | 1.71 |
ViT-VQGAN | 32 × 32 | 8,192 | √ | - | - | - | 1.28 |
taming-VQGAN | 16 × 16 | 16,384 | √ | 0.542 | 19.93 | 0.177 | 3.64 |
RQ-VAE | 8 × 8 × 16 | 16,384 | √ | - | - | - | 1.83 |
VQGAN-LC | 16 × 16 | 100,000 | × | 0.589 | 23.80 | 0.120 | 2.62 |
OptVQ (ours) | 16 × 16 × 4 | 16,384 | √ | 0.717 | 26.59 | 0.076 | 1.00 |
OptVQ (ours) | 16 × 16 × 8 | 16,384 | √ | 0.729 | 27.57 | 0.066 | 0.91 |
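For readers unfamiliar with the PSNR column, the sketch below shows the standard definition, PSNR = 10 · log10(MAX² / MSE), in plain Python. The helper name `psnr` and the `max_val` default are illustrative assumptions. The data range and averaging used by the evaluation script are not stated here, so values from this sketch may not match the table exactly.

```python
import math

def psnr(reference, reconstruction, max_val=1.0):
    """Peak signal-to-noise ratio in dB between two equal-length pixel sequences.

    `max_val` is the largest possible pixel value (assumed to be 1.0 here).
    """
    if not reference or len(reference) != len(reconstruction):
        raise ValueError("inputs must be non-empty and have the same length")
    mse = sum((a - b) ** 2 for a, b in zip(reference, reconstruction)) / len(reference)
    if mse == 0:
        return math.inf  # identical inputs
    return 10.0 * math.log10(max_val ** 2 / mse)

# Toy usage with flattened "images" whose values lie in [0, 1]
original = [0.0, 0.5, 1.0, 0.25]
noisy = [0.1, 0.4, 0.9, 0.35]
print(f"PSNR: {psnr(original, noisy):.2f} dB")
```

Higher PSNR means a smaller mean squared error relative to the pixel range, which is why the column is marked with an upward arrow.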
We visualize the process of OptVQ and Vanilla VQ on a two-dimensional toy example. The left figure with red points represents the baseline (Vanilla VQ), and the right figure with green points represents the proposed method (OptVQ).
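The contrast in the toy example can be approximated in a few lines of plain Python. The sketch below compares nearest-neighbor assignment with a balanced assignment obtained from Sinkhorn iterations under uniform marginals. It is a simplified illustration, not the repository's `VectorQuantizerSinkhorn`: the function names, the `eps` and `n_iters` values, and the toy data are all assumptions made for this example.

```python
import math

def squared_dist(p, q):
    return sum((a - b) ** 2 for a, b in zip(p, q))

def nearest_assign(points, codes):
    """Vanilla VQ: each point picks its closest code."""
    return [min(range(len(codes)), key=lambda j: squared_dist(p, codes[j]))
            for p in points]

def sinkhorn_assign(points, codes, eps=1.0, n_iters=200):
    """Balanced assignment via entropic optimal transport with uniform marginals."""
    n, m = len(points), len(codes)
    # Gibbs kernel K[i][j] = exp(-cost / eps)
    K = [[math.exp(-squared_dist(p, c) / eps) for c in codes] for p in points]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iters):
        # Scale rows so that each point carries mass 1/n
        u = [(1.0 / n) / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        # Scale columns so that each code receives mass 1/m
        v = [(1.0 / m) / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    # Transport plan P[i][j] = u[i] * K[i][j] * v[j]; each point takes its heaviest code
    return [max(range(m), key=lambda j: u[i] * K[i][j] * v[j]) for i in range(n)]

# Toy data: every point lies near the first code and far from the second
points = [(0.0, 0.0), (0.2, 0.1), (0.4, 0.3), (0.6, 0.6)]
codes = [(0.0, 0.0), (3.0, 3.0)]

vanilla = nearest_assign(points, codes)
balanced = sinkhorn_assign(points, codes)
print("vanilla VQ codes used:", sorted(set(vanilla)))
print("Sinkhorn codes used:  ", sorted(set(balanced)))
```

On this input the nearest-neighbor rule maps every point to the first code, while the balanced assignment spreads the points over both codes, because the column constraint forces each code to receive half of the total mass.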
Please install the dependencies by running the following command:
# install the dependencies
pip install -r requirements.txt
# install the faiss-gpu package via conda
conda install -c pytorch -c nvidia faiss-gpu=1.8.0
# install the optvq package
pip install -e .
If you want to use our quantizer in your project, you can follow the code below:
# Given the input tensor x, the quantizer will output the quantized tensor x_quant, the loss, and the indices.
from optvq.models.quantizer import VectorQuantizerSinkhorn
quantizer = VectorQuantizerSinkhorn(n_e=1024, e_dim=256, num_head=1)
x_quant, loss, indices = quantizer(x)
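One way to check for index collapse is to count how many distinct codebook entries appear in the returned indices. The helper below is a minimal sketch. `codebook_utilization` is a hypothetical name, not part of the `optvq` package, and it expects a flat sequence of integers. Assuming `indices` is a PyTorch tensor, `indices.flatten().tolist()` would produce such a sequence.

```python
from collections import Counter

def codebook_utilization(indices, n_e):
    """Return (fraction of codes used, usage counts) for a flat sequence of code indices."""
    counts = Counter(int(i) for i in indices)
    used = sum(1 for idx in counts if 0 <= idx < n_e)
    return used / n_e, counts

# Hypothetical indices for a codebook with 8 entries
example_indices = [0, 3, 3, 7, 0, 3]
ratio, counts = codebook_utilization(example_indices, n_e=8)
print(f"utilization: {ratio:.1%}")
print("most used codes:", counts.most_common(2))
```

A utilization far below 100% over a large batch suggests that only a small part of the codebook is being selected.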
Please download the pre-trained models from the following links:
Model | Link (Tsinghua) | Link (Hugging Face) |
---|---|---|
OptVQ (16 × 16 × 4) | Download | Download |
OptVQ (16 × 16 × 8) | Download | Download |
You can load from the Hugging Face model hub by running the following code:
# Example: load the OptVQ with 16 x 16 x 4
from optvq.models.vqgan_hf import VQModelHF
model = VQModelHF.from_pretrained("BorelTHU/optvq-16x16x4")
You can also write the following code to load the pre-trained model locally:
# Example: load the OptVQ with 16 x 16 x 4
from optvq.utils.init import initiate_from_config_recursively
from omegaconf import OmegaConf
import torch
config = OmegaConf.load("configs/optvq.yaml")
model = initiate_from_config_recursively(config.autoencoder)
params = torch.load(..., map_location="cpu")
model.load_state_dict(params["model"])
After loading the model, you can perform inference (reconstruction):
# load the dataset
dataset = ... # the input should be normalized to [-1, 1]
data = dataset[...] # size: (BS, C, H, W)
# reconstruct the input
with torch.no_grad():
    quant, *_ = model.encode(data)
    recon = model.decode(quant)
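The comment above requires inputs in [-1, 1]. Assuming the raw images are stored as 8-bit values in [0, 255], which this README does not state, the mapping and its inverse can be sketched as follows. The helper names `to_model_range` and `to_uint8` are hypothetical.

```python
def to_model_range(pixel):
    """Map an 8-bit pixel value in [0, 255] to [-1, 1]."""
    return pixel / 127.5 - 1.0

def to_uint8(value):
    """Map a model output in [-1, 1] back to an integer in [0, 255], clamping overshoot."""
    clamped = max(-1.0, min(1.0, value))
    return int(round((clamped + 1.0) * 127.5))

print(to_model_range(0), to_model_range(255))
print(to_uint8(-1.0), to_uint8(1.0))
```

The same arithmetic applies elementwise to a tensor. Clamping before the inverse mapping guards against reconstructions that slightly overshoot the valid range.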
To evaluate the model, you can use the following code:
config_path=configs/imagenet/optvq_256_f16_h4.yaml
log_dir=<path-to-log-folder>
resume=<path-to-checkpoint-folder>
python eval.py --config $config_path --log_dir $log_dir --resume $resume --is_distributed
We train the OptVQ (16 × 16 × 4) model on the ImageNet dataset with 8 NVIDIA 4090 GPUs for 50 epochs (around 8 days). The training script is as follows:
config_path=configs/imagenet/optvq_256_f16_h4.yaml
log_dir=<path-to-log-folder>
python train.py --config $config_path --log_dir $log_dir --is_distributed --lr 2e-6
- Image Generation: With our limited compute resources, the generation experiment on ImageNet with MaskGiT would take more than a month on 8 NVIDIA 4090 GPUs. If you are interested in this experiment, please contact us.
- High Compression Ratio: We also plan to explore how to train a tokenizer with a high compression ratio (e.g., f=32). If you have any ideas, please feel free to contact us.
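To make the compression ratio concrete, the sketch below computes the latent grid for a given downsampling factor. It assumes that the config name `optvq_256_f16_h4` denotes a 256 × 256 input, a factor of f=16, and 4 heads, which matches the 16 × 16 × 4 latent size reported above. The helper `latent_grid` is hypothetical.

```python
def latent_grid(image_size, downsample_factor, num_heads=1):
    """Return (height, width, heads) of the index grid for a square image."""
    if image_size % downsample_factor != 0:
        raise ValueError("image_size must be divisible by downsample_factor")
    side = image_size // downsample_factor
    return side, side, num_heads

for f in (16, 32):
    h, w, heads = latent_grid(256, f, num_heads=4)
    print(f"f={f}: {h} x {w} x {heads} = {h * w * heads} indices per image")
```

Under these assumptions, moving from f=16 to f=32 cuts the number of spatial positions by a factor of four, so each index has to carry more information about the image.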
If you find this work useful, please consider citing it.
@article{zhang2024preventing,
title = {Preventing Local Pitfalls in Vector Quantization via Optimal Transport},
author = {Borui Zhang and Wenzhao Zheng and Jie Zhou and Jiwen Lu},
journal = {ArXiv},
year = {2024},
volume = {abs/2412.15195}
}