Dataset distillation (DD) generates small synthetic datasets that can efficiently train deep networks with a limited amount of memory and compute. Despite the success of DD methods for supervised learning, DD for self-supervised learning (SSL) pre-training of deep models has remained unaddressed. Pre-training on unlabeled data is crucial for efficiently generalizing to downstream tasks with limited labeled data. In this work, we propose the first effective DD method for SSL pre-training. First, we show, theoretically and empirically, that naïve application of supervised DD methods to SSL fails, due to the high variance of the SSL gradient. Then, we address this issue by relying on insights from the knowledge distillation (KD) literature. Specifically, we train a small student model to match the representations of a larger teacher model trained with SSL. Then, we generate a small synthetic dataset by matching the training trajectories of the student models. As the KD objective has considerably lower variance than SSL, our approach can generate synthetic datasets that successfully pre-train high-quality encoders. Through extensive experiments, we show that our distilled sets lead to up to 13% higher accuracy than prior work on a variety of downstream tasks, in the presence of limited labeled data.
git clone [email protected]:jiayini1119/MKDT.git
pip install -r requirements.txt
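For example, the dependencies can be installed inside a fresh virtual environment (the environment name mkdt_env below is arbitrary):
python -m venv mkdt_env
source mkdt_env/bin/activate
pip install -r requirements.txt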
We obtained the teacher model trained with Barlow Twins using the checkpoint provided in the KRRST repository. Download and save the checkpoints under /krrst_teacher_ckpt in the repository root.
We obtained the teacher model trained with SimCLR using the checkpoint provided in the SAS repository.
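For example, assuming a downloaded Barlow Twins teacher checkpoint at ~/Downloads/barlow_twins_teacher.pt (a placeholder filename), it can be placed as follows from the repository root:
mkdir -p krrst_teacher_ckpt
cp ~/Downloads/barlow_twins_teacher.pt krrst_teacher_ckpt/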
To get the target representations:
python get_target_rep.py --dataset {CIFAR10/CIFAR100/Tiny} --model {model: ConvNetD4 for TinyImageNet and ConvNet for other datasets} --ssl_algorithm {barlow_twins/simclr} --data_path {dataset path} --result_dir {directory to store the target representations} --device {device}
By default, the target representations will be saved in /{result_dir}_{ssl_algorithm}/{dataset}_target_rep_train.pt.
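For example, a possible invocation for CIFAR-10 with Barlow Twins (the data path, result directory, and device string below are placeholders):
python get_target_rep.py --dataset CIFAR10 --model ConvNet --ssl_algorithm barlow_twins --data_path /path/to/data --result_dir target_rep --device cuda:0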
Run the following scripts to get the expert trajectories:
CIFAR 10: commands/buffer/{ssl_algorithm}/c10_get_trajectory.sh
CIFAR 100: commands/buffer/{ssl_algorithm}/c100_get_trajectory.sh
Tiny ImageNet: commands/buffer/{ssl_algorithm}/tiny_get_trajectory.sh
The buffers will be saved in the directory buffer/{ssl_algorithm}/{dataset}/{model}.
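For instance, to generate the Barlow Twins expert trajectories for CIFAR-10 (assuming the script is run with bash from the repository root):
bash commands/buffer/barlow_twins/c10_get_trajectory.sh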
To obtain the high-loss subset used to initialize the distilled dataset:
python get_target_rep.py --dataset {CIFAR10/CIFAR100/Tiny} --data_path {dataset path} --model {model} --num_buffers {number of buffers} --ssl_algo {SSL algorithm used for training} --train_labels_path {path to the target representations of the dataset} --batch_train {training batch size} --device {device}
For example:
python get_target_rep.py --dataset CIFAR100 --data_path /home/data --model ConvNet --num_buffers 100 --ssl_algo barlow_twins --train_labels_path /home/jennyni/MKDT/target_rep/barlow_twins/CIFAR100_target_rep_train.pt
Run the following scripts to distill the dataset (using Barlow Twins as the SSL algorithm):
CIFAR 10
2 percent: commands/distill/barlow_twins/CIFAR10/2_per.sh
5 percent: commands/distill/barlow_twins/CIFAR10/5_per.sh
CIFAR 100
2 percent: commands/distill/barlow_twins/CIFAR100/2_per.sh
5 percent: commands/distill/barlow_twins/CIFAR100/5_per.sh
Tiny ImageNet
2 percent: commands/distill/barlow_twins/Tiny/2_per.sh
5 percent: commands/distill/barlow_twins/Tiny/5_per.sh
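For example, to distill a 2 percent synthetic set for CIFAR-100 with Barlow Twins:
bash commands/distill/barlow_twins/CIFAR100/2_per.sh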
The following scripts contain the commands to run the evaluation for different subsets (e.g., MKDT, random, KRRST). Append --subset_frac {0.01/0.05/0.1/0.5} to a command to evaluate with a different evaluation subset fraction:
CIFAR 10
2 percent: commands/eval/c10_2per.sh --subset_frac {0.01/0.05/0.1/0.5}
5 percent: commands/eval/c10_5per.sh --subset_frac {0.01/0.05/0.1/0.5}
CIFAR 100
2 percent: commands/eval/c100_2per.sh --subset_frac {0.01/0.05/0.1/0.5}
5 percent: commands/eval/c100_5per.sh --subset_frac {0.01/0.05/0.1/0.5}
Tiny ImageNet
2 percent: commands/eval/tiny_2per.sh --subset_frac {0.01/0.05/0.1/0.5}
5 percent: commands/eval/tiny_5per.sh --subset_frac {0.01/0.05/0.1/0.5}
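For example, to evaluate the 2 percent CIFAR-100 distilled set using a 5 percent labeled evaluation subset:
bash commands/eval/c100_2per.sh --subset_frac 0.05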
You can visualize the tables comparing the different subset results for a dataset using commands/exp_plotting.ipynb.
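For example, the notebook can be opened from the repository root (assuming Jupyter is installed):
jupyter notebook commands/exp_plotting.ipynb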
The code is based on the following repositories.
https://github.com/GeorgeCazenavette/mtt-distillation
https://github.com/db-Lee/selfsup_dd
@inproceedings{joshi2025kd,
title={Dataset Distillation via Knowledge Distillation: Towards Efficient Self-Supervised Pre-Training of Deep Networks},
author={Joshi, Siddharth and Ni, Jiayi and Mirzasoleiman, Baharan},
booktitle={Proceedings of the International Conference on Learning Representations (ICLR)},
year={2025}
}