diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index e2228fdacf30..b8fbf8f54362 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -238,12 +238,13 @@ jobs: run_flax_tpu_tests: name: Nightly Flax TPU Tests - runs-on: docker-tpu + runs-on: + group: gcp-ct5lp-hightpu-8t if: github.event_name == 'schedule' container: image: diffusers/diffusers-flax-tpu - options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/ --privileged + options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache defaults: run: shell: bash @@ -519,4 +520,4 @@ jobs: # if: always() # run: | # pip install slack_sdk tabulate -# python utils/log_reports.py >> $GITHUB_STEP_SUMMARY \ No newline at end of file +# python utils/log_reports.py >> $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml index 2289d1b5cad1..055c282e7c1e 100644 --- a/.github/workflows/push_tests.yml +++ b/.github/workflows/push_tests.yml @@ -161,11 +161,11 @@ jobs: flax_tpu_tests: name: Flax TPU Tests - runs-on: docker-tpu + runs-on: + group: gcp-ct5lp-hightpu-8t container: image: diffusers/diffusers-flax-tpu - options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged - defaults: + options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache defaults: run: shell: bash steps: diff --git a/examples/community/README_community_scripts.md b/examples/community/README_community_scripts.md index eae50247c9e5..3c9ad0d89bb4 100644 --- a/examples/community/README_community_scripts.md +++ b/examples/community/README_community_scripts.md @@ -241,7 +241,45 @@ from diffusers import StableDiffusionPipeline from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks from diffusers.configuration_utils import register_to_config import torch -from typing import Any, Dict, Optional +from typing import Any, Dict, Tuple, Union + + +class SDPromptSchedulingCallback(PipelineCallback): + @register_to_config + def __init__( + self, + encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + cutoff_step_ratio=None, + cutoff_step_index=None, + ): + super().__init__( + cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index + ) + + tensor_inputs = ["prompt_embeds"] + + def callback_fn( + self, pipeline, step_index, timestep, callback_kwargs + ) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + if isinstance(self.config.encoded_prompt, tuple): + prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt + else: + prompt_embeds = self.config.encoded_prompt + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index + if cutoff_step_index is not None + else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + if pipeline.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + return callback_kwargs pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained( @@ -253,28 +291,73 @@ pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained( pipeline.safety_checker = None pipeline.requires_safety_checker = False +callback = MultiPipelineCallbacks( + [ + 
SDPromptSchedulingCallback( + encoded_prompt=pipeline.encode_prompt( + prompt=f"prompt {index}", + negative_prompt=f"negative prompt {index}", + device=pipeline._execution_device, + num_images_per_prompt=1, + # pipeline.do_classifier_free_guidance can't be accessed until after pipeline is ran + do_classifier_free_guidance=True, + ), + cutoff_step_index=index, + ) for index in range(1, 20) + ] +) + +image = pipeline( + prompt="prompt" + negative_prompt="negative prompt", + callback_on_step_end=callback, + callback_on_step_end_tensor_inputs=["prompt_embeds"], +).images[0] +torch.cuda.empty_cache() +image.save('image.png') +``` -class SDPromptScheduleCallback(PipelineCallback): +```python +from diffusers import StableDiffusionXLPipeline +from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks +from diffusers.configuration_utils import register_to_config +import torch +from typing import Any, Dict, Tuple, Union + + +class SDXLPromptSchedulingCallback(PipelineCallback): @register_to_config def __init__( self, - prompt: str, - negative_prompt: Optional[str] = None, - num_images_per_prompt: int = 1, - cutoff_step_ratio=1.0, + encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + add_text_embeds: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + add_time_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + cutoff_step_ratio=None, cutoff_step_index=None, ): super().__init__( cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index ) - tensor_inputs = ["prompt_embeds"] + tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"] def callback_fn( self, pipeline, step_index, timestep, callback_kwargs ) -> Dict[str, Any]: cutoff_step_ratio = self.config.cutoff_step_ratio cutoff_step_index = self.config.cutoff_step_index + if isinstance(self.config.encoded_prompt, tuple): + prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt + else: + prompt_embeds = self.config.encoded_prompt + if isinstance(self.config.add_text_embeds, tuple): + add_text_embeds, negative_add_text_embeds = self.config.add_text_embeds + else: + add_text_embeds = self.config.add_text_embeds + if isinstance(self.config.add_time_ids, tuple): + add_time_ids, negative_add_time_ids = self.config.add_time_ids + else: + add_time_ids = self.config.add_time_ids # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio cutoff_step = ( @@ -284,34 +367,73 @@ class SDPromptScheduleCallback(PipelineCallback): ) if step_index == cutoff_step: - prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt( - prompt=self.config.prompt, - negative_prompt=self.config.negative_prompt, - device=pipeline._execution_device, - num_images_per_prompt=self.config.num_images_per_prompt, - do_classifier_free_guidance=pipeline.do_classifier_free_guidance, - ) if pipeline.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + add_text_embeds = torch.cat([negative_add_text_embeds, add_text_embeds]) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids]) callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + callback_kwargs[self.tensor_inputs[1]] = add_text_embeds + callback_kwargs[self.tensor_inputs[2]] = add_time_ids return callback_kwargs -callback = MultiPipelineCallbacks( - [ - SDPromptScheduleCallback( - prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski", - 
negative_prompt="Deformed, ugly, bad anatomy", - cutoff_step_ratio=0.25, + +pipeline: StableDiffusionXLPipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + torch_dtype=torch.float16, + variant="fp16", + use_safetensors=True, +).to("cuda") + +callbacks = [] +for index in range(1, 20): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = pipeline.encode_prompt( + prompt=f"prompt {index}", + negative_prompt=f"prompt {index}", + device=pipeline._execution_device, + num_images_per_prompt=1, + # pipeline.do_classifier_free_guidance can't be accessed until after pipeline is ran + do_classifier_free_guidance=True, + ) + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + add_time_ids = pipeline._get_add_time_ids( + (1024, 1024), + (0, 0), + (1024, 1024), + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + negative_add_time_ids = pipeline._get_add_time_ids( + (1024, 1024), + (0, 0), + (1024, 1024), + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + callbacks.append( + SDXLPromptSchedulingCallback( + encoded_prompt=(prompt_embeds, negative_prompt_embeds), + add_text_embeds=(pooled_prompt_embeds, negative_pooled_prompt_embeds), + add_time_ids=(add_time_ids, negative_add_time_ids), + cutoff_step_index=index, ) - ] -) + ) + + +callback = MultiPipelineCallbacks(callbacks) image = pipeline( - prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski", - negative_prompt="Deformed, ugly, bad anatomy", + prompt="prompt", + negative_prompt="negative prompt", callback_on_step_end=callback, - callback_on_step_end_tensor_inputs=["prompt_embeds"], + callback_on_step_end_tensor_inputs=[ + "prompt_embeds", + "add_text_embeds", + "add_time_ids", + ], ).images[0] -torch.cuda.empty_cache() -image.save('image.png') ``` diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py index f09160c4571d..c8a87a426dc0 100644 --- a/examples/community/pipeline_flux_rf_inversion.py +++ b/examples/community/pipeline_flux_rf_inversion.py @@ -648,6 +648,8 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, eta: float = 1.0, + decay_eta: Optional[bool] = False, + eta_decay_power: Optional[float] = 1.0, strength: float = 1.0, start_timestep: float = 0, stop_timestep: float = 0.25, @@ -880,12 +882,9 @@ def __call__( v_t = -noise_pred v_t_cond = (y_0 - latents) / (1 - t_i) eta_t = eta if start_timestep <= i < stop_timestep else 0.0 - if start_timestep <= i < stop_timestep: - # controlled vector field - v_hat_t = v_t + eta * (v_t_cond - v_t) - - else: - v_hat_t = v_t + if decay_eta: + eta_t = eta_t * (1 - i / num_inference_steps) ** eta_decay_power # Decay eta over the loop + v_hat_t = v_t + eta_t * (v_t_cond - v_t) # SDE Eq: 17 from https://arxiv.org/pdf/2410.10792 latents = latents + v_hat_t * (sigmas[i] - sigmas[i + 1]) diff --git a/examples/flux-control/README.md b/examples/flux-control/README.md index 493334ac2c55..26ad9d06a2af 100644 --- a/examples/flux-control/README.md +++ b/examples/flux-control/README.md @@ -36,6 +36,7 @@ accelerate launch train_control_lora_flux.py \ --max_train_steps=5000 \ --validation_image="openpose.png" \ --validation_prompt="A couple, 4k photo, highly detailed" \ + --offload \ --seed="0" \ 
--push_to_hub ``` @@ -154,6 +155,7 @@ accelerate launch --config_file=accelerate_ds2.yaml train_control_flux.py \ --validation_steps=200 \ --validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \ --validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \ + --offload \ --seed="0" \ --push_to_hub ``` diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py index ebca634cb8ce..0c8e26d5b358 100644 --- a/examples/flux-control/train_control_flux.py +++ b/examples/flux-control/train_control_flux.py @@ -541,6 +541,11 @@ def parse_args(input_args=None): default=1.29, help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", ) + parser.add_argument( + "--offload", + action="store_true", + help="Whether to offload the VAE and the text encoders to CPU when they are not used.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -999,8 +1004,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): control_latents = encode_images( batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype ) - # offload vae to CPU. - vae.cpu() + if args.offload: + # offload vae to CPU. + vae.cpu() # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -1064,7 +1070,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts: prompt_embeds.zero_() pooled_prompt_embeds.zero_() - text_encoding_pipeline = text_encoding_pipeline.to("cpu") + if args.offload: + text_encoding_pipeline = text_encoding_pipeline.to("cpu") # Predict. model_pred = flux_transformer( diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py index 5b5345ba6783..e1b234c40e61 100644 --- a/examples/flux-control/train_control_lora_flux.py +++ b/examples/flux-control/train_control_lora_flux.py @@ -573,6 +573,11 @@ def parse_args(input_args=None): default=1.29, help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", ) + parser.add_argument( + "--offload", + action="store_true", + help="Whether to offload the VAE and the text encoders to CPU when they are not used.", + ) if input_args is not None: args = parser.parse_args(input_args) @@ -1140,8 +1145,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): control_latents = encode_images( batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype ) - # offload vae to CPU. - vae.cpu() + + if args.offload: + # offload vae to CPU. + vae.cpu() # Sample a random timestep for each image # for weighting schemes where we sample timesteps non-uniformly @@ -1205,7 +1212,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts: prompt_embeds.zero_() pooled_prompt_embeds.zero_() - text_encoding_pipeline = text_encoding_pipeline.to("cpu") + if args.offload: + text_encoding_pipeline = text_encoding_pipeline.to("cpu") # Predict. 
model_pred = flux_transformer( diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index f4318fc3cd39..c1d4f0b46e15 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -18,7 +18,7 @@ from torch import nn from ..utils import deprecate -from ..utils.import_utils import is_torch_npu_available +from ..utils.import_utils import is_torch_npu_available, is_torch_version if is_torch_npu_available(): @@ -79,10 +79,10 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b self.approximate = approximate def gelu(self, gate: torch.Tensor) -> torch.Tensor: - if gate.device.type != "mps": - return F.gelu(gate, approximate=self.approximate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) + if gate.device.type == "mps" and is_torch_version("<", "2.0.0"): + # fp16 gelu not supported on mps before torch 2.0 + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) + return F.gelu(gate, approximate=self.approximate) def forward(self, hidden_states): hidden_states = self.proj(hidden_states) @@ -105,10 +105,10 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True): self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) def gelu(self, gate: torch.Tensor) -> torch.Tensor: - if gate.device.type != "mps": - return F.gelu(gate) - # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + if gate.device.type == "mps" and is_torch_version("<", "2.0.0"): + # fp16 gelu not supported on mps before torch 2.0 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + return F.gelu(gate) def forward(self, hidden_states, *args, **kwargs): if len(args) > 0 or kwargs.get("scale", None) is not None: diff --git a/src/diffusers/models/attention_flax.py b/src/diffusers/models/attention_flax.py index 25ae5d0a5d63..246f3afaf57c 100644 --- a/src/diffusers/models/attention_flax.py +++ b/src/diffusers/models/attention_flax.py @@ -216,8 +216,8 @@ def __call__(self, hidden_states, context=None, deterministic=True): hidden_states = jax_memory_efficient_attention( query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4 ) - hidden_states = hidden_states.transpose(1, 0, 2) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) else: # compute attentions if self.split_head_dim: diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py index c558c40be375..ea86d669f392 100644 --- a/src/diffusers/models/controlnets/__init__.py +++ b/src/diffusers/models/controlnets/__init__.py @@ -15,7 +15,7 @@ SparseControlNetModel, SparseControlNetOutput, ) - from .controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax, ControlNetUnionModel + from .controlnet_union import ControlNetUnionModel from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel from .multicontrolnet import MultiControlNetModel diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py index 076629200eac..fc80da76235b 100644 --- a/src/diffusers/models/controlnets/controlnet_union.py +++ b/src/diffusers/models/controlnets/controlnet_union.py @@ -11,14 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn from ...configuration_utils import ConfigMixin, register_to_config -from ...image_processor import PipelineImageInput from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import logging from ..attention_processor import ( @@ -40,76 +38,6 @@ from .controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module -@dataclass -class ControlNetUnionInput: - """ - The image input of [`ControlNetUnionModel`]: - - - 0: openpose - - 1: depth - - 2: hed/pidi/scribble/ted - - 3: canny/lineart/anime_lineart/mlsd - - 4: normal - - 5: segment - """ - - openpose: Optional[PipelineImageInput] = None - depth: Optional[PipelineImageInput] = None - hed: Optional[PipelineImageInput] = None - canny: Optional[PipelineImageInput] = None - normal: Optional[PipelineImageInput] = None - segment: Optional[PipelineImageInput] = None - - def __len__(self) -> int: - return len(vars(self)) - - def __iter__(self): - return iter(vars(self)) - - def __getitem__(self, key): - return getattr(self, key) - - def __setitem__(self, key, value): - setattr(self, key, value) - - -@dataclass -class ControlNetUnionInputProMax: - """ - The image input of [`ControlNetUnionModel`]: - - - 0: openpose - - 1: depth - - 2: hed/pidi/scribble/ted - - 3: canny/lineart/anime_lineart/mlsd - - 4: normal - - 5: segment - - 6: tile - - 7: repaint - """ - - openpose: Optional[PipelineImageInput] = None - depth: Optional[PipelineImageInput] = None - hed: Optional[PipelineImageInput] = None - canny: Optional[PipelineImageInput] = None - normal: Optional[PipelineImageInput] = None - segment: Optional[PipelineImageInput] = None - tile: Optional[PipelineImageInput] = None - repaint: Optional[PipelineImageInput] = None - - def __len__(self) -> int: - return len(vars(self)) - - def __iter__(self): - return iter(vars(self)) - - def __getitem__(self, key): - return getattr(self, key) - - def __setitem__(self, key, value): - setattr(self, key, value) - - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -680,8 +608,9 @@ def forward( sample: torch.Tensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, - controlnet_cond: Union[ControlNetUnionInput, ControlNetUnionInputProMax], + controlnet_cond: List[torch.Tensor], control_type: torch.Tensor, + control_type_idx: List[int], conditioning_scale: float = 1.0, class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, @@ -701,11 +630,13 @@ def forward( The number of timesteps to denoise an input. encoder_hidden_states (`torch.Tensor`): The encoder hidden states. - controlnet_cond (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`): + controlnet_cond (`List[torch.Tensor]`): The conditional input tensors. control_type (`torch.Tensor`): A tensor of shape `(batch, num_control_type)` with values `0` or `1` depending on whether the control type is used. + control_type_idx (`List[int]`): + The indices of `control_type`. conditioning_scale (`float`, defaults to `1.0`): The scale factor for ControlNet outputs. class_labels (`torch.Tensor`, *optional*, defaults to `None`): @@ -733,20 +664,6 @@ def forward( If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is returned where the first element is the sample tensor. 
""" - if not isinstance(controlnet_cond, (ControlNetUnionInput, ControlNetUnionInputProMax)): - raise ValueError( - "Expected type of `controlnet_cond` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`" - ) - if len(controlnet_cond) != self.config.num_control_type: - if isinstance(controlnet_cond, ControlNetUnionInput): - raise ValueError( - f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInputProMax`." - ) - elif isinstance(controlnet_cond, ControlNetUnionInputProMax): - raise ValueError( - f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInput`." - ) - # check channel order channel_order = self.config.controlnet_conditioning_channel_order @@ -830,12 +747,10 @@ def forward( inputs = [] condition_list = [] - for idx, image_type in enumerate(controlnet_cond): - if controlnet_cond[image_type] is None: - continue - condition = self.controlnet_cond_embedding(controlnet_cond[image_type]) + for cond, control_idx in zip(controlnet_cond, control_type_idx): + condition = self.controlnet_cond_embedding(cond) feat_seq = torch.mean(condition, dim=(2, 3)) - feat_seq = feat_seq + self.task_embedding[idx] + feat_seq = feat_seq + self.task_embedding[control_idx] inputs.append(feat_seq.unsqueeze(1)) condition_list.append(condition) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 702e5b586d59..b423c17c1246 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -84,6 +84,78 @@ def get_3d_sincos_pos_embed( temporal_size: int, spatial_interpolation_scale: float = 1.0, temporal_interpolation_scale: float = 1.0, + device: Optional[torch.device] = None, + output_type: str = "np", +) -> torch.Tensor: + r""" + Creates 3D sinusoidal positional embeddings. + + Args: + embed_dim (`int`): + The embedding dimension of inputs. It must be divisible by 16. + spatial_size (`int` or `Tuple[int, int]`): + The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both + spatial dimensions (height and width). + temporal_size (`int`): + The temporal dimension of postional embeddings (number of frames). + spatial_interpolation_scale (`float`, defaults to 1.0): + Scale factor for spatial grid interpolation. + temporal_interpolation_scale (`float`, defaults to 1.0): + Scale factor for temporal grid interpolation. + + Returns: + `torch.Tensor`: + The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1], + embed_dim]`. + """ + if output_type == "np": + return _get_3d_sincos_pos_embed_np( + embed_dim=embed_dim, + spatial_size=spatial_size, + temporal_size=temporal_size, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + ) + if embed_dim % 4 != 0: + raise ValueError("`embed_dim` must be divisible by 4") + if isinstance(spatial_size, int): + spatial_size = (spatial_size, spatial_size) + + embed_dim_spatial = 3 * embed_dim // 4 + embed_dim_temporal = embed_dim // 4 + + # 1. 
Spatial + grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale + grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first + grid = torch.stack(grid, dim=0) + + grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]]) + pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt") + + # 2. Temporal + grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale + pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt") + + # 3. Concat + pos_embed_spatial = pos_embed_spatial[None, :, :] + pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0) # [T, H*W, D // 4 * 3] + + pos_embed_temporal = pos_embed_temporal[:, None, :] + pos_embed_temporal = pos_embed_temporal.repeat_interleave( + spatial_size[0] * spatial_size[1], dim=1 + ) # [T, H*W, D // 4] + + pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1) # [T, H*W, D] + return pos_embed + + +def _get_3d_sincos_pos_embed_np( + embed_dim: int, + spatial_size: Union[int, Tuple[int, int]], + temporal_size: int, + spatial_interpolation_scale: float = 1.0, + temporal_interpolation_scale: float = 1.0, ) -> np.ndarray: r""" Creates 3D sinusoidal positional embeddings. @@ -106,6 +178,12 @@ def get_3d_sincos_pos_embed( The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1], embed_dim]`. """ + deprecation_message = ( + "`get_3d_sincos_pos_embed` uses `torch` and supports `device`." + " `from_numpy` is no longer required." + " Pass `output_type='pt' to use the new version now." + ) + deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) if embed_dim % 4 != 0: raise ValueError("`embed_dim` must be divisible by 4") if isinstance(spatial_size, int): @@ -139,6 +217,143 @@ def get_3d_sincos_pos_embed( def get_2d_sincos_pos_embed( + embed_dim, + grid_size, + cls_token=False, + extra_tokens=0, + interpolation_scale=1.0, + base_size=16, + device: Optional[torch.device] = None, + output_type: str = "np", +): + """ + Creates 2D sinusoidal positional embeddings. + + Args: + embed_dim (`int`): + The embedding dimension. + grid_size (`int`): + The size of the grid height and width. + cls_token (`bool`, defaults to `False`): + Whether or not to add a classification token. + extra_tokens (`int`, defaults to `0`): + The number of extra tokens to add. + interpolation_scale (`float`, defaults to `1.0`): + The scale of the interpolation. + + Returns: + pos_embed (`torch.Tensor`): + Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size, + embed_dim]` if using cls_token + """ + if output_type == "np": + deprecation_message = ( + "`get_2d_sincos_pos_embed` uses `torch` and supports `device`." + " `from_numpy` is no longer required." + " Pass `output_type='pt' to use the new version now." 
+ ) + deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) + return get_2d_sincos_pos_embed_np( + embed_dim=embed_dim, + grid_size=grid_size, + cls_token=cls_token, + extra_tokens=extra_tokens, + interpolation_scale=interpolation_scale, + base_size=base_size, + ) + if isinstance(grid_size, int): + grid_size = (grid_size, grid_size) + + grid_h = ( + torch.arange(grid_size[0], device=device, dtype=torch.float32) + / (grid_size[0] / base_size) + / interpolation_scale + ) + grid_w = ( + torch.arange(grid_size[1], device=device, dtype=torch.float32) + / (grid_size[1] / base_size) + / interpolation_scale + ) + grid = torch.meshgrid(grid_w, grid_h, indexing="xy") # here w goes first + grid = torch.stack(grid, dim=0) + + grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type) + if cls_token and extra_tokens > 0: + pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"): + r""" + This function generates 2D sinusoidal positional embeddings from a grid. + + Args: + embed_dim (`int`): The embedding dimension. + grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`. + + Returns: + `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)` + """ + if output_type == "np": + deprecation_message = ( + "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`." + " `from_numpy` is no longer required." + " Pass `output_type='pt' to use the new version now." + ) + deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) + return get_2d_sincos_pos_embed_from_grid_np( + embed_dim=embed_dim, + grid=grid, + ) + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type) # (H*W, D/2) + + emb = torch.concat([emb_h, emb_w], dim=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"): + """ + This function generates 1D positional embeddings from a grid. + + Args: + embed_dim (`int`): The embedding dimension `D` + pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)` + + Returns: + `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`. + """ + if output_type == "np": + deprecation_message = ( + "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`." + " `from_numpy` is no longer required." + " Pass `output_type='pt' to use the new version now." 
+ ) + deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False) + return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos) + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = torch.outer(pos, omega) # (M, D/2), outer product + + emb_sin = torch.sin(out) # (M, D/2) + emb_cos = torch.cos(out) # (M, D/2) + + emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D) + return emb + + +def get_2d_sincos_pos_embed_np( embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16 ): """ @@ -170,13 +385,13 @@ def get_2d_sincos_pos_embed( grid = np.stack(grid, axis=0) grid = grid.reshape([2, 1, grid_size[1], grid_size[0]]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid) if cls_token and extra_tokens > 0: pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) return pos_embed -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): +def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid): r""" This function generates 2D sinusoidal positional embeddings from a grid. @@ -191,14 +406,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): raise ValueError("embed_dim must be divisible by 2") # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1]) # (H*W, D/2) emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): +def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos): """ This function generates 1D positional embeddings from a grid. 
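The torch-based helpers introduced in this hunk (`get_1d_sincos_pos_embed_from_grid` and related functions with `output_type="pt"`) can be exercised directly, without the old NumPy round-trip that the deprecation message refers to. A minimal sketch, assuming a diffusers build that already contains this change and the import path `diffusers.models.embeddings` shown in the file header:

```python
import torch
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid

# 16 positions, 64-dim embedding; output_type="pt" keeps the result as a torch tensor on the
# device of `positions`, instead of going through the deprecated NumPy code path.
positions = torch.arange(16, dtype=torch.float32)
emb = get_1d_sincos_pos_embed_from_grid(64, positions, output_type="pt")
print(emb.shape)  # torch.Size([16, 64])
```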
@@ -288,10 +503,14 @@ def __init__( self.pos_embed = None elif pos_embed_type == "sincos": pos_embed = get_2d_sincos_pos_embed( - embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale + embed_dim, + grid_size, + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + output_type="pt", ) persistent = True if pos_embed_max_size else False - self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent) + self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent) else: raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}") @@ -341,8 +560,10 @@ def forward(self, latent): grid_size=(height, width), base_size=self.base_size, interpolation_scale=self.interpolation_scale, + device=latent.device, + output_type="pt", ) - pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device) + pos_embed = pos_embed.float().unsqueeze(0) else: pos_embed = self.pos_embed @@ -453,7 +674,9 @@ def __init__( pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames) self.register_buffer("pos_embedding", pos_embedding, persistent=persistent) - def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor: + def _get_positional_embeddings( + self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None + ) -> torch.Tensor: post_patch_height = sample_height // self.patch_size post_patch_width = sample_width // self.patch_size post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 @@ -465,8 +688,10 @@ def _get_positional_embeddings(self, sample_height: int, sample_width: int, samp post_time_compression_frames, self.spatial_interpolation_scale, self.temporal_interpolation_scale, + device=device, + output_type="pt", ) - pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1) + pos_embedding = pos_embedding.flatten(0, 1) joint_pos_embedding = torch.zeros( 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False ) @@ -521,8 +746,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor): or self.sample_width != width or self.sample_frames != pre_time_compression_frames ): - pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames) - pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype) + pos_embedding = self._get_positional_embeddings( + height, width, pre_time_compression_frames, device=embeds.device + ) + pos_embedding = pos_embedding.to(dtype=embeds.dtype) else: pos_embedding = self.pos_embedding @@ -552,9 +779,11 @@ def __init__( # Linear projection for text embeddings self.text_proj = nn.Linear(text_hidden_size, hidden_size) - pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size) + pos_embed = get_2d_sincos_pos_embed( + hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt" + ) pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size) - self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False) + self.register_buffer("pos_embed", pos_embed.float(), persistent=False) def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor: batch_size, channel, height, width = hidden_states.shape diff --git 
a/src/diffusers/models/transformers/latte_transformer_3d.py b/src/diffusers/models/transformers/latte_transformer_3d.py index 7e2b1273687d..d34ccfd20108 100644 --- a/src/diffusers/models/transformers/latte_transformer_3d.py +++ b/src/diffusers/models/transformers/latte_transformer_3d.py @@ -156,9 +156,9 @@ def __init__( # define temporal positional embedding temp_pos_embed = get_1d_sincos_pos_embed_from_grid( - inner_dim, torch.arange(0, video_length).unsqueeze(1) + inner_dim, torch.arange(0, video_length).unsqueeze(1), output_type="pt" ) # 1152 hidden size - self.register_buffer("temp_pos_embed", torch.from_numpy(temp_pos_embed).float().unsqueeze(0), persistent=False) + self.register_buffer("temp_pos_embed", temp_pos_embed.float().unsqueeze(0), persistent=False) self.gradient_checkpointing = False diff --git a/src/diffusers/models/unets/unet_3d_blocks.py b/src/diffusers/models/unets/unet_3d_blocks.py index 9c9fd7555899..195f7601dd54 100644 --- a/src/diffusers/models/unets/unet_3d_blocks.py +++ b/src/diffusers/models/unets/unet_3d_blocks.py @@ -1375,6 +1375,7 @@ def forward( res_hidden_states_tuple: Tuple[torch.Tensor, ...], temb: Optional[torch.Tensor] = None, image_only_indicator: Optional[torch.Tensor] = None, + upsample_size: Optional[int] = None, ) -> torch.Tensor: for resnet in self.resnets: # pop res hidden states @@ -1415,7 +1416,7 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states @@ -1485,6 +1486,7 @@ def forward( temb: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, image_only_indicator: Optional[torch.Tensor] = None, + upsample_size: Optional[int] = None, ) -> torch.Tensor: for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states @@ -1533,6 +1535,6 @@ def custom_forward(*inputs): if self.upsamplers is not None: for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states) + hidden_states = upsampler(hidden_states, upsample_size) return hidden_states diff --git a/src/diffusers/models/unets/unet_spatio_temporal_condition.py b/src/diffusers/models/unets/unet_spatio_temporal_condition.py index 9fb975bc32d9..308b9e01c587 100644 --- a/src/diffusers/models/unets/unet_spatio_temporal_condition.py +++ b/src/diffusers/models/unets/unet_spatio_temporal_condition.py @@ -382,6 +382,20 @@ def forward( If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise a `tuple` is returned where the first element is the sample tensor. """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + # 1. time timesteps = timestep if not torch.is_tensor(timesteps): @@ -457,15 +471,23 @@ def forward( # 5. 
up for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, image_only_indicator=image_only_indicator, ) else: @@ -473,6 +495,7 @@ def forward( hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, image_only_indicator=image_only_indicator, ) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py index 0465391d7305..bfc28615e8b4 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py @@ -40,7 +40,6 @@ AttnProcessor2_0, XFormersAttnProcessor, ) -from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -82,7 +81,6 @@ def retrieve_latents( Examples: ```py from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL - from diffusers.models.controlnets import ControlNetUnionInputProMax from diffusers.utils import load_image import torch import numpy as np @@ -114,11 +112,8 @@ def retrieve_latents( mask_np = np.array(mask) controlnet_img_np[mask_np > 0] = 0 controlnet_img = Image.fromarray(controlnet_img_np) - union_input = ControlNetUnionInputProMax( - repaint=controlnet_img, - ) # generate image - image = pipe(prompt, image=image, mask_image=mask, control_image_list=union_input).images[0] + image = pipe(prompt, image=image, mask_image=mask, control_image=[controlnet_img], control_mode=[7]).images[0] image.save("inpaint.png") ``` """ @@ -1130,7 +1125,7 @@ def __call__( prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, mask_image: PipelineImageInput = None, - control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None, + control_image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, padding_mask_crop: Optional[int] = None, @@ -1158,6 +1153,7 @@ def __call__( guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, + control_mode: Optional[Union[int, List[int]]] = None, guidance_rescale: float = 0.0, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), @@ -1345,20 +1341,6 @@ def __call__( controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)): - raise ValueError( - "Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`" - ) - if len(control_image_list) != 
controlnet.config.num_control_type: - if isinstance(control_image_list, ControlNetUnionInput): - raise ValueError( - f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`." - ) - elif isinstance(control_image_list, ControlNetUnionInputProMax): - raise ValueError( - f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`." - ) - # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] @@ -1375,36 +1357,44 @@ def __call__( elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] + if not isinstance(control_image, list): + control_image = [control_image] + + if not isinstance(control_mode, list): + control_mode = [control_mode] + + if len(control_image) != len(control_mode): + raise ValueError("Expected len(control_image) == len(control_type)") + + num_control_type = controlnet.config.num_control_type + # 1. Check inputs - control_type = [] - for image_type in control_image_list: - if control_image_list[image_type]: - self.check_inputs( - prompt, - prompt_2, - control_image_list[image_type], - mask_image, - strength, - num_inference_steps, - callback_steps, - output_type, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - ip_adapter_image, - ip_adapter_image_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - callback_on_step_end_tensor_inputs, - padding_mask_crop, - ) - control_type.append(1) - else: - control_type.append(0) + control_type = [0 for _ in range(num_control_type)] + for _image, control_idx in zip(control_image, control_mode): + control_type[control_idx] = 1 + self.check_inputs( + prompt, + prompt_2, + _image, + mask_image, + strength, + num_inference_steps, + callback_steps, + output_type, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + padding_mask_crop, + ) control_type = torch.Tensor(control_type) @@ -1499,23 +1489,21 @@ def denoising_value_valid(dnv): init_image = init_image.to(dtype=torch.float32) # 5.2 Prepare control images - for image_type in control_image_list: - if control_image_list[image_type]: - control_image = self.prepare_control_image( - image=control_image_list[image_type], - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - crops_coords=crops_coords, - resize_mode=resize_mode, - do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, - ) - height, width = control_image.shape[-2:] - control_image_list[image_type] = control_image + for idx, _ in enumerate(control_image): + control_image[idx] = self.prepare_control_image( + image=control_image[idx], + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + 
crops_coords=crops_coords, + resize_mode=resize_mode, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = control_image[idx].shape[-2:] # 5.3 Prepare mask mask = self.mask_processor.preprocess( @@ -1589,6 +1577,9 @@ def denoising_value_valid(dnv): original_size = original_size or (height, width) target_size = target_size or (height, width) + for _image in control_image: + if isinstance(_image, torch.Tensor): + original_size = original_size or _image.shape[-2:] # 10. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds @@ -1693,8 +1684,9 @@ def denoising_value_valid(dnv): control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=control_image_list, + controlnet_cond=control_image, control_type=control_type, + control_type_idx=control_mode, conditioning_scale=cond_scale, guess_mode=guess_mode, added_cond_kwargs=controlnet_added_cond_kwargs, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py index 58a8ba62e24e..78395243f6e4 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py @@ -43,7 +43,6 @@ AttnProcessor2_0, XFormersAttnProcessor, ) -from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -70,7 +69,6 @@ >>> # !pip install controlnet_aux >>> from controlnet_aux import LineartAnimeDetector >>> from diffusers import StableDiffusionXLControlNetUnionPipeline, ControlNetUnionModel, AutoencoderKL - >>> from diffusers.models.controlnets import ControlNetUnionInput >>> from diffusers.utils import load_image >>> import torch @@ -89,17 +87,14 @@ ... controlnet=controlnet, ... vae=vae, ... torch_dtype=torch.float16, + ... variant="fp16", ... ) >>> pipe.enable_model_cpu_offload() >>> # prepare image >>> processor = LineartAnimeDetector.from_pretrained("lllyasviel/Annotators") >>> controlnet_img = processor(image, output_type="pil") - >>> # set ControlNetUnion input - >>> union_input = ControlNetUnionInput( - ... canny=controlnet_img, - ... ) >>> # generate image - >>> image = pipe(prompt, image=union_input).images[0] + >>> image = pipe(prompt, control_image=[controlnet_img], control_mode=[3], height=1024, width=1024).images[0] ``` """ @@ -791,26 +786,6 @@ def check_inputs( f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" ) - def check_input( - self, - image: Union[ControlNetUnionInput, ControlNetUnionInputProMax], - ): - controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - - if not isinstance(image, (ControlNetUnionInput, ControlNetUnionInputProMax)): - raise ValueError( - "Expected type of `image` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`" - ) - if len(image) != controlnet.config.num_control_type: - if isinstance(image, ControlNetUnionInput): - raise ValueError( - f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInputProMax`." - ) - elif isinstance(image, ControlNetUnionInputProMax): - raise ValueError( - f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInput`." 
- ) - # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image def prepare_image( self, @@ -970,7 +945,7 @@ def __call__( self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, - image: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None, + control_image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, @@ -997,6 +972,7 @@ def __call__( guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, + control_mode: Optional[Union[int, List[int]]] = None, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = None, @@ -1018,10 +994,7 @@ def __call__( prompt_2 (`str` or `List[str]`, *optional*): The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is used in both text-encoders. - image (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`): - In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, - `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`, - `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + control_image (`PipelineImageInput`): The ControlNet input condition to provide guidance to the `unet` for generation. If the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or @@ -1168,38 +1141,45 @@ def __call__( controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - self.check_input(image) - # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] + if not isinstance(control_image, list): + control_image = [control_image] + + if not isinstance(control_mode, list): + control_mode = [control_mode] + + if len(control_image) != len(control_mode): + raise ValueError("Expected len(control_image) == len(control_type)") + + num_control_type = controlnet.config.num_control_type + + # 1. Check inputs + control_type = [0 for _ in range(num_control_type)] # 1. Check inputs. 
Raise error if not correct - control_type = [] - for image_type in image: - if image[image_type]: - self.check_inputs( - prompt, - prompt_2, - image[image_type], - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - ip_adapter_image, - ip_adapter_image_embeds, - negative_pooled_prompt_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - callback_on_step_end_tensor_inputs, - ) - control_type.append(1) - else: - control_type.append(0) + for _image, control_idx in zip(control_image, control_mode): + control_type[control_idx] = 1 + self.check_inputs( + prompt, + prompt_2, + _image, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) control_type = torch.Tensor(control_type) @@ -1258,20 +1238,19 @@ def __call__( ) # 4. Prepare image - for image_type in image: - if image[image_type]: - image[image_type] = self.prepare_image( - image=image[image_type], - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, - ) - height, width = image[image_type].shape[-2:] + for idx, _ in enumerate(control_image): + control_image[idx] = self.prepare_image( + image=control_image[idx], + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = control_image[idx].shape[-2:] # 5. 
Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps( @@ -1312,11 +1291,11 @@ def __call__( ) # 7.2 Prepare added time ids & embeddings - for image_type in image: - if isinstance(image[image_type], torch.Tensor): - original_size = original_size or image[image_type].shape[-2:] - + original_size = original_size or (height, width) target_size = target_size or (height, width) + for _image in control_image: + if isinstance(_image, torch.Tensor): + original_size = original_size or _image.shape[-2:] add_text_embeds = pooled_prompt_embeds if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) @@ -1424,8 +1403,9 @@ def __call__( control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=image, + controlnet_cond=control_image, control_type=control_type, + control_type_idx=control_mode, conditioning_scale=cond_scale, guess_mode=guess_mode, added_cond_kwargs=controlnet_added_cond_kwargs, @@ -1478,7 +1458,6 @@ def __call__( ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) - image = callback_outputs.pop("image", image) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py index a3002eb565ff..f36212d70755 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py @@ -43,7 +43,6 @@ AttnProcessor2_0, XFormersAttnProcessor, ) -from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax from ...models.lora import adjust_lora_scale_text_encoder from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( @@ -74,7 +73,6 @@ ControlNetUnionModel, AutoencoderKL, ) - from diffusers.models.controlnets import ControlNetUnionInputProMax from diffusers.utils import load_image import torch from PIL import Image @@ -95,6 +93,7 @@ controlnet=controlnet, vae=vae, torch_dtype=torch.float16, + variant="fp16", ).to("cuda") # `enable_model_cpu_offload` is not recommended due to multiple generations height = image.height @@ -132,14 +131,12 @@ # set ControlNetUnion input result_images = [] for sub_img, crops_coords in zip(images, crops_coords_list): - union_input = ControlNetUnionInputProMax( - tile=sub_img, - ) new_width, new_height = W, H out = pipe( prompt=[prompt] * 1, image=sub_img, - control_image_list=union_input, + control_image=[sub_img], + control_mode=[6], width=new_width, height=new_height, num_inference_steps=30, @@ -1065,7 +1062,7 @@ def __call__( prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, image: PipelineImageInput = None, - control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None, + control_image: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, strength: float = 0.8, @@ -1090,6 +1087,7 @@ def __call__( guess_mode: bool = False, control_guidance_start: Union[float, List[float]] = 0.0, control_guidance_end: Union[float, List[float]] = 1.0, + control_mode: Optional[Union[int, List[int]]] = None, original_size: Tuple[int, int] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), target_size: Tuple[int, int] = 
None, @@ -1119,10 +1117,7 @@ def __call__( `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): The initial image will be used as the starting point for the image generation process. Can also accept image latents as `image`, if passing latents directly, it will not be encoded again. - control_image_list (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`): - In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, - `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`, - `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):: + control_image (`PipelineImageInput`): The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height @@ -1291,53 +1286,47 @@ def __call__( controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet - if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)): - raise ValueError( - "Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`" - ) - if len(control_image_list) != controlnet.config.num_control_type: - if isinstance(control_image_list, ControlNetUnionInput): - raise ValueError( - f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`." - ) - elif isinstance(control_image_list, ControlNetUnionInputProMax): - raise ValueError( - f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`." - ) - # align format for control guidance if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): control_guidance_start = len(control_guidance_end) * [control_guidance_start] elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): control_guidance_end = len(control_guidance_start) * [control_guidance_end] - # 1. Check inputs. Raise error if not correct - control_type = [] - for image_type in control_image_list: - if control_image_list[image_type]: - self.check_inputs( - prompt, - prompt_2, - control_image_list[image_type], - strength, - num_inference_steps, - callback_steps, - negative_prompt, - negative_prompt_2, - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ip_adapter_image, - ip_adapter_image_embeds, - controlnet_conditioning_scale, - control_guidance_start, - control_guidance_end, - callback_on_step_end_tensor_inputs, - ) - control_type.append(1) - else: - control_type.append(0) + if not isinstance(control_image, list): + control_image = [control_image] + + if not isinstance(control_mode, list): + control_mode = [control_mode] + + if len(control_image) != len(control_mode): + raise ValueError("Expected len(control_image) == len(control_type)") + + num_control_type = controlnet.config.num_control_type + + # 1. 
Check inputs + control_type = [0 for _ in range(num_control_type)] + for _image, control_idx in zip(control_image, control_mode): + control_type[control_idx] = 1 + self.check_inputs( + prompt, + prompt_2, + _image, + strength, + num_inference_steps, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) control_type = torch.Tensor(control_type) @@ -1397,21 +1386,19 @@ def __call__( # 4. Prepare image and controlnet_conditioning_image image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) - for image_type in control_image_list: - if control_image_list[image_type]: - control_image = self.prepare_control_image( - image=control_image_list[image_type], - width=width, - height=height, - batch_size=batch_size * num_images_per_prompt, - num_images_per_prompt=num_images_per_prompt, - device=device, - dtype=controlnet.dtype, - do_classifier_free_guidance=self.do_classifier_free_guidance, - guess_mode=guess_mode, - ) - height, width = control_image.shape[-2:] - control_image_list[image_type] = control_image + for idx, _ in enumerate(control_image): + control_image[idx] = self.prepare_control_image( + image=control_image[idx], + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = control_image[idx].shape[-2:] # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -1444,10 +1431,11 @@ def __call__( ) # 7.2 Prepare added time ids & embeddings - for image_type in control_image_list: - if isinstance(control_image_list[image_type], torch.Tensor): - original_size = original_size or control_image_list[image_type].shape[-2:] + original_size = original_size or (height, width) target_size = target_size or (height, width) + for _image in control_image: + if isinstance(_image, torch.Tensor): + original_size = original_size or _image.shape[-2:] if negative_original_size is None: negative_original_size = original_size @@ -1531,8 +1519,9 @@ def __call__( control_model_input, t, encoder_hidden_states=controlnet_prompt_embeds, - controlnet_cond=control_image_list, + controlnet_cond=control_image, control_type=control_type, + control_type_idx=control_mode, conditioning_scale=cond_scale, guess_mode=guess_mode, added_cond_kwargs=controlnet_added_cond_kwargs, diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py index 01d29867dea3..24e31fa4cfc7 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py @@ -549,6 +549,8 @@ def check_inputs( prompt, prompt_2, prompt_3, + height, + width, strength, negative_prompt=None, negative_prompt_2=None, @@ -560,6 +562,15 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and 
{width}." + f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}." + ) + if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") @@ -730,6 +741,8 @@ def __call__( prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, image: PipelineImageInput = None, strength: float = 0.6, num_inference_steps: int = 50, @@ -860,11 +873,15 @@ def __call__( [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, prompt_2, prompt_3, + height, + width, strength, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, @@ -933,7 +950,7 @@ def __call__( pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) # 3. Preprocess image - image = self.image_processor.preprocess(image) + image = self.image_processor.preprocess(image, height=height, width=width) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py index c91b4ee80eaa..013c31c18e34 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py @@ -218,6 +218,9 @@ def __init__( ) self.tokenizer_max_length = self.tokenizer.model_max_length self.default_sample_size = self.transformer.config.sample_size + self.patch_size = ( + self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2 + ) # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( @@ -531,6 +534,8 @@ def check_inputs( prompt, prompt_2, prompt_3, + height, + width, strength, negative_prompt=None, negative_prompt_2=None, @@ -542,6 +547,15 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}." + f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}." 
+ ) + if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") @@ -710,6 +724,8 @@ def __call__( prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, prompt_3: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, image: PipelineImageInput = None, strength: float = 0.6, num_inference_steps: int = 50, @@ -824,12 +840,16 @@ def __call__( [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images. """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, prompt_2, prompt_3, + height, + width, strength, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, @@ -890,7 +910,7 @@ def __call__( pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) # 3. Preprocess image - image = self.image_processor.preprocess(image) + image = self.image_processor.preprocess(image, height=height, width=width) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas) diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py index 43cb9e5ad0b6..2b6e42aa5081 100644 --- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py @@ -224,6 +224,9 @@ def __init__( ) self.tokenizer_max_length = self.tokenizer.model_max_length self.default_sample_size = self.transformer.config.sample_size + self.patch_size = ( + self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2 + ) # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds def _get_t5_prompt_embeds( @@ -538,6 +541,8 @@ def check_inputs( prompt, prompt_2, prompt_3, + height, + width, strength, negative_prompt=None, negative_prompt_2=None, @@ -549,6 +554,15 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): + if ( + height % (self.vae_scale_factor * self.patch_size) != 0 + or width % (self.vae_scale_factor * self.patch_size) != 0 + ): + raise ValueError( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}." + f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}." 
+ ) + if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") @@ -953,6 +967,8 @@ def __call__( prompt, prompt_2, prompt_3, + height, + width, strength, negative_prompt=negative_prompt, negative_prompt_2=negative_prompt_2, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index a4757ac2f336..d83fa6201117 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -237,11 +237,8 @@ class StableDiffusionXLPipeline( _callback_tensor_inputs = [ "latents", "prompt_embeds", - "negative_prompt_embeds", "add_text_embeds", "add_time_ids", - "negative_pooled_prompt_embeds", - "negative_add_time_ids", ] def __init__( @@ -1243,13 +1240,8 @@ def __call__( latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) - negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds - ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 50688ddb1cb8..126f25a41adc 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -257,11 +257,8 @@ class StableDiffusionXLImg2ImgPipeline( _callback_tensor_inputs = [ "latents", "prompt_embeds", - "negative_prompt_embeds", "add_text_embeds", "add_time_ids", - "negative_pooled_prompt_embeds", - "add_neg_time_ids", ] def __init__( @@ -1438,13 +1435,8 @@ def denoising_value_valid(dnv): latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) - negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds - ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index c7c706350e8e..a378ae65eb30 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -285,11 +285,8 @@ class StableDiffusionXLInpaintPipeline( _callback_tensor_inputs = [ 
"latents", "prompt_embeds", - "negative_prompt_embeds", "add_text_embeds", "add_time_ids", - "negative_pooled_prompt_embeds", - "add_neg_time_ids", "mask", "masked_image_latents", ] @@ -1671,13 +1668,8 @@ def denoising_value_valid(dnv): latents = callback_outputs.pop("latents", latents) prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) - negative_pooled_prompt_embeds = callback_outputs.pop( - "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds - ) add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) - add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids) mask = callback_outputs.pop("mask", mask) masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents) diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py index cb1514b153ce..1e285a9670e2 100644 --- a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py +++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py @@ -104,8 +104,8 @@ def __init__( self.use_pos_embed = use_pos_embed if self.use_pos_embed: - pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) - self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5), output_type="pt") + self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=False) def forward(self, latent): latent = self.proj(latent)