Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] SD3.5 IP-Adapter Pipeline Integration #9987

Merged
merged 55 commits into from
Dec 20, 2024

Conversation

guiyrt
Copy link
Contributor

@guiyrt guiyrt commented Nov 21, 2024

What does this PR do?

Integrates IP-Adapter for SD3.5 pipeline, as discussed in #9966.

Before submitting

Who can review?

@yiyixuxu
@sayakpaul
@DN6
@asomoza
@fabiorigano
@haofanwang

@sayakpaul sayakpaul requested a review from yiyixuxu November 22, 2024 01:02
@guiyrt
Copy link
Contributor Author

guiyrt commented Dec 6, 2024

I have now a more robust integration of InstantX/SD3.5-Large-IP-Adapter, but there are still some things to clean-up and improve :) A few points that I noted:

  • After looking into IPAdapterMixin, I thought maybe it would be easier to create a new IP Adapter loader for SD3.5? It handle behavior related different types of encoder_hid_proj, which I think don't apply to SD3.5. Also, the old IP Adapter attention processors are not useful for SD3.5. If you agree, I can create a new loader class for SD3.5, which would have the load_ip_adapter method that is currently in the pipeline class (as placeholder), set_ip_adapter_scale and other relevant methods.
  • I used StableDiffusionXLPipeline as reference and noticed a difference in image encoding compared with the code from the InstantX team. In the SDXL pipeline, when output_hidden_states is True, torch.zeros_like for uncond_image_enc is passed through the image encoder, but in InstantX/SD3.5-Large-IP-Adapter.pipeline_stable_diffusion_3_ipa.py it is not. Which one should it be?
  • I changed the default argument of joint_attn_kwargs in the SD3.5 pipeline forward() to empty dict so that we don't need to check if is None before unpacking. I think it makes it simpler, but let me know if it should be reverted.
  • When IP Adapter scale=0, can we just ignore the IP Adapter? Adding zeros to skip image processing means doing image-related compute for nothing, so skipping that would be ideal. I'm not an expert yet on this, so let me know if this makes sense 😄
  • Is it valuable using .eval() and @torch.inference_mode()? I removed it because I didn't see it often in other pipelines.

A couple things needed to change before merging (I think):

  • num_images_per_prompt and batch_size are not considered yet.
  • The image_proj resampler is very similar to the existing IPAdapterPlusImageProjection, plus timestep embeddings. We could simplify TimePerceiverResampler to make use of existing implementations.
  • Extending IP Adapters to other SD3.5 pipelines when things are refined.
  • Add/update tests.

Also this (I had a few more points, but can share them later):

  • Not super into the feature_extractoras property name for IP Adapter image processor , but image_processor is already taken by VaeImageProcessor. Maybe ip_image_processor?
  • At the moment, only one IP Adapter can be loaded. Perhaps out of scope, but might be useful to use MultiIPAdapterImageProjection in the future (Happy to work on that if you find it valuable).

As I mentioned in the issue, as a proud member of the GPU poor community, I can't fit the pipeline in my puny 16GB GPU 😅
I tested the pipeline for 1 inference step (only took 5hr), for functionality check, but can't verify image generation quality.
If someone could try it or let me know of a way for me to run the pipeline, I would greatly appreciate it! Here is the inference code I used (is there a better way to share this?):

import torch
from PIL import Image

from diffusers.models.transformers import SD3Transformer2DModel
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
from transformers import SiglipVisionModel, SiglipImageProcessor

model_path = 'stabilityai/stable-diffusion-3.5-large'
image_encoder_path = "google/siglip-so400m-patch14-384"
ip_adapter_path = "InstantX/SD3.5-Large-IP-Adapter"

device = "cuda"
image_name = "image.png"

transformer = SD3Transformer2DModel.from_pretrained(
    model_path, subfolder="transformer", torch_dtype=torch.bfloat16
)

feature_extractor = SiglipImageProcessor.from_pretrained(
    image_encoder_path, torch_dtype=torch.bfloat16
)

image_encoder = SiglipVisionModel.from_pretrained(
    image_encoder_path, torch_dtype=torch.bfloat16
)

pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path,
    transformer=transformer,
    torch_dtype=torch.bfloat16,
    feature_extractor=feature_extractor,
    image_encoder=image_encoder
).to(device)

pipe.load_ip_adapter(ip_adapter_path, subfolder="", weight_name="ip-adapter.bin")
pipe.set_ip_adapter_scale(0.6)

ref_img = Image.open(image_name).convert('RGB')

# please note that SD3.5 Large is sensitive to highres generation like 1536x1536
image = pipe(
    width=1024,
    height=1024,
    prompt='a cat',
    negative_prompt="lowres, low quality, worst quality",
    num_inference_steps=24,
    guidance_scale=5.0,
    generator=torch.Generator(device).manual_seed(42),
    ip_adapter_image=ref_img,
).images[0]

image.save('result.jpg')

@guiyrt guiyrt changed the title SD3.5 IP-Adapter Pipeline Integration [WIP] SD3.5 IP-Adapter Pipeline Integration Dec 9, 2024
@guiyrt guiyrt marked this pull request as ready for review December 9, 2024 10:58
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@hlky hlky left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your contribution @guiyrt

@hlky
Copy link
Collaborator

hlky commented Dec 9, 2024

@guiyrt Could you share some example outputs here?

@guiyrt
Copy link
Contributor Author

guiyrt commented Dec 10, 2024

Thanks @hlky for fixing the tests 😃
I managed to test the pipeline with a big enough GPU, and found there was a bug in loading the IP adapter state dict, but it's fixed now. I ran the pipeline with some images, and took the chance to test with num_images_per_prompt=4. Here are some examples, with the same text prompt as in the official examples, 'a cat'. The first two images are from the assets folder in InstantX/SD3.5-Large-IP-Adapter.
grid

@hlky
Copy link
Collaborator

hlky commented Dec 10, 2024

Thanks! I'll go through your comments:

After looking into IPAdapterMixin, I thought maybe it would be easier to create a new IP Adapter loader for SD3.5?

Yes makes sense, can create a new loader class for SD3.5.

In the SDXL pipeline, when output_hidden_states is True, torch.zeros_like for uncond_image_enc is passed through the image encoder, but in InstantX/SD3.5-Large-IP-Adapter.pipeline_stable_diffusion_3_ipa.py it is not. Which one should it be?

Let's follow the original code for InstantX/SD3.5-Large-IP-Adapter.

I changed the default argument of joint_attn_kwargs in the SD3.5 pipeline forward() to empty dict so that we don't need to check if is None before unpacking. I think it makes it simpler, but let me know if it should be reverted.

Already resolved.

When IP Adapter scale=0, can we just ignore the IP Adapter? Adding zeros to skip image processing means doing image-related compute for nothing, so skipping that would be ideal.

Yes I think we can do that.

Is it valuable using .eval() and @torch.inference_mode()? I removed it because I didn't see it often in other pipelines.

Should be ok because from_pretrained sets .eval()

num_images_per_prompt and batch_size are not considered yet.

Looks like this is resolved, will check it.

The image_proj resampler is very similar to the existing IPAdapterPlusImageProjection, plus timestep embeddings. We could simplify TimePerceiverResampler to make use of existing implementations.

It would be great if we can use existing modules

Not super into the feature_extractoras property name for IP Adapter image processor , but image_processor is already taken by VaeImageProcessor. Maybe ip_image_processor?

feature_extractor is ok, same as for other IP Adapters

At the moment, only one IP Adapter can be loaded. Perhaps out of scope, but might be useful to use MultiIPAdapterImageProjection in the future

Can always be done in a follow-up PR, I don't know how common it is to use multiple IP adapter images anyway.

I can't fit the pipeline in my puny 16GB GPU

You can try pipe.enable_model_cpu_offload() or pipe.enable_sequential_cpu_offload().

@guiyrt
Copy link
Contributor Author

guiyrt commented Dec 18, 2024

Thanks for adding more docs! For images, you can open a PR here and I can merge it for you :)

Thanks! Just opened it https://huggingface.co/datasets/huggingface/documentation-images/discussions/404

