
# Pretrained Models

The following checkpoints are required to run the code:

```bash
python tools/prepare_checkpoints.py pytorch_fid

python tools/prepare_checkpoints.py lpips
python tools/convert_checkpoints.py lpips pretrained/lpips/vgg.pth

python tools/prepare_checkpoints.py torchvision --weights .VGG16_Weights.DEFAULT
```
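
If you want to confirm a download before moving on, the fetched `.pth` files should be ordinary PyTorch checkpoints that open with `torch.load`. A minimal sketch (it only inspects the file; the repo's own loading goes through the `tools/` scripts):

```python
# Quick sanity check (not part of the repo's tooling): make sure the fetched
# LPIPS weights deserialize before kicking off a run. The path matches the
# conversion command above; what the file contains is up to the upstream repo.
import torch

checkpoint = torch.load("pretrained/lpips/vgg.pth", map_location="cpu")
if isinstance(checkpoint, dict):
    print(f"loaded {len(checkpoint)} entries: {list(checkpoint)[:5]} ...")
else:
    print(f"loaded object of type {type(checkpoint).__name__}")
```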

The following checkpoints are the IU models used in the paper:

```bash
python tools/prepare_checkpoints.py clip --weights ViT-B/16
python tools/prepare_checkpoints.py dino
python tools/prepare_checkpoints.py torchvision --weights .ViT_B_16_Weights.DEFAULT
python tools/prepare_checkpoints.py mae
```
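
For orientation, these names correspond to publicly available backbones; the sketch below loads two of them through their upstream Python APIs. This is an illustration of what the weights are, not how the repo loads them (the repo presumably reads the copies cached by `prepare_checkpoints.py`):

```python
# Illustration only: two of the IU backbones named above, loaded via their
# upstream packages rather than through this repo's cached checkpoints.
import clip  # OpenAI CLIP package
import torch
from torchvision.models import ViT_B_16_Weights, vit_b_16

device = "cuda" if torch.cuda.is_available() else "cpu"

# OpenAI CLIP ViT-B/16, matching `--weights ViT-B/16` above.
clip_model, clip_preprocess = clip.load("ViT-B/16", device=device)

# torchvision ViT-B/16 with its default ImageNet weights, matching
# `--weights .ViT_B_16_Weights.DEFAULT` above.
vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT).to(device).eval()
```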

The following checkpoints are used to initialize the AR proposal networks:

```bash
python tools/prepare_checkpoints.py huggingface
```
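
The exact Hugging Face checkpoint used for initialization is not spelled out here; the sketch below only shows the generic download-and-cache pattern, with a placeholder model name that is not necessarily the one this repo uses:

```python
# Generic Hugging Face download sketch; "gpt2" is a placeholder, not
# necessarily the checkpoint this repo's AR proposal configs point to.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```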

VQGAN and VQ-KD checkpoints from their original repositories can be loaded by our code after conversion:

```bash
python tools/prepare_checkpoints.py taming_transformers
python tools/convert_checkpoints.py taming_transformers pretrained/taming-transformers/vqgan_imagenet_f16_1024.pth --check configs/vqgan/1024_imagenet_ddp.py
python tools/convert_checkpoints.py taming_transformers pretrained/taming-transformers/vqgan_imagenet_f16_16384.pth --check configs/vqgan/16384_dd2_aglwg075_imagenet_ddp.py
```

```bash
python tools/prepare_checkpoints.py beitv2
python tools/convert_checkpoints.py beitv2 pretrained/beitv2/vqkd_encoder_base_decoder_1x768x12_clip.pth
python tools/convert_checkpoints.py beitv2 pretrained/beitv2/vqkd_encoder_base_decoder_1x768x12_clip.pth --check configs/vqkd/clip_8192_imagenet_ddp.py --suffix .converted.with_decoder --with-decoder
python tools/convert_checkpoints.py beitv2 pretrained/beitv2/vqkd_encoder_base_decoder_1x768x12_dino.pth
python tools/convert_checkpoints.py beitv2 pretrained/beitv2/vqkd_encoder_base_decoder_1x768x12_dino.pth --check configs/vqkd/dino_8192_imagenet_ddp.py --suffix .converted.with_decoder --with-decoder
```
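
Conversion produces a state dict whose keys match this codebase's model definitions, saved next to the original with a `.converted` (or `.converted.with_decoder`) suffix; those are the files the test commands below point at. The `--check` flag presumably loads the result into the model built from the given config to verify the mapping. A heavily simplified sketch of the pattern (the actual key remapping lives in `tools/convert_checkpoints.py`, not the placeholder here):

```python
# Rough sketch of what a conversion step boils down to: load the upstream
# checkpoint, remap its parameter names onto this codebase's module layout,
# and save the result with a suffix. The remapping below is a placeholder.
import torch

src = "pretrained/taming-transformers/vqgan_imagenet_f16_16384.pth"
checkpoint = torch.load(src, map_location="cpu")
# Unwrap a Lightning-style checkpoint if the weights are nested.
state_dict = checkpoint.get("state_dict", checkpoint)

converted = {
    key.replace("module.", ""): value  # placeholder remapping, not the real one
    for key, value in state_dict.items()
}
torch.save(converted, src + ".converted")
```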

After generating the FID cache, you can run the following commands to validate the pretrained models:

```bash
auto_torchrun -m vq.test vqgan/16384_dd2_aglwg075_imagenet_ddp configs/vqgan/16384_dd2_aglwg075_imagenet_ddp.py --load-model-from pretrained/taming-transformers/vqgan_imagenet_f16_16384.pth.converted --visual pred_image
# {'lpips_loss': 0.28323277831077576, 'l1_image_loss': 0.06811775267124176, 'mse_image_loss': 0.013179616071283817, 'psnr': 19.970359802246094, 'ssim': 0.5023356676101685, 'fid': 4.980832106065748, 'codebook_usage': 0.059326171875, 'codebook_ppl': 6.812368392944336}

auto_torchrun -m vq.test vqkd/clip_8192_imagenet_ddp configs/vqkd/clip_8192_imagenet_ddp.py --load-model-from pretrained/beitv2/vqkd_encoder_base_decoder_1x768x12_clip.pth.converted.with_decoder
# {'cosine_embedding_r_loss': 0.16431047022342682, 'codebook_usage': 1.0, 'codebook_ppl': 8.94822883605957}
```
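
If you script these runs, a coarse tolerance check against the reference numbers above can catch a bad conversion early. The tolerance below is an arbitrary choice for illustration, not a value from the repo:

```python
# Compare a measured metric dict (as printed by vq.test) against the VQGAN
# reference numbers quoted above. The 5% relative tolerance is arbitrary.
REFERENCE = {
    "fid": 4.980832106065748,
    "lpips_loss": 0.28323277831077576,
    "psnr": 19.970359802246094,
    "ssim": 0.5023356676101685,
}

def check(measured: dict, rel_tol: float = 0.05) -> None:
    for name, expected in REFERENCE.items():
        got = measured[name]
        if abs(got - expected) > rel_tol * abs(expected):
            raise ValueError(f"{name}: got {got:.4f}, expected ~{expected:.4f}")

# Example usage once you have the printed dict:
# check({"fid": 4.98, "lpips_loss": 0.283, "psnr": 19.97, "ssim": 0.502})
```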