This repository contains a PyTorch implementation of the Pretext-Invariant Representation Learning (PIRL) algorithm on the STL10 dataset. PIRL was originally introduced by Misra et al.; the paper is linked in the references below.
PIRL is a self-supervised learning algorithm that uses contrastive learning to learn visual representations such that an image and its transformed version have similar representations, while both differ from the representations of other images. This makes the learnt representations invariant to the transformation.
In their paper, the authors primarily focus on the jigsaw puzzle transformation.
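As an illustration of the jigsaw transformation, the sketch below splits a square image into a 3 x 3 grid of tiles and shuffles them; the shuffled tiles then stand in for the transformed image $I^t$. It is a hypothetical, minimal version (plain nested lists, no random crops or per-tile augmentation) and not the transform used in this repository or in the paper; the name `jigsaw_tiles` is illustrative only.

```python
import random


def jigsaw_tiles(image, grid=3, rng=None):
    """Split a square image (list of rows) into grid x grid tiles and shuffle them.

    Returns (shuffled, perm) where shuffled[k] is original tile perm[k],
    and original tiles are numbered in row-major order.
    """
    if rng is None:
        rng = random.Random()
    size = len(image)
    if size % grid != 0 or any(len(row) != size for row in image):
        raise ValueError("image must be square with a side divisible by grid")
    t = size // grid  # side length of one tile
    tiles = []
    for gr in range(grid):
        for gc in range(grid):
            # Cut rows gr*t..(gr+1)*t and columns gc*t..(gc+1)*t
            tile = [row[gc * t:(gc + 1) * t] for row in image[gr * t:(gr + 1) * t]]
            tiles.append(tile)
    perm = list(range(grid * grid))
    rng.shuffle(perm)  # random order of the tiles
    shuffled = [tiles[p] for p in perm]
    return shuffled, perm
```

For a 6 x 6 image this yields nine 2 x 2 tiles in a random order, together with the permutation that produced them.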
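The CNN used for representation learning is trained with the NCE (Noise Contrastive Estimation) technique. NCE models the probability of the event that $(I, I^t)$, an original image and its transformed version, originate from the same data distribution, i.e.

$$h(v_I, v_{I^t}) = \frac{\exp\left(s(v_I, v_{I^t})/\tau\right)}{\exp\left(s(v_I, v_{I^t})/\tau\right) + \sum_{I' \in \mathcal{D}_N} \exp\left(s(v_{I^t}, v_{I'})/\tau\right)}$$

where $s(\cdot, \cdot)$ is the cosine similarity, $v_I$ and $v_{I^t}$ are the deep representations of the original and transformed image respectively, $\tau$ is a temperature parameter, and $\mathcal{D}_N$ is a set of negative samples (images other than $I$).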
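The probability above can be sketched in plain Python as follows. This is a minimal illustration rather than the repository's code: the function names are hypothetical, vectors are plain lists, the negatives are passed in explicitly (the README does not say how they are obtained), and the default temperature `tau=0.07` is an assumption.

```python
import math


def cosine_similarity(a, b):
    """s(a, b): cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def nce_probability(v_i, v_it, negatives, tau=0.07):
    """h(v_I, v_It): probability that (I, I^t) is a positive pair.

    `negatives` plays the role of D_N; each negative is compared with v_it.
    """
    pos = math.exp(cosine_similarity(v_i, v_it) / tau)
    neg = sum(math.exp(cosine_similarity(v_it, n) / tau) for n in negatives)
    return pos / (pos + neg)
```

With no negatives the probability is exactly 1, and a negative identical to the positive pair halves it.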
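The final NCE loss is given as:

$$\mathcal{L}_{\text{NCE}}(I, I^t) = -\log\left[h\left(f(v_I), g(v_{I^t})\right)\right] - \sum_{I' \in \mathcal{D}_N} \log\left[1 - h\left(g(v_{I^t}), f(v_{I'})\right)\right]$$

where $f(\cdot)$ and $g(\cdot)$ are linear heads applied to the representations and $\mathcal{D}_N$ is the set of negative images.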
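A plain-Python sketch of evaluating this loss for one positive pair follows. It is a literal, hypothetical evaluation of the formula, not this repository's code: the linear heads are modelled as bias-free matrices, `h(a, b, negatives)` is $\exp(s(a,b)/\tau)$ divided by that same term plus the sum of $\exp(s(b,n)/\tau)$ over the negatives, the negatives are projected with `f`, and `tau=0.07` is an assumed default.

```python
import math


def cos_sim(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def h(a, b, negatives, tau=0.07):
    """Probability that (a, b) is a positive pair; negatives are compared with b."""
    pos = math.exp(cos_sim(a, b) / tau)
    return pos / (pos + sum(math.exp(cos_sim(b, n) / tau) for n in negatives))


def linear(weight):
    """Return a bias-free linear head x -> W x, with W given as a list of rows."""
    return lambda x: [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def nce_loss(v_i, v_it, negative_reps, f, g, tau=0.07):
    """L_NCE(I, I^t) for raw representations v_I, v_It and negatives v_I'."""
    fi, git = f(v_i), g(v_it)
    f_negs = [f(v) for v in negative_reps]
    # First term: reward a high probability for the positive pair.
    loss = -math.log(h(fi, git, f_negs, tau))
    for fn in f_negs:
        # Second term: push h(g(v_It), f(v_I')) towards 0 for every negative I'.
        loss -= math.log(1.0 - h(git, fn, f_negs, tau))
    return loss
```

Since every $-\log(1 - h)$ term is non-negative, $\mathcal{L}_{\text{NCE}}$ is never smaller than its first term alone.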
Instead of the NCE loss, this implementation directly minimizes the negative log of the probability in the first equation above, evaluated on $f(v_I)$ and $g(v_{I^t})$ instead of $v_I$ and $v_{I^t}$.
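Because that probability is a softmax over the positive pair and the negatives, this objective is ordinary cross-entropy over the logits $[s_{\text{pos}}/\tau, s_{\text{neg},1}/\tau, \dots]$ with the positive at index 0, which is what `torch.nn.functional.cross_entropy` computes in PyTorch given those logits and target 0. The plain-Python sketch below shows this with a numerically stable log-sum-exp. The function name, the `tau=0.07` default and the explicit list of already-projected negatives are assumptions, since the README does not say how negatives are drawn.

```python
import math


def cos_sim(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def pirl_simplified_loss(f_vi, g_vit, f_negatives, tau=0.07):
    """-log h(f(v_I), g(v_It)), computed as cross-entropy with the positive at index 0.

    Inputs are assumed to be already projected by the heads f and g.
    """
    logits = [cos_sim(f_vi, g_vit) / tau]
    logits += [cos_sim(g_vit, n) / tau for n in f_negatives]
    m = max(logits)  # log-sum-exp trick: subtract the max before exponentiating
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[0]
```

Working in log space avoids overflow when $1/\tau$ is large, which matters more as the number of negatives and the similarity scale grow.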
The implementation uses the STL10 dataset, which can be downloaded from the STL 10 link listed below.
1. Download the raw data from the STL 10 dataset page into ./raw_stl10/
2. Run stl10_data_load.py. This saves three directories, train, test and unlabelled, in ./stl10_data/
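Step 2 is handled by stl10_data_load.py. For reference, the sketch below decodes a single image from the raw binary files, assuming the layout described on the STL-10 page: each image is 3 x 96 x 96 unsigned bytes, stored channel by channel and, within a channel, column by column. The function name `decode_stl10_image` is hypothetical and not part of this repository.

```python
def decode_stl10_image(buf, offset=0, size=96):
    """Decode one STL-10 image from raw bytes into rows of (R, G, B) tuples.

    Assumes 3 x size x size uint8 values per image, channel-major,
    with each channel stored in column-major order.
    """
    plane = size * size  # number of bytes in one colour channel
    img = []
    for row in range(size):
        pixels = []
        for col in range(size):
            # Column-major within a channel: index = channel * plane + col * size + row
            pixels.append(tuple(buf[offset + ch * plane + col * size + row] for ch in range(3)))
        img.append(pixels)
    return img
```

For example, image k of train_X.bin starts at byte offset `k * 3 * 96 * 96`, so `decode_stl10_image(data, offset=k * 3 * 96 * 96)` returns it as 96 rows of 96 (R, G, B) tuples, where `data` holds the bytes of that file.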
- Run the script pirl_stl_train_test.py for self-supervised learning, for example:
python pirl_stl_train_test.py --model-type res18 --batch-size 128 --lr 0.1 --experiment-name exp
- Run the script train_stl_after_ssl.py to fine-tune the model parameters obtained from self-supervised learning, for example:
python train_stl_after_ssl.py --model-type res18 --batch-size 128 --lr 0.1 --patience-for-lr-decay 4 --full-fine-tune True --pirl-model-name <relative_model_path from above run>
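The README does not document how `--patience-for-lr-decay` is used; its name suggests the learning rate is decayed once validation performance has stopped improving for that many epochs. The sketch below is a hypothetical, simplified plateau scheduler illustrating that idea (the class name and the decay `factor=0.1` are assumptions); in PyTorch, `torch.optim.lr_scheduler.ReduceLROnPlateau` with `mode="max"` provides similar behaviour.

```python
class PlateauLRDecay:
    """Hypothetical sketch: decay the LR after more than `patience` epochs without improvement."""

    def __init__(self, lr, patience=4, factor=0.1):
        self.lr = lr
        self.patience = patience
        self.factor = factor
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_accuracy):
        """Call once per epoch with the validation accuracy; returns the LR for the next epoch."""
        if val_accuracy > self.best:
            self.best = val_accuracy
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor  # decay and restart the patience counter
                self.bad_epochs = 0
        return self.lr
```

With `patience=4`, the learning rate would drop only after five consecutive epochs without a new best validation accuracy.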
After training the CNN in the PIRL manner, the following experiments were performed to evaluate how well the learnt weights transfer to a classification problem with limited labelled data.
| Fine-tuning strategy | Validation classification accuracy (%) |
|---|---|
| Only the softmax layer is fine-tuned | 50.50 |
| Full model is fine-tuned | 67.87 |
- PIRL paper: https://arxiv.org/abs/1912.01991
- STL 10 dataset: http://ai.stanford.edu/~acoates/stl10/
- Data loading code for STL 10: https://github.com/mttk/STL10