Simple and Fast Distillation of Diffusion Models
Zhenyu Zhou, Defang Chen, Can Wang, Chun Chen, Siwei Lyu
https://arxiv.org/abs/2409.19681
TL;DR: A simple and fast distillation method for diffusion models that accelerates fine-tuning by up to 1000 times while maintaining high-quality image generation.
Abstract: Diffusion-based generative models have demonstrated powerful performance across various tasks, but this comes at the cost of slow sampling speed. To achieve both efficient and high-quality synthesis, various distillation-based accelerated sampling methods have been developed recently. However, they generally require time-consuming fine-tuning with elaborate designs to achieve satisfactory performance at a specific number of function evaluations (NFE), making them difficult to employ in practice. To address this issue, we propose Simple and Fast Distillation (SFD) of diffusion models, which simplifies the paradigm used in existing methods and shortens their fine-tuning time by up to 1000 times. We begin with a vanilla distillation-based sampling method and boost its performance to the state of the art by identifying and addressing several small yet vital factors affecting the synthesis efficiency and quality. Our method can also achieve sampling with variable NFEs using a single distilled model. Extensive experiments demonstrate that SFD strikes a good balance between sample quality and fine-tuning cost in the few-step image generation task. For example, SFD achieves 4.53 FID (NFE=2) on CIFAR-10 with only 0.64 hours of fine-tuning on a single NVIDIA A100 GPU.
- This codebase is mainly built on the EDM codebase. To install the required packages, please refer to the EDM codebase.
- This codebase supports pre-trained diffusion models from EDM, LDM and Stable Diffusion. To load pre-trained diffusion models from these codebases, please refer to the corresponding codebase for package installation.
Run the commands in launch.sh for training, sampling and evaluation with recommended settings.
You can find descriptions of the main parameters in the next section.
The required models will be downloaded to "./src/dataset_name" automatically.
We use 4 A100 GPUs for all experiments. You can change the batch size based on your devices.
Note: `num_steps` is the number of timestamps (sampling steps + 1), so `num_steps=4` refers to 3 sampling steps. The use of AFS saves 1 step, so (`num_steps=4` with `afs=False`) equals 3 sampling steps and (`num_steps=4` with `afs=True`) equals 2 sampling steps.
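For quick reference, the counting rule above can be written as a tiny helper (a minimal sketch of the rule, not code from this repository; the function name is illustrative):

```python
def count_nfe(num_steps: int, afs: bool) -> int:
    """NFE of the student sampler: (num_steps - 1) sampling steps,
    minus one model evaluation when AFS is enabled."""
    sampling_steps = num_steps - 1
    return sampling_steps - 1 if afs else sampling_steps

assert count_nfe(4, afs=False) == 3  # 3 sampling steps
assert count_nfe(4, afs=True) == 2   # AFS saves the first model evaluation
```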
# Train a 2-NFE SFD (usable for EDM models trained on cifar10, ffhq, afhqv2 and imagenet64)
torchrun --standalone --nproc_per_node=4 --master_port=12345 train.py \
--dataset_name="cifar10" --total_kimg=200 --batch=128 --lr=5e-5 \
--num_steps=4 --M=3 --afs=True --sampler_tea="dpmpp" --max_order=3 --predict_x0=True --lower_order_final=True \
--schedule_type="polynomial" --schedule_rho=7 --use_step_condition=False --is_second_stage=False
# Train SFD-v (NFE-variable version; allows sampling with num_steps from 4 to 7, i.e. NFE from 2 to 5, using one model)
torchrun --standalone --nproc_per_node=4 --master_port=12345 train.py \
--dataset_name="cifar10" --total_kimg=200 --batch=128 --lr=5e-5 \
--num_steps=4 --M=3 --afs=True --sampler_tea="dpmpp" --max_order=3 --predict_x0=True --lower_order_final=True \
--schedule_type="polynomial" --schedule_rho=7 --use_step_condition=True --is_second_stage=False
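Conceptually, these commands distill a multi-step teacher solver (here DPM-Solver++ run on a finer grid with M extra timestamps inserted between adjacent student timestamps) into a few-step student. The sketch below is only a heavily simplified illustration of such trajectory distillation under an L2 loss; the callables and their signatures are assumptions, not the actual training loop in train.py:

```python
import torch
import torch.nn.functional as F

def distill_step(student, teacher_solver, x_cur, t_cur, t_next, teacher_timestamps):
    """One illustrative distillation update for a single student step t_cur -> t_next.

    teacher_timestamps: the finer grid (the M inserted timestamps plus t_next) that
    the teacher solver traverses between the two student timestamps.
    """
    with torch.no_grad():
        # Teacher: integrate the sampling ODE over the finer grid (hypothetical call).
        x_teacher = teacher_solver(x_cur, teacher_timestamps)
    # Student: a single distilled step over the same interval.
    x_student = student(x_cur, t_cur, t_next)
    # Match the student's one-step output to the teacher's multi-step output.
    return F.mse_loss(x_student, x_teacher)
```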
After training, the distilled SFD model will be saved at "./exps" under a five-digit experiment number (e.g. 00000).
The settings for sampling are stored in the model file. You can perform accelerated sampling with SFD by passing either the file path or the experiment number (e.g. 0) to `--model_path`.
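When an experiment number is given, it has to be resolved to a checkpoint under "./exps". The helper below is only a sketch of how such a lookup could work; the "00000*" folder pattern and the *.pkl extension are assumptions about the layout, not the repository's actual code:

```python
import glob
import os

def resolve_model_path(model_path: str, exp_root: str = "./exps") -> str:
    """Accept either a checkpoint path or an experiment number such as '0'."""
    if os.path.exists(model_path):
        return model_path
    exp_id = f"{int(model_path):05d}"  # e.g. "0" -> "00000"
    matches = sorted(glob.glob(os.path.join(exp_root, exp_id + "*", "*.pkl")))
    if not matches:
        raise FileNotFoundError(f"No experiment {exp_id} found under {exp_root}")
    return matches[-1]  # latest checkpoint of that experiment
```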
# Sample 50k images using SFD for FID evaluation
torchrun --standalone --nproc_per_node=4 --master_port=12345 sample.py \
--dataset_name='cifar10' --model_path=0 --seeds='0-49999' --batch=256
# Sample 50k images using SFD-v for FID evaluation
# When use_step_condition=True is used for distillation, set a specific num_steps during sampling
torchrun --standalone --nproc_per_node=4 --master_port=12345 sample.py \
--dataset_name='cifar10' --model_path=0 --seeds='0-49999' --batch=256 --num_steps=4
To compute Fréchet inception distance (FID), compare the generated images against the dataset reference statistics:
# FID evaluation
python fid.py calc --images="path/to/generated/images" --ref="path/to/fid/stat"
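Under the hood, FID is the Fréchet distance between two Gaussians fitted to Inception features of the generated and reference images. Below is a minimal numpy sketch of that final computation, assuming the statistics are stored as `mu`/`sigma` arrays in .npz files (the file names are placeholders):

```python
import numpy as np
import scipy.linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between two Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    diff = mu1 - mu2
    covmean, _ = scipy.linalg.sqrtm(sigma1 @ sigma2, disp=False)  # matrix square root
    covmean = np.real(covmean)  # drop tiny imaginary parts from numerical error
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))

gen = np.load("path/to/generated_stats.npz")  # placeholder paths
ref = np.load("path/to/fid_stat.npz")
print(frechet_distance(gen["mu"], gen["sigma"], ref["mu"], ref["sigma"]))
```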
We also provide a script for calculating precision, recall, density and coverage. The reference images for CIFAR-10 (cifar10-32x32.zip) can be found here.
# Precision, recall, density and coverage
python prdc.py calc --images="path/to/generated/images" --images_ref="path/to/reference/images"
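Precision, recall, density and coverage are k-nearest-neighbour manifold metrics computed on image features (Naeem et al., 2020). The sketch below shows the definitions on precomputed feature matrices; the feature extractor, the choice of k, and the brute-force distance computation are assumptions for illustration, not the script's actual implementation:

```python
import numpy as np

def pairwise_dist(a, b):
    # Brute-force Euclidean distances; fine for small sets, real code batches this.
    return np.sqrt(((a[:, None, :] - b[None, :, :]) ** 2).sum(-1))

def prdc(real, fake, k=5):
    """Precision, recall, density and coverage from real/fake feature matrices."""
    d_rr, d_ff = pairwise_dist(real, real), pairwise_dist(fake, fake)
    d_rf = pairwise_dist(real, fake)                 # d_rf[i, j] = dist(real_i, fake_j)
    r_real = np.sort(d_rr, axis=1)[:, k]             # k-NN radius of each real sample
    r_fake = np.sort(d_ff, axis=1)[:, k]             # k-NN radius of each fake sample
    precision = (d_rf < r_real[:, None]).any(axis=0).mean()
    recall = (d_rf.T < r_fake[:, None]).any(axis=0).mean()
    density = (d_rf < r_real[:, None]).sum(axis=0).mean() / k
    coverage = (d_rf.min(axis=1) < r_real).mean()
    return precision, recall, density, coverage
```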
Parameter | Default | Description |
---|---|---|
dataset_name | None | One of ['cifar10', 'ffhq', 'afhqv2', 'imagenet64', 'lsun_bedroom', 'lsun_bedroom_ldm', 'ms_coco'] |
total_kimg | 200 | Number of sampling trajectories used for training (×1000) |
batch | 128 | Total batch size |
lr | 5e-5 | Learning rate |
num_steps | 4 | Number of timestamps for the student solver |
M | 3 | Number of timestamps inserted between every two adjacent timestamps in the original schedule |
afs | True | Whether to use AFS, which saves the first model evaluation |
sampler_tea | 'dpmpp' | Teacher solver. One of ['dpm', 'dpmpp', 'euler', 'ipndm', 'heun'] |
max_order | None | Option for multi-step solvers. 1<=max_order<=4 for iPNDM; 1<=max_order<=3 for DPM-Solver++ |
predict_x0 | True | Option for DPM-Solver++. Whether to use the data-prediction formulation |
lower_order_final | True | Option for DPM-Solver++. Whether to lower the order at the final stages of sampling |
schedule_type | 'polynomial' | Time discretization schedule. One of ['polynomial', 'logsnr', 'time_uniform', 'discrete'] |
schedule_rho | 7 | Time step exponent. Must be specified when schedule_type is in ['polynomial', 'time_uniform', 'discrete'] |
use_step_condition | False | Whether to add a step condition to the model to obtain an NFE-variable model |
is_second_stage | False | Whether to perform second-stage distillation to obtain a 1-NFE model |
model_path | None | When is_second_stage=True, should be the path to an SFD or SFD-v model |
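For intuition, schedule_type='polynomial' with schedule_rho=7 corresponds to the EDM-style time discretization, where timestamps interpolate between the maximum and minimum noise levels in rho-warped space. A minimal sketch of that schedule (the sigma_min=0.002 and sigma_max=80 defaults are the usual EDM values, assumed here rather than read from this repository):

```python
import numpy as np

def polynomial_schedule(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """EDM-style polynomial timestamps t_0 = sigma_max > ... > t_{N-1} = sigma_min."""
    i = np.arange(num_steps)
    return (sigma_max ** (1 / rho)
            + i / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

print(polynomial_schedule(num_steps=4))  # 4 timestamps -> 3 sampling steps
```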
We perform sampling on a variety of pre-trained diffusion models from different codebases including EDM, LDM and Stable Diffusion. Supported pre-trained models are listed below:
Codebase | dataset_name | Resolution | Pre-trained Models | Description |
---|---|---|---|---|
EDM | cifar10 | 32 | edm-cifar10-32x32-uncond-vp.pkl | |
EDM | ffhq | 64 | edm-ffhq-64x64-uncond-vp.pkl | |
EDM | afhqv2 | 64 | edm-afhqv2-64x64-uncond-vp.pkl | |
EDM | imagenet64 | 64 | edm-imagenet-64x64-cond-adm.pkl | |
LDM | lsun_bedroom_ldm | 256 | lsun_bedrooms.zip | Latent-space |
LDM | ffhq_ldm | 256 | ffhq.zip | Latent-space |
Stable Diffusion | ms_coco | 512 | stable-diffusion-v1-5 | Classifier-free guidance |
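For the Stable Diffusion entry, sampling combines a conditional and an unconditional prediction at every step via classifier-free guidance. Below is a minimal sketch of the guidance rule; the denoiser signature and the guidance scale w are illustrative assumptions, not this repository's interface:

```python
import torch

def cfg_predict(model, x_t, t, cond, uncond, w=7.5):
    """Classifier-free guidance: push the unconditional prediction toward the
    conditional one with guidance scale w."""
    eps_cond = model(x_t, t, cond)
    eps_uncond = model(x_t, t, uncond)
    return eps_uncond + w * (eps_cond - eps_uncond)
```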
To facilitate FID evaluation of diffusion models, we provide our FID statistics for various datasets. They were collected from the Internet or computed by ourselves following the guidance of the EDM codebase.
You can compute the reference statistics for your own datasets as follows:
python fid.py ref --data=path/to/my-dataset.zip --dest=path/to/save/my-dataset.npz
If you find this repository useful, please consider citing the following paper:
@article{chen2024trajectory,
title={On the Trajectory Regularity of ODE-based Diffusion Sampling},
author={Chen, Defang and Zhou, Zhenyu and Wang, Can and Shen, Chunhua and Lyu, Siwei},
journal={arXiv preprint arXiv:2405.11326},
year={2024}
}