From a742e9d9506d9b3a2b68497c1f74c62468604f58 Mon Sep 17 00:00:00 2001 From: vipermu Date: Fri, 21 Apr 2023 15:24:27 -0700 Subject: [PATCH 01/18] feat: add controlnet + img2img strengths --- .../pipeline_stable_diffusion_controlnet.py | 151 +++++++++++++++--- 1 file changed, 129 insertions(+), 22 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index 322f2232fc8a..b8bb2b2cccb2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -683,7 +683,15 @@ def prepare_image( return image - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: @@ -701,6 +709,61 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents_img2img(self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if image is not None: + image = image.to(device=device, dtype=dtype) + + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + # deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + else: + latents = latents.to(device) + latents = latents * self.scheduler.init_noise_sigma + + + # scale the initial noise by the standard deviation required by the scheduler + # NOTE: not sure if the following line is necessary in for img2img + # latents = latents * self.scheduler.init_noise_sigma + return latents + def _default_height_width(self, height, width, image): # NOTE: It is possible that a list of images have different # dimensions for each image, so just checking the first image @@ -762,6 +825,9 @@ def __call__( cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, guess_mode: bool = False, + img2img_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, + img2img_strength: float = 1.0, + controlnet_strength: float = 1.0, ): r""" Function invoked when calling the pipeline for generation. @@ -891,6 +957,7 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) + # 4. Prepare image if isinstance(self.controlnet, ControlNetModel): image = self.prepare_image( @@ -928,20 +995,52 @@ def __call__( # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - + # 6. Prepare latent variables num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) + + if img2img_strength > 0.0: + img2img_image = self.prepare_image( + image=img2img_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + img2img_image = 2.0 * img2img_image - 1.0 + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1 - img2img_strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + latents = self.prepare_latents_img2img( + img2img_image[0][None, :], + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + else: + timesteps = self.scheduler.timesteps + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) @@ -981,14 +1080,22 @@ def __call__( mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - ).sample + if t <= controlnet_strength * 1000: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + else: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample # perform guidance if do_classifier_free_guidance: @@ -1037,4 +1144,4 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) \ No newline at end of file From 404a51f74837443e4bdc20d793f41f7121d4b3a6 Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Mon, 19 Jun 2023 23:23:05 +0000 Subject: [PATCH 02/18] wip: modified img2img and inpaint pipeline --- .../stable_diffusion_controlnet_inpaint.py | 19 +- inpaint_test.py | 216 ++++++++++++++++++ .../controlnet/pipeline_controlnet_inpaint.py | 40 +++- .../pipeline_stable_diffusion_img2img.py | 55 ++++- 4 files changed, 322 insertions(+), 8 deletions(-) create mode 100644 inpaint_test.py diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py index aae199f91b9e..ff9ea01c50e3 100644 --- a/examples/community/stable_diffusion_controlnet_inpaint.py +++ b/examples/community/stable_diffusion_controlnet_inpaint.py @@ -4,11 +4,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -import PIL.Image import torch import torch.nn.functional as F -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +import PIL.Image from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel @@ -20,6 +19,7 @@ randn_tensor, replace_example_docstring, ) +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -842,6 +842,7 @@ def __call__( callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + latent_mask = None, ): r""" Function invoked when calling the pipeline for generation. @@ -1094,7 +1095,18 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + from PIL import Image + latent_mask = Image.open("/home/erwann/diffusers/examples/community/soft_mask2.png") + if latent_mask is None: + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + else: + print("masking latent update") + unmasked_latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latent_mask_np = np.array(latent_mask.resize((unmasked_latents.shape))) + latents = unmasked_latents * latent_mask_np + latents * (1 - latent_mask_np) + + + # image = (image * mask_image + init_image * (1 - mask_image)).astype(np.uint8) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -1136,3 +1148,4 @@ def __call__( return (image, has_nsfw_concept) return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/inpaint_test.py b/inpaint_test.py new file mode 100644 index 000000000000..050b12db27d3 --- /dev/null +++ b/inpaint_test.py @@ -0,0 +1,216 @@ +# !pip install transformers accelerate +import os + +import numpy as np +import torch + +from diffusers import ( + ControlNetModel, + DDIMScheduler, + EulerAncestralDiscreteScheduler, + StableDiffusionControlNetInpaintPipeline, + StableDiffusionImg2ImgPipeline, +) +from diffusers.utils import load_image +from mask_utils import create_gradient, expand_image, load_image +from masked2 import StableDiffusionMaskedImg2ImgPipeline + + +init_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" +) +init_image = init_image.resize((512, 512)) + +# generator = torch.Generator(device="cpu").manual_seed(1) + +mask_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" +) +mask_image = mask_image.resize((512, 512)) + + +img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" +mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + +init_image = load_image(img_url) +mask_image = load_image(mask_url) + + +img_path = "/home/erwann/diffusers/examples/community/new_image.png" +# mask_path = "/home/erwann/diffusers/examples/community/hard_mask_5.png" +mask_path = "/home/erwann/diffusers/examples/community/mask_image.png" +init_image = load_image(img_path) +mask_image = load_image(mask_path) +# mask_image.save("mask.png") + + +# new_width = 480 +# new_height = new_width * init_image.height / init_image.width +# new_height = 640 +# init_image = init_image.resize((new_width, int(new_height))) + +# mask_image = mask_image.resize(init_image.size) +# mask_image = mask_image.resize((512, 512)) + +def make_inpaint_condition(image, image_mask): + image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 + image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 + + assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size" + image[image_mask > 0.001] = -1.0 # set as masked pixel + image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return image + + +control_image = make_inpaint_condition(init_image, mask_image) + +controlnet = ControlNetModel.from_pretrained( + "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16 +) + +# pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( +# "/home/erwann/diffusers/examples/community/realistic_vision", controlnet=controlnet, torch_dtype=torch.float16 +# ) +from custom_inpaint_pipeline import StableDiffusionMaskedLatentControlNetInpaintPipeline + + +pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + "/home/erwann/diffusers/examples/community/realistic_vision", controlnet=controlnet, torch_dtype=torch.float16 +) + +# pipe = StableDiffusionMaskedLatentControlNetInpaintPipeline.from_pretrained( +# "/home/erwann/diffusers/examples/community/realistic_vision", controlnet=controlnet, torch_dtype=torch.float16 +# ) +pipe = StableDiffusionMaskedLatentControlNetInpaintPipeline( + pipe.vae, pipe.text_encoder, pipe.tokenizer, pipe.unet, pipe.controlnet, pipe.scheduler, None, None, +) + + +# pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( +# "/home/erwann/diffusers/examples/community/deliberate", controlnet=controlnet, torch_dtype=torch.float16 +# ) +# pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + # "/home/erwann/generation-service/safetensor-models/sd1.5", controlnet=controlnet, torch_dtype=torch.float16 +# ) +# generator = None +# speed up diffusion process with faster scheduler and memory optimization +# pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) + +pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config) + + +from diffusers import DPMSolverMultistepScheduler + + +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True) + + + + +# init_image = load_image("/home/erwann/diffusers/examples/community/castle.png") +init_image = load_image("/home/erwann/diffusers/examples/community/bmw.png") +init_image = init_image.resize((512, 512)) + + +extended_image, mask_image = expand_image(init_image, expand_x=0, expand_y=-256) +print("Image size after extending, " + str(extended_image.size)) +control_image = make_inpaint_condition(extended_image, mask_image) +blend_mask = create_gradient(mask_image, x=None, y=-256, offset=200) + +extended_image.save("extended_image.png") +mask_image.save("mask_image.png") +blend_mask.save("blend_mask.png") +# vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").half() +# pipe.vae = vae + +pipe.enable_model_cpu_offload() +pipe.enable_xformers_memory_efficient_attention() + +generator = None +# generator = torch.Generator().manual_seed(456) +generator = torch.Generator().manual_seed(123) +# generate image +pipe.safety_checker = None +prompt= "bmw drifting, pink smoke" +images = pipe( + prompt, + num_inference_steps=25, + generator=generator, + guidance_scale=6.0, + negative_prompt="deformed iris, deformed pupils, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", + # eta=1.0, + # eta=1.0, + # soft_mask=blend_mask, + width=extended_image.width, + height=extended_image.height, + image=extended_image, + mask_image=blend_mask, + control_image=control_image, + num_images_per_prompt=4, + controlnet_conditioning_scale=1., + guess_mode=True, +).images + + +folder = "_".join(prompt.split(" ")) +folder = "no_prompt" if len(folder) == 0 else folder +os.makedirs(folder, exist_ok=True) +print("Saving to ", folder) + +for i, image in enumerate(images): + image.save(os.path.join(folder, f"2_extend_{i}.png")) + + +#best config .35 / 20 steps + +# img2img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("/home/erwann/generation-service/safetensor-models/real", safety_checker=None) +# img2img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("/home/erwann/generation-service/safetensor-models/realistic_vision", safety_checker=None) +img2img_pipe = StableDiffusionMaskedImg2ImgPipeline.from_pretrained("/home/erwann/generation-service/safetensor-models/realistic_vision", safety_checker=None) + +print("Scheduler") +print(img2img_pipe.scheduler) + + +# img2imgpipe = StableDiffusionImg2ImgPipeline( +# vae=pipe.vae, +# text_encoder=pipe.text_encoder, +# tokenizer=pipe.tokenizer, +# unet=pipe.unet, +# scheduler=pipe.scheduler, +# safety_checker=None, +# feature_extractor=pipe.feature_extractor, +# ) + + +img2img_pipe = img2img_pipe.to("cuda") +img2img_pipe.enable_attention_slicing() +img2img_pipe.enable_xformers_memory_efficient_attention() + +# soft_mask_pil = Image.open("/home/erwann/diffusers/examples/community/soft_mask_5.png") + +# img2img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained("/home/erwann/generation-service/safetensor-models/real", safety_checker=None) +from PIL import Image + + +for i, image in enumerate(images): + final_image = img2img_pipe( + prompt, + image=image, + mask_image=blend_mask, + strength=0.350, + num_inference_steps=19, + generator=generator, + ).images[0] + final_image.save(os.path.join(folder, f"img2img_{i}_real_cfg8_9.png")) + # plt.imshow(final_image) + # plt.show() + +# import matplotlib.pyplot as plt +# from PIL import Image + + + +# plt.imshow(image) +# plt.show() +# plt.show() diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 165e2d88dca6..a4989fde4fab 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -20,9 +20,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -import PIL.Image import torch import torch.nn.functional as F + +import PIL.Image from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...image_processor import VaeImageProcessor @@ -901,6 +902,33 @@ def _default_height_width(self, height, width, image): return height, width # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents + def resize_mask( + self, mask, dtype=torch.float16, + ): + height = mask.height + width = mask.width + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + # mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + # mask[mask < 0.5] = 0 + # mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device="cuda", dtype=dtype) + return mask + + + + def prepare_mask_latents( self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance ): @@ -1304,6 +1332,9 @@ def __call__( if num_channels_unet == 4: init_latents_proper = image_latents[:1] init_mask = mask[:1] + from PIL import Image + + # soft_mask = self.resize_mask(soft_mask_pil,) if i < len(timesteps) - 1: noise_timestep = timesteps[i + 1] @@ -1311,7 +1342,12 @@ def __call__( init_latents_proper, noise, torch.tensor([noise_timestep]) ) - latents = (1 - init_mask) * init_latents_proper + init_mask * latents + soft_mask = None + if soft_mask is None: + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + else: + print("Using soft mask in controlnet inpaint") + latents = (1 - soft_mask) * init_latents_proper + soft_mask * latents # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index e9e91b646ed5..070fdc65ca7f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -17,8 +17,9 @@ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np -import PIL import torch + +import PIL from packaging import version from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer @@ -602,10 +603,35 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) # get latents + unnoised_latents = init_latents.clone() init_latents = self.scheduler.add_noise(init_latents, noise, timestep) latents = init_latents - return latents + # return latents + return latents, noise, unnoised_latents + + def resize_mask( + self, mask, dtype=torch.float16, + ): + height = mask.height + width = mask.width + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + # mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask = torch.from_numpy(mask) + + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device="cuda", dtype=dtype) + return mask + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -634,6 +660,7 @@ def __call__( callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + mask=None, ): r""" Function invoked when calling the pipeline for generation. @@ -745,7 +772,7 @@ def __call__( latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables - latents = self.prepare_latents( + latents, noise, init_latents = self.prepare_latents( image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator ) @@ -777,6 +804,28 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + init_latents_proper = init_latents[:1] + from PIL import Image + soft_mask = None + + print("old******************************") + soft_mask_pil = Image.open("/home/erwann/diffusers/examples/community/soft_mask_5.png") + # soft_mask = None + soft_mask = self.resize_mask(soft_mask_pil,) + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + if soft_mask is None: + print("no soft mask") + # latents = (1 - init_mask) * init_latents_proper + init_mask * latents + else: + print("using soft mask") + latents = (1 - soft_mask) * init_latents_proper + soft_mask * latents + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() From 1ea612ccf1531debebcdc4cd6e7a63495e3cce1d Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Mon, 19 Jun 2023 23:36:51 +0000 Subject: [PATCH 03/18] fix import for controlnet pipeline --- .../controlnet/pipeline_controlnet.py | 458 ++++++++++++------ .../pipeline_stable_diffusion_controlnet.py | 7 +- 2 files changed, 303 insertions(+), 162 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 2a86ee0dfe1e..b2a747becd7f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -15,31 +15,31 @@ import inspect import os -import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -import PIL.Image import torch -import torch.nn.functional as F +from torch import nn + +import PIL.Image from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer -from ...image_processor import VaeImageProcessor -from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models.controlnet import ControlNetOutput +from ...models.modeling_utils import ModelMixin +from ...pipelines.stable_diffusion import StableDiffusionPipelineOutput +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( + PIL_INTERPOLATION, is_accelerate_available, is_accelerate_version, - is_compiled_module, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline -from ..stable_diffusion import StableDiffusionPipelineOutput -from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from .multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -91,7 +91,66 @@ """ -class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): +class MultiControlNetModel(ModelMixin): + r""" + Multiple `ControlNetModel` wrapper class for Multi-ControlNet + + This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be + compatible with `ControlNetModel`. + + Args: + controlnets (`List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `ControlNetModel` as a list. + """ + + def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: List[torch.tensor], + conditioning_scale: List[float], + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): + down_samples, mid_sample = controlnet( + sample, + timestep, + encoder_hidden_states, + image, + scale, + class_labels, + timestep_cond, + attention_mask, + cross_attention_kwargs, + guess_mode, + return_dict, + ) + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + return down_block_res_samples, mid_block_res_sample + + +class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. @@ -171,10 +230,6 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) - self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False - ) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing @@ -195,24 +250,6 @@ def disable_vae_slicing(self): """ self.vae.disable_slicing() - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling - def enable_vae_tiling(self): - r""" - Enable tiled VAE decoding. - - When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in - several steps. This is useful to save a large amount of memory and to allow the processing of larger images. - """ - self.vae.enable_tiling() - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() - def enable_sequential_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, @@ -291,7 +328,6 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -316,14 +352,7 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): - self._lora_scale = lora_scale - if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -380,7 +409,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): + elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -438,28 +467,19 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) + else: + has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - warnings.warn( - "The decode_latents method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor instead", - FutureWarning, - ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] + image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -487,12 +507,17 @@ def check_inputs( self, prompt, image, + height, + width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, controlnet_conditioning_scale=1.0, ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): @@ -537,20 +562,9 @@ def check_inputs( ) # Check `image` - is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule - ) - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): + if isinstance(self.controlnet, ControlNetModel): self.check_image(image, prompt, prompt_embeds) - elif ( - isinstance(self.controlnet, MultiControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, MultiControlNetModel) - ): + elif isinstance(self.controlnet, MultiControlNetModel): if not isinstance(image, list): raise TypeError("For multiple controlnets: `image` must be type `list`") @@ -569,18 +583,10 @@ def check_inputs( assert False # Check `controlnet_conditioning_scale` - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): + if isinstance(self.controlnet, ControlNetModel): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - elif ( - isinstance(self.controlnet, MultiControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, MultiControlNetModel) - ): + elif isinstance(self.controlnet, MultiControlNetModel): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): raise ValueError("A single batch of multiple conditionings are supported at the moment.") @@ -597,26 +603,21 @@ def check_inputs( def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) - image_is_np = isinstance(image, np.ndarray) image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) - image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) - - if ( - not image_is_pil - and not image_is_tensor - and not image_is_np - and not image_is_pil_list - and not image_is_tensor_list - and not image_is_np_list - ): + + if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: raise TypeError( - f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" ) if image_is_pil: image_batch_size = 1 - else: + elif image_is_tensor: + image_batch_size = image.shape[0] + elif image_is_pil_list: + image_batch_size = len(image) + elif image_is_tensor_list: image_batch_size = len(image) if prompt is not None and isinstance(prompt, str): @@ -643,7 +644,29 @@ def prepare_image( do_classifier_free_guidance=False, guess_mode=False, ): - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + image_batch_size = image.shape[0] if image_batch_size == 1: @@ -661,7 +684,15 @@ def prepare_image( return image - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: @@ -679,6 +710,86 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents_img2img(self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if image is not None: + image = image.to(device=device, dtype=dtype) + + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + # deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + else: + latents = latents.to(device) + latents = latents * self.scheduler.init_noise_sigma + + + # scale the initial noise by the standard deviation required by the scheduler + # NOTE: not sure if the following line is necessary in for img2img + # latents = latents * self.scheduler.init_noise_sigma + return latents + + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + + width = (width // 8) * 8 # round down to nearest multiple of 8 + + return height, width + # override DiffusionPipeline def save_pretrained( self, @@ -696,14 +807,7 @@ def save_pretrained( def __call__( self, prompt: Union[str, List[str]] = None, - image: Union[ - torch.FloatTensor, - PIL.Image.Image, - np.ndarray, - List[torch.FloatTensor], - List[PIL.Image.Image], - List[np.ndarray], - ] = None, + image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -722,6 +826,9 @@ def __call__( cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, guess_mode: bool = False, + img2img_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, + img2img_strength: float = 1.0, + controlnet_strength: float = 1.0, ): r""" Function invoked when calling the pipeline for generation. @@ -730,8 +837,8 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: - `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, + `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If @@ -807,11 +914,15 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, image, + height, + width, callback_steps, negative_prompt, prompt_embeds, @@ -833,22 +944,10 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - - if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) - - global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - guess_mode = guess_mode or global_pool_conditions + if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets) # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) prompt_embeds = self._encode_prompt( prompt, device, @@ -857,11 +956,11 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, ) + # 4. Prepare image - if isinstance(controlnet, ControlNetModel): + if isinstance(self.controlnet, ControlNetModel): image = self.prepare_image( image=image, width=width, @@ -869,12 +968,11 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=controlnet.dtype, + dtype=self.controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) - height, width = image.shape[-2:] - elif isinstance(controlnet, MultiControlNetModel): + elif isinstance(self.controlnet, MultiControlNetModel): images = [] for image_ in image: @@ -885,7 +983,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=controlnet.dtype, + dtype=self.controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) @@ -893,26 +991,57 @@ def __call__( images.append(image_) image = images - height, width = image[0].shape[-2:] else: assert False # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - + # 6. Prepare latent variables num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) + + if img2img_strength > 0.0: + img2img_image = self.prepare_image( + image=img2img_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + img2img_image = 2.0 * img2img_image - 1.0 + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1 - img2img_strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + latents = self.prepare_latents_img2img( + img2img_image[0][None, :], + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + else: + timesteps = self.scheduler.timesteps + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) @@ -928,15 +1057,14 @@ def __call__( # controlnet(s) inference if guess_mode and do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. - control_model_input = latents - control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_latent_model_input = latents controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: - control_model_input = latent_model_input + controlnet_latent_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds down_block_res_samples, mid_block_res_sample = self.controlnet( - control_model_input, + controlnet_latent_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=image, @@ -953,15 +1081,22 @@ def __call__( mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - return_dict=False, - )[0] + if t <= controlnet_strength * 1000: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + else: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample # perform guidance if do_classifier_free_guidance: @@ -969,7 +1104,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -984,19 +1119,24 @@ def __call__( self.controlnet.to("cpu") torch.cuda.empty_cache() - if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - else: + if output_type == "latent": image = latents has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + # 8. Post-processing + image = self.decode_latents(latents) - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: @@ -1005,4 +1145,4 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) \ No newline at end of file diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index b8bb2b2cccb2..4cf4084eb62c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -18,9 +18,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -import PIL.Image import torch from torch import nn + +import PIL.Image from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...loaders import TextualInversionLoaderMixin @@ -37,8 +38,8 @@ replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline -from . import StableDiffusionPipelineOutput -from .safety_checker import StableDiffusionSafetyChecker +# from . import StableDiffusionPipelineOutput +# from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 2981accd00adb55b2f018fe9eaada9c4d9984c8b Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Mon, 19 Jun 2023 23:46:30 +0000 Subject: [PATCH 04/18] merge upstream and preserve img2img controlnet pipeline modificatiosn --- controlnet_inpaint.py | 197 +++++++++++++++ custom_inpaint_pipeline.py | 491 +++++++++++++++++++++++++++++++++++++ mask_utils.py | 110 +++++++++ masked2.py | 353 ++++++++++++++++++++++++++ masked_img2img.py | 222 +++++++++++++++++ 5 files changed, 1373 insertions(+) create mode 100644 controlnet_inpaint.py create mode 100644 custom_inpaint_pipeline.py create mode 100644 mask_utils.py create mode 100644 masked2.py create mode 100644 masked_img2img.py diff --git a/controlnet_inpaint.py b/controlnet_inpaint.py new file mode 100644 index 000000000000..ec7e1c06189a --- /dev/null +++ b/controlnet_inpaint.py @@ -0,0 +1,197 @@ +# !pip install transformers accelerate +import os +from typing import List, Optional, Union + +import numpy as np +import torch + +import PIL +from contexttimer import Timer +from diffusers import ( + ControlNetModel, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + StableDiffusionControlNetInpaintPipeline, + StableDiffusionImg2ImgPipeline, +) +from diffusers.utils import load_image +from mask_utils import create_gradient, expand_image, make_inpaint_condition + + +def controlnet_inpaint( + self, + prompt, + image: Union[str, PIL.Image], + mask: Union[str, PIL.Image], + output_type="pil", + pipe=None, + num_inference_steps: int = 25, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + seed=None, +): + """ + Takes image and mask (can be PIL images, a local path, or a URL) + """ + init_image = load_image(image) + mask_image = load_image(mask) + control_image = make_inpaint_condition(init_image, mask_image) + + with Timer() as t: + if pipe is None: + # pipe = sd.get_pipeline() + raise NotImplementedError("Need to pass in a pipeline for now") + t_load = t.elapsed + + if seed is not None: + generator = torch.Generator().manual_seed(seed) + + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True) + with Timer() as t: + images = pipe( + prompt, + num_inference_steps=num_inference_steps, + generator=generator, + guidance_scale=guidance_scale, + negative_prompt="deformed iris, deformed pupils, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", + width=init_image.width, + height=init_image.height, + image=init_image, + mask_image=mask_image, + control_image=control_image, + num_images_per_prompt=4, + output_type=output_type, + ).images + t_inference = t.elapsed + + return { + "images": images, + "performance": { + "t_load": t_load, + "t_inference": t_inference, + } + } + +def controlnet_extend( + self, + prompt: Union[str, List[str]], + init_image: Union[torch.FloatTensor, PIL.Image.Image], + model_id: str = "realistic_vision", + seed: Optional[int] = None, + expand_offset_x: int = 0, + expand_offset_y: int = 0, + img2img_strength: float = 0.35, + img2img_steps: int = 15, + mask_offset: float = 40, + num_inference_steps: int = 25, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + output_type: Optional[str] = "pil", + # pipe +): + + with Timer() as t: + controlnet_inpaint_pipe = self.model_loader.get_pipeline(model_id, StableDiffusionControlNetInpaintPipeline) + t_load = t.elapsed + + with Timer() as t: + extended_image, mask_img = expand_image(init_image, expand_x=expand_offset_x, expand_y=expand_offset_y) + print("Image size after extending, " + str(extended_image.size)) + + blend_mask = create_gradient(mask_img, x=expand_offset_x, y=expand_offset_y, offset=mask_offset) + t_preprocess = t.elapsed + + inpaint_results = controlnet_inpaint( + prompt, + extended_image, + mask_img, + output_type="pil", + pipe=controlnet_inpaint_pipe, + seed=seed, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + ) + inpainted_images = inpaint_results["images"] + performance = inpaint_results["performance"] + + + img2img_pipe = StableDiffusionImg2ImgPipeline( + vae=controlnet_inpaint_pipe.vae, + text_encoder=controlnet_inpaint_pipe.text_encoder, + tokenizer=controlnet_inpaint_pipe.tokenizer, + unet=controlnet_inpaint_pipe.unet, + scheduler=controlnet_inpaint_pipe.scheduler, + safety_checker=None, + feature_extractor=controlnet_inpaint_pipe.feature_extractor, + ) + img2img_pipe = img2img_pipe.to("cuda") + + + # uses masked img2img to homogenize the inpainted zone and the original image, making them look more natural and reduce border artefacts + generator = torch.Generator().manual_seed(seed) if seed is not None else None + with Timer() as t: + final_images = img2img_pipe( + prompt, + # negative_prompt=negative_prompt, + negative_prompt=None, + image=inpainted_images, + mask_image=blend_mask, + strength=img2img_strength, + num_inference_steps=img2img_steps, + generator=generator, + ).images + t_img2img = t.elapsed + + if "t_load" in performance: + performance["t_load"] = performance["t_load"] + t_load + performance["t_preprocess"] = t_preprocess + performance["t_img2img"] = t_img2img + performance["t_inpaint"] = performance["t_inference"] + performance["t_inference"] = sum([performance["t_inference"], t_img2img]) + + for i, image in enumerate(final_images): + image.save("final_image_" + str(i) + ".png") + return { + "images": final_images, + "performance": performance, + } + + +if __name__ == "__main__": + init_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" + ) + init_image = init_image.resize((512, 512)) + + # generator = torch.Generator(device="cpu").manual_seed(1) + + mask_image = load_image( + "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" + ) + mask_image = mask_image.resize((512, 512)) + + img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + + init_image = load_image(img_url) + mask_image = load_image(mask_url) + + img_path = "/home/erwann/diffusers/examples/community/new_image.png" + # mask_path = "/home/erwann/diffusers/examples/community/hard_mask_5.png" + mask_path = "/home/erwann/diffusers/examples/community/mask_image.png" + init_image = load_image(img_path) + mask_image = load_image(mask_path) + # mask_image.save("mask.png") + + # new_width = 480 + # new_height = new_width * init_image.height / init_image.width + # new_height = 640 + # init_image = init_image.resize((new_width, int(new_height))) + + # mask_image = mask_image.resize(init_image.size) + # mask_image = mask_image.resize((512, 512)) diff --git a/custom_inpaint_pipeline.py b/custom_inpaint_pipeline.py new file mode 100644 index 000000000000..45652477db34 --- /dev/null +++ b/custom_inpaint_pipeline.py @@ -0,0 +1,491 @@ +import inspect +import os +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F + +import PIL.Image +from diffusers import StableDiffusionControlNetInpaintPipeline +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + load_image, + logging, + randn_tensor, + replace_example_docstring, +) +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + + +def prepare_mask_and_masked_image(image, mask, height, width, return_image=False): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + + + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. + TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + if mask is None: + raise ValueError("`mask_image` input cannot be undefined.") + + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") + + # Batch single image + if image.ndim == 3: + assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" + assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" + assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + # resize all images w.r.t passed height an width + image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + # n.b. ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + + return mask, masked_image + +class StableDiffusionMaskedLatentControlNetInpaintPipeline(StableDiffusionControlNetInpaintPipeline): + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet, + scheduler: KarrasDiffusionSchedulers, + feature_extractor: CLIPImageProcessor, + safety_checker=None, + requires_safety_checker: bool = False, + ): + super().__init__( + vae, + text_encoder, + tokenizer, + unet, + controlnet, + scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker = False, + ) + + def resize_mask( + self, mask, dtype=torch.float16, + ): + height = mask.height + width = mask.width + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + # mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] + mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask = torch.from_numpy(mask) + + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device="cuda", dtype=dtype) + print("mask unique values" , torch.unique(mask)) + return mask + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.Tensor, PIL.Image.Image] = None, + mask_image: Union[torch.Tensor, PIL.Image.Image] = None, + soft_mask=None, + control_image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + strength: float = 1.0, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 0.5, + guess_mode: bool = False, + ): + height, width = self._default_height_width(height, width, image) + + if soft_mask is not None: + soft_mask_pil = load_image(soft_mask) + soft_mask = self.resize_mask(soft_mask_pil,) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + control_image = self.prepare_control_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + elif isinstance(controlnet, MultiControlNetModel): + control_images = [] + + for control_image_ in control_image: + control_image_ = self.prepare_control_image( + image=control_image_, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + + control_images.append(control_image_) + + control_image = control_images + else: + assert False + + # 4. Preprocess mask and image - resizes image and mask w.r.t height and width + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1.0 + + # 6. Prepare latent variables + num_channels_latents = self.vae.config.latent_channels + num_channels_unet = self.unet.config.in_channels + return_image_latents = num_channels_unet == 4 + latents_outputs = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + return_noise=True, + return_image_latents=return_image_latents, + ) + + if return_image_latents: + latents, noise, image_latents = latents_outputs + else: + latents, noise = latents_outputs + + # 7. Prepare mask latent variables + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=control_image, + conditioning_scale=controlnet_conditioning_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + if num_channels_unet == 9: + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if num_channels_unet == 4: + init_latents_proper = image_latents[:1] + init_mask = mask[:1] + from PIL import Image + + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + if soft_mask is None: + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + else: + print("Using soft mask in controlnet inpaint") + latents = (1 - soft_mask) * init_latents_proper + soft_mask * latents + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) \ No newline at end of file diff --git a/mask_utils.py b/mask_utils.py new file mode 100644 index 000000000000..668d8772789a --- /dev/null +++ b/mask_utils.py @@ -0,0 +1,110 @@ +import numpy as np +import torch + +from diffusers.utils import load_image +from PIL import Image, ImageDraw, ImageFilter + + +def make_inpaint_condition(image, image_mask): + image = np.array(image.convert("RGB")).astype(np.float32) / 255.0 + image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0 + + assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size" + image[image_mask > 0.001] = -1.0 # set as masked pixel + image = np.expand_dims(image, 0).transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return image + + +def create_gradient(image, y=None, x=None, offset=40): + """ + Takes a binary mask (white = area to be inpainted, black = area to be kept from original image) and creates a gradient at the border of the mask. The gradient adds a white to black gradient that extends into the original black area. + + This ensures that the inpainted area is not a hard border, but a smooth transition from the inpainted area to the original image. + + Used to blend together latents in MaskedImg2ImgPipeline + """ + if y is None and x is None: + raise ValueError("Either y or x must be specified") + draw = ImageDraw.Draw(image) + if y and x: + raise ValueError("Only one of y or x must be specified (for now)") + + sign = 1 + if offset < 0: + sign = -1 + + offset = abs(offset) + if y is not None: + if y > 0: + y = image.height - y + if offset > 0: + sign = -1 + else: + y = abs(y) + for i in range(abs(offset)): + color = abs(255 - int(255 * (i / abs(offset)))) # calculate grayscale color + i *= sign + draw.line([(0, y+i), (image.width, y+i)], fill=(color, color, color)) + if x is not None: + if x > 0: + x = image.width - x + if offset > 0: + sign = -1 + else: + x = abs(x) + for i in range(abs(offset)): + color = abs(255 - int(255 * (i / abs(offset)))) # calculate grayscale color + i *= sign + draw.line([(x+i, 0), (x+i, image.height)], fill=(color, color, color)) + return image + +# def soften_mask(mask_before_blur, mask_img, blur_radius): +# # Apply Gaussian Blur to the mask +# blurred_mask = mask_img.filter(ImageFilter.GaussianBlur(blur_radius)) +# mask_before_blur = mask_before_blur.convert("L") + +# blurred_mask.paste(mask_before_blur, mask=mask_before_blur) + +# return blurred_mask + +def expand_image(img, expand_y=0, expand_x=0): + # Load the image + img = load_image(img) + width, height = img.size + + # Create a new image with expanded height + new_height = height + abs(expand_y) + new_width = width + abs(expand_x) + + new_img = Image.new('RGB', (new_width, new_height), color = 'white') + + # Create a mask image + mask_img = Image.new('1', (new_width, new_height), color = 'white') + + # If expand_y is positive, the image is expanded on the bottom. + # If expand_y is negative, the image is expanded on the top. + y_position = 0 if expand_y > 0 else abs(expand_y) + x_position = 0 if expand_x > 0 else abs(expand_x) + new_img.paste(img, (x_position, y_position)) + + # Create mask + mask_img.paste(Image.new('1', img.size, color = 'black'), (x_position, y_position)) + mask_img = mask_img.convert("RGB") + + # soft_mask_img = soften_mask(mask_img, mask_img, 50) + # return new_img, mask_img, soft_mask_img + + return new_img, mask_img + +if __name__ == '__main__': + # Usage: + path = "/home/erwann/diffusers/examples/community/castle.png" + expand = 256 + new_img, mask_img = expand_image(path, expand_x=expand) + new_img.save('new_image.png') + mask_img.save('mask_image.png') + # soft_mask.save('soft_mask.png') + softened_mask = create_gradient(mask_img, x=expand, offset=40) + softened_mask.save('soft_mask.png') + diff --git a/masked2.py b/masked2.py new file mode 100644 index 000000000000..e44997165998 --- /dev/null +++ b/masked2.py @@ -0,0 +1,353 @@ +import inspect +import warnings +from typing import * +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch + +import PIL +from diffusers import StableDiffusionImg2ImgPipeline +from diffusers.utils import ( + PIL_INTERPOLATION, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, + load_image, +) +from packaging import version +from PIL import Image +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + + +class StableDiffusionPipelineOutput: + def __init__(self, images): + self.images = images + + +def preprocess(image): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL.Image.LANCZOS))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionMaskedImg2ImgPipeline(StableDiffusionImg2ImgPipeline): + def __init__( + self, + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker, + feature_extractor, + requires_safety_checker=False, + ): + super().__init__( + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker, + feature_extractor, + requires_safety_checker=False, + ) + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + unnoised_latents = init_latents.clone() + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + # return latents + return latents, noise, unnoised_latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + mask_image=None, + strength: float = 0.8, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded + again. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # resize mask to apply to latents + mask_image = load_image(mask_image) + latent_mask = self.resize_mask(mask_image,) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Preprocess image + image = self.image_processor.preprocess(image) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents, noise, init_latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + init_latents_proper = init_latents[:1] + # soft_mask = None + + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.add_noise( + init_latents_proper, noise, torch.tensor([noise_timestep]) + ) + + if latent_mask is None: + print("no mask") + # latents = (1 - init_mask) * init_latents_proper + init_mask * latents + else: + print("using mask") + latents = (1 - latent_mask) * init_latents_proper + latent_mask * latents + + + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image) \ No newline at end of file diff --git a/masked_img2img.py b/masked_img2img.py new file mode 100644 index 000000000000..a8d2a8913e61 --- /dev/null +++ b/masked_img2img.py @@ -0,0 +1,222 @@ +from typing import * + +import numpy as np +import torch + +import PIL +from diffusers import StableDiffusionImg2ImgPipeline + + +class StableDiffusionPipelineOutput: + def __init__(self, images): + self.images = images + +def preprocess(image): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + + image = [np.array(i.resize((w, h), resample=PIL.Image.LANCZOS))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionMaskedImg2ImgPipeline(StableDiffusionImg2ImgPipeline): + def __init__( + self, + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker, + feature_extractor, + requires_safety_checker = True, + ): + super().__init__( + vae, + text_encoder, + tokenizer, + unet, + scheduler, + safety_checker, + feature_extractor, + requires_safety_checker = True, + ) + + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]], + image: Union[torch.FloatTensor, PIL.Image.Image], + mask_image, + strength: float = 0.8, + num_inference_steps: Optional[int] = 20, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: Optional[int] = 1, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 0.8): + Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` + will be used as a starting point, adding more noise to it the larger the `strength`. The number of + denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will + be maximum and the denoising process will run for the full number of iterations specified in + `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored + if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + og_image = image + og_mask = mask_image + print("overriden call") + + # 1. Check inputs + self.check_inputs(prompt, strength, callback_steps) + + # 2. Define call parameters + batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_embeddings = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt + ) + + # 4. Preprocess image + image = preprocess(image) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + latents = self.prepare_latents( + image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator + ) + + # turn PIL mask_image into a tensor and resize it to the same size as latents + latent_mask = torch.tensor(np.array(mask_image.convert("L").resize(latents.shape[-2:][::-1]))).to(device) + latent_mask = latent_mask / 255.0 + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + init_latents = self.prepare_latents( + image, t, batch_size, num_images_per_prompt, text_embeddings.dtype, device, generator + ) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + latents = latents * latent_mask + init_latents * (1 - latent_mask) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Run safety checker + # image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) + + mask = np.asarray(og_mask).astype(np.float32)[None, ::] / 255 + + + image = (image*255).astype(np.float32) * mask + np.array(og_image).astype(np.float32) * (1-mask) + + # 11. Convert to PIL + if output_type == "pil": + image = self.numpy_to_pil(image / 255) + + if not return_dict: + return (image, False) + + return StableDiffusionPipelineOutput(images=image) From 03da9552d20efe338b38d0caa31beab52297a918 Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Wed, 21 Jun 2023 21:20:19 +0000 Subject: [PATCH 05/18] restore broken imports --- src/diffusers/pipelines/controlnet/pipeline_controlnet.py | 1 + .../pipeline_stable_diffusion_controlnet.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index b2a747becd7f..9e8b8ef7a29f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -22,6 +22,7 @@ from torch import nn import PIL.Image +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...loaders import TextualInversionLoaderMixin diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py index 4cf4084eb62c..30554abdb6f8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_controlnet.py @@ -22,6 +22,8 @@ from torch import nn import PIL.Image + +# from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer from ...loaders import TextualInversionLoaderMixin @@ -38,8 +40,8 @@ replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline -# from . import StableDiffusionPipelineOutput -# from .safety_checker import StableDiffusionSafetyChecker +from . import StableDiffusionPipelineOutput +from .safety_checker import StableDiffusionSafetyChecker logger = logging.get_logger(__name__) # pylint: disable=invalid-name From 325ace2e8d3494d16fd3b5ff5e6cb457443bca69 Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Mon, 10 Jul 2023 14:48:54 +0000 Subject: [PATCH 06/18] fix conflict in controlnet pipeline --- .../controlnet/pipeline_controlnet.py | 63 +++---------------- 1 file changed, 8 insertions(+), 55 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index cd03e53f8e3c..9e8b8ef7a29f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -515,8 +515,6 @@ def check_inputs( prompt_embeds=None, negative_prompt_embeds=None, controlnet_conditioning_scale=1.0, - control_guidance_start=0.0, - control_guidance_end=1.0, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -603,27 +601,6 @@ def check_inputs( else: assert False - if len(control_guidance_start) != len(control_guidance_end): - raise ValueError( - f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." - ) - - if isinstance(self.controlnet, MultiControlNetModel): - if len(control_guidance_start) != len(self.controlnet.nets): - raise ValueError( - f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." - ) - - for start, end in zip(control_guidance_start, control_guidance_end): - if start >= end: - raise ValueError( - f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." - ) - if start < 0.0: - raise ValueError(f"control guidance start: {start} can't be smaller than 0.") - if end > 1.0: - raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) @@ -850,6 +827,9 @@ def __call__( cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, guess_mode: bool = False, + img2img_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, + img2img_strength: float = 1.0, + controlnet_strength: float = 1.0, ): r""" Function invoked when calling the pipeline for generation. @@ -925,10 +905,6 @@ def __call__( guess_mode (`bool`, *optional*, defaults to `False`): In this mode, the ControlNet encoder will try best to recognize the content of the input image even if you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. - control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): - The percentage of total steps at which the controlnet starts applying. - control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): - The percentage of total steps at which the controlnet stops applying. Examples: @@ -939,6 +915,8 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -951,8 +929,6 @@ def __call__( prompt_embeds, negative_prompt_embeds, controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, ) # 2. Define call parameters @@ -969,17 +945,8 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - - if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) - - global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - guess_mode = guess_mode or global_pool_conditions + if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets) # 3. Encode input prompt prompt_embeds = self._encode_prompt( @@ -1080,15 +1047,6 @@ def __call__( # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7.1 Create tensor stating which controlnets to keep - controlnet_keep = [] - for i in range(len(timesteps)): - keeps = [ - 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) - for s, e in zip(control_guidance_start, control_guidance_end) - ] - controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) - # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -1106,17 +1064,12 @@ def __call__( controlnet_latent_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds - if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] - else: - cond_scale = controlnet_conditioning_scale * controlnet_keep[i] - down_block_res_samples, mid_block_res_sample = self.controlnet( controlnet_latent_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=image, - conditioning_scale=cond_scale, + conditioning_scale=controlnet_conditioning_scale, guess_mode=guess_mode, return_dict=False, ) From 5038961a5e67dc9e5b0596c976806b9468edac0b Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Thu, 27 Jul 2023 18:12:07 +0000 Subject: [PATCH 07/18] remove invisible watermark --- examples/dreambooth/README_sdxl.md | 1 + .../controlnet/pipeline_controlnet.py | 574 +++++++++++------- .../controlnet/pipeline_controlnet_sd_xl.py | 3 +- .../pipeline_if_img2img_superresolution.py | 3 +- .../pipeline_stable_diffusion_xl.py | 9 +- .../stable_diffusion_xl/watermark.py | 17 +- 6 files changed, 374 insertions(+), 233 deletions(-) diff --git a/examples/dreambooth/README_sdxl.md b/examples/dreambooth/README_sdxl.md index 7dcde78f2cfd..e850f598f08e 100644 --- a/examples/dreambooth/README_sdxl.md +++ b/examples/dreambooth/README_sdxl.md @@ -80,6 +80,7 @@ export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0" export INSTANCE_DIR="dog" export OUTPUT_DIR="lora-trained-xl" +# python train_dreambooth_lora_sdxl.py \ accelerate launch train_dreambooth_lora_sdxl.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --instance_data_dir=$INSTANCE_DIR \ diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 4320240dac63..9e8b8ef7a29f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -14,31 +14,33 @@ import inspect -import warnings +import os from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -import PIL.Image import torch -import torch.nn.functional as F +from torch import nn + +import PIL.Image +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer -from ...image_processor import VaeImageProcessor -from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from ...loaders import TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models.controlnet import ControlNetOutput +from ...models.modeling_utils import ModelMixin +from ...pipelines.stable_diffusion import StableDiffusionPipelineOutput +from ...pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( + PIL_INTERPOLATION, is_accelerate_available, is_accelerate_version, - is_compiled_module, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline -from ..stable_diffusion import StableDiffusionPipelineOutput -from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from .multicontrolnet import MultiControlNetModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -90,9 +92,66 @@ """ -class StableDiffusionControlNetPipeline( - DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin -): +class MultiControlNetModel(ModelMixin): + r""" + Multiple `ControlNetModel` wrapper class for Multi-ControlNet + + This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be + compatible with `ControlNetModel`. + + Args: + controlnets (`List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. You must set multiple + `ControlNetModel` as a list. + """ + + def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]): + super().__init__() + self.nets = nn.ModuleList(controlnets) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: List[torch.tensor], + conditioning_scale: List[float], + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[ControlNetOutput, Tuple]: + for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): + down_samples, mid_sample = controlnet( + sample, + timestep, + encoder_hidden_states, + image, + scale, + class_labels, + timestep_cond, + attention_mask, + cross_attention_kwargs, + guess_mode, + return_dict, + ) + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip(down_block_res_samples, down_samples) + ] + mid_block_res_sample += mid_sample + + return down_block_res_samples, mid_block_res_sample + + +class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin): r""" Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance. @@ -172,44 +231,46 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) - self.control_image_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False - ) self.register_to_config(requires_safety_checker=requires_safety_checker) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): r""" - Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to - compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. """ self.vae.enable_slicing() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): r""" - Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to computing decoding in one step. """ self.vae.disable_slicing() - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling - def enable_vae_tiling(self): + def enable_sequential_cpu_offload(self, gpu_id=0): r""" - Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow - processing larger images. + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. """ - self.vae.enable_tiling() + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling - def disable_vae_tiling(self): - r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to - computing decoding in one step. - """ - self.vae.disable_tiling() + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) def enable_model_cpu_offload(self, gpu_id=0): r""" @@ -239,6 +300,25 @@ def enable_model_cpu_offload(self, gpu_id=0): # We'll offload the last model manually. self.final_offload_hook = hook + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( self, @@ -249,7 +329,6 @@ def _encode_prompt( negative_prompt=None, prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, ): r""" Encodes the prompt into text encoder hidden states. @@ -274,14 +353,7 @@ def _encode_prompt( Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): - self._lora_scale = lora_scale - if prompt is not None and isinstance(prompt, str): batch_size = 1 elif prompt is not None and isinstance(prompt, list): @@ -338,7 +410,7 @@ def _encode_prompt( uncond_tokens: List[str] if negative_prompt is None: uncond_tokens = [""] * batch_size - elif prompt is not None and type(prompt) is not type(negative_prompt): + elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." @@ -396,28 +468,19 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker def run_safety_checker(self, image, device, dtype): - if self.safety_checker is None: - has_nsfw_concept = None - else: - if torch.is_tensor(image): - feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") - else: - feature_extractor_input = self.image_processor.numpy_to_pil(image) - safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) image, has_nsfw_concept = self.safety_checker( images=image, clip_input=safety_checker_input.pixel_values.to(dtype) ) + else: + has_nsfw_concept = None return image, has_nsfw_concept # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents def decode_latents(self, latents): - warnings.warn( - "The decode_latents method is deprecated and will be removed in a future version. Please" - " use VaeImageProcessor instead", - FutureWarning, - ) latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] + image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() @@ -445,14 +508,17 @@ def check_inputs( self, prompt, image, + height, + width, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, controlnet_conditioning_scale=1.0, - control_guidance_start=0.0, - control_guidance_end=1.0, ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + if (callback_steps is None) or ( callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) ): @@ -497,20 +563,9 @@ def check_inputs( ) # Check `image` - is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( - self.controlnet, torch._dynamo.eval_frame.OptimizedModule - ) - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): + if isinstance(self.controlnet, ControlNetModel): self.check_image(image, prompt, prompt_embeds) - elif ( - isinstance(self.controlnet, MultiControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, MultiControlNetModel) - ): + elif isinstance(self.controlnet, MultiControlNetModel): if not isinstance(image, list): raise TypeError("For multiple controlnets: `image` must be type `list`") @@ -520,7 +575,7 @@ def check_inputs( raise ValueError("A single batch of multiple conditionings are supported at the moment.") elif len(image) != len(self.controlnet.nets): raise ValueError( - f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + "For multiple controlnets: `image` must have the same length as the number of controlnets." ) for image_ in image: @@ -529,18 +584,10 @@ def check_inputs( assert False # Check `controlnet_conditioning_scale` - if ( - isinstance(self.controlnet, ControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, ControlNetModel) - ): + if isinstance(self.controlnet, ControlNetModel): if not isinstance(controlnet_conditioning_scale, float): raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") - elif ( - isinstance(self.controlnet, MultiControlNetModel) - or is_compiled - and isinstance(self.controlnet._orig_mod, MultiControlNetModel) - ): + elif isinstance(self.controlnet, MultiControlNetModel): if isinstance(controlnet_conditioning_scale, list): if any(isinstance(i, list) for i in controlnet_conditioning_scale): raise ValueError("A single batch of multiple conditionings are supported at the moment.") @@ -554,50 +601,24 @@ def check_inputs( else: assert False - if len(control_guidance_start) != len(control_guidance_end): - raise ValueError( - f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." - ) - - if isinstance(self.controlnet, MultiControlNetModel): - if len(control_guidance_start) != len(self.controlnet.nets): - raise ValueError( - f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}." - ) - - for start, end in zip(control_guidance_start, control_guidance_end): - if start >= end: - raise ValueError( - f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." - ) - if start < 0.0: - raise ValueError(f"control guidance start: {start} can't be smaller than 0.") - if end > 1.0: - raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") - def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) image_is_tensor = isinstance(image, torch.Tensor) - image_is_np = isinstance(image, np.ndarray) image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) - image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) - - if ( - not image_is_pil - and not image_is_tensor - and not image_is_np - and not image_is_pil_list - and not image_is_tensor_list - and not image_is_np_list - ): + + if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list: raise TypeError( - f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors" ) if image_is_pil: image_batch_size = 1 - else: + elif image_is_tensor: + image_batch_size = image.shape[0] + elif image_is_pil_list: + image_batch_size = len(image) + elif image_is_tensor_list: image_batch_size = len(image) if prompt is not None and isinstance(prompt, str): @@ -624,7 +645,29 @@ def prepare_image( do_classifier_free_guidance=False, guess_mode=False, ): - image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + if not isinstance(image, torch.Tensor): + if isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + images = [] + + for image_ in image: + image_ = image_.convert("RGB") + image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) + image_ = np.array(image_) + image_ = image_[None, :] + images.append(image_) + + image = images + + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + image_batch_size = image.shape[0] if image_batch_size == 1: @@ -642,7 +685,15 @@ def prepare_image( return image - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: @@ -660,19 +711,104 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents_img2img(self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if image is not None: + image = image.to(device=device, dtype=dtype) + + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + # deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + else: + latents = latents.to(device) + latents = latents * self.scheduler.init_noise_sigma + + + # scale the initial noise by the standard deviation required by the scheduler + # NOTE: not sure if the following line is necessary in for img2img + # latents = latents * self.scheduler.init_noise_sigma + return latents + + def _default_height_width(self, height, width, image): + # NOTE: It is possible that a list of images have different + # dimensions for each image, so just checking the first image + # is not _exactly_ correct, but it is simple. + while isinstance(image, list): + image = image[0] + + if height is None: + if isinstance(image, PIL.Image.Image): + height = image.height + elif isinstance(image, torch.Tensor): + height = image.shape[2] + + height = (height // 8) * 8 # round down to nearest multiple of 8 + + if width is None: + if isinstance(image, PIL.Image.Image): + width = image.width + elif isinstance(image, torch.Tensor): + width = image.shape[3] + + width = (width // 8) * 8 # round down to nearest multiple of 8 + + return height, width + + # override DiffusionPipeline + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + safe_serialization: bool = False, + variant: Optional[str] = None, + ): + if isinstance(self.controlnet, ControlNetModel): + super().save_pretrained(save_directory, safe_serialization, variant) + else: + raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.") + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, - image: Union[ - torch.FloatTensor, - PIL.Image.Image, - np.ndarray, - List[torch.FloatTensor], - List[PIL.Image.Image], - List[np.ndarray], - ] = None, + image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -691,8 +827,9 @@ def __call__( cross_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_conditioning_scale: Union[float, List[float]] = 1.0, guess_mode: bool = False, - control_guidance_start: Union[float, List[float]] = 0.0, - control_guidance_end: Union[float, List[float]] = 1.0, + img2img_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, + img2img_strength: float = 1.0, + controlnet_strength: float = 1.0, ): r""" Function invoked when calling the pipeline for generation. @@ -701,8 +838,8 @@ def __call__( prompt (`str` or `List[str]`, *optional*): The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. instead. - image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: - `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, + `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If @@ -760,7 +897,7 @@ def __call__( cross_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added to the residual in the original unet. If multiple ControlNets are specified in init, you can set the @@ -768,10 +905,6 @@ def __call__( guess_mode (`bool`, *optional*, defaults to `False`): In this mode, the ControlNet encoder will try best to recognize the content of the input image even if you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. - control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): - The percentage of total steps at which the controlnet starts applying. - control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): - The percentage of total steps at which the controlnet stops applying. Examples: @@ -782,30 +915,20 @@ def __call__( list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`. """ - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - - # align format for control guidance - if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): - control_guidance_start = len(control_guidance_end) * [control_guidance_start] - elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): - control_guidance_end = len(control_guidance_start) * [control_guidance_end] - elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): - mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 - control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ - control_guidance_end - ] + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, image) # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, image, + height, + width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds, controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, ) # 2. Define call parameters @@ -822,20 +945,10 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) - - global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - guess_mode = guess_mode or global_pool_conditions + if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets) # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) prompt_embeds = self._encode_prompt( prompt, device, @@ -844,11 +957,11 @@ def __call__( negative_prompt, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, ) + # 4. Prepare image - if isinstance(controlnet, ControlNetModel): + if isinstance(self.controlnet, ControlNetModel): image = self.prepare_image( image=image, width=width, @@ -856,12 +969,11 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=controlnet.dtype, + dtype=self.controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) - height, width = image.shape[-2:] - elif isinstance(controlnet, MultiControlNetModel): + elif isinstance(self.controlnet, MultiControlNetModel): images = [] for image_ in image: @@ -872,7 +984,7 @@ def __call__( batch_size=batch_size * num_images_per_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - dtype=controlnet.dtype, + dtype=self.controlnet.dtype, do_classifier_free_guidance=do_classifier_free_guidance, guess_mode=guess_mode, ) @@ -880,39 +992,61 @@ def __call__( images.append(image_) image = images - height, width = image[0].shape[-2:] else: assert False # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - + # 6. Prepare latent variables num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) + + if img2img_strength > 0.0: + img2img_image = self.prepare_image( + image=img2img_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + img2img_image = 2.0 * img2img_image - 1.0 + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1 - img2img_strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + latents = self.prepare_latents_img2img( + img2img_image[0][None, :], + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + else: + timesteps = self.scheduler.timesteps + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # 7.1 Create tensor stating which controlnets to keep - controlnet_keep = [] - for i in range(len(timesteps)): - keeps = [ - 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) - for s, e in zip(control_guidance_start, control_guidance_end) - ] - controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) - # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -924,24 +1058,18 @@ def __call__( # controlnet(s) inference if guess_mode and do_classifier_free_guidance: # Infer ControlNet only for the conditional batch. - control_model_input = latents - control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_latent_model_input = latents controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] else: - control_model_input = latent_model_input + controlnet_latent_model_input = latent_model_input controlnet_prompt_embeds = prompt_embeds - if isinstance(controlnet_keep[i], list): - cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] - else: - cond_scale = controlnet_conditioning_scale * controlnet_keep[i] - down_block_res_samples, mid_block_res_sample = self.controlnet( - control_model_input, + controlnet_latent_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, controlnet_cond=image, - conditioning_scale=cond_scale, + conditioning_scale=controlnet_conditioning_scale, guess_mode=guess_mode, return_dict=False, ) @@ -954,15 +1082,22 @@ def __call__( mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - return_dict=False, - )[0] + if t <= controlnet_strength * 1000: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + else: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + ).sample # perform guidance if do_classifier_free_guidance: @@ -970,7 +1105,7 @@ def __call__( noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -985,19 +1120,24 @@ def __call__( self.controlnet.to("cpu") torch.cuda.empty_cache() - if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - else: + if output_type == "latent": image = latents has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] + # 10. Convert to PIL + image = self.numpy_to_pil(image) else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + # 8. Post-processing + image = self.decode_latents(latents) - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) # Offload last model to CPU if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: @@ -1006,4 +1146,4 @@ def __call__( if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) \ No newline at end of file diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 7017e7f14b53..9a743a40aa6a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -17,9 +17,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -import PIL.Image import torch import torch.nn.functional as F + +import PIL.Image from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import VaeImageProcessor diff --git a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py index d00a19c92421..8160acbac656 100644 --- a/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +++ b/src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py @@ -5,9 +5,10 @@ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np -import PIL import torch import torch.nn.functional as F + +import PIL from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer from ...loaders import LoraLoaderMixin diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index dd9d0d04e000..e1dbc6237a66 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -17,6 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch + from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import VaeImageProcessor @@ -29,13 +30,7 @@ XFormersAttnProcessor, ) from ...schedulers import KarrasDiffusionSchedulers -from ...utils import ( - is_accelerate_available, - is_accelerate_version, - logging, - randn_tensor, - replace_example_docstring, -) +from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionXLPipelineOutput from .watermark import StableDiffusionXLWatermarker diff --git a/src/diffusers/pipelines/stable_diffusion_xl/watermark.py b/src/diffusers/pipelines/stable_diffusion_xl/watermark.py index bc6c9bf649b1..0c9a7e78568c 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/watermark.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/watermark.py @@ -1,5 +1,6 @@ import numpy as np import torch + from imwatermark import WatermarkEncoder @@ -17,15 +18,17 @@ def __init__(self): self.encoder.set_watermark("bits", self.watermark) def apply_watermark(self, images: torch.FloatTensor): + print("skipping watermark") + return images # can't encode images that are smaller than 256 - if images.shape[-1] < 256: - return images + # if images.shape[-1] < 256: + # return images - images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy() + # images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy() - images = [self.encoder.encode(image, "dwtDct") for image in images] + # images = [self.encoder.encode(image, "dwtDct") for image in images] - images = torch.from_numpy(np.array(images)).permute(0, 3, 1, 2) + # images = torch.from_numpy(np.array(images)).permute(0, 3, 1, 2) - images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0) - return images + # images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0) + # return images From a8e444f4775d2f12c865b4760ec351a92df8d738 Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Mon, 31 Jul 2023 17:33:10 +0000 Subject: [PATCH 08/18] krea hacks --- README.md | 2 +- docs/source/en/api/pipelines/auto_pipeline.md | 4 +- docs/source/en/api/pipelines/kandinsky.md | 79 +++ docs/source/en/api/pipelines/kandinsky_v22.md | 17 +- .../api/pipelines/stable_diffusion/adapter.md | 2 +- .../stable_diffusion/stable_diffusion_xl.md | 16 + docs/source/en/optimization/onnx.md | 4 +- docs/source/en/training/controlnet.md | 4 + docs/source/en/training/instructpix2pix.md | 2 +- docs/source/en/training/lora.md | 50 +- .../using-diffusers/controlling_generation.md | 1 + .../community/composable_stable_diffusion.py | 2 +- examples/controlnet/README_sdxl.md | 6 +- examples/controlnet/requirements_sdxl.txt | 1 - examples/controlnet/train_controlnet_sdxl.py | 7 +- examples/dreambooth/requirements_sdxl.txt | 1 - examples/dreambooth/train_dreambooth.py | 3 +- examples/dreambooth/train_dreambooth_lora.py | 9 +- .../dreambooth/train_dreambooth_lora_sdxl.py | 14 +- examples/instruct_pix2pix/README_sdxl.md | 53 +- ...x_xl.py => train_instruct_pix2pix_sdxl.py} | 0 src/diffusers/__init__.py | 19 +- src/diffusers/loaders.py | 533 +++++++++++++----- src/diffusers/models/attention_processor.py | 75 ++- src/diffusers/models/lora.py | 16 +- src/diffusers/models/resnet.py | 15 +- src/diffusers/models/transformer_2d.py | 6 +- src/diffusers/pipelines/__init__.py | 22 +- .../pipelines/controlnet/__init__.py | 11 +- .../controlnet/pipeline_controlnet_sd_xl.py | 23 +- src/diffusers/pipelines/pipeline_utils.py | 44 +- .../pipeline_semantic_stable_diffusion.py | 4 +- .../stable_diffusion/convert_from_ckpt.py | 3 +- .../pipeline_onnx_stable_diffusion.py | 1 + .../pipeline_onnx_stable_diffusion_img2img.py | 1 + .../pipeline_onnx_stable_diffusion_inpaint.py | 1 + ...ne_onnx_stable_diffusion_inpaint_legacy.py | 1 + .../pipeline_onnx_stable_diffusion_upscale.py | 2 + .../pipeline_stable_diffusion_upscale.py | 7 +- .../pipelines/stable_diffusion_xl/__init__.py | 5 +- .../pipeline_stable_diffusion_xl.py | 49 +- .../pipeline_stable_diffusion_xl_img2img.py | 24 +- .../pipeline_stable_diffusion_xl_inpaint.py | 31 +- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 23 +- .../stable_diffusion_xl/watermark.py | 8 + .../scheduling_dpmsolver_multistep.py | 2 + ...formers_and_invisible_watermark_objects.py | 77 --- .../dummy_torch_and_transformers_objects.py | 75 +++ tests/models/test_lora_layers.py | 7 +- tests/others/test_dependencies.py | 11 + .../test_stable_diffusion_upscale.py | 62 ++ .../test_stable_diffusion_xl.py | 4 +- .../test_stable_diffusion_xl_img2img.py | 19 +- .../test_stable_diffusion_xl_inpaint.py | 20 +- ...stable_diffusion_xl_instruction_pix2pix.py | 17 +- tests/pipelines/test_pipelines.py | 83 ++- tests/pipelines/test_pipelines_common.py | 4 +- 57 files changed, 1204 insertions(+), 378 deletions(-) rename examples/instruct_pix2pix/{train_instruct_pix2pix_xl.py => train_instruct_pix2pix_sdxl.py} (100%) delete mode 100644 src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py diff --git a/README.md b/README.md index 9307df83d7d6..ec6bddbc1fbf 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@


- +

diff --git a/docs/source/en/api/pipelines/auto_pipeline.md b/docs/source/en/api/pipelines/auto_pipeline.md index 4ae2b86ac269..65ea855405e9 100644 --- a/docs/source/en/api/pipelines/auto_pipeline.md +++ b/docs/source/en/api/pipelines/auto_pipeline.md @@ -39,8 +39,8 @@ Currently AutoPipeline support the Text-to-Image, Image-to-Image, and Inpainting - [Stable Diffusion Controlnet](./api/pipelines/controlnet) - [Stable Diffusion XL](./stable_diffusion/stable_diffusion_xl) - [IF](./if) -- [Kandinsky](./kandinsky)(./kandinsky)(./kandinsky)(./kandinsky)(./kandinsky) -- [Kandinsky 2.2]()(./kandinsky) +- [Kandinsky](./kandinsky) +- [Kandinsky 2.2](./kandinsky) ## AutoPipelineForText2Image diff --git a/docs/source/en/api/pipelines/kandinsky.md b/docs/source/en/api/pipelines/kandinsky.md index 79c602b8bc07..5ebeaebf4e5e 100644 --- a/docs/source/en/api/pipelines/kandinsky.md +++ b/docs/source/en/api/pipelines/kandinsky.md @@ -105,6 +105,30 @@ One cheeseburger monster coming up! Enjoy! ![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/cheeseburger.png) + + +We also provide an end-to-end Kandinsky pipeline [`KandinskyCombinedPipeline`], which combines both the prior pipeline and text-to-image pipeline, and lets you perform inference in a single step. You can create the combined pipeline with the [`~AutoPipelineForTextToImage.from_pretrained`] method + +```python +from diffusers import AutoPipelineForTextToImage +import torch + +pipe = AutoPipelineForTextToImage.from_pretrained( + "kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16 +) +pipe.enable_model_cpu_offload() +``` + +Under the hood, it will automatically load both [`KandinskyPriorPipeline`] and [`KandinskyPipeline`]. To generate images, you no longer need to call both pipelines and pass the outputs from one to another. You only need to call the combined pipeline once. You can set different `guidance_scale` and `num_inference_steps` for the prior pipeline with the `prior_guidance_scale` and `prior_num_inference_steps` arguments. + +```python +prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting" +negative_prompt = "low quality, bad quality" + +image = pipe(prompt=prompt, negative_prompt=negative_prompt, prior_guidance_scale =1.0, guidance_scacle = 4.0, height=768, width=768).images[0] +``` + + The Kandinsky model works extremely well with creative prompts. Here is some of the amazing art that can be created using the exact same process but with different prompts. ```python @@ -187,6 +211,34 @@ out.images[0].save("fantasy_land.png") ![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/img2img_fantasyland.png) + + +You can also use the [`KandinskyImg2ImgCombinedPipeline`] for end-to-end image-to-image generation with Kandinsky 2.1 + +```python +from diffusers import AutoPipelineForImage2Image +import torch +import requests +from io import BytesIO +from PIL import Image +import os + +pipe = AutoPipelineForImage2Image.from_pretrained("kandinsky-community/kandinsky-2-1", torch_dtype=torch.float16) +pipe.enable_model_cpu_offload() + +prompt = "A fantasy landscape, Cinematic lighting" +negative_prompt = "low quality, bad quality" + +url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + +response = requests.get(url) +original_image = Image.open(BytesIO(response.content)).convert("RGB") +original_image.thumbnail((768, 768)) + +image = pipe(prompt=prompt, image=original_image, strength=0.3).images[0] +``` + + ### Text Guided Inpainting Generation You can use [`KandinskyInpaintPipeline`] to edit images. In this example, we will add a hat to the portrait of a cat. @@ -231,6 +283,33 @@ image.save("cat_with_hat.png") ``` ![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/kandinsky-docs/inpaint_cat_hat.png) + + +To use the [`KandinskyInpaintCombinedPipeline`] to perform end-to-end image inpainting generation, you can run below code instead + +```python +from diffusers import AutoPipelineForInpainting + +pipe = AutoPipelineForInpainting.from_pretrained("kandinsky-community/kandinsky-2-1-inpaint", torch_dtype=torch.float16) +pipe.enable_model_cpu_offload() +image = pipe(prompt=prompt, image=original_image, mask_image=mask).images[0] +``` + + +🚨🚨🚨 __Breaking change for Kandinsky Mask Inpainting__ 🚨🚨🚨 + +We introduced a breaking change for Kandinsky inpainting pipeline in the following pull request: https://github.com/huggingface/diffusers/pull/4207. Previously we accepted a mask format where black pixels represent the masked-out area. This is inconsistent with all other pipelines in diffusers. We have changed the mask format in Knaindsky and now using white pixels instead. +Please upgrade your inpainting code to follow the above. If you are using Kandinsky Inpaint in production. You now need to change the mask to: + +```python +# For PIL input +import PIL.ImageOps +mask = PIL.ImageOps.invert(mask) + +# For PyTorch and Numpy input +mask = 1 - mask +``` + ### Interpolate The [`KandinskyPriorPipeline`] also comes with a cool utility function that will allow you to interpolate the latent space of different images and texts super easily. Here is an example of how you can create an Impressionist-style portrait for your pet based on "The Starry Night". diff --git a/docs/source/en/api/pipelines/kandinsky_v22.md b/docs/source/en/api/pipelines/kandinsky_v22.md index 074bc5b8d64c..3f88997ff4f5 100644 --- a/docs/source/en/api/pipelines/kandinsky_v22.md +++ b/docs/source/en/api/pipelines/kandinsky_v22.md @@ -11,7 +11,22 @@ specific language governing permissions and limitations under the License. The Kandinsky 2.2 release includes robust new text-to-image models that support text-to-image generation, image-to-image generation, image interpolation, and text-guided image inpainting. The general workflow to perform these tasks using Kandinsky 2.2 is the same as in Kandinsky 2.1. First, you will need to use a prior pipeline to generate image embeddings based on your text prompt, and then use one of the image decoding pipelines to generate the output image. The only difference is that in Kandinsky 2.2, all of the decoding pipelines no longer accept the `prompt` input, and the image generation process is conditioned with only `image_embeds` and `negative_image_embeds`. -Let's look at an example of how to perform text-to-image generation using Kandinsky 2.2. +Same as with Kandinsky 2.1, the easiest way to perform text-to-image generation is to use the combined Kandinsky pipeline. This process is exactly the same as Kandinsky 2.1. All you need to do is to replace the Kandinsky 2.1 checkpoint with 2.2. + +```python +from diffusers import AutoPipelineForText2Image +import torch + +pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-2-2-decoder", torch_dtype=torch.float16) +pipe.enable_model_cpu_offload() + +prompt = "A alien cheeseburger creature eating itself, claymation, cinematic, moody lighting" +negative_prompt = "low quality, bad quality" + +image = pipe(prompt=prompt, negative_prompt=negative_prompt, prior_guidance_scale =1.0, height=768, width=768).images[0] +``` + +Now, let's look at an example where we take separate steps to run the prior pipeline and text-to-image pipeline. This way, we can understand what's happening under the hood and how Kandinsky 2.2 differs from Kandinsky 2.1. First, let's create the prior pipeline and text-to-image pipeline with Kandinsky 2.2 checkpoints. diff --git a/docs/source/en/api/pipelines/stable_diffusion/adapter.md b/docs/source/en/api/pipelines/stable_diffusion/adapter.md index 19351e1713b6..75b4f186e6be 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/adapter.md +++ b/docs/source/en/api/pipelines/stable_diffusion/adapter.md @@ -69,7 +69,7 @@ Next, create the adapter pipeline import torch from diffusers import StableDiffusionAdapterPipeline, T2IAdapter -adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1") +adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_color_sd14v1", torch_dtype=torch.float16) pipe = StableDiffusionAdapterPipeline.from_pretrained( "CompVis/stable-diffusion-v1-4", adapter=adapter, diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md index 8da3e2f10727..8486641da2c4 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.md @@ -38,9 +38,25 @@ You can install the libraries as follows: pip install transformers pip install accelerate pip install safetensors +``` + +### Watermarker + +We recommend to add an invisible watermark to images generating by Stable Diffusion XL, this can help with identifying if an image is machine-synthesised for downstream applications. To do so, please install +the [invisible-watermark library](https://pypi.org/project/invisible-watermark/) via: + +``` pip install invisible-watermark>=0.2.0 ``` +If the `invisible-watermark` library is installed the watermarker will be used **by default**. + +If you have other provisions for generating or deploying images safely, you can disable the watermarker as follows: + +```py +pipe = StableDiffusionXLPipeline.from_pretrained(..., add_watermarker=False) +``` + ### Text-to-Image You can use SDXL as follows for *text-to-image*: diff --git a/docs/source/en/optimization/onnx.md b/docs/source/en/optimization/onnx.md index 89ea43521726..1eefc116cbf4 100644 --- a/docs/source/en/optimization/onnx.md +++ b/docs/source/en/optimization/onnx.md @@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License. --> -# How to use the ONNX Runtime for inference +# How to use ONNX Runtime for inference 🤗 [Optimum](https://github.com/huggingface/optimum) provides a Stable Diffusion pipeline compatible with ONNX Runtime. @@ -27,7 +27,7 @@ pip install optimum["onnxruntime"] ### Inference -To load an ONNX model and run inference with the ONNX Runtime, you need to replace [`StableDiffusionPipeline`] with `ORTStableDiffusionPipeline`. In case you want to load a PyTorch model and convert it to the ONNX format on-the-fly, you can set `export=True`. +To load an ONNX model and run inference with ONNX Runtime, you need to replace [`StableDiffusionPipeline`] with `ORTStableDiffusionPipeline`. In case you want to load a PyTorch model and convert it to the ONNX format on-the-fly, you can set `export=True`. ```python from optimum.onnxruntime import ORTStableDiffusionPipeline diff --git a/docs/source/en/training/controlnet.md b/docs/source/en/training/controlnet.md index 16a9ba95f057..b2b75f7cf110 100644 --- a/docs/source/en/training/controlnet.md +++ b/docs/source/en/training/controlnet.md @@ -327,3 +327,7 @@ image = pipe(prompt, num_inference_steps=20, generator=generator, image=control_ image.save("./output.png") ``` + +## Stable Diffusion XL + +Training with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) is also supported via the `train_controlnet_sdxl.py` script. Please refer to the docs [here](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_sdxl.md). \ No newline at end of file diff --git a/docs/source/en/training/instructpix2pix.md b/docs/source/en/training/instructpix2pix.md index f9557c7bea2e..4a8c738c1076 100644 --- a/docs/source/en/training/instructpix2pix.md +++ b/docs/source/en/training/instructpix2pix.md @@ -212,4 +212,4 @@ If you're looking for some interesting ways to use the InstructPix2Pix training ## Stable Diffusion XL -We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with DreamBooth and LoRA via the `train_dreambooth_lora_sdxl.py` script. Please refer to the docs [here](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/README_sdxl.md). \ No newline at end of file +Training with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) is also supported via the `train_instruct_pix2pix_sdxl.py` script. Please refer to the docs [here](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/README_sdxl.md). \ No newline at end of file diff --git a/docs/source/en/training/lora.md b/docs/source/en/training/lora.md index 670a94658160..fd88d74854b2 100644 --- a/docs/source/en/training/lora.md +++ b/docs/source/en/training/lora.md @@ -354,4 +354,52 @@ directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so: lora_model_id = "sayakpaul/civitai-light-shadow-lora" lora_filename = "light_and_shadow.safetensors" pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename) -``` \ No newline at end of file +``` + +### Supporting Stable Diffusion XL LoRAs trained using the Kohya-trainer + +With this [PR](https://github.com/huggingface/diffusers/pull/4287), there should now be better support for loading Kohya-style LoRAs trained on Stable Diffusion XL (SDXL). + +Here are some example checkpoints we tried out: + +* SDXL 0.9: + * https://civitai.com/models/22279?modelVersionId=118556 + * https://civitai.com/models/104515/sdxlor30costumesrevue-starlight-saijoclaudine-lora + * https://civitai.com/models/108448/daiton-sdxl-test + * https://filebin.net/2ntfqqnapiu9q3zx/pixelbuildings128-v1.safetensors +* SDXL 1.0: + * https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors + +Here is an example of how to perform inference with these checkpoints in `diffusers`: + +```python +from diffusers import DiffusionPipeline +import torch + +base_model_id = "stabilityai/stable-diffusion-xl-base-0.9" +pipeline = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda") +pipeline.load_lora_weights(".", weight_name="Kamepan.safetensors") + +prompt = "anime screencap, glint, drawing, best quality, light smile, shy, a full body of a girl wearing wedding dress in the middle of the forest beneath the trees, fireflies, big eyes, 2d, cute, anime girl, waifu, cel shading, magical girl, vivid colors, (outline:1.1), manga anime artstyle, masterpiece, offical wallpaper, glint " +negative_prompt = "(deformed, bad quality, sketch, depth of field, blurry:1.1), grainy, bad anatomy, bad perspective, old, ugly, realistic, cartoon, disney, bad propotions" +generator = torch.manual_seed(2947883060) +num_inference_steps = 30 +guidance_scale = 7 + +image = pipeline( + prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps, + generator=generator, guidance_scale=guidance_scale +).images[0] +image.save("Kamepan.png") +``` + +`Kamepan.safetensors` comes from https://civitai.com/models/22279?modelVersionId=118556 . + +If you notice carefully, the inference UX is exactly identical to what we presented in the sections above. + +Thanks to [@isidentical](https://github.com/isidentical) for helping us on integrating this feature. + +### Known limitations specific to the Kohya-styled LoRAs + +* SDXL LoRAs that have both the text encoders are currently leading to weird results. We're actively investigating the issue. +* When images don't looks similar to other UIs such ComfyUI, it can be beacause of multiple reasons as explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736). \ No newline at end of file diff --git a/docs/source/en/using-diffusers/controlling_generation.md b/docs/source/en/using-diffusers/controlling_generation.md index b4b3a9bbcc48..9c52de573d08 100644 --- a/docs/source/en/using-diffusers/controlling_generation.md +++ b/docs/source/en/using-diffusers/controlling_generation.md @@ -40,6 +40,7 @@ Unless otherwise mentioned, these are techniques that work with existing models 12. [Custom Diffusion](#custom-diffusion) 13. [Model Editing](#model-editing) 14. [DiffEdit](#diffedit) +15. [T2I-Adapter](#t2i-adapter) For convenience, we provide a table to denote which methods are inference-only and which require fine-tuning/training. diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py index 95292f5bdae8..8a2263b096c3 100644 --- a/examples/community/composable_stable_diffusion.py +++ b/examples/community/composable_stable_diffusion.py @@ -423,7 +423,7 @@ def __call__( num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): + guidance_scale (`float`, *optional*, defaults to 5.0): Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > diff --git a/examples/controlnet/README_sdxl.md b/examples/controlnet/README_sdxl.md index db8dada65427..4a7797b9572c 100644 --- a/examples/controlnet/README_sdxl.md +++ b/examples/controlnet/README_sdxl.md @@ -1,6 +1,6 @@ -# DreamBooth training example for Stable Diffusion XL (SDXL) +# ControlNet training example for Stable Diffusion XL (SDXL) -The `train_controlnet_sdxl.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). +The `train_controlnet_sdxl.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). ## Running locally with PyTorch @@ -128,4 +128,4 @@ image.save("./output.png") ### Specifying a better VAE -SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). \ No newline at end of file +SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). diff --git a/examples/controlnet/requirements_sdxl.txt b/examples/controlnet/requirements_sdxl.txt index 0192e03ddb4c..5ab6e9932e10 100644 --- a/examples/controlnet/requirements_sdxl.txt +++ b/examples/controlnet/requirements_sdxl.txt @@ -4,6 +4,5 @@ transformers>=4.25.1 ftfy tensorboard Jinja2 -invisible-watermark>=0.2.0 datasets wandb diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index dfe12e352791..6be07a38056f 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -210,7 +210,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N yaml = f""" --- -license: creativeml-openrail-m +license: openrail++ base_model: {base_model} tags: - stable-diffusion-xl @@ -227,12 +227,7 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N These are controlnet weights trained on {base_model} with new type of conditioning. {img_str} """ - model_card += """ -## License - -[SDXL 1.0 License](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md) -""" with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) diff --git a/examples/dreambooth/requirements_sdxl.txt b/examples/dreambooth/requirements_sdxl.txt index 7a9936ee4003..7a612982f4ab 100644 --- a/examples/dreambooth/requirements_sdxl.txt +++ b/examples/dreambooth/requirements_sdxl.txt @@ -4,4 +4,3 @@ transformers>=4.25.1 ftfy tensorboard Jinja2 -invisible-watermark>=0.2.0 \ No newline at end of file diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index c3789406a3fb..f5e904d91c02 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and import argparse +import copy import gc import hashlib import itertools @@ -1116,7 +1117,7 @@ def compute_text_embeddings(prompt): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - tracker_config = vars(args) + tracker_config = vars(copy.deepcopy(args)) tracker_config.pop("validation_images") accelerator.init_trackers("dreambooth", config=tracker_config) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index a8aaf38158db..f6c9990a37f3 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and import argparse +import copy import gc import hashlib import itertools @@ -924,10 +925,10 @@ def load_model_hook(models, input_dir): else: raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir) - LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_) + lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_ + lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_ ) accelerator.register_save_state_pre_hook(save_model_hook) @@ -1067,7 +1068,7 @@ def compute_text_embeddings(prompt): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - tracker_config = vars(args) + tracker_config = vars(copy.deepcopy(args)) tracker_config.pop("validation_images") accelerator.init_trackers("dreambooth-lora", config=tracker_config) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 38db36e5b670..0383ab4b9928 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -73,7 +73,7 @@ def save_model_card( yaml = f""" --- -license: creativeml-openrail-m +license: openrail++ base_model: {base_model} instance_prompt: {prompt} tags: @@ -94,10 +94,6 @@ def save_model_card( LoRA for the text encoder was enabled: {train_text_encoder}. Special VAE used for training: {vae_path}. - -## License - -[SDXL 1.0 License](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md) """ with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) @@ -829,13 +825,13 @@ def load_model_hook(models, input_dir): else: raise ValueError(f"unexpected save model: {model.__class__}") - lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir) - LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_) + lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_) LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_ + lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_ ) LoraLoaderMixin.load_lora_into_text_encoder( - lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_ + lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_ ) accelerator.register_save_state_pre_hook(save_model_hook) diff --git a/examples/instruct_pix2pix/README_sdxl.md b/examples/instruct_pix2pix/README_sdxl.md index 8e3e6c881235..3d521916b47b 100644 --- a/examples/instruct_pix2pix/README_sdxl.md +++ b/examples/instruct_pix2pix/README_sdxl.md @@ -4,9 +4,9 @@ [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) (or SDXL) is the latest image generation model that is tailored towards more photorealistic outputs with more detailed imagery and composition compared to previous SD models. It leverages a three times larger UNet backbone. The increase of model parameters is mainly due to more attention blocks and a larger cross-attention context as SDXL uses a second text encoder. -The `train_instruct_pix2pix_xl.py` script shows how to implement the training procedure and adapt it for Stable Diffusion XL. +The `train_instruct_pix2pix_sdxl.py` script shows how to implement the training procedure and adapt it for Stable Diffusion XL. -***Disclaimer: Even though `train_instruct_pix2pix_xl.py` implements the InstructPix2Pix +***Disclaimer: Even though `train_instruct_pix2pix_sdxl.py` implements the InstructPix2Pix training procedure while being faithful to the [original implementation](https://github.com/timothybrooks/instruct-pix2pix) we have only tested it on a [small-scale dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples). This can impact the end results. For better results, we recommend longer training runs with a larger dataset. [Here](https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered) you can find a large dataset for InstructPix2Pix training.*** ## Running locally with PyTorch @@ -33,7 +33,7 @@ export DATASET_ID="fusing/instructpix2pix-1000-samples" Now, we can launch training: ```bash -python train_instruct_pix2pix_xl.py \ +python train_instruct_pix2pix_sdxl.py \ --pretrained_model_name_or_path=$MODEL_NAME \ --dataset_name=$DATASET_ID \ --enable_xformers_memory_efficient_attention \ @@ -50,7 +50,7 @@ Additionally, we support performing validation inference to monitor training pro with Weights and Biases. You can enable this feature with `report_to="wandb"`: ```bash -python train_instruct_pix2pix_xl.py \ +python train_instruct_pix2pix_sdxl.py \ --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \ --dataset_name=$DATASET_ID \ --use_ema \ @@ -146,3 +146,48 @@ Particularly, `image_guidance_scale` and `guidance_scale` can have a profound im on the generated ("edited") image (see [here](https://twitter.com/RisingSayak/status/1628392199196151808?s=20) for an example). If you're looking for some interesting ways to use the InstructPix2Pix training methodology, we welcome you to check out this blog post: [Instruction-tuning Stable Diffusion with InstructPix2Pix](https://huggingface.co/blog/instruction-tuning-sd). + +## Compare between SD and SDXL + +We aim to understand the differences resulting from the use of SD-1.5 and SDXL-0.9 as pretrained models. To achieve this, we trained on the [small toy dataset](https://huggingface.co/datasets/fusing/instructpix2pix-1000-samples) using both of these pretrained models. The training script is as follows: + +```bash +export MODEL_NAME="runwayml/stable-diffusion-v1-5" or "stabilityai/stable-diffusion-xl-base-0.9" +export DATASET_ID="fusing/instructpix2pix-1000-samples" + +CUDA_VISIBLE_DEVICES=1 python train_instruct_pix2pix.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --dataset_name=$DATASET_ID \ + --use_ema \ + --enable_xformers_memory_efficient_attention \ + --resolution=512 --random_flip \ + --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \ + --max_train_steps=15000 \ + --checkpointing_steps=5000 --checkpoints_total_limit=1 \ + --learning_rate=5e-05 --lr_warmup_steps=0 \ + --conditioning_dropout_prob=0.05 \ + --seed=42 \ + --val_image_url="https://datasets-server.huggingface.co/assets/fusing/instructpix2pix-1000-samples/--/fusing--instructpix2pix-1000-samples/train/23/input_image/image.jpg" \ + --validation_prompt="make it in Japan" \ + --report_to=wandb +``` + +We discovered that compared to training with SD-1.5 as the pretrained model, SDXL-0.9 results in a lower training loss value (SD-1.5 yields 0.0599, SDXL scores 0.0254). Moreover, from a visual perspective, the results obtained using SDXL demonstrated fewer artifacts and a richer detail. Notably, SDXL starts to preserve the structure of the original image earlier on. + +The following two GIFs provide intuitive visual results. We observed, for each step, what kind of results could be achieved using the image +

+ input for make it Japan +

+with "make it in Japan” as the prompt. It can be seen that SDXL starts preserving the details of the original image earlier, resulting in higher fidelity outcomes sooner. + +* SD-1.5: https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sd_ip2p_training_val_img_progress.gif + +

+ input for make it Japan +

+ +* SDXL: https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl_ip2p_training_val_img_progress.gif + +

+ input for make it Japan +

diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_xl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py similarity index 100% rename from examples/instruct_pix2pix/train_instruct_pix2pix_xl.py rename to examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ba666fef13ed..3db2683640fd 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -185,6 +185,11 @@ StableDiffusionPix2PixZeroPipeline, StableDiffusionSAGPipeline, StableDiffusionUpscalePipeline, + StableDiffusionXLControlNetPipeline, + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLInstructPix2PixPipeline, + StableDiffusionXLPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, TextToVideoSDPipeline, @@ -202,20 +207,6 @@ VQDiffusionPipeline, ) -try: - if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 -else: - from .pipelines import ( - StableDiffusionXLControlNetPipeline, - StableDiffusionXLImg2ImgPipeline, - StableDiffusionXLInpaintPipeline, - StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLPipeline, - ) - try: if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e5b506259172..6a6e03117e64 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import re import warnings from collections import defaultdict from contextlib import nullcontext @@ -56,7 +57,6 @@ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" -TOTAL_EXAMPLE_KEYS = 5 TEXT_INVERSION_NAME = "learned_embeds.bin" TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors" @@ -257,7 +257,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict use_safetensors = kwargs.pop("use_safetensors", None) # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning - network_alpha = kwargs.pop("network_alpha", None) + network_alphas = kwargs.pop("network_alphas", None) if use_safetensors and not is_safetensors_available(): raise ValueError( @@ -322,7 +322,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict attn_processors = {} non_attn_lora_layers = [] - is_lora = all("lora" in k for k in state_dict.keys()) + is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) if is_lora: @@ -339,10 +339,25 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys} lora_grouped_dict = defaultdict(dict) - for key, value in state_dict.items(): + mapped_network_alphas = {} + + all_keys = list(state_dict.keys()) + for key in all_keys: + value = state_dict.pop(key) attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) lora_grouped_dict[attn_processor_key][sub_key] = value + # Create another `mapped_network_alphas` dictionary so that we can properly map them. + if network_alphas is not None: + for k in network_alphas: + if k.replace(".alpha", "") in key: + mapped_network_alphas.update({attn_processor_key: network_alphas[k]}) + + if len(state_dict) > 0: + raise ValueError( + f"The state_dict has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}" + ) + for key, value_dict in lora_grouped_dict.items(): attn_processor = self for sub_key in key.split("."): @@ -352,13 +367,27 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict # or add_{k,v,q,out_proj}_proj_lora layers. if "lora.down.weight" in value_dict: rank = value_dict["lora.down.weight"].shape[0] - hidden_size = value_dict["lora.up.weight"].shape[0] if isinstance(attn_processor, LoRACompatibleConv): - lora = LoRAConv2dLayer(hidden_size, hidden_size, rank, network_alpha) + in_features = attn_processor.in_channels + out_features = attn_processor.out_channels + kernel_size = attn_processor.kernel_size + + lora = LoRAConv2dLayer( + in_features=in_features, + out_features=out_features, + rank=rank, + kernel_size=kernel_size, + stride=attn_processor.stride, + padding=attn_processor.padding, + network_alpha=mapped_network_alphas.get(key), + ) elif isinstance(attn_processor, LoRACompatibleLinear): lora = LoRALinearLayer( - attn_processor.in_features, attn_processor.out_features, rank, network_alpha + attn_processor.in_features, + attn_processor.out_features, + rank, + mapped_network_alphas.get(key), ) else: raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.") @@ -366,32 +395,64 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} lora.load_state_dict(value_dict) non_attn_lora_layers.append((attn_processor, lora)) - continue - - rank = value_dict["to_k_lora.down.weight"].shape[0] - hidden_size = value_dict["to_k_lora.up.weight"].shape[0] - - if isinstance( - attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0) - ): - cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1] - attn_processor_class = LoRAAttnAddedKVProcessor else: - cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] - if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)): - attn_processor_class = LoRAXFormersAttnProcessor + # To handle SDXL. + rank_mapping = {} + hidden_size_mapping = {} + for projection_id in ["to_k", "to_q", "to_v", "to_out"]: + rank = value_dict[f"{projection_id}_lora.down.weight"].shape[0] + hidden_size = value_dict[f"{projection_id}_lora.up.weight"].shape[0] + + rank_mapping.update({f"{projection_id}_lora.down.weight": rank}) + hidden_size_mapping.update({f"{projection_id}_lora.up.weight": hidden_size}) + + if isinstance( + attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0) + ): + cross_attention_dim = value_dict["add_k_proj_lora.down.weight"].shape[1] + attn_processor_class = LoRAAttnAddedKVProcessor + else: + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + if isinstance(attn_processor, (XFormersAttnProcessor, LoRAXFormersAttnProcessor)): + attn_processor_class = LoRAXFormersAttnProcessor + else: + attn_processor_class = ( + LoRAAttnProcessor2_0 + if hasattr(F, "scaled_dot_product_attention") + else LoRAAttnProcessor + ) + + if attn_processor_class is not LoRAAttnAddedKVProcessor: + attn_processors[key] = attn_processor_class( + rank=rank_mapping.get("to_k_lora.down.weight"), + hidden_size=hidden_size_mapping.get("to_k_lora.up.weight"), + cross_attention_dim=cross_attention_dim, + network_alpha=mapped_network_alphas.get(key), + q_rank=rank_mapping.get("to_q_lora.down.weight"), + q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight"), + v_rank=rank_mapping.get("to_v_lora.down.weight"), + v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight"), + out_rank=rank_mapping.get("to_out_lora.down.weight"), + out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight"), + # rank=rank_mapping.get("to_k_lora.down.weight", None), + # hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None), + # q_rank=rank_mapping.get("to_q_lora.down.weight", None), + # q_hidden_size=hidden_size_mapping.get("to_q_lora.up.weight", None), + # v_rank=rank_mapping.get("to_v_lora.down.weight", None), + # v_hidden_size=hidden_size_mapping.get("to_v_lora.up.weight", None), + # out_rank=rank_mapping.get("to_out_lora.down.weight", None), + # out_hidden_size=hidden_size_mapping.get("to_out_lora.up.weight", None), + ) else: - attn_processor_class = ( - LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + attn_processors[key] = attn_processor_class( + rank=rank_mapping.get("to_k_lora.down.weight", None), + hidden_size=hidden_size_mapping.get("to_k_lora.up.weight", None), + cross_attention_dim=cross_attention_dim, + network_alpha=mapped_network_alphas.get(key), ) - attn_processors[key] = attn_processor_class( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - rank=rank, - network_alpha=network_alpha, - ) - attn_processors[key].load_state_dict(value_dict) + attn_processors[key].load_state_dict(value_dict) + elif is_custom_diffusion: custom_diffusion_grouped_dict = defaultdict(dict) for key, value in state_dict.items(): @@ -434,8 +495,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict # set ff layers for target_module, lora_layer in non_attn_lora_layers: - if hasattr(target_module, "set_lora_layer"): - target_module.set_lora_layer(lora_layer) + target_module.set_lora_layer(lora_layer) + # It should raise an error if we don't have a set lora here + # if hasattr(target_module, "set_lora_layer"): + # target_module.set_lora_layer(lora_layer) def save_attn_procs( self, @@ -873,11 +936,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di kwargs (`dict`, *optional*): See [`~loaders.LoraLoaderMixin.lora_state_dict`]. """ - state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet) + state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) self.load_lora_into_text_encoder( state_dict, - network_alpha=network_alpha, + network_alphas=network_alphas, text_encoder=self.text_encoder, lora_scale=self.lora_scale, ) @@ -889,7 +952,7 @@ def lora_state_dict( **kwargs, ): r""" - Return state dict for lora weights + Return state dict for lora weights and the network alphas. @@ -950,6 +1013,7 @@ def lora_state_dict( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) + unet_config = kwargs.pop("unet_config", None) use_safetensors = kwargs.pop("use_safetensors", None) if use_safetensors and not is_safetensors_available(): @@ -1011,53 +1075,158 @@ def lora_state_dict( else: state_dict = pretrained_model_name_or_path_or_dict - # Convert kohya-ss Style LoRA attn procs to diffusers attn procs - network_alpha = None - if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()): - state_dict, network_alpha = cls._convert_kohya_lora_to_diffusers(state_dict) + network_alphas = None + if all( + ( + k.startswith("lora_te_") + or k.startswith("lora_unet_") + or k.startswith("lora_te1_") + or k.startswith("lora_te2_") + ) + for k in state_dict.keys() + ): + # Map SDXL blocks correctly. + if unet_config is not None: + # use unet config to remap block numbers + state_dict = cls._map_sgm_blocks_to_diffusers(state_dict, unet_config) + state_dict, network_alphas = cls._convert_kohya_lora_to_diffusers(state_dict) - return state_dict, network_alpha + return state_dict, network_alphas @classmethod - def load_lora_into_unet(cls, state_dict, network_alpha, unet): + def _map_sgm_blocks_to_diffusers(cls, state_dict, unet_config, delimiter="_", block_slice_pos=5): + is_all_unet = all(k.startswith("lora_unet") for k in state_dict) + new_state_dict = {} + inner_block_map = ["resnets", "attentions", "upsamplers"] + + # Retrieves # of down, mid and up blocks + input_block_ids, middle_block_ids, output_block_ids = set(), set(), set() + for layer in state_dict: + if "text" not in layer: + layer_id = int(layer.split(delimiter)[:block_slice_pos][-1]) + if "input_blocks" in layer: + input_block_ids.add(layer_id) + elif "middle_block" in layer: + middle_block_ids.add(layer_id) + elif "output_blocks" in layer: + output_block_ids.add(layer_id) + else: + raise ValueError("Checkpoint not supported") + + input_blocks = { + layer_id: [key for key in state_dict if f"input_blocks{delimiter}{layer_id}" in key] + for layer_id in input_block_ids + } + middle_blocks = { + layer_id: [key for key in state_dict if f"middle_block{delimiter}{layer_id}" in key] + for layer_id in middle_block_ids + } + output_blocks = { + layer_id: [key for key in state_dict if f"output_blocks{delimiter}{layer_id}" in key] + for layer_id in output_block_ids + } + + # Rename keys accordingly + for i in input_block_ids: + block_id = (i - 1) // (unet_config.layers_per_block + 1) + layer_in_block_id = (i - 1) % (unet_config.layers_per_block + 1) + + for key in input_blocks[i]: + inner_block_id = int(key.split(delimiter)[block_slice_pos]) + inner_block_key = inner_block_map[inner_block_id] if "op" not in key else "downsamplers" + inner_layers_in_block = str(layer_in_block_id) if "op" not in key else "0" + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + [str(block_id), inner_block_key, inner_layers_in_block] + + key.split(delimiter)[block_slice_pos + 1 :] + ) + new_state_dict[new_key] = state_dict.pop(key) + + for i in middle_block_ids: + key_part = None + if i == 0: + key_part = [inner_block_map[0], "0"] + elif i == 1: + key_part = [inner_block_map[1], "0"] + elif i == 2: + key_part = [inner_block_map[0], "1"] + else: + raise ValueError(f"Invalid middle block id {i}.") + + for key in middle_blocks[i]: + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + key_part + key.split(delimiter)[block_slice_pos:] + ) + new_state_dict[new_key] = state_dict.pop(key) + + for i in output_block_ids: + block_id = i // (unet_config.layers_per_block + 1) + layer_in_block_id = i % (unet_config.layers_per_block + 1) + + for key in output_blocks[i]: + inner_block_id = int(key.split(delimiter)[block_slice_pos]) + inner_block_key = inner_block_map[inner_block_id] + inner_layers_in_block = str(layer_in_block_id) if inner_block_id < 2 else "0" + new_key = delimiter.join( + key.split(delimiter)[: block_slice_pos - 1] + + [str(block_id), inner_block_key, inner_layers_in_block] + + key.split(delimiter)[block_slice_pos + 1 :] + ) + new_state_dict[new_key] = state_dict.pop(key) + + if is_all_unet and len(state_dict) > 0: + raise ValueError("At this point all state dict entries have to be converted.") + else: + # Remaining is the text encoder state dict. + for k, v in state_dict.items(): + new_state_dict.update({k: v}) + + return new_state_dict + + @classmethod + def load_lora_into_unet(cls, state_dict, network_alphas, unet): """ - This will load the LoRA layers specified in `state_dict` into `unet` + This will load the LoRA layers specified in `state_dict` into `unet`. Parameters: state_dict (`dict`): A standard state dict containing the lora layer parameters. The keys can either be indexed directly into the unet or prefixed with an additional `unet` which can be used to distinguish between text encoder lora layers. - network_alpha (`float`): + network_alphas (`Dict[str, float]`): See `LoRALinearLayer` for more details. unet (`UNet2DConditionModel`): The UNet model to load the LoRA layers into. """ - # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. keys = list(state_dict.keys()) + if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys): # Load the layers corresponding to UNet. - unet_keys = [k for k in keys if k.startswith(cls.unet_name)] logger.info(f"Loading {cls.unet_name}.") - unet_lora_state_dict = { - k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys - } - unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha) - # Otherwise, we're dealing with the old format. This means the `state_dict` should only - # contain the module names of the `unet` as its keys WITHOUT any prefix. - elif not all( - key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in state_dict.keys() - ): - unet.load_attn_procs(state_dict, network_alpha=network_alpha) + unet_keys = [k for k in keys if k.startswith(cls.unet_name)] + state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys} + + if network_alphas is not None: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)] + network_alphas = { + k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + else: + # Otherwise, we're dealing with the old format. This means the `state_dict` should only + # contain the module names of the `unet` as its keys WITHOUT any prefix. warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." warnings.warn(warn_message) + # load loras into unet + unet.load_attn_procs(state_dict, network_alphas=network_alphas) + @classmethod - def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, prefix=None, lora_scale=1.0): + def load_lora_into_text_encoder(cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0): """ This will load the LoRA layers specified in `state_dict` into `text_encoder` @@ -1065,7 +1234,7 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr state_dict (`dict`): A standard state dict containing the lora layer parameters. The key should be prefixed with an additional `text_encoder` to distinguish between unet lora layers. - network_alpha (`float`): + network_alphas (`Dict[str, float]`): See `LoRALinearLayer` for more details. text_encoder (`CLIPTextModel`): The text encoder model to load the LoRA layers into. @@ -1134,14 +1303,19 @@ def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, pr ].shape[1] patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys()) - cls._modify_text_encoder(text_encoder, lora_scale, network_alpha, rank=rank, patch_mlp=patch_mlp) + cls._modify_text_encoder( + text_encoder, + lora_scale, + network_alphas, + rank=rank, + patch_mlp=patch_mlp, + ) # set correct dtype & device text_encoder_lora_state_dict = { k: v.to(device=text_encoder.device, dtype=text_encoder.dtype) for k, v in text_encoder_lora_state_dict.items() } - load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False) if len(load_state_dict_results.unexpected_keys) != 0: raise ValueError( @@ -1176,7 +1350,7 @@ def _modify_text_encoder( cls, text_encoder, lora_scale=1, - network_alpha=None, + network_alphas=None, rank=4, dtype=None, patch_mlp=False, @@ -1189,37 +1363,46 @@ def _modify_text_encoder( cls._remove_text_encoder_monkey_patch_classmethod(text_encoder) lora_parameters = [] + network_alphas = {} if network_alphas is None else network_alphas + + for name, attn_module in text_encoder_attn_modules(text_encoder): + query_alpha = network_alphas.get(name + ".k.proj.alpha") + key_alpha = network_alphas.get(name + ".q.proj.alpha") + value_alpha = network_alphas.get(name + ".v.proj.alpha") + proj_alpha = network_alphas.get(name + ".out.proj.alpha") - for _, attn_module in text_encoder_attn_modules(text_encoder): attn_module.q_proj = PatchedLoraProjection( - attn_module.q_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + attn_module.q_proj, lora_scale, network_alpha=query_alpha, rank=rank, dtype=dtype ) lora_parameters.extend(attn_module.q_proj.lora_linear_layer.parameters()) attn_module.k_proj = PatchedLoraProjection( - attn_module.k_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + attn_module.k_proj, lora_scale, network_alpha=key_alpha, rank=rank, dtype=dtype ) lora_parameters.extend(attn_module.k_proj.lora_linear_layer.parameters()) attn_module.v_proj = PatchedLoraProjection( - attn_module.v_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + attn_module.v_proj, lora_scale, network_alpha=value_alpha, rank=rank, dtype=dtype ) lora_parameters.extend(attn_module.v_proj.lora_linear_layer.parameters()) attn_module.out_proj = PatchedLoraProjection( - attn_module.out_proj, lora_scale, network_alpha, rank=rank, dtype=dtype + attn_module.out_proj, lora_scale, network_alpha=proj_alpha, rank=rank, dtype=dtype ) lora_parameters.extend(attn_module.out_proj.lora_linear_layer.parameters()) if patch_mlp: - for _, mlp_module in text_encoder_mlp_modules(text_encoder): + for name, mlp_module in text_encoder_mlp_modules(text_encoder): + fc1_alpha = network_alphas.get(name + ".fc1.alpha") + fc2_alpha = network_alphas.get(name + ".fc2.alpha") + mlp_module.fc1 = PatchedLoraProjection( - mlp_module.fc1, lora_scale, network_alpha, rank=rank, dtype=dtype + mlp_module.fc1, lora_scale, network_alpha=fc1_alpha, rank=rank, dtype=dtype ) lora_parameters.extend(mlp_module.fc1.lora_linear_layer.parameters()) mlp_module.fc2 = PatchedLoraProjection( - mlp_module.fc2, lora_scale, network_alpha, rank=rank, dtype=dtype + mlp_module.fc2, lora_scale, network_alpha=fc2_alpha, rank=rank, dtype=dtype ) lora_parameters.extend(mlp_module.fc2.lora_linear_layer.parameters()) @@ -1326,77 +1509,163 @@ def save_function(weights, filename): def _convert_kohya_lora_to_diffusers(cls, state_dict): unet_state_dict = {} te_state_dict = {} - network_alpha = None - unloaded_keys = [] - - for key, value in state_dict.items(): - if "hada" in key or "skip" in key: - unloaded_keys.append(key) - elif "lora_down" in key: - lora_name = key.split(".")[0] - lora_name_up = lora_name + ".lora_up.weight" - lora_name_alpha = lora_name + ".alpha" - if lora_name_alpha in state_dict: - alpha = state_dict[lora_name_alpha].item() - if network_alpha is None: - network_alpha = alpha - elif network_alpha != alpha: - raise ValueError("Network alpha is not consistent") - - if lora_name.startswith("lora_unet_"): - diffusers_name = key.replace("lora_unet_", "").replace("_", ".") + te2_state_dict = {} + network_alphas = {} + + # every down weight has a corresponding up weight and potentially an alpha weight + lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")] + for key in lora_keys: + lora_name = key.split(".")[0] + lora_name_up = lora_name + ".lora_up.weight" + lora_name_alpha = lora_name + ".alpha" + + # if lora_name_alpha in state_dict: + # alpha = state_dict.pop(lora_name_alpha).item() + # network_alphas.update({lora_name_alpha: alpha}) + + if lora_name.startswith("lora_unet_"): + diffusers_name = key.replace("lora_unet_", "").replace("_", ".") + + if "input.blocks" in diffusers_name: + diffusers_name = diffusers_name.replace("input.blocks", "down_blocks") + else: diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") + + if "middle.block" in diffusers_name: + diffusers_name = diffusers_name.replace("middle.block", "mid_block") + else: diffusers_name = diffusers_name.replace("mid.block", "mid_block") + if "output.blocks" in diffusers_name: + diffusers_name = diffusers_name.replace("output.blocks", "up_blocks") + else: diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") - diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") - diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") - diffusers_name = diffusers_name.replace("proj.in", "proj_in") - diffusers_name = diffusers_name.replace("proj.out", "proj_out") - if "transformer_blocks" in diffusers_name: - if "attn1" in diffusers_name or "attn2" in diffusers_name: - diffusers_name = diffusers_name.replace("attn1", "attn1.processor") - diffusers_name = diffusers_name.replace("attn2", "attn2.processor") - unet_state_dict[diffusers_name] = value - unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] - elif "ff" in diffusers_name: - unet_state_dict[diffusers_name] = value - unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] - elif any(key in diffusers_name for key in ("proj_in", "proj_out")): - unet_state_dict[diffusers_name] = value - unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] - - elif lora_name.startswith("lora_te_"): - diffusers_name = key.replace("lora_te_", "").replace("_", ".") - diffusers_name = diffusers_name.replace("text.model", "text_model") - diffusers_name = diffusers_name.replace("self.attn", "self_attn") - diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") - if "self_attn" in diffusers_name: - te_state_dict[diffusers_name] = value - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] - elif "mlp" in diffusers_name: - # Be aware that this is the new diffusers convention and the rest of the code might - # not utilize it yet. - diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") - te_state_dict[diffusers_name] = value - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] - logger.info("Kohya-style checkpoint detected.") - if len(unloaded_keys) > 0: - example_unloaded_keys = ", ".join(x for x in unloaded_keys[:TOTAL_EXAMPLE_KEYS]) - logger.warning( - f"There are some keys (such as: {example_unloaded_keys}) in the checkpoints we don't provide support for." + diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") + diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") + diffusers_name = diffusers_name.replace("proj.in", "proj_in") + diffusers_name = diffusers_name.replace("proj.out", "proj_out") + diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj") + + # SDXL specificity. + if "emb" in diffusers_name: + pattern = r"\.\d+(?=\D*$)" + diffusers_name = re.sub(pattern, "", diffusers_name, count=1) + if ".in." in diffusers_name: + diffusers_name = diffusers_name.replace("in.layers.2", "conv1") + if ".out." in diffusers_name: + diffusers_name = diffusers_name.replace("out.layers.3", "conv2") + if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name: + diffusers_name = diffusers_name.replace("op", "conv") + if "skip" in diffusers_name: + diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut") + + if "transformer_blocks" in diffusers_name: + if "attn1" in diffusers_name or "attn2" in diffusers_name: + diffusers_name = diffusers_name.replace("attn1", "attn1.processor") + diffusers_name = diffusers_name.replace("attn2", "attn2.processor") + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif "ff" in diffusers_name: + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif any(key in diffusers_name for key in ("proj_in", "proj_out")): + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + else: + unet_state_dict[diffusers_name] = state_dict.pop(key) + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + elif lora_name.startswith("lora_te_"): + diffusers_name = key.replace("lora_te_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + if "self_attn" in diffusers_name: + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + # (sayakpaul): Duplicate code. Needs to be cleaned. + elif lora_name.startswith("lora_te1_"): + diffusers_name = key.replace("lora_te1_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + if "self_attn" in diffusers_name: + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + te_state_dict[diffusers_name] = state_dict.pop(key) + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + # (sayakpaul): Duplicate code. Needs to be cleaned. + elif lora_name.startswith("lora_te2_"): + diffusers_name = key.replace("lora_te2_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + if "self_attn" in diffusers_name: + te2_state_dict[diffusers_name] = state_dict.pop(key) + te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + elif "mlp" in diffusers_name: + # Be aware that this is the new diffusers convention and the rest of the code might + # not utilize it yet. + diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.") + te2_state_dict[diffusers_name] = state_dict.pop(key) + te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up) + + # Rename the alphas so that they can be mapped appropriately. + if lora_name_alpha in state_dict: + alpha = state_dict.pop(lora_name_alpha).item() + if lora_name_alpha.startswith("lora_unet_"): + prefix = "unet." + elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")): + prefix = "text_encoder." + else: + prefix = "text_encoder_2." + new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha" + network_alphas.update({new_name: alpha}) + + if len(state_dict) > 0: + raise ValueError( + f"The following keys have not been correctly be renamed: \n\n {', '.join(state_dict.keys())}" ) - unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()} - te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()} + logger.info("Kohya-style checkpoint detected.") + unet_state_dict = {f"{cls.unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()} + te_state_dict = { + f"{cls.text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items() + } + te2_state_dict = ( + {f"text_encoder_2.{module_name}": params for module_name, params in te2_state_dict.items()} + if len(te2_state_dict) > 0 + else None + ) + if te2_state_dict is not None: + te_state_dict.update(te2_state_dict) + new_state_dict = {**unet_state_dict, **te_state_dict} - return new_state_dict, network_alpha + return new_state_dict, network_alphas def unload_lora_weights(self): """ diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index de4adec042f6..43497c2284ac 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -521,17 +521,32 @@ class LoRAAttnProcessor(nn.Module): Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. """ - def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.rank = rank - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + q_rank = kwargs.pop("q_rank", None) + q_hidden_size = kwargs.pop("q_hidden_size", None) + q_rank = q_rank if q_rank is not None else rank + q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size + + v_rank = kwargs.pop("v_rank", None) + v_hidden_size = kwargs.pop("v_hidden_size", None) + v_rank = v_rank if v_rank is not None else rank + v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size + + out_rank = kwargs.pop("out_rank", None) + out_hidden_size = kwargs.pop("out_hidden_size", None) + out_rank = out_rank if out_rank is not None else rank + out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size + + self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha) self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) + self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None @@ -1144,7 +1159,13 @@ class LoRAXFormersAttnProcessor(nn.Module): """ def __init__( - self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None + self, + hidden_size, + cross_attention_dim, + rank=4, + attention_op: Optional[Callable] = None, + network_alpha=None, + **kwargs, ): super().__init__() @@ -1153,10 +1174,25 @@ def __init__( self.rank = rank self.attention_op = attention_op - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + q_rank = kwargs.pop("q_rank", None) + q_hidden_size = kwargs.pop("q_hidden_size", None) + q_rank = q_rank if q_rank is not None else rank + q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size + + v_rank = kwargs.pop("v_rank", None) + v_hidden_size = kwargs.pop("v_hidden_size", None) + v_rank = v_rank if v_rank is not None else rank + v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size + + out_rank = kwargs.pop("out_rank", None) + out_hidden_size = kwargs.pop("out_hidden_size", None) + out_rank = out_rank if out_rank is not None else rank + out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size + + self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha) self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) + self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) def __call__( self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None @@ -1231,7 +1267,7 @@ class LoRAAttnProcessor2_0(nn.Module): Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. """ - def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -1240,10 +1276,25 @@ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha= self.cross_attention_dim = cross_attention_dim self.rank = rank - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + q_rank = kwargs.pop("q_rank", None) + q_hidden_size = kwargs.pop("q_hidden_size", None) + q_rank = q_rank if q_rank is not None else rank + q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size + + v_rank = kwargs.pop("v_rank", None) + v_hidden_size = kwargs.pop("v_hidden_size", None) + v_rank = v_rank if v_rank is not None else rank + v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size + + out_rank = kwargs.pop("out_rank", None) + out_hidden_size = kwargs.pop("out_hidden_size", None) + out_rank = out_rank if out_rank is not None else rank + out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size + + self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha) self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha) + self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): residual = hidden_states diff --git a/src/diffusers/models/lora.py b/src/diffusers/models/lora.py index bb8389745776..171f1323cf84 100644 --- a/src/diffusers/models/lora.py +++ b/src/diffusers/models/lora.py @@ -14,6 +14,7 @@ from typing import Optional +import torch.nn.functional as F from torch import nn @@ -48,14 +49,19 @@ def forward(self, hidden_states): class LoRAConv2dLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4, network_alpha=None): + def __init__( + self, in_features, out_features, rank=4, kernel_size=(1, 1), stride=(1, 1), padding=0, network_alpha=None + ): super().__init__() if rank > min(in_features, out_features): raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") - self.down = nn.Conv2d(in_features, rank, (1, 1), (1, 1), bias=False) - self.up = nn.Conv2d(rank, out_features, (1, 1), (1, 1), bias=False) + self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) + # according to the official kohya_ss trainer kernel_size are always fixed for the up layer + # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129 + self.up = nn.Conv2d(rank, out_features, kernel_size=(1, 1), stride=(1, 1), bias=False) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning self.network_alpha = network_alpha @@ -91,7 +97,9 @@ def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]): def forward(self, x): if self.lora_layer is None: - return super().forward(x) + # make sure to the functional Conv2D function as otherwise torch.compile's graph will break + # see: https://github.com/huggingface/diffusers/pull/4315 + return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) else: return super().forward(x) + self.lora_layer(x) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 24c3b07e7cb6..72aa17ed2c2d 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -23,6 +23,7 @@ from .activations import get_activation from .attention import AdaGroupNorm from .attention_processor import SpatialNorm +from .lora import LoRACompatibleConv, LoRACompatibleLinear class Upsample1D(nn.Module): @@ -126,7 +127,7 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann if use_conv_transpose: conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) elif use_conv: - conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if name == "conv": @@ -196,7 +197,7 @@ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name= self.name = name if use_conv: - conv = nn.Conv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding) else: assert self.channels == self.out_channels conv = nn.AvgPool2d(kernel_size=stride, stride=stride) @@ -534,13 +535,13 @@ def __init__( else: self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1) if temb_channels is not None: if self.time_embedding_norm == "default": - self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels) + self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels) elif self.time_embedding_norm == "scale_shift": - self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels) + self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels) elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial": self.time_emb_proj = None else: @@ -557,7 +558,7 @@ def __init__( self.dropout = torch.nn.Dropout(dropout) conv_2d_out_channels = conv_2d_out_channels or out_channels - self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) + self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) self.nonlinearity = get_activation(non_linearity) @@ -583,7 +584,7 @@ def __init__( self.conv_shortcut = None if self.use_in_shortcut: - self.conv_shortcut = torch.nn.Conv2d( + self.conv_shortcut = LoRACompatibleConv( in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias ) diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index bbd93430da14..998535c58a73 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -23,7 +23,7 @@ from ..utils import BaseOutput, deprecate from .attention import BasicTransformerBlock from .embeddings import PatchEmbed -from .lora import LoRACompatibleConv +from .lora import LoRACompatibleConv, LoRACompatibleLinear from .modeling_utils import ModelMixin @@ -137,7 +137,7 @@ def __init__( self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) if use_linear_projection: - self.proj_in = nn.Linear(in_channels, inner_dim) + self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) else: self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: @@ -193,7 +193,7 @@ def __init__( if self.is_input_continuous: # TODO: should use out_channels for continuous projections if use_linear_projection: - self.proj_out = nn.Linear(inner_dim, in_channels) + self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) else: self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) elif self.is_input_vectorized: diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index a22da6373181..2e8cee9ce697 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -1,7 +1,6 @@ from ..utils import ( OptionalDependencyNotAvailable, is_flax_available, - is_invisible_watermark_available, is_k_diffusion_available, is_librosa_available, is_note_seq_available, @@ -51,6 +50,7 @@ StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline, + StableDiffusionXLControlNetPipeline, ) from .deepfloyd_if import ( IFImg2ImgPipeline, @@ -108,6 +108,12 @@ StableUnCLIPPipeline, ) from .stable_diffusion_safe import StableDiffusionPipelineSafe + from .stable_diffusion_xl import ( + StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLInpaintPipeline, + StableDiffusionXLInstructPix2PixPipeline, + StableDiffusionXLPipeline, + ) from .t2i_adapter import StableDiffusionAdapterPipeline from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline, VideoToVideoSDPipeline from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline @@ -121,20 +127,6 @@ from .vq_diffusion import VQDiffusionPipeline -try: - if not (is_torch_available() and is_transformers_available() and is_invisible_watermark_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 -else: - from .controlnet import StableDiffusionXLControlNetPipeline - from .stable_diffusion_xl import ( - StableDiffusionXLImg2ImgPipeline, - StableDiffusionXLInpaintPipeline, - StableDiffusionXLInstructPix2PixPipeline, - StableDiffusionXLPipeline, - ) - try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/pipelines/controlnet/__init__.py b/src/diffusers/pipelines/controlnet/__init__.py index 83a4b37f0441..d5f7eb6b4fcc 100644 --- a/src/diffusers/pipelines/controlnet/__init__.py +++ b/src/diffusers/pipelines/controlnet/__init__.py @@ -1,21 +1,11 @@ from ...utils import ( OptionalDependencyNotAvailable, is_flax_available, - is_invisible_watermark_available, is_torch_available, is_transformers_available, ) -try: - if not (is_transformers_available() and is_torch_available() and is_invisible_watermark_available()): - raise OptionalDependencyNotAvailable() -except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 -else: - from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline - - try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() @@ -26,6 +16,7 @@ from .pipeline_controlnet import StableDiffusionControlNetPipeline from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline + from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline if is_transformers_available() and is_flax_available(): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 9a743a40aa6a..dd0ffd82ca44 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -23,6 +23,8 @@ import PIL.Image from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer +from diffusers.utils.import_utils import is_invisible_watermark_available + from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel @@ -43,7 +45,11 @@ ) from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput -from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + + +if is_invisible_watermark_available(): + from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + from .multicontrolnet import MultiControlNetModel @@ -110,6 +116,7 @@ def __init__( controlnet: ControlNetModel, scheduler: KarrasDiffusionSchedulers, force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, ): super().__init__() @@ -131,7 +138,13 @@ def __init__( self.control_image_processor = VaeImageProcessor( vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False ) - self.watermark = StableDiffusionXLWatermarker() + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing @@ -293,7 +306,6 @@ def encode_prompt( text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids @@ -996,7 +1008,10 @@ def __call__( image = latents return StableDiffusionXLPipelineOutput(images=image) - image = self.watermark.apply_watermark(image) + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index ab9163102013..49595c7f7662 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -494,6 +494,7 @@ class DiffusionPipeline(ConfigMixin): _optional_components = [] _exclude_from_cpu_offload = [] _load_connected_pipes = False + _is_onnx = False def register_modules(self, **kwargs): # import it here to avoid circular import @@ -839,6 +840,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P If set to `None`, the safetensors weights are downloaded if they're available **and** if the safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors weights. If set to `False`, safetensors weights are not loaded. + use_onnx (`bool`, *optional*, defaults to `None`): + If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights + will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is + `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending + with `.onnx` and `.pb`. kwargs (remaining dictionary of keyword arguments, *optional*): Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline class). The overwritten components are passed directly to the pipelines `__init__` method. See example @@ -1268,6 +1274,15 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: variant (`str`, *optional*): Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when loading `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + If set to `None`, the safetensors weights are downloaded if they're available **and** if the + safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors + weights. If set to `False`, safetensors weights are not loaded. + use_onnx (`bool`, *optional*, defaults to `False`): + If set to `True`, ONNX weights will always be downloaded if present. If set to `False`, ONNX weights + will never be downloaded. By default `use_onnx` defaults to the `_is_onnx` class attribute which is + `False` for non-ONNX pipelines and `True` for ONNX pipelines. ONNX weights include both files ending + with `.onnx` and `.pb`. Returns: `os.PathLike`: @@ -1293,6 +1308,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: custom_revision = kwargs.pop("custom_revision", None) variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) + use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) if use_safetensors and not is_safetensors_available(): @@ -1364,7 +1380,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: pretrained_model_name, use_auth_token, variant, revision, model_filenames ) - model_folder_names = {os.path.split(f)[0] for f in model_filenames} + model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names} # all filenames compatible with variant will be added allow_patterns = list(model_filenames) @@ -1411,6 +1427,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: ): ignore_patterns = ["*.bin", "*.msgpack"] + use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx + if not use_onnx: + ignore_patterns += ["*.onnx", "*.pb"] + safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")} safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")} if ( @@ -1423,6 +1443,10 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: else: ignore_patterns = ["*.safetensors", "*.msgpack"] + use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx + if not use_onnx: + ignore_patterns += ["*.onnx", "*.pb"] + bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")} bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")} if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames: @@ -1474,11 +1498,25 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: user_agent=user_agent, ) - if pipeline_class._load_connected_pipes: + # retrieve pipeline class from local file + cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None) + pipeline_class = getattr(diffusers, cls_name, None) + + if pipeline_class is not None and pipeline_class._load_connected_pipes: modelcard = ModelCard.load(os.path.join(cached_folder, "README.md")) connected_pipes = sum([getattr(modelcard.data, k, []) for k in CONNECTED_PIPES_KEYS], []) for connected_pipe_repo_id in connected_pipes: - DiffusionPipeline.download(connected_pipe_repo_id) + download_kwargs = { + "cache_dir": cache_dir, + "resume_download": resume_download, + "force_download": force_download, + "proxies": proxies, + "local_files_only": local_files_only, + "use_auth_token": use_auth_token, + "variant": variant, + "use_safetensors": use_safetensors, + } + DiffusionPipeline.download(connected_pipe_repo_id, **download_kwargs) return cached_folder diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index 29082beb9128..ce7e804dde63 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -442,7 +442,7 @@ def __call__( if do_classifier_free_guidance: uncond_tokens: List[str] if negative_prompt is None: - uncond_tokens = [""] + uncond_tokens = [""] * batch_size elif type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" @@ -471,7 +471,7 @@ def __call__( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1) uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) # For classifier free guidance, we need to do two forward passes. diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index fdbe1dfaeffb..07604d7c082f 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -1186,6 +1186,7 @@ def download_from_original_stable_diffusion_ckpt( StableDiffusionInpaintPipeline, StableDiffusionPipeline, StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, ) @@ -1542,7 +1543,7 @@ def download_from_original_stable_diffusion_ckpt( checkpoint, config_name, prefix="conditioner.embedders.1.model.", has_projection=True, **config_kwargs ) - pipe = pipeline_class( + pipe = StableDiffusionXLPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py index eb02f6cb321c..6c8ff7fe78df 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py @@ -41,6 +41,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline): feature_extractor: CLIPImageProcessor _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True def __init__( self, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py index 293ed7d981b8..508085094b16 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py @@ -98,6 +98,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline): feature_extractor: CLIPImageProcessor _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True def __init__( self, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py index 0bb39c4b1c61..4856babce807 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py @@ -90,6 +90,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline): feature_extractor: CLIPImageProcessor _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True def __init__( self, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py index 8ef7a781451c..a4b54b9724fb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py @@ -67,6 +67,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline): Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ _optional_components = ["safety_checker", "feature_extractor"] + _is_onnx = True vae_encoder: OnnxRuntimeModel vae_decoder: OnnxRuntimeModel diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py index 56681391aeeb..93e86def7a05 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py @@ -46,6 +46,8 @@ def preprocess(image): class OnnxStableDiffusionUpscalePipeline(StableDiffusionUpscalePipeline): + _is_onnx = True + def __init__( self, vae: OnnxRuntimeModel, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index 4de7487c2e45..582bf6223d44 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -424,10 +424,13 @@ def check_inputs( # verify batch size of prompt and image are same if image is a list or tensor or numpy array if isinstance(image, list) or isinstance(image, torch.Tensor) or isinstance(image, np.ndarray): - if isinstance(prompt, str): + if prompt is not None and isinstance(prompt, str): batch_size = 1 - else: + elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + if isinstance(image, list): image_batch_size = len(image) else: diff --git a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py index f2994f7d7d2d..02bd96cfc23c 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/__init__.py @@ -7,7 +7,6 @@ from ...utils import ( BaseOutput, OptionalDependencyNotAvailable, - is_invisible_watermark_available, is_torch_available, is_transformers_available, ) @@ -28,10 +27,10 @@ class StableDiffusionXLPipelineOutput(BaseOutput): try: - if not (is_transformers_available() and is_torch_available() and is_invisible_watermark_available()): + if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from ...utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index e1dbc6237a66..a61fd0e303d9 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -30,10 +30,24 @@ XFormersAttnProcessor, ) from ...schedulers import KarrasDiffusionSchedulers +<<<<<<< HEAD from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring +======= +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_invisible_watermark_available, + logging, + randn_tensor, + replace_example_docstring, +) +>>>>>>> ba43ce3476ffa649a6a14f0e13af07df27f1c66f from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionXLPipelineOutput -from .watermark import StableDiffusionXLWatermarker + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -79,11 +93,11 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad In addition the pipeline inherits the following loading methods: - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] - - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`] - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] as well as the following saving methods: - - *LoRA*: [`loaders.LoraLoaderMixin.save_lora_weights`] + - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`] Args: vae ([`AutoencoderKL`]): @@ -120,6 +134,7 @@ def __init__( unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, ): super().__init__() @@ -137,7 +152,12 @@ def __init__( self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) self.default_sample_size = self.unet.config.sample_size - self.watermark = StableDiffusionXLWatermarker() + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -295,7 +315,6 @@ def encode_prompt( text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids @@ -834,7 +853,10 @@ def __call__( image = latents return StableDiffusionXLPipelineOutput(images=image) - image = self.watermark.apply_watermark(image) + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU @@ -848,14 +870,21 @@ def __call__( # Overrride to properly handle the loading and unloading of the additional text encoder. def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): - state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) - self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet) + # We could have accessed the unet config from `lora_state_dict()` too. We pass + # it here explicitly to be able to tell that it's coming from an SDXL + # pipeline. + state_dict, network_alphas = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, + unet_config=self.unet.config, + **kwargs, + ) + self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} if len(text_encoder_state_dict) > 0: self.load_lora_into_text_encoder( text_encoder_state_dict, - network_alpha=network_alpha, + network_alphas=network_alphas, text_encoder=self.text_encoder, prefix="text_encoder", lora_scale=self.lora_scale, @@ -865,7 +894,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di if len(text_encoder_2_state_dict) > 0: self.load_lora_into_text_encoder( text_encoder_2_state_dict, - network_alpha=network_alpha, + network_alphas=network_alphas, text_encoder=self.text_encoder_2, prefix="text_encoder_2", lora_scale=self.lora_scale, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 011a939672fd..e69e4bc74d43 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -33,13 +33,17 @@ from ...utils import ( is_accelerate_available, is_accelerate_version, + is_invisible_watermark_available, logging, randn_tensor, replace_example_docstring, ) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionXLPipelineOutput -from .watermark import StableDiffusionXLWatermarker + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -131,6 +135,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, requires_aesthetics_score: bool = False, force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, ): super().__init__() @@ -148,7 +153,12 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.watermark = StableDiffusionXLWatermarker() + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -307,7 +317,6 @@ def encode_prompt( text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids @@ -906,15 +915,17 @@ def denoising_value_valid(dnv): negative_aesthetic_score, dtype=prompt_embeds.dtype, ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = add_time_ids.to(device) # 9. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -988,7 +999,10 @@ def denoising_value_valid(dnv): image = latents return StableDiffusionXLPipelineOutput(images=image) - image = self.watermark.apply_watermark(image) + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 6ae94be8722d..8b96b558ec7c 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -30,10 +30,20 @@ XFormersAttnProcessor, ) from ...schedulers import KarrasDiffusionSchedulers -from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_invisible_watermark_available, + logging, + randn_tensor, + replace_example_docstring, +) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionXLPipelineOutput -from .watermark import StableDiffusionXLWatermarker + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -265,6 +275,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, requires_aesthetics_score: bool = False, force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, ): super().__init__() @@ -282,7 +293,12 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.watermark = StableDiffusionXLWatermarker() + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): @@ -442,7 +458,6 @@ def encode_prompt( text_input_ids = text_inputs.input_ids untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids @@ -1168,15 +1183,17 @@ def denoising_value_valid(dnv): negative_aesthetic_score, dtype=prompt_embeds.dtype, ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0) prompt_embeds = prompt_embeds.to(device) add_text_embeds = add_text_embeds.to(device) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = add_time_ids.to(device) # 11. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -1264,6 +1281,10 @@ def denoising_value_valid(dnv): else: return StableDiffusionXLPipelineOutput(images=latents) + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index a78ce1f52147..eec5f840277a 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -34,12 +34,16 @@ deprecate, is_accelerate_available, is_accelerate_version, + is_invisible_watermark_available, logging, randn_tensor, ) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionXLPipelineOutput -from .watermark import StableDiffusionXLWatermarker + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -109,6 +113,7 @@ def __init__( scheduler: KarrasDiffusionSchedulers, requires_aesthetics_score: bool = False, force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, ): super().__init__() @@ -128,7 +133,12 @@ def __init__( self.vae.config.force_upcast = True # force the VAE to be in float32 mode, as it overflows in float16 - self.watermark = StableDiffusionXLWatermarker() + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None def enable_vae_slicing(self): r""" @@ -811,6 +821,7 @@ def __call__( negative_aesthetic_score, dtype=prompt_embeds.dtype, ) + add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) original_prompt_embeds_len = len(prompt_embeds) original_add_text_embeds_len = len(add_text_embeds) @@ -819,6 +830,7 @@ def __call__( if do_classifier_free_guidance: prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0) add_text_embeds = torch.cat([add_text_embeds, negative_pooled_prompt_embeds], dim=0) + add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1) add_time_ids = torch.cat([add_time_ids, add_neg_time_ids], dim=0) # Make dimensions consistent @@ -828,7 +840,7 @@ def __call__( prompt_embeds = prompt_embeds.to(device).to(torch.float32) add_text_embeds = add_text_embeds.to(device).to(torch.float32) - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + add_time_ids = add_time_ids.to(device) # 11. Denoising loop self.unet = self.unet.to(torch.float32) @@ -906,7 +918,10 @@ def __call__( image = latents return StableDiffusionXLPipelineOutput(images=image) - image = self.watermark.apply_watermark(image) + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) # Offload last model to CPU diff --git a/src/diffusers/pipelines/stable_diffusion_xl/watermark.py b/src/diffusers/pipelines/stable_diffusion_xl/watermark.py index 0c9a7e78568c..2e75d2e081b4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/watermark.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/watermark.py @@ -1,7 +1,15 @@ import numpy as np import torch +<<<<<<< HEAD from imwatermark import WatermarkEncoder +======= +from ...utils import is_invisible_watermark_available + + +if is_invisible_watermark_available(): + from imwatermark import WatermarkEncoder +>>>>>>> ba43ce3476ffa649a6a14f0e13af07df27f1c66f # Copied from https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index d7516fa601e1..aa5db72a6771 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -266,7 +266,9 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + print("test") if self.config.use_karras_sigmas: + print('use karras sigmas') log_sigmas = np.log(sigmas) sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps) timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round() diff --git a/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py deleted file mode 100644 index eae5528148a5..000000000000 --- a/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +++ /dev/null @@ -1,77 +0,0 @@ -# This file is autogenerated by the command `make fix-copies`, do not edit. -from ..utils import DummyObject, requires_backends - - -class StableDiffusionXLControlNetPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers", "invisible_watermark"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers", "invisible_watermark"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) - - -class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers", "invisible_watermark"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers", "invisible_watermark"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) - - -class StableDiffusionXLInpaintPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers", "invisible_watermark"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers", "invisible_watermark"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) - - -class StableDiffusionXLInstructPix2PixPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers", "invisible_watermark"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers", "invisible_watermark"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) - - -class StableDiffusionXLPipeline(metaclass=DummyObject): - _backends = ["torch", "transformers", "invisible_watermark"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch", "transformers", "invisible_watermark"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 254b99e85c05..df8009dd0e27 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -827,6 +827,81 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableDiffusionXLControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLInstructPix2PixPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableDiffusionXLPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableUnCLIPImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 8c751bc6bf07..000748312fca 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -563,10 +563,10 @@ def get_dummy_components(self): projection_dim=32, ) text_encoder = CLIPTextModel(text_encoder_config) - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) - tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet) text_encoder_one_lora_layers = create_text_encoder_lora_layers(text_encoder) @@ -737,8 +737,7 @@ def test_a1111(self): ).images images = images[0, -3:, -3:, -1].flatten() - - expected = np.array([0.3636, 0.3708, 0.3694, 0.3679, 0.3829, 0.3677, 0.3692, 0.3688, 0.3292]) + expected = np.array([0.3725, 0.3767, 0.3761, 0.3796, 0.3827, 0.3763, 0.3831, 0.3809, 0.3392]) self.assertTrue(np.allclose(images, expected, atol=1e-4)) diff --git a/tests/others/test_dependencies.py b/tests/others/test_dependencies.py index 3436cf92d896..3bac611e3f4f 100644 --- a/tests/others/test_dependencies.py +++ b/tests/others/test_dependencies.py @@ -14,6 +14,7 @@ import inspect import unittest +from importlib import import_module class DependencyTester(unittest.TestCase): @@ -37,3 +38,13 @@ def test_backend_registration(self): elif backend == "invisible_watermark": backend = "invisible-watermark" assert backend in deps, f"{backend} is not in the deps table!" + + def test_pipeline_imports(self): + import diffusers + import diffusers.pipelines + + all_classes = inspect.getmembers(diffusers, inspect.isclass) + for cls_name, cls_module in all_classes: + if hasattr(diffusers.pipelines, cls_name): + pipeline_folder_module = ".".join(str(cls_module.__module__).split(".")[:3]) + _ = import_module(pipeline_folder_module, str(cls_name)) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py index 7100e5023a5d..1905af6c695f 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py @@ -210,6 +210,68 @@ def test_stable_diffusion_upscale_batch(self): image = output.images assert image.shape[0] == 2 + def test_stable_diffusion_upscale_prompt_embeds(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet_upscale + low_res_scheduler = DDPMScheduler() + scheduler = DDIMScheduler(prediction_type="v_prediction") + vae = self.dummy_vae + text_encoder = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionUpscalePipeline( + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + max_noise_level=350, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + [prompt], + image=low_res_image, + generator=generator, + guidance_scale=6.0, + noise_level=20, + num_inference_steps=2, + output_type="np", + ) + + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + prompt_embeds = sd_pipe._encode_prompt(prompt, device, 1, False) + image_from_prompt_embeds = sd_pipe( + prompt_embeds=prompt_embeds, + image=[low_res_image], + generator=generator, + guidance_scale=6.0, + noise_level=20, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] + + image_slice = image[0, -3:, -3:, -1] + image_from_prompt_embeds_slice = image_from_prompt_embeds[0, -3:, -3:, -1] + + expected_height_width = low_res_image.size[0] * 4 + assert image.shape == (1, expected_height_width, expected_height_width, 3) + expected_slice = np.array([0.3113, 0.3910, 0.4272, 0.4859, 0.5061, 0.4652, 0.5362, 0.5715, 0.5661]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_prompt_embeds_slice.flatten() - expected_slice).max() < 1e-2 + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") def test_stable_diffusion_upscale_fp16(self): """Test that stable diffusion upscale works with fp16""" diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index ad6905cfa4e3..2d251a658658 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -100,10 +100,10 @@ def get_dummy_components(self): projection_dim=32, ) text_encoder = CLIPTextModel(text_encoder_config) - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) - tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") components = { "unet": unet, diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index f91b06d7503e..1e879151ac2f 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -64,7 +64,7 @@ def get_dummy_components(self, skip_first_text_encoder=False): addition_embed_type="text_time", addition_time_embed_dim=8, transformer_layers_per_block=(1, 2), - projection_class_embeddings_input_dim=80, # 6 * 8 + 32 + projection_class_embeddings_input_dim=72, # 5 * 8 + 32 cross_attention_dim=64 if not skip_first_text_encoder else 32, ) scheduler = EulerDiscreteScheduler( @@ -100,10 +100,10 @@ def get_dummy_components(self, skip_first_text_encoder=False): projection_dim=32, ) text_encoder = CLIPTextModel(text_encoder_config) - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) - tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") components = { "unet": unet, @@ -113,9 +113,18 @@ def get_dummy_components(self, skip_first_text_encoder=False): "tokenizer": tokenizer if not skip_first_text_encoder else None, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, + "requires_aesthetics_score": True, } return components + def test_components_function(self): + init_components = self.get_dummy_components() + init_components.pop("requires_aesthetics_score") + pipe = self.pipeline_class(**init_components) + + self.assertTrue(hasattr(pipe, "components")) + self.assertTrue(set(pipe.components.keys()) == set(init_components.keys())) + def get_dummy_inputs(self, device, seed=0): image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) image = image / 2 + 0.5 @@ -147,7 +156,7 @@ def test_stable_diffusion_xl_img2img_euler(self): assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.4656, 0.4840, 0.4439, 0.6698, 0.5574, 0.4524, 0.5799, 0.5943, 0.5165]) + expected_slice = np.array([0.4664, 0.4886, 0.4403, 0.6902, 0.5592, 0.4534, 0.5931, 0.5951, 0.5224]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -165,7 +174,7 @@ def test_stable_diffusion_xl_refiner(self): assert image.shape == (1, 32, 32, 3) - expected_slice = np.array([0.4676, 0.4865, 0.4335, 0.6715, 0.5578, 0.4497, 0.5847, 0.5967, 0.5198]) + expected_slice = np.array([0.4578, 0.4981, 0.4301, 0.6454, 0.5588, 0.4442, 0.5678, 0.5940, 0.5176]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py index c8f4230992b5..05ce3f11973e 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py @@ -66,7 +66,7 @@ def get_dummy_components(self, skip_first_text_encoder=False): addition_embed_type="text_time", addition_time_embed_dim=8, transformer_layers_per_block=(1, 2), - projection_class_embeddings_input_dim=80, # 6 * 8 + 32 + projection_class_embeddings_input_dim=72, # 5 * 8 + 32 cross_attention_dim=64 if not skip_first_text_encoder else 32, ) scheduler = EulerDiscreteScheduler( @@ -102,10 +102,10 @@ def get_dummy_components(self, skip_first_text_encoder=False): projection_dim=32, ) text_encoder = CLIPTextModel(text_encoder_config) - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) - tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") components = { "unet": unet, @@ -115,6 +115,7 @@ def get_dummy_components(self, skip_first_text_encoder=False): "tokenizer": tokenizer if not skip_first_text_encoder else None, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, + "requires_aesthetics_score": True, } return components @@ -142,6 +143,14 @@ def get_dummy_inputs(self, device, seed=0): } return inputs + def test_components_function(self): + init_components = self.get_dummy_components() + init_components.pop("requires_aesthetics_score") + pipe = self.pipeline_class(**init_components) + + self.assertTrue(hasattr(pipe, "components")) + self.assertTrue(set(pipe.components.keys()) == set(init_components.keys())) + def test_stable_diffusion_xl_inpaint_euler(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator components = self.get_dummy_components() @@ -155,7 +164,7 @@ def test_stable_diffusion_xl_inpaint_euler(self): assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.6965, 0.5584, 0.5693, 0.5739, 0.6092, 0.6620, 0.5902, 0.5612, 0.5319]) + expected_slice = np.array([0.8029, 0.5523, 0.5825, 0.6003, 0.6702, 0.7018, 0.6369, 0.5955, 0.5123]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 @@ -250,10 +259,9 @@ def test_stable_diffusion_xl_refiner(self): image = sd_pipe(**inputs).images image_slice = image[0, -3:, -3:, -1] - print(torch.from_numpy(image_slice).flatten()) assert image.shape == (1, 64, 64, 3) - expected_slice = np.array([0.9106, 0.6563, 0.6766, 0.6537, 0.6709, 0.7367, 0.6537, 0.5937, 0.5418]) + expected_slice = np.array([0.7045, 0.4838, 0.5454, 0.6270, 0.6168, 0.6717, 0.6484, 0.5681, 0.4922]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py index c7667178ec75..bbb0fe698087 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py @@ -68,7 +68,7 @@ def get_dummy_components(self): addition_embed_type="text_time", addition_time_embed_dim=8, transformer_layers_per_block=(1, 2), - projection_class_embeddings_input_dim=80, # 6 * 8 + 32 + projection_class_embeddings_input_dim=72, # 5 * 8 + 32 cross_attention_dim=64, ) @@ -105,10 +105,10 @@ def get_dummy_components(self): projection_dim=32, ) text_encoder = CLIPTextModel(text_encoder_config) - tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) - tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") components = { "unet": unet, @@ -118,8 +118,7 @@ def get_dummy_components(self): "tokenizer": tokenizer, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, - # "safety_checker": None, - # "feature_extractor": None, + "requires_aesthetics_score": True, } return components @@ -141,6 +140,14 @@ def get_dummy_inputs(self, device, seed=0): } return inputs + def test_components_function(self): + init_components = self.get_dummy_components() + init_components.pop("requires_aesthetics_score") + pipe = self.pipeline_class(**init_components) + + self.assertTrue(hasattr(pipe, "components")) + self.assertTrue(set(pipe.components.keys()) == set(init_components.keys())) + def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index b06e2fe65b51..5ce2316c9b19 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -310,6 +310,49 @@ def test_download_bin_index(self): assert len([f for f in files if ".bin" in f]) == 8 assert not any(".safetensors" in f for f in files) + def test_download_no_openvino_by_default(self): + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdirname = DiffusionPipeline.download( + "hf-internal-testing/tiny-stable-diffusion-open-vino", + cache_dir=tmpdirname, + ) + + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] + files = [item for sublist in all_root_files for item in sublist] + + # make sure that by default no openvino weights are downloaded + assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files) + assert not any("openvino_" in f for f in files) + + def test_download_no_onnx_by_default(self): + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdirname = DiffusionPipeline.download( + "hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline", + cache_dir=tmpdirname, + ) + + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] + files = [item for sublist in all_root_files for item in sublist] + + # make sure that by default no onnx weights are downloaded + assert all((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files) + assert not any((f.endswith(".onnx") or f.endswith(".pb")) for f in files) + + with tempfile.TemporaryDirectory() as tmpdirname: + tmpdirname = DiffusionPipeline.download( + "hf-internal-testing/tiny-random-OnnxStableDiffusionPipeline", + cache_dir=tmpdirname, + use_onnx=True, + ) + + all_root_files = [t[-1] for t in os.walk(os.path.join(tmpdirname))] + files = [item for sublist in all_root_files for item in sublist] + + # if `use_onnx` is specified make sure weights are downloaded + assert any((f.endswith(".json") or f.endswith(".bin") or f.endswith(".txt")) for f in files) + assert any((f.endswith(".onnx")) for f in files) + assert any((f.endswith(".pb")) for f in files) + def test_download_no_safety_checker(self): prompt = "hello" pipe = StableDiffusionPipeline.from_pretrained( @@ -374,7 +417,7 @@ def test_cached_files_are_used_when_no_internet(self): response_mock.json.return_value = {} # Download this model to make sure it's in the cache. - orig_pipe = StableDiffusionPipeline.from_pretrained( + orig_pipe = DiffusionPipeline.from_pretrained( "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None ) orig_comps = {k: v for k, v in orig_pipe.components.items() if hasattr(v, "parameters")} @@ -382,7 +425,7 @@ def test_cached_files_are_used_when_no_internet(self): # Under the mock environment we get a 500 error when trying to reach the model. with mock.patch("requests.request", return_value=response_mock): # Download this model to make sure it's in the cache. - pipe = StableDiffusionPipeline.from_pretrained( + pipe = DiffusionPipeline.from_pretrained( "hf-internal-testing/tiny-stable-diffusion-torch", safety_checker=None ) comps = {k: v for k, v in pipe.components.items() if hasattr(v, "parameters")} @@ -392,6 +435,42 @@ def test_cached_files_are_used_when_no_internet(self): if p1.data.ne(p2.data).sum() > 0: assert False, "Parameters not the same!" + def test_local_files_only_are_used_when_no_internet(self): + # A mock response for an HTTP head request to emulate server down + response_mock = mock.Mock() + response_mock.status_code = 500 + response_mock.headers = {} + response_mock.raise_for_status.side_effect = HTTPError + response_mock.json.return_value = {} + + # first check that with local files only the pipeline can only be used if cached + with self.assertRaises(FileNotFoundError): + with tempfile.TemporaryDirectory() as tmpdirname: + orig_pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", local_files_only=True, cache_dir=tmpdirname + ) + + # now download + orig_pipe = DiffusionPipeline.download("hf-internal-testing/tiny-stable-diffusion-torch") + + # make sure it can be loaded with local_files_only + orig_pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", local_files_only=True + ) + orig_comps = {k: v for k, v in orig_pipe.components.items() if hasattr(v, "parameters")} + + # Under the mock environment we get a 500 error when trying to connect to the internet. + # Make sure it works local_files_only only works here! + with mock.patch("requests.request", return_value=response_mock): + # Download this model to make sure it's in the cache. + pipe = DiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch") + comps = {k: v for k, v in pipe.components.items() if hasattr(v, "parameters")} + + for m1, m2 in zip(orig_comps.values(), comps.values()): + for p1, p2 in zip(m1.parameters(), m2.parameters()): + if p1.data.ne(p2.data).sum() > 0: + assert False, "Parameters not the same!" + def test_download_from_variant_folder(self): for safe_avail in [False, True]: import diffusers diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 8e95f4f46443..1c71e2a908bc 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -387,7 +387,7 @@ def _test_inference_batch_consistent( batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)] # make last batch super long - batched_inputs[name][-1] = 2000 * "very long" + batched_inputs[name][-1] = 100 * "very long" # or else we have images else: batched_inputs[name] = batch_size * [value] @@ -462,7 +462,7 @@ def _test_inference_batch_single_identical( batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)] # make last batch super long - batched_inputs[name][-1] = 2000 * "very long" + batched_inputs[name][-1] = 100 * "very long" # or else we have images else: batched_inputs[name] = batch_size * [value] From 05a5fe8a7d932344dedfbabf1744e295212ae158 Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Mon, 31 Jul 2023 22:35:13 +0000 Subject: [PATCH 09/18] cleanup --- controlnet_inpaint.py | 197 --------------- custom_inpaint_pipeline.py | 491 ------------------------------------- 2 files changed, 688 deletions(-) delete mode 100644 controlnet_inpaint.py delete mode 100644 custom_inpaint_pipeline.py diff --git a/controlnet_inpaint.py b/controlnet_inpaint.py deleted file mode 100644 index ec7e1c06189a..000000000000 --- a/controlnet_inpaint.py +++ /dev/null @@ -1,197 +0,0 @@ -# !pip install transformers accelerate -import os -from typing import List, Optional, Union - -import numpy as np -import torch - -import PIL -from contexttimer import Timer -from diffusers import ( - ControlNetModel, - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - StableDiffusionControlNetInpaintPipeline, - StableDiffusionImg2ImgPipeline, -) -from diffusers.utils import load_image -from mask_utils import create_gradient, expand_image, make_inpaint_condition - - -def controlnet_inpaint( - self, - prompt, - image: Union[str, PIL.Image], - mask: Union[str, PIL.Image], - output_type="pil", - pipe=None, - num_inference_steps: int = 25, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - seed=None, -): - """ - Takes image and mask (can be PIL images, a local path, or a URL) - """ - init_image = load_image(image) - mask_image = load_image(mask) - control_image = make_inpaint_condition(init_image, mask_image) - - with Timer() as t: - if pipe is None: - # pipe = sd.get_pipeline() - raise NotImplementedError("Need to pass in a pipeline for now") - t_load = t.elapsed - - if seed is not None: - generator = torch.Generator().manual_seed(seed) - - pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True) - with Timer() as t: - images = pipe( - prompt, - num_inference_steps=num_inference_steps, - generator=generator, - guidance_scale=guidance_scale, - negative_prompt="deformed iris, deformed pupils, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck", - width=init_image.width, - height=init_image.height, - image=init_image, - mask_image=mask_image, - control_image=control_image, - num_images_per_prompt=4, - output_type=output_type, - ).images - t_inference = t.elapsed - - return { - "images": images, - "performance": { - "t_load": t_load, - "t_inference": t_inference, - } - } - -def controlnet_extend( - self, - prompt: Union[str, List[str]], - init_image: Union[torch.FloatTensor, PIL.Image.Image], - model_id: str = "realistic_vision", - seed: Optional[int] = None, - expand_offset_x: int = 0, - expand_offset_y: int = 0, - img2img_strength: float = 0.35, - img2img_steps: int = 15, - mask_offset: float = 40, - num_inference_steps: int = 25, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - output_type: Optional[str] = "pil", - # pipe -): - - with Timer() as t: - controlnet_inpaint_pipe = self.model_loader.get_pipeline(model_id, StableDiffusionControlNetInpaintPipeline) - t_load = t.elapsed - - with Timer() as t: - extended_image, mask_img = expand_image(init_image, expand_x=expand_offset_x, expand_y=expand_offset_y) - print("Image size after extending, " + str(extended_image.size)) - - blend_mask = create_gradient(mask_img, x=expand_offset_x, y=expand_offset_y, offset=mask_offset) - t_preprocess = t.elapsed - - inpaint_results = controlnet_inpaint( - prompt, - extended_image, - mask_img, - output_type="pil", - pipe=controlnet_inpaint_pipe, - seed=seed, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - negative_prompt=negative_prompt, - num_images_per_prompt=num_images_per_prompt, - ) - inpainted_images = inpaint_results["images"] - performance = inpaint_results["performance"] - - - img2img_pipe = StableDiffusionImg2ImgPipeline( - vae=controlnet_inpaint_pipe.vae, - text_encoder=controlnet_inpaint_pipe.text_encoder, - tokenizer=controlnet_inpaint_pipe.tokenizer, - unet=controlnet_inpaint_pipe.unet, - scheduler=controlnet_inpaint_pipe.scheduler, - safety_checker=None, - feature_extractor=controlnet_inpaint_pipe.feature_extractor, - ) - img2img_pipe = img2img_pipe.to("cuda") - - - # uses masked img2img to homogenize the inpainted zone and the original image, making them look more natural and reduce border artefacts - generator = torch.Generator().manual_seed(seed) if seed is not None else None - with Timer() as t: - final_images = img2img_pipe( - prompt, - # negative_prompt=negative_prompt, - negative_prompt=None, - image=inpainted_images, - mask_image=blend_mask, - strength=img2img_strength, - num_inference_steps=img2img_steps, - generator=generator, - ).images - t_img2img = t.elapsed - - if "t_load" in performance: - performance["t_load"] = performance["t_load"] + t_load - performance["t_preprocess"] = t_preprocess - performance["t_img2img"] = t_img2img - performance["t_inpaint"] = performance["t_inference"] - performance["t_inference"] = sum([performance["t_inference"], t_img2img]) - - for i, image in enumerate(final_images): - image.save("final_image_" + str(i) + ".png") - return { - "images": final_images, - "performance": performance, - } - - -if __name__ == "__main__": - init_image = load_image( - "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png" - ) - init_image = init_image.resize((512, 512)) - - # generator = torch.Generator(device="cpu").manual_seed(1) - - mask_image = load_image( - "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png" - ) - mask_image = mask_image.resize((512, 512)) - - img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" - mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" - - init_image = load_image(img_url) - mask_image = load_image(mask_url) - - img_path = "/home/erwann/diffusers/examples/community/new_image.png" - # mask_path = "/home/erwann/diffusers/examples/community/hard_mask_5.png" - mask_path = "/home/erwann/diffusers/examples/community/mask_image.png" - init_image = load_image(img_path) - mask_image = load_image(mask_path) - # mask_image.save("mask.png") - - # new_width = 480 - # new_height = new_width * init_image.height / init_image.width - # new_height = 640 - # init_image = init_image.resize((new_width, int(new_height))) - - # mask_image = mask_image.resize(init_image.size) - # mask_image = mask_image.resize((512, 512)) diff --git a/custom_inpaint_pipeline.py b/custom_inpaint_pipeline.py deleted file mode 100644 index 45652477db34..000000000000 --- a/custom_inpaint_pipeline.py +++ /dev/null @@ -1,491 +0,0 @@ -import inspect -import os -import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn.functional as F - -import PIL.Image -from diffusers import StableDiffusionControlNetInpaintPipeline -from diffusers.image_processor import VaeImageProcessor -from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin -from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel -from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel -from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput -from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import ( - is_accelerate_available, - is_accelerate_version, - is_compiled_module, - load_image, - logging, - randn_tensor, - replace_example_docstring, -) -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer - - -def prepare_mask_and_masked_image(image, mask, height, width, return_image=False): - """ - Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be - converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the - ``image`` and ``1`` for the ``mask``. - - The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be - binarized (``mask > 0.5``) and cast to ``torch.float32`` too. - - Args: - image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. - It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` - ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. - mask (_type_): The mask to apply to the image, i.e. regions to inpaint. - It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` - ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. - - - Raises: - ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask - should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. - TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not - (ot the other way around). - - Returns: - tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 - dimensions: ``batch x channels x height x width``. - """ - - if image is None: - raise ValueError("`image` input cannot be undefined.") - - if mask is None: - raise ValueError("`mask_image` input cannot be undefined.") - - if isinstance(image, torch.Tensor): - if not isinstance(mask, torch.Tensor): - raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not") - - # Batch single image - if image.ndim == 3: - assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)" - image = image.unsqueeze(0) - - # Batch and add channel dim for single mask - if mask.ndim == 2: - mask = mask.unsqueeze(0).unsqueeze(0) - - # Batch single mask or add channel dim - if mask.ndim == 3: - # Single batched mask, no channel dim or single mask not batched but channel dim - if mask.shape[0] == 1: - mask = mask.unsqueeze(0) - - # Batched masks no channel dim - else: - mask = mask.unsqueeze(1) - - assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions" - assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions" - assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size" - - # Check image is in [-1, 1] - if image.min() < -1 or image.max() > 1: - raise ValueError("Image should be in [-1, 1] range") - - # Check mask is in [0, 1] - if mask.min() < 0 or mask.max() > 1: - raise ValueError("Mask should be in [0, 1] range") - - # Binarize mask - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - - # Image as float32 - image = image.to(dtype=torch.float32) - elif isinstance(mask, torch.Tensor): - raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not") - else: - # preprocess image - if isinstance(image, (PIL.Image.Image, np.ndarray)): - image = [image] - if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): - # resize all images w.r.t passed height an width - image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image] - image = [np.array(i.convert("RGB"))[None, :] for i in image] - image = np.concatenate(image, axis=0) - elif isinstance(image, list) and isinstance(image[0], np.ndarray): - image = np.concatenate([i[None, :] for i in image], axis=0) - - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 - - # preprocess mask - if isinstance(mask, (PIL.Image.Image, np.ndarray)): - mask = [mask] - - if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): - mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] - mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) - mask = mask.astype(np.float32) / 255.0 - elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): - mask = np.concatenate([m[None, None, :] for m in mask], axis=0) - - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - mask = torch.from_numpy(mask) - - masked_image = image * (mask < 0.5) - - # n.b. ensure backwards compatibility as old function does not return image - if return_image: - return mask, masked_image, image - - return mask, masked_image - -class StableDiffusionMaskedLatentControlNetInpaintPipeline(StableDiffusionControlNetInpaintPipeline): - _optional_components = ["safety_checker", "feature_extractor"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - controlnet, - scheduler: KarrasDiffusionSchedulers, - feature_extractor: CLIPImageProcessor, - safety_checker=None, - requires_safety_checker: bool = False, - ): - super().__init__( - vae, - text_encoder, - tokenizer, - unet, - controlnet, - scheduler, - safety_checker=None, - feature_extractor=None, - requires_safety_checker = False, - ) - - def resize_mask( - self, mask, dtype=torch.float16, - ): - height = mask.height - width = mask.width - if isinstance(mask, (PIL.Image.Image, np.ndarray)): - mask = [mask] - if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): - # mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask] - mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0) - mask = mask.astype(np.float32) / 255.0 - elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): - mask = np.concatenate([m[None, None, :] for m in mask], axis=0) - - mask = torch.from_numpy(mask) - - mask = torch.nn.functional.interpolate( - mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) - ) - mask = mask.to(device="cuda", dtype=dtype) - print("mask unique values" , torch.unique(mask)) - return mask - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]] = None, - image: Union[torch.Tensor, PIL.Image.Image] = None, - mask_image: Union[torch.Tensor, PIL.Image.Image] = None, - soft_mask=None, - control_image: Union[ - torch.FloatTensor, - PIL.Image.Image, - np.ndarray, - List[torch.FloatTensor], - List[PIL.Image.Image], - List[np.ndarray], - ] = None, - height: Optional[int] = None, - width: Optional[int] = None, - strength: float = 1.0, - num_inference_steps: int = 50, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_conditioning_scale: Union[float, List[float]] = 0.5, - guess_mode: bool = False, - ): - height, width = self._default_height_width(height, width, image) - - if soft_mask is not None: - soft_mask_pil = load_image(soft_mask) - soft_mask = self.resize_mask(soft_mask_pil,) - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - control_image, - height, - width, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - controlnet_conditioning_scale, - ) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - - if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): - controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) - - global_pool_conditions = ( - controlnet.config.global_pool_conditions - if isinstance(controlnet, ControlNetModel) - else controlnet.nets[0].config.global_pool_conditions - ) - guess_mode = guess_mode or global_pool_conditions - - # 3. Encode input prompt - text_encoder_lora_scale = ( - cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None - ) - prompt_embeds = self._encode_prompt( - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=text_encoder_lora_scale, - ) - - # 4. Prepare image - if isinstance(controlnet, ControlNetModel): - control_image = self.prepare_control_image( - image=control_image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, - ) - elif isinstance(controlnet, MultiControlNetModel): - control_images = [] - - for control_image_ in control_image: - control_image_ = self.prepare_control_image( - image=control_image_, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, - ) - - control_images.append(control_image_) - - control_image = control_images - else: - assert False - - # 4. Preprocess mask and image - resizes image and mask w.r.t height and width - mask, masked_image, init_image = prepare_mask_and_masked_image( - image, mask_image, height, width, return_image=True - ) - - # 5. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps( - num_inference_steps=num_inference_steps, strength=strength, device=device - ) - # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) - latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise - is_strength_max = strength == 1.0 - - # 6. Prepare latent variables - num_channels_latents = self.vae.config.latent_channels - num_channels_unet = self.unet.config.in_channels - return_image_latents = num_channels_unet == 4 - latents_outputs = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - image=init_image, - timestep=latent_timestep, - is_strength_max=is_strength_max, - return_noise=True, - return_image_latents=return_image_latents, - ) - - if return_image_latents: - latents, noise, image_latents = latents_outputs - else: - latents, noise = latents_outputs - - # 7. Prepare mask latent variables - mask, masked_image_latents = self.prepare_mask_latents( - mask, - masked_image, - batch_size * num_images_per_prompt, - height, - width, - prompt_embeds.dtype, - device, - generator, - do_classifier_free_guidance, - ) - - # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # controlnet(s) inference - if guess_mode and do_classifier_free_guidance: - # Infer ControlNet only for the conditional batch. - control_model_input = latents - control_model_input = self.scheduler.scale_model_input(control_model_input, t) - controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] - else: - control_model_input = latent_model_input - controlnet_prompt_embeds = prompt_embeds - - down_block_res_samples, mid_block_res_sample = self.controlnet( - control_model_input, - t, - encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=control_image, - conditioning_scale=controlnet_conditioning_scale, - guess_mode=guess_mode, - return_dict=False, - ) - - if guess_mode and do_classifier_free_guidance: - # Infered ControlNet only for the conditional batch. - # To apply the output of ControlNet to both the unconditional and conditional batches, - # add 0 to the unconditional batch to keep it unchanged. - down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] - mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) - - # predict the noise residual - if num_channels_unet == 9: - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - return_dict=False, - )[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - if num_channels_unet == 4: - init_latents_proper = image_latents[:1] - init_mask = mask[:1] - from PIL import Image - - - if i < len(timesteps) - 1: - noise_timestep = timesteps[i + 1] - init_latents_proper = self.scheduler.add_noise( - init_latents_proper, noise, torch.tensor([noise_timestep]) - ) - - if soft_mask is None: - latents = (1 - init_mask) * init_latents_proper + init_mask * latents - else: - print("Using soft mask in controlnet inpaint") - latents = (1 - soft_mask) * init_latents_proper + soft_mask * latents - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - callback(i, t, latents) - - # If we do sequential model offloading, let's offload unet and controlnet - # manually for max memory savings - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.unet.to("cpu") - self.controlnet.to("cpu") - torch.cuda.empty_cache() - - if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - else: - image = latents - has_nsfw_concept = None - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - # Offload last model to CPU - if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: - self.final_offload_hook.offload() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) \ No newline at end of file From 95464c8f2635a7a4987cd1d4fade9ee35d4b4ae5 Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Tue, 1 Aug 2023 18:23:42 +0000 Subject: [PATCH 10/18] clean conflifcts --- .../pipeline_stable_diffusion_xl.py | 11 ----------- .../pipelines/stable_diffusion_xl/watermark.py | 4 ---- 2 files changed, 15 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index a61fd0e303d9..fa0e701a61d4 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -30,18 +30,7 @@ XFormersAttnProcessor, ) from ...schedulers import KarrasDiffusionSchedulers -<<<<<<< HEAD from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring -======= -from ...utils import ( - is_accelerate_available, - is_accelerate_version, - is_invisible_watermark_available, - logging, - randn_tensor, - replace_example_docstring, -) ->>>>>>> ba43ce3476ffa649a6a14f0e13af07df27f1c66f from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionXLPipelineOutput diff --git a/src/diffusers/pipelines/stable_diffusion_xl/watermark.py b/src/diffusers/pipelines/stable_diffusion_xl/watermark.py index 2e75d2e081b4..0ec7a1bc3653 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/watermark.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/watermark.py @@ -1,15 +1,11 @@ import numpy as np import torch -<<<<<<< HEAD -from imwatermark import WatermarkEncoder -======= from ...utils import is_invisible_watermark_available if is_invisible_watermark_available(): from imwatermark import WatermarkEncoder ->>>>>>> ba43ce3476ffa649a6a14f0e13af07df27f1c66f # Copied from https://github.com/Stability-AI/generative-models/blob/613af104c6b85184091d42d374fef420eddb356d/scripts/demo/streamlit_helpers.py#L66 From 70550c0b77468f791ffb9f987823be4fbc3920fb Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Tue, 1 Aug 2023 22:06:32 +0000 Subject: [PATCH 11/18] fix watermark --- src/diffusers/pipelines/pipeline_utils.py | 6 +++--- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 9 ++++++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 49595c7f7662..9245a95b803c 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -26,15 +26,15 @@ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np -import PIL import torch + +import diffusers +import PIL from huggingface_hub import ModelCard, hf_hub_download, model_info, snapshot_download from packaging import version from requests.exceptions import HTTPError from tqdm.auto import tqdm -import diffusers - from .. import __version__ from ..configuration_utils import ConfigMixin from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index fa0e701a61d4..534def0a5400 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -30,7 +30,14 @@ XFormersAttnProcessor, ) from ...schedulers import KarrasDiffusionSchedulers -from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_invisible_watermark_available, + logging, + randn_tensor, + replace_example_docstring, +) from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionXLPipelineOutput From aa8eb9d264a998ad21a9f95f8274b4b229d66135 Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Sun, 6 Aug 2023 08:11:32 +0000 Subject: [PATCH 12/18] add support for early stopping cfg --- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 9 +++++++++ .../pipeline_stable_diffusion_xl_img2img.py | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 534def0a5400..eb79bea4801c 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -560,6 +560,7 @@ def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, + end_cfg=0.6, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -832,6 +833,14 @@ def __call__( # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if end_cfg is not None and i / num_inference_steps > end_cfg and do_classifier_free_guidance: + print("ENDING 2 CFG") + do_classifier_free_guidance = False + prompt_embeds = prompt_embeds[-1:] + add_text_embeds = add_text_embeds[-1:] + add_time_ids = add_time_ids[-1:] + + # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index e69e4bc74d43..13f331483f25 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -16,8 +16,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -import PIL.Image import torch + +import PIL.Image from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import VaeImageProcessor From 87127737cce1052cc88509a1e058d04d4d08e8ba Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Thu, 10 Aug 2023 06:05:01 +0000 Subject: [PATCH 13/18] fix; kand pipelien undo fuckup --- .../pipeline_kandinsky2_2_combined.py | 46 ++++++++++++++++++- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py index 977a82fdbc9f..4f8626a9bdba 100644 --- a/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +++ b/src/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py @@ -177,6 +177,9 @@ def __init__( movq=movq, ) + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared @@ -187,6 +190,17 @@ def enable_model_cpu_offload(self, gpu_id=0): self.prior_pipe.enable_model_cpu_offload() self.decoder_pipe.enable_model_cpu_offload() + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + def progress_bar(self, iterable=None, total=None): self.prior_pipe.progress_bar(iterable=iterable, total=total) self.decoder_pipe.progress_bar(iterable=iterable, total=total) @@ -378,6 +392,9 @@ def __init__( movq=movq, ) + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared @@ -388,6 +405,17 @@ def enable_model_cpu_offload(self, gpu_id=0): self.prior_pipe.enable_model_cpu_offload() self.decoder_pipe.enable_model_cpu_offload() + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + def progress_bar(self, iterable=None, total=None): self.prior_pipe.progress_bar(iterable=iterable, total=total) self.decoder_pipe.progress_bar(iterable=iterable, total=total) @@ -427,7 +455,7 @@ def __call__( The prompt or prompts to guide the image generation. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded + process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded again. negative_prompt (`str` or `List[str]`, *optional*): The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored @@ -601,6 +629,9 @@ def __init__( movq=movq, ) + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): + self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op) + def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared @@ -611,6 +642,17 @@ def enable_model_cpu_offload(self, gpu_id=0): self.prior_pipe.enable_model_cpu_offload() self.decoder_pipe.enable_model_cpu_offload() + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id) + def progress_bar(self, iterable=None, total=None): self.prior_pipe.progress_bar(iterable=iterable, total=total) self.decoder_pipe.progress_bar(iterable=iterable, total=total) @@ -650,7 +692,7 @@ def __call__( The prompt or prompts to guide the image generation. image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, or tensor representing an image batch, that will be used as the starting point for the - process. Can also accpet image latents as `image`, if passing latents directly, it will not be encoded + process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded again. mask_image (`np.array`): Tensor representing an image batch, to mask `image`. White pixels in the mask will be repainted, while From 67570616aa77343711bbb8b1e8890053268f1b7d Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Thu, 10 Aug 2023 22:52:07 +0000 Subject: [PATCH 14/18] feat: controlnet support, wip img2img --- .../controlnet/pipeline_controlnet_sd_xl.py | 134 ++++++++++++++++-- 1 file changed, 122 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index dd0ffd82ca44..3565e3e4c996 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -21,9 +21,8 @@ import torch.nn.functional as F import PIL.Image -from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer - from diffusers.utils.import_utils import is_invisible_watermark_available +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin @@ -180,6 +179,60 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() + def prepare_latents_img2img(self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + if image is not None: + image = image.to(device=device, dtype=dtype) + + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + if latents is None: + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + # deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + else: + latents = latents.to(device) + latents = latents * self.scheduler.init_noise_sigma + + + # scale the initial noise by the standard deviation required by the scheduler + # NOTE: not sure if the following line is necessary in for img2img + # latents = latents * self.scheduler.init_noise_sigma + return latents + def enable_model_cpu_offload(self, gpu_id=0): r""" Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared @@ -643,6 +696,15 @@ def upcast_vae(self): self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -681,6 +743,9 @@ def __call__( original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, + img2img_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, + img2img_strength: float = 1.0, + **kwargs ): r""" Function invoked when calling the pipeline for generation. @@ -881,16 +946,61 @@ def __call__( # 6. Prepare latent variables num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) + + + if img2img_strength > 0.0: + img2img_image = self.prepare_image( + image=img2img_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + img2img_image = 2.0 * img2img_image - 1.0 + + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1 - img2img_strength, device) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + latents = self.prepare_latents_img2img( + img2img_image[0][None, :], + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + else: + timesteps = self.scheduler.timesteps + + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # latents = self.prepare_latents( + # batch_size * num_images_per_prompt, + # num_channels_latents, + # height, + # width, + # prompt_embeds.dtype, + # device, + # generator, + # latents, + # ) # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) From 405988353ecdfc71bf0322d33fd69d47c5a57e9c Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Fri, 11 Aug 2023 00:43:53 +0000 Subject: [PATCH 15/18] wip: img2img, sorta working --- .../controlnet/pipeline_controlnet_sd_xl.py | 140 ++++++++++-------- 1 file changed, 82 insertions(+), 58 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 3565e3e4c996..4f58e3c5fe2c 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -179,57 +179,72 @@ def disable_vae_tiling(self): """ self.vae.disable_tiling() - def prepare_latents_img2img(self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): - if image is not None: - image = image.to(device=device, dtype=dtype) - - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) - if isinstance(generator, list) and len(generator) != batch_size: + def prepare_latents_img2img( + self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True + ): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" ) - if latents is None: - if isinstance(generator, list): + + # Offload text encoder if `enable_model_cpu_offload` was enabled + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.text_encoder_2.to("cpu") + torch.cuda.empty_cache() + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + if image.shape[1] == 4: + init_latents = image + + else: + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.config.force_upcast: + image = image.float() + self.vae.to(dtype=torch.float32) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + elif isinstance(generator, list): init_latents = [ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: init_latents = self.vae.encode(image).latent_dist.sample(generator) - + + if self.vae.config.force_upcast: + self.vae.to(dtype) + + init_latents = init_latents.to(dtype) init_latents = self.vae.config.scaling_factor * init_latents - if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: - # expand init_latents for batch_size - deprecation_message = ( - f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" - " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" - " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" - " your script to pass as many initial images as text prompts to suppress this warning." - ) - # deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) - additional_image_per_prompt = batch_size // init_latents.shape[0] - init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) - elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." - ) - else: - init_latents = torch.cat([init_latents], dim=0) + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + if add_noise: shape = init_latents.shape noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - # get latents init_latents = self.scheduler.add_noise(init_latents, noise, timestep) - latents = init_latents - else: - latents = latents.to(device) - latents = latents * self.scheduler.init_noise_sigma + latents = init_latents - # scale the initial noise by the standard deviation required by the scheduler - # NOTE: not sure if the following line is necessary in for img2img + print('mul sigma') # latents = latents * self.scheduler.init_noise_sigma return latents @@ -696,13 +711,28 @@ def upcast_vae(self): self.vae.decoder.conv_in.to(dtype) self.vae.decoder.mid_block.to(dtype) - def get_timesteps(self, num_inference_steps, strength, device): + def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): # get the original timestep using init_timestep - init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + else: + t_start = 0 - t_start = max(num_inference_steps - init_timestep, 0) timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + if denoising_start is not None: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps)) + return torch.tensor(timesteps), len(timesteps) + return timesteps, num_inference_steps - t_start @torch.no_grad() @@ -745,6 +775,7 @@ def __call__( target_size: Tuple[int, int] = None, img2img_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, img2img_strength: float = 1.0, + hack=False, **kwargs ): r""" @@ -948,34 +979,27 @@ def __call__( num_channels_latents = self.unet.config.in_channels - if img2img_strength > 0.0: - img2img_image = self.prepare_image( - image=img2img_image, - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=self.controlnet.dtype, - do_classifier_free_guidance=do_classifier_free_guidance, - guess_mode=guess_mode, - ) - img2img_image = 2.0 * img2img_image - 1.0 + if img2img_strength > 0.0 or hack: + print("DFJlkaSDK:J:LKJDFLK:DSJF\n\n\n\n WARNING remove hack") + img2img_image = self.image_processor.preprocess(img2img_image) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1 - img2img_strength, device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, 1 - img2img_strength, device, denoising_start=None + ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + add_noise = True + # 6. Prepare latent variables latents = self.prepare_latents_img2img( - img2img_image[0][None, :], + img2img_image, latent_timestep, - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, + batch_size, + num_images_per_prompt, prompt_embeds.dtype, device, generator, - latents, + add_noise, ) else: timesteps = self.scheduler.timesteps From 2f8cd5fc6de4b8c8d1fcb11bf28f8a4bb58b0afe Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Fri, 11 Aug 2023 19:00:52 +0000 Subject: [PATCH 16/18] add controlnet_strength (end of controlnetet guidance) to sdxl pipeline --- .../controlnet/pipeline_controlnet.py | 2 +- .../controlnet/pipeline_controlnet_sd_xl.py | 50 ++++++++++++++----- 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 9e8b8ef7a29f..69718bda82f2 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -1082,7 +1082,7 @@ def __call__( mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual - if t <= controlnet_strength * 1000: + if t >= controlnet_strength * 1000: noise_pred = self.unet( latent_model_input, t, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 4f58e3c5fe2c..a1f7575ef2af 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -770,11 +770,12 @@ def __call__( guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, + controlnet_strength: float = 1.0, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, img2img_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, - img2img_strength: float = 1.0, + img2img_strength: float = 0.0, hack=False, **kwargs ): @@ -979,14 +980,16 @@ def __call__( num_channels_latents = self.unet.config.in_channels - if img2img_strength > 0.0 or hack: - print("DFJlkaSDK:J:LKJDFLK:DSJF\n\n\n\n WARNING remove hack") + + if (img2img_strength > 0.0) and latents is None: img2img_image = self.image_processor.preprocess(img2img_image) timesteps, num_inference_steps = self.get_timesteps( num_inference_steps, 1 - img2img_strength, device, denoising_start=None ) + print(f"timesteps = {timesteps}") latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + print(f"latent_timestep = {latent_timestep}") add_noise = True @@ -1001,6 +1004,14 @@ def __call__( generator, add_noise, ) + elif latents is not None: + print("using pased latetns") + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, 1 - img2img_strength, device, denoising_start=None + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + print(f"latent_timestep = {latent_timestep}") + latents = latents else: timesteps = self.scheduler.timesteps @@ -1058,6 +1069,7 @@ def __call__( # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + print("starting denoising loop with time steps: ", timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance @@ -1099,16 +1111,28 @@ def __call__( mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=cross_attention_kwargs, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] + + if t >= controlnet_strength * 1000: + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + else: + print("DEBUG: not using controlnet at ts", t) + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: From 27c29b333a42ecf19f62c3c702150a483342688e Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Fri, 11 Aug 2023 20:24:42 +0000 Subject: [PATCH 17/18] add cnet start / end, make them a percentage of num_inference_steps as opposed to using value of (potentially nonlinear) timesteps --- .../pipelines/controlnet/pipeline_controlnet.py | 8 ++++++-- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 10 ++++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 69718bda82f2..c61680fe97f5 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -829,7 +829,8 @@ def __call__( guess_mode: bool = False, img2img_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, img2img_strength: float = 1.0, - controlnet_strength: float = 1.0, + controlnet_start: float = 1.0, + controlnet_end: float = 0.0, ): r""" Function invoked when calling the pipeline for generation. @@ -1082,7 +1083,9 @@ def __call__( mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) # predict the noise residual - if t >= controlnet_strength * 1000: + # if t <= controlnet_start * 1000: + if i / len(timesteps) >= controlnet_start and i / len(timesteps) <= controlnet_end: + print("using controlnet", t) noise_pred = self.unet( latent_model_input, t, @@ -1092,6 +1095,7 @@ def __call__( mid_block_additional_residual=mid_block_res_sample, ).sample else: + print("not using controlnet", t) noise_pred = self.unet( latent_model_input, t, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index a1f7575ef2af..f9f61761f2a2 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -770,7 +770,9 @@ def __call__( guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, - controlnet_strength: float = 1.0, + controlnet_end: float = 0.0, + controlnet_start: float = 1., + # controlnet_strength: float = None, #for legacy support with old sdxl original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, @@ -896,6 +898,8 @@ def __call__( control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ control_guidance_end ] + print(f"controlnet_end = {controlnet_end}") + print(f"controlnet_start = {controlnet_start}") # 1. Check inputs. Raise error if not correct self.check_inputs( @@ -1112,7 +1116,9 @@ def __call__( # predict the noise residual - if t >= controlnet_strength * 1000: + # if t <= controlnet_start * 1000 and t >= controlnet_end * 1000: + if i / len(timesteps) >= controlnet_start and i / len(timesteps) <= controlnet_end: + print("using cnet at timetsep" , t) noise_pred = self.unet( latent_model_input, t, From 0a3cbf6619430935ed0a85b1bd73b4b9b8e259e4 Mon Sep 17 00:00:00 2001 From: erwannmillon Date: Fri, 11 Aug 2023 21:36:22 +0000 Subject: [PATCH 18/18] change default cnet start/end vals for nonxl controlnet pipelien --- .gitignore | 10 +++++++++- .../pipelines/controlnet/pipeline_controlnet.py | 4 ++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 45602a1f547e..a9815f231c4c 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,14 @@ __pycache__/ *.so # tests and logs +**.jpg +**.jpeg +*.jpeg +**/*.png +**.png +*.png +**/*.png +**/*.jpeg tests/fixtures/cached_*_text.txt logs/ lightning_logs/ @@ -173,4 +181,4 @@ tags # ruff .ruff_cache -wandb \ No newline at end of file +wandb diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index c61680fe97f5..4c71a7858e31 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -829,8 +829,8 @@ def __call__( guess_mode: bool = False, img2img_image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, img2img_strength: float = 1.0, - controlnet_start: float = 1.0, - controlnet_end: float = 0.0, + controlnet_start: float = 0.0, + controlnet_end: float = 1.0, ): r""" Function invoked when calling the pipeline for generation.