From 319124847216a57a6ae12b567689aa72b28f1c02 Mon Sep 17 00:00:00 2001 From: Daniel Regado <35548192+guiyrt@users.noreply.github.com> Date: Fri, 20 Dec 2024 00:48:18 +0000 Subject: [PATCH] [WIP] SD3.5 IP-Adapter Pipeline Integration (#9987) * Added support for single IPAdapter on SD3.5 pipeline --------- Co-authored-by: hlky Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: YiYi Xu --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/attnprocessor.md | 2 + docs/source/en/api/loaders/ip_adapter.md | 6 + docs/source/en/api/loaders/transformer_sd3.md | 29 ++ .../stable_diffusion/stable_diffusion_3.md | 69 ++++- src/diffusers/loaders/__init__.py | 12 +- src/diffusers/loaders/ip_adapter.py | 251 +++++++++++++++++- src/diffusers/loaders/transformer_sd3.py | 89 +++++++ src/diffusers/models/attention_processor.py | 172 ++++++++++++ src/diffusers/models/embeddings.py | 181 +++++++++++++ .../models/transformers/transformer_sd3.py | 16 +- .../pipeline_stable_diffusion_3.py | 129 ++++++++- .../test_pipeline_stable_diffusion_3.py | 2 + 13 files changed, 935 insertions(+), 25 deletions(-) create mode 100644 docs/source/en/api/loaders/transformer_sd3.md create mode 100644 src/diffusers/loaders/transformer_sd3.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 27e9fe5e191b..6ac66db73026 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -238,6 +238,8 @@ title: Textual Inversion - local: api/loaders/unet title: UNet + - local: api/loaders/transformer_sd3 + title: SD3Transformer2D - local: api/loaders/peft title: PEFT title: Loaders diff --git a/docs/source/en/api/attnprocessor.md b/docs/source/en/api/attnprocessor.md index fee0d7e35764..8bdffc330567 100644 --- a/docs/source/en/api/attnprocessor.md +++ b/docs/source/en/api/attnprocessor.md @@ -86,6 +86,8 @@ An attention processor is a class for applying different types of attention mech [[autodoc]] models.attention_processor.IPAdapterAttnProcessor2_0 +[[autodoc]] models.attention_processor.SD3IPAdapterJointAttnProcessor2_0 + ## JointAttnProcessor2_0 [[autodoc]] models.attention_processor.JointAttnProcessor2_0 diff --git a/docs/source/en/api/loaders/ip_adapter.md b/docs/source/en/api/loaders/ip_adapter.md index a10f30ef8e5b..946a8b1af875 100644 --- a/docs/source/en/api/loaders/ip_adapter.md +++ b/docs/source/en/api/loaders/ip_adapter.md @@ -24,6 +24,12 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading] [[autodoc]] loaders.ip_adapter.IPAdapterMixin +## SD3IPAdapterMixin + +[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin + - all + - is_ip_adapter_active + ## IPAdapterMaskProcessor [[autodoc]] image_processor.IPAdapterMaskProcessor \ No newline at end of file diff --git a/docs/source/en/api/loaders/transformer_sd3.md b/docs/source/en/api/loaders/transformer_sd3.md new file mode 100644 index 000000000000..4fc9603054b4 --- /dev/null +++ b/docs/source/en/api/loaders/transformer_sd3.md @@ -0,0 +1,29 @@ + + +# SD3Transformer2D + +This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead. + +The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs. + + + +To learn more about how to load LoRA weights, see the [LoRA](../../using-diffusers/loading_adapters#lora) loading guide. + + + +## SD3Transformer2DLoadersMixin + +[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin + - all + - _load_ip_adapter_weights \ No newline at end of file diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md index 8170c5280d38..eb67964ab0bd 100644 --- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md +++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md @@ -59,9 +59,76 @@ image.save("sd3_hello_world.png") - [`stabilityai/stable-diffusion-3.5-large`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large) - [`stabilityai/stable-diffusion-3.5-large-turbo`](https://huggingface.co/stabilityai/stable-diffusion-3-5-large-turbo) +## Image Prompting with IP-Adapters + +An IP-Adapter lets you prompt SD3 with images, in addition to the text prompt. This is especially useful when describing complex concepts that are difficult to articulate through text alone and you have reference images. To load and use an IP-Adapter, you need: + +- `image_encoder`: Pre-trained vision model used to obtain image features, usually a CLIP image encoder. +- `feature_extractor`: Image processor that prepares the input image for the chosen `image_encoder`. +- `ip_adapter_id`: Checkpoint containing parameters of image cross attention layers and image projection. + +IP-Adapters are trained for a specific model architecture, so they also work in finetuned variations of the base model. You can use the [`~SD3IPAdapterMixin.set_ip_adapter_scale`] function to adjust how strongly the output aligns with the image prompt. The higher the value, the more closely the model follows the image prompt. A default value of 0.5 is typically a good balance, ensuring the model considers both the text and image prompts equally. + +```python +import torch +from PIL import Image + +from diffusers import StableDiffusion3Pipeline +from transformers import SiglipVisionModel, SiglipImageProcessor + +image_encoder_id = "google/siglip-so400m-patch14-384" +ip_adapter_id = "InstantX/SD3.5-Large-IP-Adapter" + +feature_extractor = SiglipImageProcessor.from_pretrained( + image_encoder_id, + torch_dtype=torch.float16 +) +image_encoder = SiglipVisionModel.from_pretrained( + image_encoder_id, + torch_dtype=torch.float16 +).to( "cuda") + +pipe = StableDiffusion3Pipeline.from_pretrained( + "stabilityai/stable-diffusion-3.5-large", + torch_dtype=torch.float16, + feature_extractor=feature_extractor, + image_encoder=image_encoder, +).to("cuda") + +pipe.load_ip_adapter(ip_adapter_id) +pipe.set_ip_adapter_scale(0.6) + +ref_img = Image.open("image.jpg").convert('RGB') + +image = pipe( + width=1024, + height=1024, + prompt="a cat", + negative_prompt="lowres, low quality, worst quality", + num_inference_steps=24, + guidance_scale=5.0, + ip_adapter_image=ref_img +).images[0] + +image.save("result.jpg") +``` + +
+ +
IP-Adapter examples with prompt "a cat"
+
+ + + + +Check out [IP-Adapter](../../../using-diffusers/ip_adapter) to learn more about how IP-Adapters work. + + + + ## Memory Optimisations for SD3 -SD3 uses three text encoders, one if which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware. +SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low resource hardware. ### Running Inference with Model Offloading diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 6ea382d721de..c7ea0be55db2 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -56,6 +56,7 @@ def text_encoder_attn_modules(text_encoder): if is_torch_available(): _import_structure["single_file_model"] = ["FromOriginalModelMixin"] + _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"] _import_structure["unet"] = ["UNet2DConditionLoadersMixin"] _import_structure["utils"] = ["AttnProcsLayers"] if is_transformers_available(): @@ -74,7 +75,10 @@ def text_encoder_attn_modules(text_encoder): "SanaLoraLoaderMixin", ] _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"] - _import_structure["ip_adapter"] = ["IPAdapterMixin"] + _import_structure["ip_adapter"] = [ + "IPAdapterMixin", + "SD3IPAdapterMixin", + ] _import_structure["peft"] = ["PeftAdapterMixin"] @@ -82,11 +86,15 @@ def text_encoder_attn_modules(text_encoder): if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .single_file_model import FromOriginalModelMixin + from .transformer_sd3 import SD3Transformer2DLoadersMixin from .unet import UNet2DConditionLoadersMixin from .utils import AttnProcsLayers if is_transformers_available(): - from .ip_adapter import IPAdapterMixin + from .ip_adapter import ( + IPAdapterMixin, + SD3IPAdapterMixin, + ) from .lora_pipeline import ( AmusedLoraLoaderMixin, CogVideoXLoraLoaderMixin, diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index ca460f948e6f..11ce4f1634d7 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -33,15 +33,18 @@ if is_transformers_available(): - from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel + +from ..models.attention_processor import ( + AttnProcessor, + AttnProcessor2_0, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, + IPAdapterXFormersAttnProcessor, + JointAttnProcessor2_0, + SD3IPAdapterJointAttnProcessor2_0, +) - from ..models.attention_processor import ( - AttnProcessor, - AttnProcessor2_0, - IPAdapterAttnProcessor, - IPAdapterAttnProcessor2_0, - IPAdapterXFormersAttnProcessor, - ) logger = logging.get_logger(__name__) @@ -348,3 +351,235 @@ def unload_ip_adapter(self): else value.__class__() ) self.unet.set_attn_processor(attn_procs) + + +class SD3IPAdapterMixin: + """Mixin for handling StableDiffusion 3 IP Adapters.""" + + @property + def is_ip_adapter_active(self) -> bool: + """Checks if IP-Adapter is loaded and scale > 0. + + IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0, + the image context is irrelevant. + + Returns: + `bool`: True when IP-Adapter is loaded and any layer has scale > 0. + """ + scales = [ + attn_proc.scale + for attn_proc in self.transformer.attn_processors.values() + if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0) + ] + + return len(scales) > 0 and any(scale > 0 for scale in scales) + + @validate_hf_hub_args + def load_ip_adapter( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + weight_name: str = "ip-adapter.safetensors", + subfolder: Optional[str] = None, + image_encoder_folder: Optional[str] = "image_encoder", + **kwargs, + ) -> None: + """ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`ModelMixin.save_pretrained`]. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + weight_name (`str`, defaults to "ip-adapter.safetensors"): + The name of the weight file to load. If a list is passed, it should have the same length as + `subfolder`. + subfolder (`str`, *optional*): + The subfolder location of a model file within a larger model repository on the Hub or locally. If a + list is passed, it should have the same length as `weight_name`. + image_encoder_folder (`str`, *optional*, defaults to `image_encoder`): + The subfolder location of the image encoder within a larger model repository on the Hub or locally. + Pass `None` to not load the image encoder. If the image encoder is located in a folder inside + `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g. + `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than + `subfolder`, you should pass the path to the folder that contains image encoder weights, for example, + `image_encoder_folder="different_subfolder/image_encoder"`. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + # Load the main state dict first + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) + + if low_cpu_mem_usage and not is_accelerate_available(): + low_cpu_mem_usage = False + logger.warning( + "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" + " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" + " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" + " install accelerate\n```\n." + ) + + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): + raise NotImplementedError( + "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" + " `low_cpu_mem_usage=False`." + ) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = load_state_dict(model_file) + else: + state_dict = pretrained_model_name_or_path_or_dict + + keys = list(state_dict.keys()) + if "image_proj" not in keys and "ip_adapter" not in keys: + raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.") + + # Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None: + if image_encoder_folder is not None: + if not isinstance(pretrained_model_name_or_path_or_dict, dict): + logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}") + if image_encoder_folder.count("/") == 0: + image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix() + else: + image_encoder_subfolder = Path(image_encoder_folder).as_posix() + + # Commons args for loading image encoder and image processor + kwargs = { + "low_cpu_mem_usage": low_cpu_mem_usage, + "cache_dir": cache_dir, + "local_files_only": local_files_only, + } + + self.register_modules( + feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs).to( + self.device, dtype=self.dtype + ), + image_encoder=SiglipVisionModel.from_pretrained(image_encoder_subfolder, **kwargs).to( + self.device, dtype=self.dtype + ), + ) + else: + raise ValueError( + "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict." + ) + else: + logger.warning( + "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter." + "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead." + ) + + # Load IP-Adapter into transformer + self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage) + + def set_ip_adapter_scale(self, scale: float) -> None: + """ + Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only + conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages + the model to produce more diverse images, but they may not be as aligned with the image prompt. + + Example: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.set_ip_adapter_scale(0.6) + >>> ... + ``` + + Args: + scale (float): + IP-Adapter scale to be set. + + """ + for attn_processor in self.transformer.attn_processors.values(): + if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0): + attn_processor.scale = scale + + def unload_ip_adapter(self) -> None: + """ + Unloads the IP Adapter weights. + + Example: + + ```python + >>> # Assuming `pipeline` is already loaded with the IP Adapter weights. + >>> pipeline.unload_ip_adapter() + >>> ... + ``` + """ + # Remove image encoder + if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None: + self.image_encoder = None + self.register_to_config(image_encoder=None) + + # Remove feature extractor + if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None: + self.feature_extractor = None + self.register_to_config(feature_extractor=None) + + # Remove image projection + self.transformer.image_proj = None + + # Restore original attention processors layers + attn_procs = { + name: ( + JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__() + ) + for name, value in self.transformer.attn_processors.items() + } + self.transformer.set_attn_processor(attn_procs) diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py new file mode 100644 index 000000000000..435d1da06ca1 --- /dev/null +++ b/src/diffusers/loaders/transformer_sd3.py @@ -0,0 +1,89 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict + +from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0 +from ..models.embeddings import IPAdapterTimeImageProjection +from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta + + +class SD3Transformer2DLoadersMixin: + """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`.""" + + def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None: + """Sets IP-Adapter attention processors, image projection, and loads state_dict. + + Args: + state_dict (`Dict`): + State dict with keys "ip_adapter", which contains parameters for attention processors, and + "image_proj", which contains parameters for image projection net. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + """ + # IP-Adapter cross attention parameters + hidden_size = self.config.attention_head_dim * self.config.num_attention_heads + ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads + timesteps_emb_dim = state_dict["ip_adapter"]["0.norm_ip.linear.weight"].shape[1] + + # Dict where key is transformer layer index, value is attention processor's state dict + # ip_adapter state dict keys example: "0.norm_ip.linear.weight" + layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))} + for key, weights in state_dict["ip_adapter"].items(): + idx, name = key.split(".", maxsplit=1) + layer_state_dict[int(idx)][name] = weights + + # Create IP-Adapter attention processor + attn_procs = {} + for idx, name in enumerate(self.attn_processors.keys()): + attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0( + hidden_size=hidden_size, + ip_hidden_states_dim=ip_hidden_states_dim, + head_dim=self.config.attention_head_dim, + timesteps_emb_dim=timesteps_emb_dim, + ).to(self.device, dtype=self.dtype) + + if not low_cpu_mem_usage: + attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True) + else: + load_model_dict_into_meta( + attn_procs[name], layer_state_dict[idx], device=self.device, dtype=self.dtype + ) + + self.set_attn_processor(attn_procs) + + # Image projetion parameters + embed_dim = state_dict["image_proj"]["proj_in.weight"].shape[1] + output_dim = state_dict["image_proj"]["proj_out.weight"].shape[0] + hidden_dim = state_dict["image_proj"]["proj_in.weight"].shape[0] + heads = state_dict["image_proj"]["layers.0.attn.to_q.weight"].shape[0] // 64 + num_queries = state_dict["image_proj"]["latents"].shape[1] + timestep_in_dim = state_dict["image_proj"]["time_embedding.linear_1.weight"].shape[1] + + # Image projection + self.image_proj = IPAdapterTimeImageProjection( + embed_dim=embed_dim, + output_dim=output_dim, + hidden_dim=hidden_dim, + heads=heads, + num_queries=num_queries, + timestep_in_dim=timestep_in_dim, + ).to(device=self.device, dtype=self.dtype) + + if not low_cpu_mem_usage: + self.image_proj.load_state_dict(state_dict["image_proj"], strict=True) + else: + load_model_dict_into_meta(self.image_proj, state_dict["image_proj"], device=self.device, dtype=self.dtype) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 05cbaa40e693..ed0dd4f71d27 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -5243,6 +5243,177 @@ def __call__( return hidden_states +class SD3IPAdapterJointAttnProcessor2_0(torch.nn.Module): + """ + Attention processor for IP-Adapter used typically in processing the SD3-like self-attention projections, with + additional image-based information and timestep embeddings. + + Args: + hidden_size (`int`): + The number of hidden channels. + ip_hidden_states_dim (`int`): + The image feature dimension. + head_dim (`int`): + The number of head channels. + timesteps_emb_dim (`int`, defaults to 1280): + The number of input channels for timestep embedding. + scale (`float`, defaults to 0.5): + IP-Adapter scale. + """ + + def __init__( + self, + hidden_size: int, + ip_hidden_states_dim: int, + head_dim: int, + timesteps_emb_dim: int = 1280, + scale: float = 0.5, + ): + super().__init__() + + # To prevent circular import + from .normalization import AdaLayerNorm, RMSNorm + + self.norm_ip = AdaLayerNorm(timesteps_emb_dim, output_dim=ip_hidden_states_dim * 2, norm_eps=1e-6, chunk_dim=1) + self.to_k_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) + self.to_v_ip = nn.Linear(ip_hidden_states_dim, hidden_size, bias=False) + self.norm_q = RMSNorm(head_dim, 1e-6) + self.norm_k = RMSNorm(head_dim, 1e-6) + self.norm_ip_k = RMSNorm(head_dim, 1e-6) + self.scale = scale + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + ip_hidden_states: torch.FloatTensor = None, + temb: torch.FloatTensor = None, + ) -> torch.FloatTensor: + """ + Perform the attention computation, integrating image features (if provided) and timestep embeddings. + + If `ip_hidden_states` is `None`, this is equivalent to using JointAttnProcessor2_0. + + Args: + attn (`Attention`): + Attention instance. + hidden_states (`torch.FloatTensor`): + Input `hidden_states`. + encoder_hidden_states (`torch.FloatTensor`, *optional*): + The encoder hidden states. + attention_mask (`torch.FloatTensor`, *optional*): + Attention mask. + ip_hidden_states (`torch.FloatTensor`, *optional*): + Image embeddings. + temb (`torch.FloatTensor`, *optional*): + Timestep embeddings. + + Returns: + `torch.FloatTensor`: Output hidden states. + """ + residual = hidden_states + + batch_size = hidden_states.shape[0] + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + img_query = query + img_key = key + img_value = value + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # `context` projections. + if encoder_hidden_states is not None: + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) + + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + # Split the attention outputs. + hidden_states, encoder_hidden_states = ( + hidden_states[:, : residual.shape[1]], + hidden_states[:, residual.shape[1] :], + ) + if not attn.context_pre_only: + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + # IP Adapter + if self.scale != 0 and ip_hidden_states is not None: + # Norm image features + norm_ip_hidden_states = self.norm_ip(ip_hidden_states, temb=temb) + + # To k and v + ip_key = self.to_k_ip(norm_ip_hidden_states) + ip_value = self.to_v_ip(norm_ip_hidden_states) + + # Reshape + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # Norm + query = self.norm_q(img_query) + img_key = self.norm_k(img_key) + ip_key = self.norm_ip_k(ip_key) + + # cat img + key = torch.cat([img_key, ip_key], dim=2) + value = torch.cat([img_value, ip_value], dim=2) + + ip_hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + ip_hidden_states = ip_hidden_states.transpose(1, 2).view(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + ip_hidden_states * self.scale + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if encoder_hidden_states is not None: + return hidden_states, encoder_hidden_states + else: + return hidden_states + + class PAGIdentitySelfAttnProcessor2_0: r""" Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0). @@ -5772,6 +5943,7 @@ def __call__( IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor, + SD3IPAdapterJointAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0, PAGCFGIdentitySelfAttnProcessor2_0, LoRAAttnProcessor, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 69b3ee8466f4..f1b339e6180b 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -2396,6 +2396,187 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: return out +class IPAdapterTimeImageProjectionBlock(nn.Module): + """Block for IPAdapterTimeImageProjection. + + Args: + hidden_dim (`int`, defaults to 1280): + The number of hidden channels. + dim_head (`int`, defaults to 64): + The number of head channels. + heads (`int`, defaults to 20): + Parallel attention heads. + ffn_ratio (`int`, defaults to 4): + The expansion ratio of feedforward network hidden layer channels. + """ + + def __init__( + self, + hidden_dim: int = 1280, + dim_head: int = 64, + heads: int = 20, + ffn_ratio: int = 4, + ) -> None: + super().__init__() + from .attention import FeedForward + + self.ln0 = nn.LayerNorm(hidden_dim) + self.ln1 = nn.LayerNorm(hidden_dim) + self.attn = Attention( + query_dim=hidden_dim, + cross_attention_dim=hidden_dim, + dim_head=dim_head, + heads=heads, + bias=False, + out_bias=False, + ) + self.ff = FeedForward(hidden_dim, hidden_dim, activation_fn="gelu", mult=ffn_ratio, bias=False) + + # AdaLayerNorm + self.adaln_silu = nn.SiLU() + self.adaln_proj = nn.Linear(hidden_dim, 4 * hidden_dim) + self.adaln_norm = nn.LayerNorm(hidden_dim) + + # Set attention scale and fuse KV + self.attn.scale = 1 / math.sqrt(math.sqrt(dim_head)) + self.attn.fuse_projections() + self.attn.to_k = None + self.attn.to_v = None + + def forward(self, x: torch.Tensor, latents: torch.Tensor, timestep_emb: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + x (`torch.Tensor`): + Image features. + latents (`torch.Tensor`): + Latent features. + timestep_emb (`torch.Tensor`): + Timestep embedding. + + Returns: + `torch.Tensor`: Output latent features. + """ + + # Shift and scale for AdaLayerNorm + emb = self.adaln_proj(self.adaln_silu(timestep_emb)) + shift_msa, scale_msa, shift_mlp, scale_mlp = emb.chunk(4, dim=1) + + # Fused Attention + residual = latents + x = self.ln0(x) + latents = self.ln1(latents) * (1 + scale_msa[:, None]) + shift_msa[:, None] + + batch_size = latents.shape[0] + + query = self.attn.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + key, value = self.attn.to_kv(kv_input).chunk(2, dim=-1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.attn.heads + + query = query.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.attn.heads, head_dim).transpose(1, 2) + + weight = (query * self.attn.scale) @ (key * self.attn.scale).transpose(-2, -1) + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + latents = weight @ value + + latents = latents.transpose(1, 2).reshape(batch_size, -1, self.attn.heads * head_dim) + latents = self.attn.to_out[0](latents) + latents = self.attn.to_out[1](latents) + latents = latents + residual + + ## FeedForward + residual = latents + latents = self.adaln_norm(latents) * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + return self.ff(latents) + residual + + +# Modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +class IPAdapterTimeImageProjection(nn.Module): + """Resampler of SD3 IP-Adapter with timestep embedding. + + Args: + embed_dim (`int`, defaults to 1152): + The feature dimension. + output_dim (`int`, defaults to 2432): + The number of output channels. + hidden_dim (`int`, defaults to 1280): + The number of hidden channels. + depth (`int`, defaults to 4): + The number of blocks. + dim_head (`int`, defaults to 64): + The number of head channels. + heads (`int`, defaults to 20): + Parallel attention heads. + num_queries (`int`, defaults to 64): + The number of queries. + ffn_ratio (`int`, defaults to 4): + The expansion ratio of feedforward network hidden layer channels. + timestep_in_dim (`int`, defaults to 320): + The number of input channels for timestep embedding. + timestep_flip_sin_to_cos (`bool`, defaults to True): + Flip the timestep embedding order to `cos, sin` (if True) or `sin, cos` (if False). + timestep_freq_shift (`int`, defaults to 0): + Controls the timestep delta between frequencies between dimensions. + """ + + def __init__( + self, + embed_dim: int = 1152, + output_dim: int = 2432, + hidden_dim: int = 1280, + depth: int = 4, + dim_head: int = 64, + heads: int = 20, + num_queries: int = 64, + ffn_ratio: int = 4, + timestep_in_dim: int = 320, + timestep_flip_sin_to_cos: bool = True, + timestep_freq_shift: int = 0, + ) -> None: + super().__init__() + self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dim) / hidden_dim**0.5) + self.proj_in = nn.Linear(embed_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + self.layers = nn.ModuleList( + [IPAdapterTimeImageProjectionBlock(hidden_dim, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) + self.time_proj = Timesteps(timestep_in_dim, timestep_flip_sin_to_cos, timestep_freq_shift) + self.time_embedding = TimestepEmbedding(timestep_in_dim, hidden_dim, act_fn="silu") + + def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass. + + Args: + x (`torch.Tensor`): + Image features. + timestep (`torch.Tensor`): + Timestep in denoising process. + Returns: + `Tuple`[`torch.Tensor`, `torch.Tensor`]: The pair (latents, timestep_emb). + """ + timestep_emb = self.time_proj(timestep).to(dtype=x.dtype) + timestep_emb = self.time_embedding(timestep_emb) + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + x = x + timestep_emb[:, None] + + for block in self.layers: + latents = block(x, latents, timestep_emb) + + latents = self.proj_out(latents) + latents = self.norm_out(latents) + + return latents, timestep_emb + + class MultiIPAdapterImageProjection(nn.Module): def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]): super().__init__() diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py index 79c4069e9a37..415540ef7f6a 100644 --- a/src/diffusers/models/transformers/transformer_sd3.py +++ b/src/diffusers/models/transformers/transformer_sd3.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin from ...models.attention import FeedForward, JointTransformerBlock from ...models.attention_processor import ( Attention, @@ -103,7 +103,9 @@ def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor): return hidden_states -class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class SD3Transformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, SD3Transformer2DLoadersMixin +): """ The Transformer model introduced in Stable Diffusion 3. @@ -349,8 +351,8 @@ def forward( Input `hidden_states`. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. - pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected - from the embeddings of input conditions. + pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): + Embeddings projected from the embeddings of input conditions. timestep (`torch.LongTensor`): Used to indicate denoising step. block_controlnet_hidden_states (`list` of `torch.Tensor`): @@ -390,6 +392,12 @@ def forward( temb = self.time_text_embed(timestep, pooled_projections) encoder_hidden_states = self.context_embedder(encoder_hidden_states) + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states, ip_temb = self.image_proj(ip_adapter_image_embeds, timestep) + + joint_attention_kwargs.update(ip_hidden_states=ip_hidden_states, temb=ip_temb) + for index_block, block in enumerate(self.transformer_blocks): # Skip specified layers is_skip = True if skip_layers is not None and index_block in skip_layers else False diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py index 0a51dcbc1261..a53d786798ca 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py @@ -1,4 +1,4 @@ -# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,14 +17,16 @@ import torch from transformers import ( + BaseImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer, + PreTrainedModel, T5EncoderModel, T5TokenizerFast, ) -from ...image_processor import VaeImageProcessor -from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin from ...models.autoencoders import AutoencoderKL from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -142,7 +144,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin): +class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin): r""" Args: transformer ([`SD3Transformer2DModel`]): @@ -174,10 +176,14 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle tokenizer_3 (`T5TokenizerFast`): Tokenizer of class [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + image_encoder (`PreTrainedModel`, *optional*): + Pre-trained Vision Model for IP Adapter. + feature_extractor (`BaseImageProcessor`, *optional*): + Image processor for IP Adapter. """ - model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" - _optional_components = [] + model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae" + _optional_components = ["image_encoder", "feature_extractor"] _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"] def __init__( @@ -191,6 +197,8 @@ def __init__( tokenizer_2: CLIPTokenizer, text_encoder_3: T5EncoderModel, tokenizer_3: T5TokenizerFast, + image_encoder: PreTrainedModel = None, + feature_extractor: BaseImageProcessor = None, ): super().__init__() @@ -204,6 +212,8 @@ def __init__( tokenizer_3=tokenizer_3, transformer=transformer, scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, ) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 @@ -683,6 +693,83 @@ def num_timesteps(self): def interrupt(self): return self._interrupt + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image + def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor: + """Encodes the given image into a feature representation using a pre-trained image encoder. + + Args: + image (`PipelineImageInput`): + Input image to be encoded. + device: (`torch.device`): + Torch device. + + Returns: + `torch.Tensor`: The encoded image feature representation. + """ + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=self.dtype) + + return self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + ) -> torch.Tensor: + """Prepares image embeddings for use in the IP-Adapter. + + Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed. + + Args: + ip_adapter_image (`PipelineImageInput`, *optional*): + The input image to extract features from for IP-Adapter. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Precomputed image embeddings. + device: (`torch.device`, *optional*): + Torch device. + num_images_per_prompt (`int`, defaults to 1): + Number of images that should be generated per prompt. + do_classifier_free_guidance (`bool`, defaults to True): + Whether to use classifier free guidance or not. + """ + device = device or self._execution_device + + if ip_adapter_image_embeds is not None: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2) + else: + single_image_embeds = ip_adapter_image_embeds + elif ip_adapter_image is not None: + single_image_embeds = self.encode_image(ip_adapter_image, device) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.zeros_like(single_image_embeds) + else: + raise ValueError("Neither `ip_adapter_image_embeds` or `ip_adapter_image_embeds` were provided.") + + image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0) + image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0) + + return image_embeds.to(device=device) + + def enable_sequential_cpu_offload(self, *args, **kwargs): + if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload: + logger.warning( + "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses " + "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling " + "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`." + ) + + super().enable_sequential_cpu_offload(*args, **kwargs) + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -705,6 +792,8 @@ def __call__( negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, joint_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -713,9 +802,9 @@ def __call__( callback_on_step_end_tensor_inputs: List[str] = ["latents"], max_sequence_length: int = 256, skip_guidance_layers: List[int] = None, - skip_layer_guidance_scale: int = 2.8, - skip_layer_guidance_stop: int = 0.2, - skip_layer_guidance_start: int = 0.01, + skip_layer_guidance_scale: float = 2.8, + skip_layer_guidance_stop: float = 0.2, + skip_layer_guidance_start: float = 0.01, mu: Optional[float] = None, ): r""" @@ -781,6 +870,11 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. + ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`torch.Tensor`, *optional*): + Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images, + emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to + `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -938,7 +1032,22 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - # 6. Denoising loop + # 6. Prepare image embeddings + if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None: + ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + if self.joint_attention_kwargs is None: + self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds} + else: + self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds) + + # 7. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py index 07ce5487f256..a6f718ae4fbb 100644 --- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py +++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py @@ -103,6 +103,8 @@ def get_dummy_components(self): "tokenizer_3": tokenizer_3, "transformer": transformer, "vae": vae, + "image_encoder": None, + "feature_extractor": None, } def get_dummy_inputs(self, device, seed=0):