diff --git a/diffusion/models/models.py b/diffusion/models/models.py
index 4326cf1e..737b125b 100644
--- a/diffusion/models/models.py
+++ b/diffusion/models/models.py
@@ -10,6 +10,7 @@
 import torch
 from composer.devices import DeviceGPU
 from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel
+from peft import LoraConfig
 from torchmetrics import MeanSquaredError
 from transformers import CLIPTextModel, CLIPTokenizer, PretrainedConfig
 
@@ -67,6 +68,8 @@ def stable_diffusion_2(
     fsdp: bool = True,
     clip_qkv: Optional[float] = None,
     use_xformers: bool = True,
+    lora_rank: Optional[int] = None,
+    lora_alpha: Optional[int] = None,
 ):
     """Stable diffusion v2 training setup.
 
@@ -108,6 +111,8 @@ def stable_diffusion_2(
         fsdp (bool): Whether to use FSDP. Defaults to True.
         clip_qkv (float, optional): If not None, clip the qkv values to this value. Defaults to None.
         use_xformers (bool): Whether to use xformers for attention. Defaults to True.
+        lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
+        lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
     """
     latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)
 
@@ -215,6 +220,40 @@ def stable_diffusion_2(
         mask_pad_tokens=mask_pad_tokens,
         fsdp=fsdp,
     )
+    if lora_rank is not None:
+        assert lora_alpha is not None
+        model.unet.requires_grad_(False)
+        for param in model.unet.parameters():
+            param.requires_grad_(False)
+
+        unet_lora_config = LoraConfig(
+            r=lora_rank,
+            lora_alpha=lora_alpha,
+            init_lora_weights='gaussian',
+            target_modules=['to_k', 'to_q', 'to_v', 'to_out.0'],
+        )
+        model.unet.add_adapter(unet_lora_config)
+        model.unet._fsdp_wrap = True
+        if hasattr(model.unet, 'mid_block') and model.unet.mid_block is not None:
+            for attention in model.unet.mid_block.attentions:
+                attention._fsdp_wrap = True
+            for resnet in model.unet.mid_block.resnets:
+                resnet._fsdp_wrap = True
+        for block in model.unet.up_blocks:
+            if hasattr(block, 'attentions'):
+                for attention in block.attentions:
+                    attention._fsdp_wrap = True
+            if hasattr(block, 'resnets'):
+                for resnet in block.resnets:
+                    resnet._fsdp_wrap = True
+        for block in model.unet.down_blocks:
+            if hasattr(block, 'attentions'):
+                for attention in block.attentions:
+                    attention._fsdp_wrap = True
+            if hasattr(block, 'resnets'):
+                for resnet in block.resnets:
+                    resnet._fsdp_wrap = True
+
     if torch.cuda.is_available():
         model = DeviceGPU().module_to_device(model)
         if is_xformers_installed and use_xformers:
@@ -262,6 +301,8 @@ def stable_diffusion_xl(
     fsdp: bool = True,
     clip_qkv: Optional[float] = None,
     use_xformers: bool = True,
+    lora_rank: Optional[int] = None,
+    lora_alpha: Optional[int] = None,
 ):
     """Stable diffusion 2 training setup + SDXL UNet and VAE.
 
@@ -315,6 +356,8 @@ def stable_diffusion_xl(
         clip_qkv (float, optional): If not None, clip the qkv values to this value. Improves stability of training.
             Default: ``None``.
         use_xformers (bool): Whether to use xformers for attention. Defaults to True.
+        lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
+        lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
     """
     latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)
 
@@ -481,6 +524,40 @@ def stable_diffusion_xl(
         fsdp=fsdp,
         sdxl=True,
     )
+
+    if lora_rank is not None:
+        assert lora_alpha is not None
+        model.unet.requires_grad_(False)
+        for param in model.unet.parameters():
+            param.requires_grad_(False)
+
+        unet_lora_config = LoraConfig(
+            r=lora_rank,
+            lora_alpha=lora_alpha,
+            init_lora_weights='gaussian',
+            target_modules=['to_k', 'to_q', 'to_v', 'to_out.0'],
+        )
+        model.unet.add_adapter(unet_lora_config)
+        model.unet._fsdp_wrap = True
+        if hasattr(model.unet, 'mid_block') and model.unet.mid_block is not None:
+            for attention in model.unet.mid_block.attentions:
+                attention._fsdp_wrap = True
+            for resnet in model.unet.mid_block.resnets:
+                resnet._fsdp_wrap = True
+        for block in model.unet.up_blocks:
+            if hasattr(block, 'attentions'):
+                for attention in block.attentions:
+                    attention._fsdp_wrap = True
+            if hasattr(block, 'resnets'):
+                for resnet in block.resnets:
+                    resnet._fsdp_wrap = True
+        for block in model.unet.down_blocks:
+            if hasattr(block, 'attentions'):
+                for attention in block.attentions:
+                    attention._fsdp_wrap = True
+            if hasattr(block, 'resnets'):
+                for resnet in block.resnets:
+                    resnet._fsdp_wrap = True
     if torch.cuda.is_available():
         model = DeviceGPU().module_to_device(model)
         if is_xformers_installed and use_xformers:
diff --git a/setup.py b/setup.py
index 469d1268..f909592a 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
     'diffusers[torch]==0.26.3', 'transformers[torch]==4.38.2', 'huggingface_hub==0.21.2', 'wandb==0.16.3',
     'xformers==0.0.23.post1', 'triton==2.1.0', 'torchmetrics[image]==1.3.1', 'lpips==0.1.4', 'clean-fid==0.1.35',
     'clip@git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33', 'gradio==4.19.2',
-    'datasets==2.19.2'
+    'datasets==2.19.2', 'peft==0.12.0'
 ]
 
 extras_require = {}