- CIFAR: download the CIFAR dataset to `~/datasets/cifar` (the path can be changed in the configuration files; see the download sketch after this list).
- ImageNet: download the ImageNet dataset to `~/datasets/ILSVRC2012` and pre-process it with this script.
- Tiny-ImageNet: download the Tiny-ImageNet dataset to `~/datasets/tiny-imagenet-200`.
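
For the CIFAR datasets, torchvision can download the files directly into the expected folder; a minimal sketch (CIFAR-100 shown, CIFAR-10 is analogous; ImageNet and Tiny-ImageNet must be fetched manually as described above):

```python
import os
import torchvision

# Download CIFAR-100 train/test splits into the folder the configs expect.
root = os.path.expanduser("~/datasets/cifar")
torchvision.datasets.CIFAR100(root=root, train=True, download=True)
torchvision.datasets.CIFAR100(root=root, train=False, download=True)
```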
- torch>=1.10.2
- torchvision>=0.11.3
- tqdm
- timm
- tensorboard
- scipy
- PyYAML
- pandas
- numpy
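
All of these can be installed with pip in one step (a GPU build of torch/torchvision may additionally need the index URL from pytorch.org matching your CUDA version):

```bash
pip install "torch>=1.10.2" "torchvision>=0.11.3" tqdm timm tensorboard scipy PyYAML pandas numpy
```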
In the `config` directory, we provide configurations for training on CIFAR10/100, Tiny-ImageNet, and ImageNet.
The following script will start training NesT-Ti-S (with heavy constituent heads) from scratch on CIFAR100:
```bash
export CUDA_VISIBLE_DEVICES=0,1
port=9874

python dist_engine.py \
    --num-nodes 1 \
    --rank 0 \
    --master-url tcp://localhost:${port} \
    --backend nccl \
    --multiprocessing \
    --file-name-cfg cls \
    --cfg-filepath config/cifar100/nest/sparse-nest-tiny.yaml \
    --log-dir run/cifar100/nest/sparse-nest-tiny \
    --worker worker
```
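
If `dist_engine.py` follows the usual torch.distributed rendezvous convention that its flags suggest, a two-node run would repeat the command on each machine with `--num-nodes 2`, `--rank` set to the node index, and `--master-url` pointing at node 0 (a sketch only; `node0` is a placeholder hostname, and the actual semantics of these flags are defined by the repo):

```bash
# On the second node of a two-node job (node 0 uses --rank 0):
python dist_engine.py \
    --num-nodes 2 \
    --rank 1 \
    --master-url tcp://node0:9874 \
    --backend nccl \
    --multiprocessing \
    --file-name-cfg cls \
    --cfg-filepath config/cifar100/nest/sparse-nest-tiny.yaml \
    --log-dir run/cifar100/nest/sparse-nest-tiny \
    --worker worker
```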
You can train your chosen network on your chosen dataset by pointing `--cfg-filepath` at the corresponding config file. We provide all the models used in the paper and define them in `vit_mutual/models/__init__.py`. The prefix of the network name is important: `sparse` indicates the use of single-axis constituents, while `global-sparse` indicates dual-axis constituents; `small` indicates slim heads, and omitting this prefix means heavy heads.
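
Putting the prefixes together, the naming scheme implies variants like the following (only `sparse-nest-tiny` appears in the command above; the other names, and the exact position of `small`, are inferred from the stated convention and may differ in the repo):

```
sparse-nest-tiny               # single-axis constituents, heavy heads
sparse-small-nest-tiny         # single-axis constituents, slim heads
global-sparse-nest-tiny        # dual-axis constituents, heavy heads
global-sparse-small-nest-tiny  # dual-axis constituents, slim heads
```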
You can run `visualize_attention.py` to visualize attention, and `hessian.py` to compute Hessian max eigenvalues via the power iteration algorithm. `hessian.py` requires the PyHessian library, which can be installed from pip:

```bash
pip install pyhessian
```
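
For orientation, here is a minimal sketch of how PyHessian's power iteration is typically invoked (the model, criterion, and batch below are placeholders; `hessian.py` in this repo will differ in its details):

```python
import torch
from pyhessian import hessian  # PyHessian's Hessian analysis helper

# Placeholders: substitute the trained model and a real batch of data.
model = torch.nn.Linear(10, 2)
criterion = torch.nn.CrossEntropyLoss()
inputs, targets = torch.randn(8, 10), torch.randint(0, 2, (8,))

# Power iteration for the top Hessian eigenvalue.
hessian_comp = hessian(model, criterion, data=(inputs, targets), cuda=False)
top_eigenvalues, _ = hessian_comp.eigenvalues(top_n=1)
print("Hessian max eigenvalue:", top_eigenvalues[0])
```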