from pathlib import Path
from typing import Optional
from packaging import version
+ import itertools

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.auto import tqdm
- from transformers import AutoTokenizer, PretrainedConfig
+ from transformers import CLIPTextModel, AutoTokenizer, PretrainedConfig

import diffusers
from diffusers import __version__
    StableDiffusionPipeline,
    DPMSolverMultistepScheduler,
)
- from svdiff_pytorch import load_unet_for_svdiff, SCHEDULER_MAPPING
+ from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, SCHEDULER_MAPPING
from diffusers.loaders import AttnProcsLayers
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
@@ -72,32 +73,12 @@ def save_model_card(repo_id: str, base_model=str, prompt=str, repo_folder=None):
    """
    model_card = f"""
# SVDiff-pytorch - {repo_id}
- These are SVDiff weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
+ These are SVDiff weights for {base_model}. The weights were trained on {prompt}.
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)


- def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
-     text_encoder_config = PretrainedConfig.from_pretrained(
-         pretrained_model_name_or_path,
-         subfolder="text_encoder",
-         revision=revision,
-     )
-     model_class = text_encoder_config.architectures[0]
-
-     if model_class == "CLIPTextModel":
-         from transformers import CLIPTextModel
-
-         return CLIPTextModel
-     elif model_class == "RobertaSeriesModelWithTransformation":
-         from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
-
-         return RobertaSeriesModelWithTransformation
-     else:
-         raise ValueError(f"{model_class} is not supported.")
-
-
def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
@@ -271,9 +252,15 @@ def parse_args(input_args=None):
    parser.add_argument(
        "--learning_rate",
        type=float,
-         default=5e-4,
+         default=1e-3,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
+     parser.add_argument(
+         "--learning_rate_1d",
+         type=float,
+         default=1e-6,
+         help="Initial learning rate (after the potential warmup period) to use for 1-d weights",
+     )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
@@ -380,6 +367,11 @@ def parse_args(input_args=None):
    parser.add_argument(
        "--enable_token_merging", action="store_true", help="Whether or not to use tomesd on prior generation"
    )
+     parser.add_argument(
+         "--train_text_encoder",
+         action="store_true",
+         help="Whether to train spectral shifts of the text encoder. If set, the text encoder should be float32 precision.",
+     )
    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
@@ -594,6 +586,11 @@ def main(args):
    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+     if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+         raise ValueError(
+             "Gradient accumulation is not supported when training the text encoder in distributed training. "
+             "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+         )
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -700,14 +697,14 @@ def main(args):
            use_fast=False,
        )

-     # import correct text encoder class
-     text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
-
    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
-     text_encoder = text_encoder_cls.from_pretrained(
-         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
-     )
+     if args.train_text_encoder:
+         text_encoder = load_text_encoder_for_svdiff(args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
+     else:
+         text_encoder = CLIPTextModel.from_pretrained(
+             args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+         )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = load_unet_for_svdiff(args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=True)

@@ -716,26 +713,26 @@ def main(args):
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    optim_params = []
+     optim_params_1d = []
    for n, p in unet.named_parameters():
        if "delta" in n:
            p.requires_grad = True
-             optim_params.append(p)
+             if "norm" in n:
+                 optim_params_1d.append(p)
+             else:
+                 optim_params.append(p)
+     if args.train_text_encoder:
+         for n, p in text_encoder.named_parameters():
+             if "delta" in n:
+                 p.requires_grad = True
+                 if "norm" in n:
+                     optim_params_1d.append(p)
+                 else:
+                     optim_params.append(p)
+
    total_params = sum(p.numel() for p in optim_params)
    print(f"Number of Trainable Parameters: {total_params * 1.e-6:.2f} M")

-     # For mixed precision training we cast the text_encoder and vae weights to half-precision
-     # as these models are only used for inference, keeping weights in full precision is not required.
-     weight_dtype = torch.float32
-     if accelerator.mixed_precision == "fp16":
-         weight_dtype = torch.float16
-     elif accelerator.mixed_precision == "bf16":
-         weight_dtype = torch.bfloat16
-
-     # Move unet, vae and text_encoder to device and cast to weight_dtype
-     # unet.to(accelerator.device, dtype=weight_dtype)
-     vae.to(accelerator.device, dtype=weight_dtype)
-     text_encoder.to(accelerator.device, dtype=weight_dtype)
-
    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers
@@ -751,12 +748,26 @@ def main(args):

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
+         if args.train_text_encoder:
+             text_encoder.gradient_checkpointing_enable()

-     if args.scale_lr:
-         args.learning_rate = (
-             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+     # Check that all trainable models are in full precision
+     low_precision_error_string = (
+         "Please make sure to always have all model weights in full float32 precision when starting training - even if"
+         " doing mixed precision training. copy of the weights should still be float32."
+     )
+
+     if accelerator.unwrap_model(unet).dtype != torch.float32:
+         raise ValueError(
+             f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
        )

+     if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
+         raise ValueError(
+             f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
+             f" {low_precision_error_string}"
+         )
+
    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
@@ -782,7 +793,7 @@ def main(args):

    # Optimizer creation
    optimizer = optimizer_class(
-         optim_params,
+         [{"params": optim_params}, {"params": optim_params_1d, "lr": args.learning_rate_1d}],
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
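
For context (illustrative sketch, not from the diff): the two-group optimizer above relies on PyTorch's per-parameter-group options, where any key set inside a group dict, here "lr", overrides the constructor default. So the 1-d spectral shifts train at --learning_rate_1d while everything else uses --learning_rate. A minimal sketch of the pattern, with toy tensors standing in for the real UNet/text-encoder parameters:

import torch

optim_params = [torch.nn.Parameter(torch.zeros(16, 16))]  # shifts of 2-d and larger weights
optim_params_1d = [torch.nn.Parameter(torch.zeros(16))]   # shifts of 1-d ("norm") weights

optimizer = torch.optim.AdamW(
    [{"params": optim_params}, {"params": optim_params_1d, "lr": 1e-6}],
    lr=1e-3,  # default for the first group, mirroring the new --learning_rate default
    weight_decay=1e-2,
)
print([group["lr"] for group in optimizer.param_groups])  # [0.001, 1e-06]
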
@@ -826,9 +837,29 @@ def main(args):
    )

    # Prepare everything with our `accelerator`.
-     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-         unet, optimizer, train_dataloader, lr_scheduler
-     )
+     if args.train_text_encoder:
+         unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+             unet, text_encoder, optimizer, train_dataloader, lr_scheduler
+         )
+     else:
+         unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+             unet, optimizer, train_dataloader, lr_scheduler
+         )
+
+     # For mixed precision training we cast the text_encoder and vae weights to half-precision
+     # as these models are only used for inference, keeping weights in full precision is not required.
+     weight_dtype = torch.float32
+     if accelerator.mixed_precision == "fp16":
+         weight_dtype = torch.float16
+     elif accelerator.mixed_precision == "bf16":
+         weight_dtype = torch.bfloat16
+
+     # Move unet, vae and text_encoder to device and cast to weight_dtype
+     # unet.to(accelerator.device, dtype=weight_dtype)
+     vae.to(accelerator.device, dtype=weight_dtype)
+     if not args.train_text_encoder:
+         text_encoder.to(accelerator.device, dtype=weight_dtype)
+

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -842,14 +873,27 @@ def main(args):
    if accelerator.is_main_process:
        accelerator.init_trackers("svdiff-pytorch", config=vars(args))

-     def save_weights(step):
+     # cache keys to save
+     state_dict_keys = [k for k in accelerator.unwrap_model(unet).state_dict().keys() if "delta" in k]
+     if args.train_text_encoder:
+         state_dict_keys_te = [k for k in accelerator.unwrap_model(text_encoder).state_dict().keys() if "delta" in k]
+
+     def save_weights(step, save_path=None):
        # Create the pipeline using using the trained modules and save it.
        if accelerator.is_main_process:
-             save_path = os.path.join(args.output_dir, f"checkpoint-{step}")
+             if save_path is None:
+                 save_path = os.path.join(args.output_dir, f"checkpoint-{step}")
            os.makedirs(save_path, exist_ok=True)
-             unet_model = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
-             state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
+             state_dict = accelerator.unwrap_model(unet, keep_fp32_wrapper=True).state_dict()
+             # state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
+             state_dict = {k: state_dict[k] for k in state_dict_keys}
            save_file(state_dict, os.path.join(save_path, "spectral_shifts.safetensors"))
+             if args.train_text_encoder:
+                 state_dict = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True).state_dict()
+                 # state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
+                 state_dict = {k: state_dict[k] for k in state_dict_keys_te}
+                 save_file(state_dict, os.path.join(save_path, "spectral_shifts_te.safetensors"))
+
            print(f"[*] Weights saved at {save_path}")

    # Train!
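
Aside (illustrative, not from the diff): the checkpoints written by save_weights are plain safetensors files holding only the filtered "delta" tensors, so they can be inspected, or re-applied to a model with matching parameter names, via load_state_dict(strict=False). A small round-trip sketch with a toy module, where the "weight" name filter stands in for the "delta" filter and the file name matches the one saved above:

import torch
from safetensors.torch import load_file, save_file

model = torch.nn.Linear(4, 4)
full_state = model.state_dict()
subset = {k: full_state[k] for k in full_state if "weight" in k}  # stand-in for the "delta" filter
save_file(subset, "spectral_shifts.safetensors")

loaded = load_file("spectral_shifts.safetensors")                 # dict of tensors keyed by name
missing, unexpected = model.load_state_dict(loaded, strict=False)
print(sorted(loaded.keys()), missing)                             # ['weight'] ['bias']
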
@@ -897,6 +941,8 @@ def save_weights(step):

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
+         if args.train_text_encoder:
+             text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -952,7 +998,11 @@ def save_weights(step):

                accelerator.backward(loss)
                if accelerator.sync_gradients:
-                     params_to_clip = unet.parameters()
+                     params_to_clip = (
+                         itertools.chain(unet.parameters(), text_encoder.parameters())
+                         if args.train_text_encoder
+                         else unet.parameters()
+                     )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
@@ -970,7 +1020,7 @@ def save_weights(step):
                        # accelerator.save_state(save_path)
                        # logger.info(f"Saved state to {save_path}")

-             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "lr_1d": lr_scheduler.get_last_lr()[1]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

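The new "lr_1d" entry works because LRScheduler.get_last_lr() returns one value per optimizer parameter group, so index 0 is the main group and index 1 the 1-d group added earlier. A self-contained check of that behaviour (toy parameters, illustrative only):

import torch

w = torch.nn.Parameter(torch.zeros(4, 4))
b = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.AdamW([{"params": [w]}, {"params": [b], "lr": 1e-6}], lr=1e-3)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda step: 1.0)
print(sched.get_last_lr())  # [0.001, 1e-06], one entry per param group
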
@@ -982,14 +1032,8 @@ def save_weights(step):
                log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)

    accelerator.wait_for_everyone()
-     save_weights(global_step)
    # put the latest checkpoint to output-dir
-     save_path = args.output_dir
-     unet_model = accelerator.unwrap_model(unet, keep_fp32_wrapper=True)
-     state_dict = {k: v for k, v in unet_model.state_dict().items() if "delta" in k}
-     save_file(state_dict, os.path.join(save_path, "spectral_shifts.safetensors"))
-     print(f"[*] Weights saved at {save_path}")
-
+     save_weights(global_step, save_path=args.output_dir)
    if accelerator.is_main_process:
        if args.push_to_hub:
            save_model_card(