Let's train vision transformers for CIFAR-10!
This is an unofficial and elementary implementation of the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".
I use PyTorch for the implementation.
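The core idea behind ViT is to treat an image as a sequence of patch tokens. As an illustration only (this is not code from this repository), the plain-Python sketch below splits a 32x32 RGB image (the CIFAR-10 resolution) into non-overlapping 4x4 patches and flattens each patch into a vector. In a ViT, these vectors would then be linearly projected to the embedding dimension, and position embeddings and a class token would be added. The `patchify` helper and its row, column, channel flattening order are hypothetical choices made for this sketch.

```python
from typing import List

# Hypothetical image layout for this sketch: image[row][col][channel]
Image = List[List[List[float]]]


def patchify(image: Image, patch: int) -> List[List[float]]:
    """Split an H x W x C image into non-overlapping patch x patch tiles.

    Each tile is flattened (row, then column, then channel) into a vector of
    length patch * patch * C. Tiles are returned in raster order:
    left to right, then top to bottom.
    """
    height, width = len(image), len(image[0])
    if height % patch or width % patch:
        raise ValueError("image height and width must be divisible by the patch size")
    tokens = []
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            vec = []
            for r in range(top, top + patch):
                for c in range(left, left + patch):
                    vec.extend(image[r][c])
            tokens.append(vec)
    return tokens


# Dummy 32x32 RGB image; each pixel stores its own (row, col) to make tiles easy to inspect.
img = [[[float(r), float(c), 0.0] for c in range(32)] for r in range(32)]
tokens = patchify(img, patch=4)
print(len(tokens), len(tokens[0]))
```

With `patch=4`, the image becomes (32 / 4)^2 = 64 tokens, each of length 4 * 4 * 3 = 48.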
- Added ConvMixer implementation. Really simple! (2021/10)
- Added wandb training logs to reproduce results. (2022/3)
- Added CaiT and ViT-small. (2022/3)
- Added Swin Transformers. (2022/3)
- Added MLP mixer. (2022/6)
- Changed default training settings for ViT.
Usage examples:

```bash
python train_cifar10.py  # vit-patchsize-4
python train_cifar10.py --size 48  # vit-patchsize-4-imsize-48
python train_cifar10.py --patch 2  # vit-patchsize-2
python train_cifar10.py --net vit_small --n_epochs 400  # vit-small
python train_cifar10.py --net vit_timm  # train with pretrained ViT (timm)
python train_cifar10.py --net convmixer --n_epochs 400  # train with ConvMixer
python train_cifar10.py --net mlpmixer --n_epochs 500 --aug --lr 1e-3  # train with MLP mixer
python train_cifar10.py --net cait --n_epochs 200  # train with CaiT
python train_cifar10.py --net swin --n_epochs 400  # train with Swin Transformers
python train_cifar10.py --net res18  # resnet18 + randaug
```
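The `--patch` and `--size` options change how many tokens the transformer attends over. The back-of-the-envelope sketch below assumes that `--size` sets the square input resolution (CIFAR-10 images are natively 32x32), that the default patch size is 4, and it ignores the extra class token. Under those assumptions the token count is (image_size / patch)^2. Because self-attention compares every pair of tokens, its cost grows roughly with the square of that count. `num_patches` is a hypothetical helper, not part of `train_cifar10.py`.

```python
def num_patches(image_size: int, patch: int) -> int:
    """Number of patch tokens for a square image cut into square patches (class token excluded)."""
    if image_size % patch:
        raise ValueError("image_size must be divisible by patch")
    return (image_size // patch) ** 2


# (image_size, patch) for the configurations used in this README; sizes are assumptions.
configs = {
    "default (imsize 32, patch 4)": (32, 4),
    "--size 48 (patch 4)": (48, 4),
    "--patch 2": (32, 2),
    "--patch 8": (32, 8),
}

baseline = num_patches(32, 4)
for name, (size, patch) in configs.items():
    n = num_patches(size, patch)
    # Attention cost is roughly proportional to n**2, so compare squared token counts.
    ratio = (n / baseline) ** 2
    print(f"{name}: {n} tokens, ~{ratio:g}x the default attention cost")
```

Smaller patches give the model finer-grained tokens at a steep compute cost (patch 2 yields 256 tokens, about 16x the attention cost of patch 4), while patch 8 leaves only 16 tokens per CIFAR-10 image.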
| Model | Accuracy | Train Log |
|---|---|---|
| ViT patch=2 | 80% | |
| ViT patch=4, 200 epochs | 80% | Log |
| ViT patch=4, 500 epochs | 88% | Log |
| ViT patch=8 | 30% | |
| ViT small | 80% | |
| MLP mixer | 88% | |
| CaiT | 80% | |
| Swin-t | 90% | |
| ViT small (timm transfer) | 97.5% | |
| ViT base (timm transfer) | 98.5% | |
| ConvMixerTiny (no pretrain) | 96.3% | Log |
| resnet18 | 93% | |
| resnet18 + randaug | 95% | Log |