Commit 9199552

Authored Apr 12, 2023
Merge pull request #5 from mkshing/v0.2.0
v0.2.0
2 parents 7f47d31 + c645374 commit 9199552

20 files changed: +3,727 −386 lines
 

‎README.md

+90 −12

@@ -12,10 +12,17 @@ My summary tweet is found [here](https://twitter.com/mk1stats/status/16428655051
 left: LoRA, right: SVDiff
 
 
-Compared with LoRA, the number of trainable parameters is 0.6 M less parameters and the file size is only <1MB (LoRA: 3.1MB)!!
+Compared with LoRA, the number of trainable parameters is 0.5 M fewer and the file size is only 1.2MB (LoRA: 3.1MB)!!
 
 ![kumamon](assets/kumamon.png)
 
+## Updates
+### 2023.4.11
+- Released v0.2.0 (please see [here](https://github.com/mkshing/svdiff-pytorch/releases/tag/v0.2.0) for the details)
+- Add [Single Image Editing](#single-image-editing)
+  ![chair-result](assets/chair-result.png)
+  <br>"photo of a ~~pink~~ blue chair with black legs"
+
 ## Installation
 ```
 $ pip install svdiff-pytorch
@@ -26,9 +33,10 @@ $ git clone https://github.com/mkshing/svdiff-pytorch
 $ pip install -r requirements.txt
 ```
 
-## Training
-The following example script is for "Single-Subject Generation", which is a domain-tuning on a single object or concept (using 3-5 images). (See Section 4.1)
+## Single-Subject Generation
+"Single-Subject Generation" is domain tuning on a single object or concept (using 3-5 images). (See Section 4.1)
 
+### Training
 According to the paper, the learning rate for SVDiff needs to be 1000 times larger than the lr used for fine-tuning.
 
 ```bash
@@ -48,29 +56,32 @@ accelerate launch train_svdiff.py \
   --resolution=512 \
   --train_batch_size=1 \
   --gradient_accumulation_steps=1 \
-  --learning_rate=5e-3 \
+  --learning_rate=1e-3 \
+  --learning_rate_1d=1e-6 \
+  --train_text_encoder \
   --lr_scheduler="constant" \
   --lr_warmup_steps=0 \
   --num_class_images=200 \
-  --max_train_steps=800
+  --max_train_steps=500
 ```
 
-
-## Inference
+### Inference
 
 ```python
 from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
 import torch
 
-from svdiff_pytorch import load_unet_for_svdiff
+from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff
 
 pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
-spectral_shifts_ckpt = "spectral_shifts.safetensors-path"
-unet = load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt, subfolder="unet")
+spectral_shifts_ckpt_dir = "ckpt-dir-path"
+unet = load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="unet")
+text_encoder = load_text_encoder_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="text_encoder")
 # load pipe
 pipe = StableDiffusionPipeline.from_pretrained(
     pretrained_model_name_or_path,
     unet=unet,
+    text_encoder=text_encoder,
 )
 pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
 pipe.to("cuda")
@@ -82,14 +93,14 @@ You can use the following CLI too! Once it's done, you will see `grid.png` for t
 ```bash
 python inference.py \
   --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
-  --spectral_shifts_ckpt="spectral_shifts.safetensors-path" \
+  --spectral_shifts_ckpt="ckpt-dir-path" \
   --prompt="A picture of a sks dog in a bucket" \
   --scheduler_type="dpm_solver++" \
   --num_inference_steps=25 \
   --num_images_per_prompt=2
 ```
 
-## Gradio
+### Gradio
 You can also try SVDiff-pytorch in a UI with [gradio](https://gradio.app/). This demo supports both training and inference!
 
 [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/svdiff-library/SVDiff-Training-UI)
@@ -103,7 +114,73 @@ $ export HF_TOKEN="YOUR_HF_TOKEN_HERE"
 $ python app.py
 ```
 
+## Single Image Editing
+### Training
+In Single Image Editing, your instance prompt should be just the description of your input image **without the identifier**.
+
+```bash
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export INSTANCE_DIR="dir-path-to-input-image"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_svdiff.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --instance_data_dir=$INSTANCE_DIR \
+  --class_data_dir=$CLASS_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --with_prior_preservation --prior_loss_weight=1.0 \
+  --instance_prompt="photo of a pink chair with black legs" \
+  --class_prompt="photo of a chair" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=1 \
+  --learning_rate=1e-3 \
+  --learning_rate_1d=1e-6 \
+  --train_text_encoder \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --num_class_images=200 \
+  --max_train_steps=500
+```
+
+### Inference
+
+```python
+import torch
+from PIL import Image
+from diffusers import DDIMScheduler
+from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, StableDiffusionPipelineWithDDIMInversion
+
+pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
+spectral_shifts_ckpt_dir = "ckpt-dir-path"
+image = "path-to-image"
+source_prompt = "prompt-for-image"
+target_prompt = "prompt-you-want-to-generate"
+
+unet = load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="unet")
+text_encoder = load_text_encoder_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="text_encoder")
+# load pipe
+pipe = StableDiffusionPipelineWithDDIMInversion.from_pretrained(
+    pretrained_model_name_or_path,
+    unet=unet,
+    text_encoder=text_encoder,
+)
+pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+pipe.to("cuda")
+
+# (optional) ddim inversion
+# if you don't do it, inv_latents = None
+image = Image.open(image).convert("RGB").resize((512, 512))
+# in SVDiff, they use guidance scale=1 in ddim inversion
+inv_latents = pipe.invert(source_prompt, image=image, guidance_scale=1.0).latents
+
+image = pipe(target_prompt, latents=inv_latents).images[0]
+```
+
+
 ## Additional Features
+
 ### Spectral Shift Scaling
 
 ![scale](assets/scale.png)
@@ -165,6 +242,7 @@ And, add `--enable_tome_merging` to your training arguments!
 - [x] Training
 - [x] Inference
 - [x] Scaling spectral shifts
+- [x] Support Single Image Editing
 - [ ] Support multiple spectral shifts (Section 3.2)
 - [ ] Cut-Mix-Unmix (Section 3.3)
 - [ ] SVDiff + LoRA

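As a quick sanity check of the parameter and file-size comparison above, the spectral shifts of the UNet can be counted directly. This is a minimal, hypothetical sketch (not part of the repository) that assumes the v0.2.0 API shown in this README diff:

```python
from svdiff_pytorch import load_unet_for_svdiff

# load the SVDiff-wrapped UNet without any trained checkpoint; every parameter whose
# name contains "delta" is one spectral shift (one scalar per singular value)
unet = load_unet_for_svdiff("runwayml/stable-diffusion-v1-5", subfolder="unet")
n_delta = sum(p.numel() for n, p in unet.named_parameters() if "delta" in n)
print(f"UNet spectral shifts: {n_delta / 1e6:.2f} M parameters")
# a fp32 safetensors file of these shifts is roughly n_delta * 4 bytes
print(f"approx. checkpoint size: {n_delta * 4 / 1e6:.1f} MB")
```

The exact figures depend on which layers carry spectral shifts in a given version, so treat the printout rather than the README number as ground truth for your install.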
‎assets/chair-result.png

490 KB

‎inference.py

+45 −2

@@ -1,10 +1,13 @@
 import argparse
+import os
 from tqdm import tqdm
 import random
 import torch
+import huggingface_hub
+from transformers import CLIPTextModel
 from diffusers import StableDiffusionPipeline
 from diffusers.utils import is_xformers_available
-from svdiff_pytorch import load_unet_for_svdiff, SCHEDULER_MAPPING, image_grid
+from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, SCHEDULER_MAPPING, image_grid
 
 
 def parse_args():
@@ -14,7 +17,7 @@ def parse_args():
     # diffusers config
     parser.add_argument("--prompt", type=str, nargs="?", default="a photo of *s", help="the prompt to render")
     parser.add_argument("--num_inference_steps", type=int, default=50, help="number of sampling steps")
-    parser.add_argument("--guidance_scale", type=float, default=1.0, help="unconditional guidance scale")
+    parser.add_argument("--guidance_scale", type=float, default=7.5, help="unconditional guidance scale")
     parser.add_argument("--num_images_per_prompt", type=int, default=1, help="number of images per prompt")
     parser.add_argument("--height", type=int, default=512, help="image height, in pixel space",)
     parser.add_argument("--width", type=int, default=512, help="image width, in pixel space",)
@@ -27,6 +30,33 @@ def parse_args():
     return args
 
 
+def load_text_encoder(pretrained_model_name_or_path, spectral_shifts_ckpt, device, fp16=False):
+    if os.path.isdir(spectral_shifts_ckpt):
+        spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts_te.safetensors")
+    elif not os.path.exists(spectral_shifts_ckpt):
+        # download from hub
+        hf_hub_kwargs = {}
+        try:
+            spectral_shifts_ckpt = huggingface_hub.hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts_te.safetensors", **hf_hub_kwargs)
+        except huggingface_hub.utils.EntryNotFoundError:
+            return CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16 if fp16 else None).to(device)
+    if not os.path.exists(spectral_shifts_ckpt):
+        return CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16 if fp16 else None).to(device)
+    text_encoder = load_text_encoder_for_svdiff(
+        pretrained_model_name_or_path=pretrained_model_name_or_path,
+        spectral_shifts_ckpt=spectral_shifts_ckpt,
+        subfolder="text_encoder",
+    )
+    # first perform svd and cache
+    for module in text_encoder.modules():
+        if hasattr(module, "perform_svd"):
+            module.perform_svd()
+    if fp16:
+        text_encoder = text_encoder.to(device, dtype=torch.float16)
+    return text_encoder
+
+
+
 def main():
     args = parse_args()
     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -40,10 +70,18 @@ def main():
             module.perform_svd()
     if args.fp16:
         unet = unet.to(device, dtype=torch.float16)
+    text_encoder = load_text_encoder(
+        pretrained_model_name_or_path=args.pretrained_model_name_or_path,
+        spectral_shifts_ckpt=args.spectral_shifts_ckpt,
+        fp16=args.fp16,
+        device=device
+    )
+
     # load pipe
     pipe = StableDiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
         unet=unet,
+        text_encoder=text_encoder,
         requires_safety_checker=False,
         safety_checker=None,
         feature_extractor=None,
@@ -67,6 +105,11 @@ def main():
         for module in pipe.unet.modules():
             if hasattr(module, "set_scale"):
                 module.set_scale(scale=args.spectral_shifts_scale)
+        if not isinstance(pipe.text_encoder, CLIPTextModel):
+            for module in pipe.text_encoder.modules():
+                if hasattr(module, "set_scale"):
+                    module.set_scale(scale=args.spectral_shifts_scale)
+
         print(f"Set spectral_shifts_scale to {args.spectral_shifts_scale}!")
 
     if args.seed == "random_seed":

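The new `set_scale` loop above is what `--spectral_shifts_scale` drives: the same scale is pushed into every SVD layer of the UNet and, when a spectral-shift text encoder was loaded, of the text encoder as well. A minimal, hypothetical sketch of doing the same by hand on a pipeline loaded as in the README (`pipe` and the value 1.2 are placeholders, not repository defaults):

```python
scale = 1.2  # > 1.0 amplifies the learned shifts, < 1.0 tones them down, 1.0 is the trained behaviour

for module in pipe.unet.modules():
    if hasattr(module, "set_scale"):
        module.set_scale(scale=scale)

# a plain CLIPTextModel has no SVD layers, so the hasattr guard makes this a no-op in that case
for module in pipe.text_encoder.modules():
    if hasattr(module, "set_scale"):
        module.set_scale(scale=scale)
```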
‎requirements.txt

+1 −1

@@ -2,7 +2,7 @@ diffusers==0.14.0
 accelerate
 torchvision
 safetensors
-transformers>=4.25.1
+transformers>=4.25.1, <=4.27.3
 ftfy
 tensorboard
 Jinja2

‎scripts/svdiff_pytorch.ipynb

+1,681 −260
Large diffs are not rendered by default.

‎setup.py

+1 −1

@@ -3,7 +3,7 @@
 
 setup(
     name="svdiff-pytorch",
-    version="0.1.1",
+    version="0.2.0",
     author="Makoto Shing",
     url="https://github.com/mkshing/svdiff-pytorch",
     description="Implementation of 'SVDiff: Compact Parameter Space for Diffusion Fine-Tuning'",

‎svdiff_pytorch/__init__.py

+3 −1

@@ -1,2 +1,4 @@
 from svdiff_pytorch.diffusers_models.unet_2d_condition import UNet2DConditionModel as UNet2DConditionModelForSVDiff
-from svdiff_pytorch.utils import load_unet_for_svdiff, image_grid, SCHEDULER_MAPPING
+from svdiff_pytorch.transformers_models_clip.modeling_clip import CLIPTextModel as CLIPTextModelForSVDiff
+from svdiff_pytorch.utils import load_unet_for_svdiff, load_text_encoder_for_svdiff, image_grid, SCHEDULER_MAPPING
+from svdiff_pytorch.pipeline_stable_diffusion_ddim_inversion import StableDiffusionPipelineWithDDIMInversion

‎svdiff_pytorch/diffusers_models/attention.py

+7 −7

@@ -21,7 +21,7 @@
 from diffusers.utils.import_utils import is_xformers_available
 from svdiff_pytorch.diffusers_models.cross_attention import CrossAttention
 from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
-from svdiff_pytorch.layers import SVDLinear
+from svdiff_pytorch.layers import SVDLinear, SVDGroupNorm, SVDLayerNorm
 
 
 if is_xformers_available():
@@ -62,7 +62,7 @@ def __init__(
 
         self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
         self.num_head_size = num_head_channels
-        self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
+        self.group_norm = SVDGroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
 
         # define q,k,v as linear layers
         self.query = SVDLinear(channels, channels)
@@ -252,7 +252,7 @@ def __init__(
         elif self.use_ada_layer_norm_zero:
             self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
         else:
-            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+            self.norm1 = SVDLayerNorm(dim, elementwise_affine=norm_elementwise_affine)
 
         if cross_attention_dim is not None:
             # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
@@ -261,13 +261,13 @@ def __init__(
             self.norm2 = (
                 AdaLayerNorm(dim, num_embeds_ada_norm)
                 if self.use_ada_layer_norm
-                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+                else SVDLayerNorm(dim, elementwise_affine=norm_elementwise_affine)
             )
         else:
             self.norm2 = None
 
         # 3. Feed-forward
-        self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
+        self.norm3 = SVDLayerNorm(dim, elementwise_affine=norm_elementwise_affine)
 
     def forward(
         self,
@@ -453,7 +453,7 @@ def __init__(self, embedding_dim, num_embeddings):
         self.emb = nn.Embedding(num_embeddings, embedding_dim)
         self.silu = nn.SiLU()
         self.linear = SVDLinear(embedding_dim, embedding_dim * 2)
-        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
+        self.norm = SVDLayerNorm(embedding_dim, elementwise_affine=False)
 
     def forward(self, x, timestep):
         emb = self.linear(self.silu(self.emb(timestep)))
@@ -474,7 +474,7 @@ def __init__(self, embedding_dim, num_embeddings):
 
         self.silu = nn.SiLU()
         self.linear = SVDLinear(embedding_dim, 6 * embedding_dim, bias=True)
-        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        self.norm = SVDLayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
 
     def forward(self, x, timestep, class_labels, hidden_dtype=None):
         emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))

‎svdiff_pytorch/diffusers_models/cross_attention.py

+3 −3

@@ -19,7 +19,7 @@
 
 from diffusers.utils import deprecate, logging
 from diffusers.utils.import_utils import is_xformers_available
-from svdiff_pytorch.layers import SVDLinear
+from svdiff_pytorch.layers import SVDLinear, SVDGroupNorm, SVDLayerNorm
 
 
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -80,12 +80,12 @@ def __init__(
         self.added_kv_proj_dim = added_kv_proj_dim
 
         if norm_num_groups is not None:
-            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+            self.group_norm = SVDGroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
         else:
             self.group_norm = None
 
         if cross_attention_norm:
-            self.norm_cross = nn.LayerNorm(cross_attention_dim)
+            self.norm_cross = SVDLayerNorm(cross_attention_dim)
 
         self.to_q = SVDLinear(query_dim, inner_dim, bias=bias)
         self.to_k = SVDLinear(cross_attention_dim, inner_dim, bias=bias)

‎svdiff_pytorch/diffusers_models/embeddings.py

+1 −1

@@ -137,7 +137,7 @@ def __init__(
             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
         )
         if layer_norm:
-            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
+            self.norm = SVDLayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
         else:
             self.norm = None

‎svdiff_pytorch/diffusers_models/resnet.py

+4 −4

@@ -6,7 +6,7 @@
 import torch.nn.functional as F
 
 from svdiff_pytorch.diffusers_models.attention import AdaGroupNorm
-from svdiff_pytorch.layers import SVDConv1d, SVDConv2d, SVDLinear
+from svdiff_pytorch.layers import SVDConv1d, SVDConv2d, SVDLinear, SVDGroupNorm, SVDLayerNorm
 
 
 class Upsample1D(nn.Module):
@@ -472,7 +472,7 @@ def __init__(
         if self.time_embedding_norm == "ada_group":
             self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
         else:
-            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+            self.norm1 = SVDGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
 
         self.conv1 = SVDConv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
@@ -491,7 +491,7 @@ def __init__(
         if self.time_embedding_norm == "ada_group":
             self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
         else:
-            self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+            self.norm2 = SVDGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
 
         self.dropout = torch.nn.Dropout(dropout)
         conv_2d_out_channels = conv_2d_out_channels or out_channels
@@ -609,7 +609,7 @@ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
         super().__init__()
 
         self.conv1d = SVDConv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
-        self.group_norm = nn.GroupNorm(n_groups, out_channels)
+        self.group_norm = SVDGroupNorm(n_groups, out_channels)
         self.mish = nn.Mish()
 
     def forward(self, x):

‎svdiff_pytorch/diffusers_models/transformer_2d.py

+4 −4

@@ -24,7 +24,7 @@
 from svdiff_pytorch.diffusers_models.attention import BasicTransformerBlock
 from diffusers.models.embeddings import PatchEmbed
 from diffusers.models.modeling_utils import ModelMixin
-from svdiff_pytorch.layers import SVDConv1d, SVDConv2d, SVDLinear
+from svdiff_pytorch.layers import SVDConv1d, SVDConv2d, SVDLinear, SVDGroupNorm, SVDLayerNorm
 
 
 @dataclass
@@ -143,7 +143,7 @@ def __init__(
         if self.is_input_continuous:
             self.in_channels = in_channels
 
-            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+            self.norm = SVDGroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
             if use_linear_projection:
                 self.proj_in = SVDLinear(in_channels, inner_dim)
             else:
@@ -205,10 +205,10 @@ def __init__(
             else:
                 self.proj_out = SVDConv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
-            self.norm_out = nn.LayerNorm(inner_dim)
+            self.norm_out = SVDLayerNorm(inner_dim)
             self.out = SVDLinear(inner_dim, self.num_vector_embeds - 1)
         elif self.is_input_patches:
-            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+            self.norm_out = SVDLayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
             self.proj_out_1 = SVDLinear(inner_dim, 2 * inner_dim)
             self.proj_out_2 = SVDLinear(inner_dim, patch_size * patch_size * self.out_channels)

‎svdiff_pytorch/diffusers_models/unet_2d_blocks.py

+3 −3

@@ -22,7 +22,7 @@
 from svdiff_pytorch.diffusers_models.dual_transformer_2d import DualTransformer2DModel
 from svdiff_pytorch.diffusers_models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D
 from svdiff_pytorch.diffusers_models.transformer_2d import Transformer2DModel
-from svdiff_pytorch.layers import SVDConv1d, SVDConv2d, SVDLinear
+from svdiff_pytorch.layers import SVDConv1d, SVDConv2d, SVDLinear, SVDLayerNorm, SVDGroupNorm
 
 
 def get_down_block(
@@ -2089,7 +2089,7 @@ def __init__(
                 kernel="fir",
             )
             self.skip_conv = SVDConv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
-            self.skip_norm = torch.nn.GroupNorm(
+            self.skip_norm = SVDGroupNorm(
                 num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
             )
             self.act = nn.SiLU()
@@ -2186,7 +2186,7 @@ def __init__(
                 kernel="fir",
             )
             self.skip_conv = SVDConv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
-            self.skip_norm = torch.nn.GroupNorm(
+            self.skip_norm = SVDGroupNorm(
                 num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
             )
             self.act = nn.SiLU()

‎svdiff_pytorch/diffusers_models/unet_2d_condition.py

+2 −2

@@ -34,7 +34,7 @@
     get_down_block,
     get_up_block,
 )
-from svdiff_pytorch.layers import SVDConv1d, SVDConv2d, SVDLinear
+from svdiff_pytorch.layers import SVDConv1d, SVDConv2d, SVDLinear, SVDGroupNorm, SVDLayerNorm
 
 
 logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -348,7 +348,7 @@ def __init__(
 
         # out
         if norm_num_groups is not None:
-            self.conv_norm_out = nn.GroupNorm(
+            self.conv_norm_out = SVDGroupNorm(
                 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
             )
             self.conv_act = nn.SiLU()

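The seven `diffusers_models` diffs above all make the same substitution: `nn.GroupNorm`/`nn.LayerNorm` become `SVDGroupNorm`/`SVDLayerNorm` so that the 1-D norm weights also receive trainable spectral shifts. The constructors are call-compatible, and with the shift still at zero the SVD variants reproduce the original layers. A minimal, hypothetical check (not from the repository):

```python
import torch
from svdiff_pytorch.layers import SVDGroupNorm

x = torch.randn(2, 32, 8, 8)
svd_norm = SVDGroupNorm(num_groups=8, num_channels=32)   # same signature as nn.GroupNorm
ref_norm = torch.nn.GroupNorm(num_groups=8, num_channels=32)

# delta starts at zero and singular values are non-negative, so
# U @ diag(relu(S + 0)) @ Vh rebuilds the original (default) weight
assert torch.allclose(svd_norm(x), ref_norm(x), atol=1e-6)
```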
‎svdiff_pytorch/layers.py

+122 −16

@@ -17,12 +17,9 @@ def __init__(
         nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
         assert type(kernel_size) is int
         weight_reshaped = rearrange(self.weight, 'co cin h w -> co (cin h w)')
-        U, S, Vh = torch.linalg.svd(weight_reshaped, full_matrices=False)
-        self.U = U
-        self.S = S
-        self.Vh = Vh
+        self.U, self.S, self.Vh = torch.linalg.svd(weight_reshaped, full_matrices=False)
         # initialize to 0 for smooth tuning
-        self.delta = nn.Parameter(torch.zeros_like(S))
+        self.delta = nn.Parameter(torch.zeros_like(self.S))
         self.weight.requires_grad = False
         self.done_svd = False
         self.scale = scale
@@ -63,12 +60,9 @@ def __init__(
         nn.Conv1d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
         assert type(kernel_size) is int
         weight_reshaped = rearrange(self.weight, 'co cin h w -> co (cin h w)')
-        U, S, Vh = torch.linalg.svd(weight_reshaped, full_matrices=False)
-        self.U = U
-        self.S = S
-        self.Vh = Vh
+        self.U, self.S, self.Vh = torch.linalg.svd(weight_reshaped, full_matrices=False)
         # initialize to 0 for smooth tuning
-        self.delta = nn.Parameter(torch.zeros_like(S))
+        self.delta = nn.Parameter(torch.zeros_like(self.S))
         self.weight.requires_grad = False
         self.done_svd = False
         self.scale = scale
@@ -107,12 +101,9 @@ def __init__(
         **kwargs
     ):
         nn.Linear.__init__(self, in_features, out_features, **kwargs)
-        U, S, Vh = torch.linalg.svd(self.weight, full_matrices=False)
-        self.U = U
-        self.S = S
-        self.Vh = Vh
+        self.U, self.S, self.Vh = torch.linalg.svd(self.weight, full_matrices=False)
         # initialize to 0 for smooth tuning
-        self.delta = nn.Parameter(torch.zeros_like(S))
+        self.delta = nn.Parameter(torch.zeros_like(self.S))
         self.weight.requires_grad = False
         self.done_svd = False
         self.scale = scale
@@ -135,4 +126,119 @@ def forward(self, x: torch.Tensor):
             # this happens after loading the state dict
             self.perform_svd()
         weight_updated = self.U.to(x.device, dtype=x.dtype) @ torch.diag(F.relu(self.S.to(x.device, dtype=x.dtype)+self.scale * self.delta)) @ self.Vh.to(x.device, dtype=x.dtype)
-        return F.linear(x, weight_updated, bias=self.bias)
+        return F.linear(x, weight_updated, bias=self.bias)
+
+
+class SVDEmbedding(nn.Embedding):
+    # LoRA implemented in a dense layer
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        scale: float = 1.0,
+        **kwargs
+    ):
+        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
+        self.U, self.S, self.Vh = torch.linalg.svd(self.weight, full_matrices=False)
+        # initialize to 0 for smooth tuning
+        self.delta = nn.Parameter(torch.zeros_like(self.S))
+        self.weight.requires_grad = False
+        self.done_svd = False
+        self.scale = scale
+        self.reset_parameters()
+
+    def set_scale(self, scale: float):
+        self.scale = scale
+
+    def perform_svd(self):
+        self.U, self.S, self.Vh = torch.linalg.svd(self.weight, full_matrices=False)
+        self.done_svd = True
+
+    def reset_parameters(self):
+        nn.Embedding.reset_parameters(self)
+        if hasattr(self, 'delta'):
+            nn.init.zeros_(self.delta)
+
+    def forward(self, x: torch.Tensor):
+        if not self.done_svd:
+            # this happens after loading the state dict
+            self.perform_svd()
+        weight_updated = self.U.to(x.device) @ torch.diag(F.relu(self.S.to(x.device)+self.scale * self.delta)) @ self.Vh.to(x.device)
+        return F.embedding(x, weight_updated, padding_idx=self.padding_idx, max_norm=self.max_norm, norm_type=self.norm_type, scale_grad_by_freq=self.scale_grad_by_freq, sparse=self.sparse)
+
+
+# 1-D
+class SVDLayerNorm(nn.LayerNorm):
+    def __init__(
+        self,
+        normalized_shape: int,
+        scale: float = 1.0,
+        **kwargs
+    ):
+        nn.LayerNorm.__init__(self, normalized_shape=normalized_shape, **kwargs)
+        self.U, self.S, self.Vh = torch.linalg.svd(self.weight.unsqueeze(0), full_matrices=False)
+        # initialize to 0 for smooth tuning
+        self.delta = nn.Parameter(torch.zeros_like(self.S))
+        self.weight.requires_grad = False
+        self.done_svd = False
+        self.scale = scale
+        self.reset_parameters()
+
+    def set_scale(self, scale: float):
+        self.scale = scale
+
+    def perform_svd(self):
+        self.U, self.S, self.Vh = torch.linalg.svd(self.weight.unsqueeze(0), full_matrices=False)
+        self.done_svd = True
+
+    def reset_parameters(self):
+        nn.LayerNorm.reset_parameters(self)
+        if hasattr(self, 'delta'):
+            nn.init.zeros_(self.delta)
+
+    def forward(self, x: torch.Tensor):
+        if not self.done_svd:
+            # this happens after loading the state dict
+            self.perform_svd()
+        weight_updated = self.U.to(x.device, dtype=x.dtype) @ torch.diag(F.relu(self.S.to(x.device, dtype=x.dtype)+self.scale * self.delta)) @ self.Vh.to(x.device, dtype=x.dtype)
+        weight_updated = weight_updated.squeeze(0)
+        return F.layer_norm(x, normalized_shape=self.normalized_shape, weight=weight_updated, bias=self.bias, eps=self.eps)
+
+
+class SVDGroupNorm(nn.GroupNorm):
+    def __init__(
+        self,
+        num_groups: int,
+        num_channels: int,
+        scale: float = 1.0,
+        **kwargs
+    ):
+        nn.GroupNorm.__init__(self, num_groups, num_channels, **kwargs)
+        self.U, self.S, self.Vh = torch.linalg.svd(self.weight.unsqueeze(0), full_matrices=False)
+        # initialize to 0 for smooth tuning
+        self.delta = nn.Parameter(torch.zeros_like(self.S))
+        self.weight.requires_grad = False
+        self.done_svd = False
+        self.scale = scale
+        self.reset_parameters()
+
+    def set_scale(self, scale: float):
+        self.scale = scale
+
+    def perform_svd(self):
+        self.U, self.S, self.Vh = torch.linalg.svd(self.weight.unsqueeze(0), full_matrices=False)
+        self.done_svd = True
+
+    def reset_parameters(self):
+        nn.GroupNorm.reset_parameters(self)
+        if hasattr(self, 'delta'):
+            nn.init.zeros_(self.delta)
+
+    def forward(self, x: torch.Tensor):
+        if not self.done_svd:
+            # this happens after loading the state dict
+            self.perform_svd()
+        weight_updated = self.U.to(x.device, dtype=x.dtype) @ torch.diag(F.relu(self.S.to(x.device, dtype=x.dtype)+self.scale * self.delta)) @ self.Vh.to(x.device, dtype=x.dtype)
+        weight_updated = weight_updated.squeeze(0)
+        return F.group_norm(x, num_groups=self.num_groups, weight=weight_updated, bias=self.bias, eps=self.eps)
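All of these layers share one idea from the paper: keep the frozen weight's SVD factors `U, S, Vh` and learn only a shift `delta` on the singular values, rebuilding `W = U @ diag(relu(S + scale * delta)) @ Vh` in `forward`. A minimal, hypothetical sketch on `SVDLinear` (not from the repository):

```python
import torch
import torch.nn.functional as F
from svdiff_pytorch.layers import SVDLinear

torch.manual_seed(0)
layer = SVDLinear(4, 3)
x = torch.randn(2, 4)

# with delta == 0 the reconstruction equals the frozen weight, so the layer is a drop-in replacement
out = layer(x)
assert torch.allclose(out, F.linear(x, layer.weight, layer.bias), atol=1e-5)
assert layer.weight.requires_grad is False and layer.delta.requires_grad is True

# shifting the singular values changes the effective weight without ever touching layer.weight
with torch.no_grad():
    layer.delta += 0.1
print((layer(x) - out).abs().max())  # non-zero: the spectral shift is now applied
```

Only the `delta` tensors end up in `spectral_shifts.safetensors`, which is why the checkpoints stay in the low-megabyte range.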
svdiff_pytorch/pipeline_stable_diffusion_ddim_inversion.py

+250

@@ -0,0 +1,250 @@
from typing import Any, Callable, Dict, List, Optional, Union
import PIL
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline, DDIMInverseScheduler
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero import Pix2PixInversionPipelineOutput


class StableDiffusionPipelineWithDDIMInversion(StableDiffusionPipeline):
    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker: bool = True):
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker)
        self.inverse_scheduler = DDIMInverseScheduler.from_config(self.scheduler.config)
        # self.register_modules(inverse_scheduler=DDIMInverseScheduler.from_config(self.scheduler.config))

    def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
        if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
            raise ValueError(
                f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
            )

        image = image.to(device=device, dtype=dtype)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if isinstance(generator, list):
            init_latents = [
                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
            ]
            init_latents = torch.cat(init_latents, dim=0)
        else:
            init_latents = self.vae.encode(image).latent_dist.sample(generator)

        init_latents = self.vae.config.scaling_factor * init_latents

        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
            )
        else:
            init_latents = torch.cat([init_latents], dim=0)

        latents = init_latents

        return latents

    def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int):
        pred_type = self.inverse_scheduler.config.prediction_type
        alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep]

        beta_prod_t = 1 - alpha_prod_t

        if pred_type == "epsilon":
            return model_output
        elif pred_type == "sample":
            return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5)
        elif pred_type == "v_prediction":
            return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`"
            )

    def auto_corr_loss(self, hidden_states, generator=None):
        batch_size, channel, height, width = hidden_states.shape
        if batch_size > 1:
            raise ValueError("Only batch_size 1 is supported for now")

        hidden_states = hidden_states.squeeze(0)
        # hidden_states must be shape [C,H,W] now
        reg_loss = 0.0
        for i in range(hidden_states.shape[0]):
            noise = hidden_states[i][None, None, :, :]
            while True:
                roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2

                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        return reg_loss

    def kl_divergence(self, hidden_states):
        mean = hidden_states.mean()
        var = hidden_states.var()
        return var + mean**2 - 1 - torch.log(var + 1e-7)

    # based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py#L1063
    @torch.no_grad()
    def invert(
        self,
        prompt: Optional[str] = None,
        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        lambda_auto_corr: float = 20.0,
        lambda_kl: float = 20.0,
        num_reg_steps: int = 0,  # disabled
        num_auto_corr_rolls: int = 5,
    ):
        # 1. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]
        if cross_attention_kwargs is None:
            cross_attention_kwargs = {}

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Preprocess image
        image = preprocess(image)

        # 4. Prepare latent variables
        latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator)

        # 5. Encode input prompt
        num_images_per_prompt = 1
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            prompt_embeds=prompt_embeds,
        )

        # 4. Prepare timesteps
        self.inverse_scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.inverse_scheduler.timesteps

        # 7. Denoising loop where we obtain the cross-attention maps.
        num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
        with self.progress_bar(total=num_inference_steps - 1) as progress_bar:
            for i, t in enumerate(timesteps[:-1]):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # regularization of the noise prediction
                with torch.enable_grad():
                    for _ in range(num_reg_steps):
                        if lambda_auto_corr > 0:
                            for _ in range(num_auto_corr_rolls):
                                var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)

                                # Derive epsilon from model output before regularizing to IID standard normal
                                var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)

                                l_ac = self.auto_corr_loss(var_epsilon, generator=generator)
                                l_ac.backward()

                                grad = var.grad.detach() / num_auto_corr_rolls
                                noise_pred = noise_pred - lambda_auto_corr * grad

                        if lambda_kl > 0:
                            var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)

                            # Derive epsilon from model output before regularizing to IID standard normal
                            var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)

                            l_kld = self.kl_divergence(var_epsilon)
                            l_kld.backward()

                            grad = var.grad.detach()
                            noise_pred = noise_pred - lambda_kl * grad

                        noise_pred = noise_pred.detach()

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        inverted_latents = latents.detach().clone()

        # 8. Post-processing
        image = self.decode_latents(latents.detach())

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        # 9. Convert to PIL.
        if output_type == "pil":
            image = self.numpy_to_pil(image)

        if not return_dict:
            return (inverted_latents, image)

        return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image)


if __name__ == '__main__':
    from PIL import Image
    from diffusers import DDIMScheduler
    model_id = "CompVis/stable-diffusion-v1-4"
    input_prompt = "A photo of Barack Obama"
    prompt = "A photo of Barack Obama smiling with a big grin"
    url = "obama.png"  # https://github.com/cccntu/efficient-prompt-to-prompt/blob/main/ddim-inversion.ipynb

    pipe = StableDiffusionPipelineWithDDIMInversion.from_pretrained(
        model_id,
        # make sure to load ddim here
        scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler"),
    )
    image = Image.open(url).convert("RGB").resize((512, 512))
    # in SVDiff, they use guidance scale=1 in ddim inversion
    inv_latents = pipe.invert(input_prompt, image=image, guidance_scale=1.0).latents
    image = pipe(prompt, latents=inv_latents).images[0]
    image.save("out.png")
svdiff_pytorch/transformers_models_clip/__init__.py

+2

@@ -0,0 +1,2 @@
# all files in this folder were taken from https://github.com/huggingface/transformers/blob/v4.27.3/src/transformers/models/clip/modeling_clip.py
# so, these files follow the LICENSE of transformers

‎svdiff_pytorch/transformers_models_clip/modeling_clip.py

+1,325
Large diffs are not rendered by default.

‎svdiff_pytorch/utils.py

+75 −5

@@ -13,10 +13,11 @@
     EulerDiscreteScheduler,
     EulerAncestralDiscreteScheduler,
 )
+from transformers import CLIPTextModel, CLIPTextConfig
 from diffusers import UNet2DConditionModel
 from safetensors.torch import safe_open
-from huggingface_hub import hf_hub_download
-from svdiff_pytorch import UNet2DConditionModelForSVDiff
+import huggingface_hub
+from svdiff_pytorch import UNet2DConditionModelForSVDiff, CLIPTextModelForSVDiff
 
 
 
@@ -38,7 +39,7 @@ def load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=Non
     missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
     if len(missing_keys) > 0:
         raise ValueError(
-            f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
+            f"Cannot load {model.__class__.__name__} from {pretrained_model_name_or_path} because the following keys are"
             f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
             " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
             " those weights or else make sure your checkpoint file is correct."
@@ -57,7 +58,7 @@ def load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=Non
         elif not os.path.exists(spectral_shifts_ckpt):
             # download from hub
             hf_hub_kwargs = {} if hf_hub_kwargs is None else hf_hub_kwargs
-            spectral_shifts_ckpt = hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts.safetensors", **hf_hub_kwargs)
+            spectral_shifts_ckpt = huggingface_hub.hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts.safetensors", **hf_hub_kwargs)
         assert os.path.exists(spectral_shifts_ckpt)
 
         with safe_open(spectral_shifts_ckpt, framework="pt", device="cpu") as f:
@@ -68,7 +69,7 @@ def load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=Non
                     set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key), dtype=torch_dtype)
                 else:
                     set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key))
-        print(f"Resume from {spectral_shifts_ckpt}")
+        print(f"Resumed from {spectral_shifts_ckpt}")
     if "torch_dtype" in kwargs:
         model = model.to(kwargs["torch_dtype"])
     model.register_to_config(_name_or_path=pretrained_model_name_or_path)
@@ -80,6 +81,75 @@ def load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=Non
 
 
 
+def load_text_encoder_for_svdiff(
+    pretrained_model_name_or_path,
+    spectral_shifts_ckpt=None,
+    hf_hub_kwargs=None,
+    **kwargs
+):
+    """
+    https://github.com/huggingface/diffusers/blob/v0.14.0/src/diffusers/models/modeling_utils.py#L541
+    """
+    config = CLIPTextConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+    original_model = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
+    state_dict = original_model.state_dict()
+    with accelerate.init_empty_weights():
+        model = CLIPTextModelForSVDiff(config)
+    # load pre-trained weights
+    param_device = "cpu"
+    torch_dtype = kwargs["torch_dtype"] if "torch_dtype" in kwargs else None
+    spectral_shifts_weights = {n: torch.zeros(p.shape) for n, p in model.named_parameters() if "delta" in n}
+    state_dict.update(spectral_shifts_weights)
+    # move the params from meta device to cpu
+    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+    if len(missing_keys) > 0:
+        raise ValueError(
+            f"Cannot load {model.__class__.__name__} from {pretrained_model_name_or_path} because the following keys are"
+            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
+            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
+            " those weights or else make sure your checkpoint file is correct."
+        )
+
+    for param_name, param in state_dict.items():
+        accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+        if accepts_dtype:
+            set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
+        else:
+            set_module_tensor_to_device(model, param_name, param_device, value=param)
+
+    if spectral_shifts_ckpt:
+        if os.path.isdir(spectral_shifts_ckpt):
+            spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts_te.safetensors")
+        elif not os.path.exists(spectral_shifts_ckpt):
+            # download from hub
+            hf_hub_kwargs = {} if hf_hub_kwargs is None else hf_hub_kwargs
+            try:
+                spectral_shifts_ckpt = huggingface_hub.hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts_te.safetensors", **hf_hub_kwargs)
+            except huggingface_hub.utils.EntryNotFoundError:
+                spectral_shifts_ckpt = None
+        # load state dict only if `spectral_shifts_te.safetensors` exists
+        if spectral_shifts_ckpt is not None and os.path.exists(spectral_shifts_ckpt):
+            with safe_open(spectral_shifts_ckpt, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    # spectral_shifts_weights[key] = f.get_tensor(key)
+                    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+                    if accepts_dtype:
+                        set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key), dtype=torch_dtype)
+                    else:
+                        set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key))
+            print(f"Resumed from {spectral_shifts_ckpt}")
+
+    if "torch_dtype" in kwargs:
+        model = model.to(kwargs["torch_dtype"])
+    # model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+    # Set model in evaluation mode to deactivate DropOut modules by default
+    model.eval()
+    del original_model
+    torch.cuda.empty_cache()
+    return model
+
+
+
 def image_grid(imgs, rows, cols):
     assert len(imgs) == rows * cols
     w, h = imgs[0].size

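`spectral_shifts_ckpt` can be a checkpoint directory, a direct path to the safetensors file, or a value that triggers the `hf_hub_download` fallback above. A minimal, hypothetical usage sketch (the repo id below is a placeholder, not a real model):

```python
from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff

base = "runwayml/stable-diffusion-v1-5"
ckpt = "your-username/your-svdiff-checkpoint"  # hypothetical Hub repo holding spectral_shifts*.safetensors

# neither a directory nor a local file, so both loaders fall back to huggingface_hub.hf_hub_download
unet = load_unet_for_svdiff(base, spectral_shifts_ckpt=ckpt, subfolder="unet")
text_encoder = load_text_encoder_for_svdiff(base, spectral_shifts_ckpt=ckpt, subfolder="text_encoder")
```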
‎train_svdiff.py

+108 −64

@@ -7,6 +7,7 @@
 from pathlib import Path
 from typing import Optional
 from packaging import version
+import itertools
 
 import numpy as np
 import torch
@@ -22,7 +23,7 @@
 from torch.utils.data import Dataset
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, PretrainedConfig
+from transformers import CLIPTextModel, AutoTokenizer, PretrainedConfig
 
 import diffusers
 from diffusers import __version__
@@ -33,7 +34,7 @@
     StableDiffusionPipeline,
     DPMSolverMultistepScheduler,
 )
-from svdiff_pytorch import load_unet_for_svdiff, SCHEDULER_MAPPING
+from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, SCHEDULER_MAPPING
 from diffusers.loaders import AttnProcsLayers
 from diffusers.optimization import get_scheduler
 from diffusers.utils import check_min_version, is_wandb_available
@@ -72,32 +73,12 @@ def save_model_card(repo_id: str, base_model=str, prompt=str, repo_folder=None):
     """
     model_card = f"""
 # SVDiff-pytorch - {repo_id}
-These are SVDiff weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
+These are SVDiff weights for {base_model}. The weights were trained on {prompt}.
 """
     with open(os.path.join(repo_folder, "README.md"), "w") as f:
         f.write(yaml + model_card)
 
 
-def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
-    text_encoder_config = PretrainedConfig.from_pretrained(
-        pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        revision=revision,
-    )
-    model_class = text_encoder_config.architectures[0]
-
-    if model_class == "CLIPTextModel":
-        from transformers import CLIPTextModel
-
-        return CLIPTextModel
-    elif model_class == "RobertaSeriesModelWithTransformation":
-        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
-
-        return RobertaSeriesModelWithTransformation
-    else:
-        raise ValueError(f"{model_class} is not supported.")
-
-
 def parse_args(input_args=None):
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
@@ -271,9 +252,15 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=5e-4,
+        default=1e-3,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
+    parser.add_argument(
+        "--learning_rate_1d",
+        type=float,
+        default=1e-6,
+        help="Initial learning rate (after the potential warmup period) to use for 1-d weights",
+    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -380,6 +367,11 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--enable_token_merging", action="store_true", help="Whether or not to use tomesd on prior generation"
     )
+    parser.add_argument(
+        "--train_text_encoder",
+        action="store_true",
+        help="Whether to train spectral shifts of the text encoder. If set, the text encoder should be float32 precision.",
+    )
     if input_args is not None:
         args = parser.parse_args(input_args)
     else:
@@ -594,6 +586,11 @@ def main(args):
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported when training the text encoder in distributed training. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -700,14 +697,14 @@ def main(args):
         use_fast=False,
     )
 
-    # import correct text encoder class
-    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
-
     # Load scheduler and models
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
-    text_encoder = text_encoder_cls.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
-    )
+    if args.train_text_encoder:
+        text_encoder = load_text_encoder_for_svdiff(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
+    else:
+        text_encoder = CLIPTextModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        )
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
     unet = load_unet_for_svdiff(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=True)
 
@@ -716,26 +713,26 @@ def main(args):
     text_encoder.requires_grad_(False)
     unet.requires_grad_(False)
     optim_params = []
+    optim_params_1d = []
     for n, p in unet.named_parameters():
         if "delta" in n:
             p.requires_grad = True
-            optim_params.append(p)
+            if "norm" in n:
+                optim_params_1d.append(p)
+            else:
+                optim_params.append(p)
+    if args.train_text_encoder:
+        for n, p in text_encoder.named_parameters():
+            if "delta" in n:
+                p.requires_grad = True
+                if "norm" in n:
+                    optim_params_1d.append(p)
+                else:
+                    optim_params.append(p)
+
     total_params = sum(p.numel() for p in optim_params)
     print(f"Number of Trainable Parameters: {total_params * 1.e-6:.2f} M")
 
-    # For mixed precision training we cast the text_encoder and vae weights to half-precision
-    # as these models are only used for inference, keeping weights in full precision is not required.
-    weight_dtype = torch.float32
-    if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
-    # Move unet, vae and text_encoder to device and cast to weight_dtype
-    # unet.to(accelerator.device, dtype=weight_dtype)
-    vae.to(accelerator.device, dtype=weight_dtype)
-    text_encoder.to(accelerator.device, dtype=weight_dtype)
-
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
@@ -751,12 +748,26 @@ def main(args):
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        if args.train_text_encoder:
+            text_encoder.gradient_checkpointing_enable()
 
-    if args.scale_lr:
-        args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+    # Check that all trainable models are in full precision
+    low_precision_error_string = (
+        "Please make sure to always have all model weights in full float32 precision when starting training - even if"
+        " doing mixed precision training. copy of the weights should still be float32."
+    )
+
+    if accelerator.unwrap_model(unet).dtype != torch.float32:
+        raise ValueError(
+            f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
         )
 
+    if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
+        raise ValueError(
+            f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
+            f" {low_precision_error_string}"
+        )
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
@@ -782,7 +793,7 @@ def main(args):
 
     # Optimizer creation
     optimizer = optimizer_class(
-        optim_params,
+        [{"params": optim_params}, {"params": optim_params_1d, "lr": args.learning_rate_1d}],
        lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -826,9 +837,29 @@
     )
 
     # Prepare everything with our `accelerator`.
-    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        unet, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.train_text_encoder:
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+        )
+    else:
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
+        )
+
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision
+    # as these models are only used for inference, keeping weights in full precision is not required.
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move unet, vae and text_encoder to device and cast to weight_dtype
+    # unet.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
+    if not args.train_text_encoder:
+        text_encoder.to(accelerator.device, dtype=weight_dtype)
+
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -842,14 +873,27 @@
     if accelerator.is_main_process:
         accelerator.init_trackers("svdiff-pytorch", config=vars(args))
 
-    def save_weights(step):
+    # cache keys to save
+    state_dict_keys = [k for k in accelerator.unwrap_model(unet).state_dict().keys() if "delta" in k]
+    if args.train_text_encoder:
+        state_dict_keys_te = [k for k in accelerator.unwrap_model(text_encoder).state_dict().keys() if "delta" in k]
+
+    def save_weights(step, save_path=None):
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
-            save_path = os.path.join(args.output_dir, f"checkpoint-{step}")
+            if save_path is None:
+                save_path = os.path.join(args.output_dir, f"checkpoint-{step}")
             os.makedirs(save_path, exist_ok=True)
-            unet_model = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
-            state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
+            state_dict = accelerator.unwrap_model(unet, keep_fp32_wrapper=True).state_dict()
+            # state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
+            state_dict = {k: state_dict[k] for k in state_dict_keys}
             save_file(state_dict, os.path.join(save_path, "spectral_shifts.safetensors"))
+            if args.train_text_encoder:
+                state_dict = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True).state_dict()
+                # state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
+                state_dict = {k: state_dict[k] for k in state_dict_keys_te}
+                save_file(state_dict, os.path.join(save_path, "spectral_shifts_te.safetensors"))
+
        print(f"[*] Weights saved at {save_path}")
 
     # Train!
@@ -897,6 +941,8 @@ def save_weights(step):
 
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
+        if args.train_text_encoder:
+            text_encoder.train()
         for step, batch in enumerate(train_dataloader):
             # Skip steps until we reach the resumed step
             if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -952,7 +998,11 @@ def save_weights(step):
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = unet.parameters()
+                    params_to_clip = (
+                        itertools.chain(unet.parameters(), text_encoder.parameters())
+                        if args.train_text_encoder
+                        else unet.parameters()
+                    )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
@@ -970,7 +1020,7 @@ def save_weights(step):
                         # accelerator.save_state(save_path)
                         # logger.info(f"Saved state to {save_path}")
 
-            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "lr_1d": lr_scheduler.get_last_lr()[1]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
 
@@ -982,14 +1032,8 @@ def save_weights(step):
             log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
 
     accelerator.wait_for_everyone()
-    save_weights(global_step)
     # put the latest checkpoint to output-dir
-    save_path = args.output_dir
-    unet_model = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
-    state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
-    save_file(state_dict, os.path.join(save_path, "spectral_shifts.safetensors"))
-    print(f"[*] Weights saved at {save_path}")
-
+    save_weights(global_step, save_path=args.output_dir)
     if accelerator.is_main_process:
         if args.push_to_hub:
             save_model_card(

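The optimizer change above is the core of the new `--learning_rate_1d` flag: spectral shifts that live in 1-D norm layers go into a second parameter group with a much smaller learning rate, while all other shifts use `--learning_rate`. A minimal, hypothetical sketch of the same pattern in isolation (toy module names, not the training script's):

```python
import torch

model = torch.nn.ModuleDict({
    "proj": torch.nn.Linear(8, 8),   # stand-in for 2-D weights (attention/conv deltas)
    "norm": torch.nn.LayerNorm(8),   # stand-in for 1-D norm weights
})

optim_params, optim_params_1d = [], []
for name, param in model.named_parameters():
    (optim_params_1d if "norm" in name else optim_params).append(param)

optimizer = torch.optim.AdamW(
    [{"params": optim_params}, {"params": optim_params_1d, "lr": 1e-6}],  # --learning_rate_1d
    lr=1e-3,  # --learning_rate, applies to the first group
)
print([group["lr"] for group in optimizer.param_groups])  # [0.001, 1e-06]
```

With two groups, the scheduler reports one learning rate per group, which is why the training loop now logs both `lr` and `lr_1d`.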