@guiyrt
Copy link
Contributor Author

guiyrt commented Dec 19, 2024

I was trying to figure out why enable_sequential_cpu_offload() or enable_model_cpu_offload() would only work with _exclude_from_cpu_offload=["image_encoder"], and I think the root cause is the PyTorch implementation of MultiHeadAttention. As you can verify in torch.nn.modules.activation.py, the arguments self.out_proj.weight and self.out_proj.bias are accessed directly, and not in their classes' forward(), which means AlignDevicesHook doesn't move that data to the GPU (and boom, device mismatch). This happens because unlike in_proj_weight and in_proj_bias, which are nn.Parameter, out_proj is of type NonDynamicallyQuantizableLinear. Apparently this is to avoid improper quantization [note in class implementation, pytorch/pytorch#58969] and is supposed to be removed at some point.

Maybe this was already known before I spiraled into this rabbit hole, but the easy fix for now is to always keep image_encoder on the GPU. Maybe there's an easy way to just keep all NonDynamicallyQuantizableLinear on GPU, that could be a compromise for a patch, but I'm not sure if CPU offloading is that crucial to go that far.

@guiyrt
Copy link
Contributor Author

guiyrt commented Dec 19, 2024

Oh, I should mention that this started because I was also breaking the SD3Transformer2DModel.forward() encapsulation by accessing self.transformer.image_proj from StableDiffusion3Pipeline.__call__. Now the image embeddings from image_encoder are passed to transformer, and image_proj is called from there (similarly to SDXL).

Let me know if I should change something, as this is changing the signature of SD3Transformer2DModel.forward(). I'm a bit unsure if I should change the argument name ip_adapter_image_embeds to ip_adapter_image_hidden_states for consistency, as all the other embeddings variables have the suffix hidden_states.

@hlky
Copy link
Collaborator

hlky commented Dec 19, 2024

Would you mind posting the traceback? Maybe there's something we can do but if the issue is in Siglip we may need to raise it with Transformers team. We probably don't want to keep Siglip on GPU, it's relatively heavy like CLIP Vision right?

Has passing image_embeds fixed the issue or no? I think we can still pass it through kwargs instead of changing the signature.

@hlky hlky mentioned this pull request Dec 19, 2024
@guiyrt
Copy link
Contributor Author

guiyrt commented Dec 19, 2024

Would you mind posting the traceback? Maybe there's something we can do but if the issue is in Siglip we may need to raise it with Transformers team. We probably don't want to keep Siglip on GPU, it's relatively heavy like CLIP Vision right?

This happens with enable_sequential_cpu_offload() but not enable_model_cpu_offload() (I can only verify enable_model_cpu_offload() until the transformer is called, then I get OOM error). As enable_model_cpu_offload()works on model level, it copies the entire image_encoder to the GPU and no issues there. But enable_sequential_cpu_offload() works on submodule level, and only moves to GPU the parameters of that submodule when its forward() is called. In MultiHeadAttention, out_proj.weightand out_proj.bias are parameters of out_proj, so they would only be moved to GPU when out_proj.forward() was invoked, which doesn't happend. Instead, these are accessed directly outside the expected scope.

SigLIP from "google/siglip-so400m-patch14-384" has about 430M params and takes about 1GB of VRAM in torch.float16 (quick testing, just loaded to GPU).

Traceback when trying to include image_encoder in CPU offloading
Traceback (most recent call last):
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/guiyrt/diffusers/run_test.py", line 41, in <module>
    images = pipe(
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/guiyrt/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py", line 1028, in __call__
    ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
  File "/home/guiyrt/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py", line 750, in prepare_ip_adapter_image_embeds
    single_image_embeds = self.encode_image(ip_adapter_image, device)
  File "/home/guiyrt/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py", line 715, in encode_image
    return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1190, in forward
    return self.vision_model(
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1101, in forward
    pooler_output = self.head(last_hidden_state) if self.use_head else None
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1128, in forward
    hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/activation.py", line 1368, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/functional.py", line 6251, in multi_head_attention_forward
    attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 273, in _fn
    result = fn(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 83, in inner
    r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 1525, in addmm
    return out + beta * self
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 273, in _fn
    result = fn(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 141, in _fn
    result = fn(**bound.arguments)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_refs/__init__.py", line 1099, in add
    output = prims.add(a, b)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_ops.py", line 716, in __call__
    return self._op(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_library/fake_impl.py", line 93, in meta_kernel
    return fake_impl_holder.kernel(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_library/utils.py", line 20, in __call__
    return self.func(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/library.py", line 1151, in inner
    return func(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_library/custom_ops.py", line 614, in fake_impl
    return self._abstract_fn(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_prims/__init__.py", line 402, in _prim_elementwise_meta
    utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 742, in check_same_device
    raise RuntimeError(msg)
RuntimeError: Tensor on device meta is not on the expected device cuda:0!
Code to reproduce

Make sure to comment out `_exclude_from_cpu_offload` in `StableDiffusion3Pipeline` (line 186)

import torch
from PIL import Image

from diffusers import StableDiffusion3Pipeline
from transformers import SiglipVisionModel, SiglipImageProcessor

model_path = "stabilityai/stable-diffusion-3.5-large"
image_encoder_path = "google/siglip-so400m-patch14-384"
ip_adapter_path = "InstantX/SD3.5-Large-IP-Adapter"

feature_extractor = SiglipImageProcessor.from_pretrained(
    image_encoder_path, torch_dtype=torch.bfloat16
)

image_encoder = SiglipVisionModel.from_pretrained(
    image_encoder_path, torch_dtype=torch.bfloat16
)

pipe = StableDiffusion3Pipeline.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    feature_extractor=feature_extractor,
    image_encoder=image_encoder,
)
pipe.load_ip_adapter(ip_adapter_path, revision="f1f54ca369ae759f9278ae9c87d46def9f133c78")
pipe.set_ip_adapter_scale(0.6)
pipe.enable_sequential_cpu_offload()

ref_img = Image.open("image.jpg").convert('RGB')

# please note that SD3.5 Large is sensitive to highres generation like 1536x1536
image = pipe(
    width=1024,
    height=1024,
    prompt="a cat",
    negative_prompt="lowres, low quality, worst quality",
    num_inference_steps=24,
    guidance_scale=5.0,
    generator=torch.manual_seed(42),
    ip_adapter_image=ref_img
).images[0]

image.save("result.jpg")

Has passing image_embeds fixed the issue or no? I think we can still pass it through kwargs instead of changing the signature.

This happened with enable_model_cpu_offload() but not enable_sequential_cpu_offload(). As enable_model_cpu_offload() moves entire models to GPU when their forward() is called, image_proj as part of transformer model, would only be moved to GPU when transformer() is called, and I was accessing transformer.image_proj before that. It wasn't a problem with enable_sequential_cpu_offload() because we were still calling the forward() of image_proj, and enable_sequential_cpu_offload() works on submodule level, not model.

Yes, it fixed the issue with pipe.enable_model_cpu_offload(). I thought of using kwargs at first as well, but joint_attention_kwargs is passed all the way to the attention processor, which only expects ip_hidden_states and temb, not the image embeds from image encoder. This raises a warning when unexpected kwargs are passed to the attention processor, so if we do it that way, we need to remove it or pass new kwargs without the image encoder embeds. Unless you mean to create **kwargs for SD3Transformer2DModel.forward()?

But as long as we don't access self.transformer.image_proj directly from StableDiffusion3Pipeline, it now works with using pipe.enable_model_cpu_offload().

Traceback accessing `image_proj` from `StableDiffusion3Pipeline` (fixed in last commit)
Traceback (most recent call last):
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/guiyrt/diffusers/run_test.py", line 41, in <module>
    images = pipe(
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/guiyrt/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py", line 1049, in __call__
    ip_hidden_states, temb = self.transformer.image_proj(
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/guiyrt/diffusers/src/diffusers/models/embeddings.py", line 2564, in forward
    timestep_emb = self.time_embedding(timestep_emb)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/guiyrt/diffusers/src/diffusers/models/embeddings.py", line 1304, in forward
    sample = self.linear_1(sample)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/guiyrt/anaconda3/envs/diffusers_9966/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 125, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

TL;DR:

  • enable_model_cpu_offload() works (as far as I can test)
  • enable_sequential_cpu_offload() works only with image_encoder excluded from CPU offload

@hlky
Copy link
Collaborator

hlky commented Dec 19, 2024

Awesome, thanks for checking into this. I think the enable_sequential_cpu_offload case is not a blocker, I'll discuss it with the team, I think we'll need to raise it with Transformers team or Accelerate team, let's add a note in the docstring to set _exclude_from_cpu_offload = ["image_encoder"] when using enable_sequential_cpu_offload, the attribute is available by default on all pipelines so pipe._exclude_from_cpu_offload = ["image_encoder"] or pipe._exclude_from_cpu_offload.append("image_encoder") should work.

As for joint_attention_kwargs, let's add ip_adapter_image_embeds from the pipeline pop in SD3Transformer2DModel then add ip_hidden_states.

@guiyrt
Copy link
Contributor Author

guiyrt commented Dec 19, 2024

Awesome, thanks for checking into this. I think the enable_sequential_cpu_offload case is not a blocker, I'll discuss it with the team, I think we'll need to raise it with Transformers team or Accelerate team, let's add a note in the docstring to set _exclude_from_cpu_offload = ["image_encoder"] when using enable_sequential_cpu_offload, the attribute is available by default on all pipelines so pipe._exclude_from_cpu_offload = ["image_encoder"] or pipe._exclude_from_cpu_offload.append("image_encoder") should work.

As enable_sequential_cpu_offload() is required to run this model in almost every consumer GPU due to VRAM requirements, it might help the average user if a warning is displayed, as a lot of people might miss the note in the docstring. What do you think if we do something like this? We can also write something in the docstring.

def enable_sequential_cpu_offload(self, *args, **kwargs):
    if "image_encoder" not in self._exclude_from_cpu_offload:
        logger.warning(
            "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
            "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
            "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
        )

    super().enable_sequential_cpu_offload(*args, **kwargs)

As for joint_attention_kwargs, let's add ip_adapter_image_embeds from the pipeline pop in SD3Transformer2DModel then add ip_hidden_states.

Makes sense, I was afraid that when skip_guidance_layers is not None and self.transformer is called again, the pop would remove ip_adapter_image_embeds, but SD3Transformer2DModel.forward() creates a copy of joint_attention_kwargs, so no harm. I'll make that change.

@hlky
Copy link
Collaborator

hlky commented Dec 19, 2024

Sure, let's add that warning, nice idea. Thank you once again for all the iterations on this.

@guiyrt
Copy link
Contributor Author

guiyrt commented Dec 19, 2024

Sure, let's add that warning, nice idea. Thank you once again for all the iterations on this.

I enjoyed and learned a lot over the course of this PR, thanks a lot for the guidance @hlky @yiyixuxu @stevhliu :) Not bad for a first PR ahah

Unless we need to update any pipeline tests or you have more change suggestions, I'd say we're golden! We can add IP-Adapters to the rest of the SD3 pipelines, but that could maybe be a different PR, I've seen interest to use it especially with controlnet pipelines (#10129). If that's up for grabs, I'm happy to go for it :)

We should try to merge Update checkpoints according to diffusers integration also once this is merged, so the checkpoints can be used.

@yiyixuxu
Copy link
Collaborator

I think we can allow the user pass the _exclude_from_cpu_offload and model_cpu_offload_seq as variables to these offloading methods ( we can do in a separate PR)

@yiyixuxu yiyixuxu merged commit 3191248 into huggingface:main Dec 20, 2024
11 of 12 checks passed
@guiyrt guiyrt deleted the sd3.5_IPAdapter branch December 20, 2024 11:51
Foundsheep pushed a commit to Foundsheep/diffusers that referenced this pull request Dec 23, 2024
* Added support for single IPAdapter on SD3.5 pipeline



---------

Co-authored-by: hlky <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
* Added support for single IPAdapter on SD3.5 pipeline



---------

Co-authored-by: hlky <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
roadmap Add to current release roadmap
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants