Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeCamp2023-240]Adding support for Consistency Models #2086

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
87 changes: 87 additions & 0 deletions configs/consistency_models/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Consistency Models (ICML'2023)

> [Consistency Models](https://arxiv.org/abs/2303.01469)

> **Task**: conditional

<!-- [ALGORITHM] -->

## Abstract

<!-- [ABSTRACT] -->

Diffusion models have significantly advanced the fields of image, audio, and video generation, but they depend on an iterative sampling process that causes slow generation. To overcome this limitation, we propose consistency models, a new family of models that generate high quality samples by directly mapping noise to data. They support fast one-step generation by design, while still allowing multistep sampling to trade compute for sample quality. They also support zero-shot data editing, such as image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks. Consistency models can be trained either by distilling pre-trained diffusion models, or as standalone generative models altogether. Through extensive experiments, we demonstrate that they outperform existing distillation techniques for diffusion models in one- and few-step sampling, achieving the new state-of-the-art FID of 3.55 on CIFAR-10 and 6.20 on ImageNet 64x64 for one-step generation. When trained in isolation, consistency models become a new family of generative models that can outperform existing one-step, non-adversarial generative models on standard benchmarks such as CIFAR-10, ImageNet 64x64 and LSUN 256x256.

<div align="center">
<img src="https://github.com/xiaomile/mmagic/assets/14927720/1586f0c0-8def-4339-b898-470333a26125" width=800>
</div>

## Pre-trained models

| Model | Dataset | Conditional | Download |
| :-------------------------------------------------------------------------------------------: | :--------: | :---------: | :------: |
| [onestep on ImageNet-64](./consistency_models_8xb256-imagenet1k-onestep-64x64.py) | imagenet1k | yes | - |
| [multistep on ImageNet-64](./consistency_models_8xb256-imagenet1k-multistep-64x64.py) | imagenet1k | yes | - |
| [onestep on LSUN Bedroom-256](./consistency_models_8xb32-LSUN-bedroom-onestep-256x256.py) | LSUN | no | - |
| [multistep on LSUN Bedroom-256](./consistency_models_8xb32-LSUN-bedroom-multistep-256x256.py) | LSUN | no | - |
| [onstep on LSUN Cat-256](./consistency_models_8xb32-LSUN-cat-onestep-256x256.py) | LSUN | no | - |
| [multistep on LSUN Cat-256](./consistency_models_8xb32-LSUN-cat-multistep-256x256.py) | LSUN | no | - |

You can also download checkpoints which is the main models in the paper to local machine and deliver the path to 'model_path' before infer.
Here are the download links for each model checkpoint:

- EDM on ImageNet-64: [edm_imagenet64_ema.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/edm_imagenet64_ema.pt)
- CD on ImageNet-64 with l2 metric: [cd_imagenet64_l2.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_imagenet64_l2.pt)
- CD on ImageNet-64 with LPIPS metric: [cd_imagenet64_lpips.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_imagenet64_lpips.pt)
- CT on ImageNet-64: [ct_imagenet64.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/ct_imagenet64.pt)
- EDM on LSUN Bedroom-256: [edm_bedroom256_ema.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/edm_bedroom256_ema.pt)
- CD on LSUN Bedroom-256 with l2 metric: [cd_bedroom256_l2.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_bedroom256_l2.pt)
- CD on LSUN Bedroom-256 with LPIPS metric: [cd_bedroom256_lpips.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_bedroom256_lpips.pt)
- CT on LSUN Bedroom-256: [ct_bedroom256.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/ct_bedroom256.pt)
- EDM on LSUN Cat-256: [edm_cat256_ema.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/edm_cat256_ema.pt)
- CD on LSUN Cat-256 with l2 metric: [cd_cat256_l2.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_cat256_l2.pt)
- CD on LSUN Cat-256 with LPIPS metric: [cd_cat256_lpips.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_cat256_lpips.pt)
- CT on LSUN Cat-256: [ct_cat256.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/ct_cat256.pt)

## quick start

**Infer**

<details>
<summary>Infer Instructions</summary>

You can use the following commands to infer with the model.

```shell
# onestep
python demo\mmagic_inference_demo.py \
--model-name consistency_models \
--model-config configs/consistency_models/consistency_models_8xb256-imagenet1k-onestep-64x64.py \
--result-out-dir demo_consistency_model.jpg

# multistep
python demo\mmagic_inference_demo.py \
--model-name consistency_models \
--model-config configs/consistency_models/consistency_models_8xb256-imagenet1k-multistep-64x64.py \
--result-out-dir demo_consistency_model.jpg

# conditional
python demo\mmagic_inference_demo.py \
--model-name consistency_models \
--model-config configs/consistency_models/consistency_models_8xb256-imagenet1k-onestep-64x64.py \
--label 145 \
--result-out-dir demo_consistency_model.jpg
```

</details>

# Citation

```bibtex
@article{song2023consistency,
title={Consistency Models},
author={Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya},
journal={arXiv preprint arXiv:2303.01469},
year={2023},
}
```
88 changes: 88 additions & 0 deletions configs/consistency_models/README_zh-CN.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Consistency Models (ICML'2023)

> [Consistency Models](https://arxiv.org/abs/2303.01469)

> **任务**: 条件生成

<!-- [ALGORITHM] -->

## 摘要

<!-- [ABSTRACT] -->

扩散模型在图像、音频和视频生成领域取得了显著的进展,但它们依赖于迭代采样过程,导致生成速度较慢。为了克服这个限制,我们提出了一种新的模型家族——一致性模型,通过直接将噪声映射到数据来生成高质量的样本。它们通过设计支持快速的单步生成,同时仍然允许多步采样以在计算和样本质量之间进行权衡。它们还支持零样本数据编辑,如图像修补、上色和超分辨率,而不需要在这些任务上进行显式训练。一致性模型可以通过蒸馏预训练的扩散模型或作为独立的生成模型进行训练。通过大量实验证明,它们在单步和少步采样方面优于现有的扩散模型蒸馏技术,实现了 CIFAR-10 上的新的最先进 FID(Fréchet Inception Distance)为 3.55,ImageNet 64x64 上为 6.20 的结果。当独立训练时,一致性模型成为一种新的生成模型家族,在 CIFAR-10、ImageNet 64x64 和 LSUN 256x256 等标准基准测试上可以优于现有的单步非对抗性生成模型。

<div align="center">
<img src="https://github.com/xiaomile/mmagic/assets/14927720/1586f0c0-8def-4339-b898-470333a26125" width=800>
</div>

## 预训练模型

| Model | Dataset | Conditional | Download |
| :-------------------------------------------------------------------------------------------: | :--------: | :---------: | :------: |
| [onestep on ImageNet-64](./consistency_models_8xb256-imagenet1k-onestep-64x64.py) | imagenet1k | yes | - |
| [multistep on ImageNet-64](./consistency_models_8xb256-imagenet1k-multistep-64x64.py) | imagenet1k | yes | - |
| [onestep on LSUN Bedroom-256](./consistency_models_8xb32-LSUN-bedroom-onestep-256x256.py) | LSUN | no | - |
| [multistep on LSUN Bedroom-256](./consistency_models_8xb32-LSUN-bedroom-multistep-256x256.py) | LSUN | no | - |
| [onstep on LSUN Cat-256](./consistency_models_8xb32-LSUN-cat-onestep-256x256.py) | LSUN | no | - |
| [multistep on LSUN Cat-256](./consistency_models_8xb32-LSUN-cat-multistep-256x256.py) | LSUN | no | - |

你也可以在进行推理前先把论文中主要模型的权重下载到本地的机器上并将权重路径传给'model_path'。
以下是每个模型权重的下载链接:

- EDM on ImageNet-64: [edm_imagenet64_ema.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/edm_imagenet64_ema.pt)
- CD on ImageNet-64 with l2 metric: [cd_imagenet64_l2.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_imagenet64_l2.pt)
- CD on ImageNet-64 with LPIPS metric: [cd_imagenet64_lpips.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_imagenet64_lpips.pt)
- CT on ImageNet-64: [ct_imagenet64.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/ct_imagenet64.pt)
- EDM on LSUN Bedroom-256: [edm_bedroom256_ema.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/edm_bedroom256_ema.pt)
- CD on LSUN Bedroom-256 with l2 metric: [cd_bedroom256_l2.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_bedroom256_l2.pt)
- CD on LSUN Bedroom-256 with LPIPS metric: [cd_bedroom256_lpips.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_bedroom256_lpips.pt)
- CT on LSUN Bedroom-256: [ct_bedroom256.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/ct_bedroom256.pt)
- EDM on LSUN Cat-256: [edm_cat256_ema.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/edm_cat256_ema.pt)
- CD on LSUN Cat-256 with l2 metric: [cd_cat256_l2.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_cat256_l2.pt)
- CD on LSUN Cat-256 with LPIPS metric: [cd_cat256_lpips.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/cd_cat256_lpips.pt)
- CT on LSUN Cat-256: [ct_cat256.pt](https://download.openxlab.org.cn/models/xiaomile/consistency_models/weight/ct_cat256.pt)
-

## 快速开始

**推理**

<details>
<summary>推理说明</summary>

您可以使用以下命令来使用该模型进行推理:

```shell
# 一步生成
python demo\mmagic_inference_demo.py \
--model-name consistency_models \
--model-config configs/consistency_models/consistency_models_8xb256-imagenet1k-onestep-64x64.py \
--result-out-dir demo_consistency_model.jpg

# 多步生成
python demo\mmagic_inference_demo.py \
--model-name consistency_models \
--model-config configs/consistency_models/consistency_models_8xb256-imagenet1k-multistep-64x64.py \
--result-out-dir demo_consistency_model.jpg

# 条件控制生成
python demo\mmagic_inference_demo.py \
--model-name consistency_models \
--model-config configs/consistency_models/consistency_models_8xb256-imagenet1k-onestep-64x64.py \
--label 145 \
--result-out-dir demo_consistency_model.jpg
```

</details>

# Citation

```bibtex
@article{song2023consistency,
title={Consistency Models},
author={Song, Yang and Dhariwal, Prafulla and Chen, Mark and Sutskever, Ilya},
journal={arXiv preprint arXiv:2303.01469},
year={2023},
}
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) OpenMMLab. All rights reserved.
_base_ = ['../_base_/default_runtime.py']

denoiser_config = dict(
type='KarrasDenoiser',
sigma_data=0.5,
sigma_max=80.0,
sigma_min=0.002,
weight_schedule='uniform',
)

unet_config = dict(
type='ConsistencyUNetModel',
in_channels=3,
model_channels=192,
num_res_blocks=3,
dropout=0.0,
channel_mult='',
use_checkpoint=False,
use_fp16=False,
num_head_channels=64,
num_heads=4,
num_heads_upsample=-1,
resblock_updown=True,
use_new_attention_order=False,
use_scale_shift_norm=True)

model = dict(
type='ConsistencyModel',
unet=unet_config,
denoiser=denoiser_config,
attention_resolutions='32,16,8',
batch_size=4,
class_cond=True,
generator='determ',
image_size=64,
learn_sigma=False,
model_path='https://download.openxlab.org.cn/models/xiaomile/'
'consistency_models/weight/cd_imagenet64_l2.pt',
num_classes=1000,
sampler='multistep',
seed=42,
training_mode='consistency_distillation',
ts='0,22,39',
data_preprocessor=dict(
type='DataPreprocessor', mean=[127.5] * 3, std=[127.5] * 3))
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) OpenMMLab. All rights reserved.
_base_ = ['../_base_/default_runtime.py']

denoiser_config = dict(
type='KarrasDenoiser',
sigma_data=0.5,
sigma_max=80.0,
sigma_min=0.002,
weight_schedule='uniform',
)

unet_config = dict(
type='ConsistencyUNetModel',
in_channels=3,
model_channels=192,
num_res_blocks=3,
dropout=0.0,
channel_mult='',
use_checkpoint=False,
use_fp16=False,
num_head_channels=64,
num_heads=4,
num_heads_upsample=-1,
resblock_updown=True,
use_new_attention_order=False,
use_scale_shift_norm=True)

model = dict(
type='ConsistencyModel',
unet=unet_config,
denoiser=denoiser_config,
attention_resolutions='32,16,8',
batch_size=4,
class_cond=True,
generator='determ',
image_size=64,
learn_sigma=False,
model_path='https://download.openxlab.org.cn/models/xiaomile/'
'consistency_models/weight/cd_imagenet64_l2.pt',
num_classes=1000,
sampler='onestep',
seed=42,
training_mode='consistency_distillation',
ts='',
data_preprocessor=dict(
type='DataPreprocessor', mean=[127.5] * 3, std=[127.5] * 3))
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (c) OpenMMLab. All rights reserved.
_base_ = ['../_base_/default_runtime.py']

denoiser_config = dict(
type='KarrasDenoiser',
sigma_data=0.5,
sigma_max=80.0,
sigma_min=0.002,
weight_schedule='uniform',
)

unet_config = dict(
type='ConsistencyUNetModel',
in_channels=3,
model_channels=256,
num_res_blocks=2,
dropout=0.0,
channel_mult='',
use_checkpoint=False,
use_fp16=False,
num_head_channels=64,
num_heads=4,
num_heads_upsample=-1,
resblock_updown=True,
use_new_attention_order=False,
use_scale_shift_norm=False)

model = dict(
type='ConsistencyModel',
unet=unet_config,
denoiser=denoiser_config,
attention_resolutions='32,16,8',
batch_size=4,
class_cond=False,
generator='determ-indiv',
image_size=256,
learn_sigma=False,
model_path='https://download.openxlab.org.cn/models/xiaomile/'
'consistency_models/weight/ct_bedroom256.pt',
num_classes=1000,
sampler='multistep',
seed=42,
training_mode='consistency_distillation',
ts='0,67,150',
steps=151,
data_preprocessor=dict(
type='DataPreprocessor', mean=[127.5] * 3, std=[127.5] * 3))
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright (c) OpenMMLab. All rights reserved.
_base_ = ['../_base_/default_runtime.py']

denoiser_config = dict(
type='KarrasDenoiser',
sigma_data=0.5,
sigma_max=80.0,
sigma_min=0.002,
weight_schedule='uniform',
)

unet_config = dict(
type='ConsistencyUNetModel',
in_channels=3,
model_channels=256,
num_res_blocks=2,
dropout=0.0,
channel_mult='',
use_checkpoint=False,
use_fp16=False,
num_head_channels=64,
num_heads=4,
num_heads_upsample=-1,
resblock_updown=True,
use_new_attention_order=False,
use_scale_shift_norm=False)

model = dict(
type='ConsistencyModel',
unet=unet_config,
denoiser=denoiser_config,
attention_resolutions='32,16,8',
batch_size=4,
class_cond=False,
generator='determ-indiv',
image_size=256,
learn_sigma=False,
model_path='https://download.openxlab.org.cn/models/xiaomile/'
'consistency_models/weight/ct_bedroom256.pt',
num_classes=1000,
sampler='onestep',
seed=42,
training_mode='consistency_distillation',
ts='',
data_preprocessor=dict(
type='DataPreprocessor', mean=[127.5] * 3, std=[127.5] * 3))
Loading
Loading