
Full SDXL Model #67

Merged: 30 commits into mosaicml:main on Oct 4, 2023

Conversation

jazcollins (Contributor) commented:

This PR contains the full implementation of Stable Diffusion XL (SDXL). SDXL uses two text encoders/tokenizers and also takes crop & size parameters from the dataloader as conditioning - a majority of the changes here are for supporting that.

A high-level description of the changes for each file:

diffusion/datasets/image_caption.py

  • Added a rand_crop flag to choose between LargestCenterSquare and RandomCropSquare - previously only center cropping was supported. This is relevant to SD2 for anyone who wants to train with random cropping, but doesn't apply to SDXL
  • Infer whether or not we're doing SDXL training from the tokenizer_name_or_path
  • Use RandomCropSquareReturnTransform for SDXL, which returns the cropping parameters used as well as the original image size (needed to train SDXL with micro-conditioning), and return the micro-conditioning as part of the training batch (a sketch of the transform selection follows below)
  • Add an option to do micro-conditioning dropout (by setting the microcond_drop_prob flag). This is not discussed in the SDXL paper but is reflected in Stability AI's implementation
  • Small changes necessary for using SDXLTokenizer
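
As a rough sketch of the transform selection described above (the build function and the constructor signatures are assumptions for illustration, not the PR's exact code):

from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropSquare,
                                                 RandomCropSquareReturnTransform)

def build_crop_transform(resize_size: int, rand_crop: bool, sdxl: bool):
    """Pick the crop transform based on the rand_crop flag and whether we're training SDXL."""
    if sdxl:
        # SDXL always needs the crop params and original size for micro-conditioning.
        return RandomCropSquareReturnTransform(resize_size)
    # SD2 can optionally use random cropping instead of the previous center crop.
    return RandomCropSquare(resize_size) if rand_crop else LargestCenterSquare(resize_size)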

diffusion/datasets/laion/transforms.py

  • Implementation of random cropping
  • RandomCropSquare (does the random crop only) and RandomCropSquareReturnTransform (does the random crop and also returns the crop params and original image size); a sketch of the latter follows below
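
A minimal sketch of what RandomCropSquareReturnTransform might look like (the interface is assumed from the description above, not copied from the PR):

import torchvision.transforms as transforms
from torchvision.transforms import functional as F

class RandomCropSquareReturnTransform:
    """Randomly crop a square and also return the crop params and original image size."""

    def __init__(self, size: int):
        self.size = size
        self.random_crop = transforms.RandomCrop(size)

    def __call__(self, img):
        orig_w, orig_h = img.size  # PIL images report (width, height)
        # Resize the short side to `size`, then take a random square crop.
        img = F.resize(img, self.size)
        c_top, c_left, h, w = self.random_crop.get_params(img, (self.size, self.size))
        img = F.crop(img, c_top, c_left, h, w)
        return img, c_top, c_left, orig_h, orig_w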

diffusion/models/layers.py

  • Contains the attention processor classes used for QKV clipping
  • Contains the zero_module function used in the SDXL init (sketched below)
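
For reference, zero_module is typically the small helper below (a sketch of the usual Stable Diffusion zero-init trick; the PR's version may differ in details). Zeroing a module's parameters means it contributes nothing to the output at initialization:

import torch.nn as nn

def zero_module(module: nn.Module) -> nn.Module:
    """Zero out all parameters of a module and return it."""
    for p in module.parameters():
        nn.init.zeros_(p)
    return module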

diffusion/models/models.py

  • Add an option to do QKV clipping for both SD2 and SDXL (clip_qkv argument)
  • For SDXL, instantiate SDXLTokenizer and SDXLTextEncoder, which wrap the two tokenizers/text encoders but can mostly be used as if they were a single tokenizer/text encoder (see the sketch after this list)
  • For SDXL, apply the zero-init trick
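
The dual-text-encoder idea behind SDXLTextEncoder looks roughly like this (an illustrative sketch, not the PR's class; it assumes the second encoder is a CLIPTextModelWithProjection so the pooled embedding is exposed as text_embeds):

import torch
import torch.nn as nn

class DualTextEncoder(nn.Module):
    """Run both CLIP text encoders and merge their outputs."""

    def __init__(self, encoder_one, encoder_two):
        super().__init__()
        self.encoder_one = encoder_one
        self.encoder_two = encoder_two

    def forward(self, input_ids_one, input_ids_two):
        out_one = self.encoder_one(input_ids_one, output_hidden_states=True)
        out_two = self.encoder_two(input_ids_two, output_hidden_states=True)
        # Concatenate the penultimate hidden states along the feature dim to
        # form the UNet's cross-attention conditioning.
        conditioning = torch.cat([out_one.hidden_states[-2], out_two.hidden_states[-2]], dim=-1)
        # The pooled embedding from the second encoder feeds micro-conditioning.
        pooled_conditioning = out_two.text_embeds
        return conditioning, pooled_conditioning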

diffusion/models/stable_diffusion.py

  • Pass an sdxl flag to StableDiffusion to indicate whether we are training an SDXL model
  • Set the appropriate latent_scale for SD2 vs. SDXL
  • Extract pooled_conditioning from the SDXL text encoder, which is used in micro-conditioning
  • Construct the micro-conditioning dict and pass it to the UNet forward call (see the sketch after this list)
  • In StableDiffusion.generate(...), allow the user to pass crop_params and size_params for SDXL micro-conditioning; otherwise set them to reasonable default values
  • In StableDiffusion.generate(...), add a new flag called zero_out_negative_prompt that zeroes out the negative prompt if it is empty (rather than tokenizing and encoding the empty string). This was added to match the behavior of the diffusers StableDiffusionXLPipeline, and in general it just seems like a good thing to do. Note: I set the default value to True, so previously made generations (e.g. with SD2) will look different despite using the same prompt/seed. You can set it to False to match previous results.
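
The micro-conditioning described above amounts to packing the original size, crop coordinates, and target size into a single "time ids" tensor that is passed to the UNet together with the pooled text embedding. A hedged sketch, using the diffusers naming convention for the added-conditioning keys (the PR's exact keys and helper may differ):

import torch

def build_micro_conditioning(pooled_conditioning, size_params, crop_params, target_size, device):
    # e.g. size_params=(1024, 1024), crop_params=(0, 0), target_size=(1024, 1024)
    add_time_ids = torch.tensor([[*size_params, *crop_params, *target_size]], device=device)
    return {'text_embeds': pooled_conditioning, 'time_ids': add_time_ids}

In generate(), reasonable defaults when the caller passes nothing would be no cropping and the output resolution used as both the original and target size.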

diffusion/callbacks/log_diffusion_images.py

  • Edits to support multiple tokenizers

setup.py

  • Pin the diffusers version to 0.21.0 - this is necessary because the attention processor implementations are tied to this version (see below)
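
In setup.py this is just a pinned requirement, roughly (the exact extras/spec used in the PR may differ):

install_requires = [
    'diffusers[torch]==0.21.0',  # attention processor implementations match this release
    # ... other dependencies
]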

There are a few remaining things I'd like to add, but this is already a big enough PR. I will add the following once this one is merged in:

  • Zeroing out dropped captions rather than tokenizing/encoding the empty string during training (this can apply to SD2 training as well)
  • Adding a CenterCropSquareReturnTransform transformation that can be used for COCO eval. Currently, with SDXL training, we random-crop the eval dataset as well
  • Allow user to pass different prompts to the different text encoders in SDXL inference

@Landanjs (Contributor) left a comment:


This is great work!! It seems really difficult to manage the SDXL-specific features and the SD2 features in one model, but you did the best organization possible. The few possible refactors may not be that much better. I proposed a few suggestions, but it's up to your discretion whether you think they are good / worth the effort.

The only possible bug I noticed was in log_diffusion_images.py.

One general proposal: there are a few if-statements due to slight differences between the SDXL and HF tokenizers. If you're up for it, I think it would clean up a bit of code if you could make them more similar. I think this requires two things (a sketch of both follows below):

  1. Have a max_length argument in the SDXL tokenizers, and in the calling code set the max length via max_length = None if self.sdxl else self.tokenizer.model_max_length.
  2. Have the SDXL tokenizer return a dictionary with the key input_ids.
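
Something like this at the call sites (a sketch of the suggestion, not existing code):

# Works the same for SD2 and SDXL once the SDXL tokenizer accepts max_length
# and returns a dict keyed by 'input_ids'.
max_length = None if self.sdxl else self.tokenizer.model_max_length
tokenized = self.tokenizer(captions,
                           padding='max_length',
                           max_length=max_length,
                           truncation=True,
                           return_tensors='pt')
input_ids = tokenized['input_ids']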

@Landanjs (Contributor) left a comment:


Amazing!! Thanks for all the fixes!

jazcollins merged commit 35f5a57 into mosaicml:main on Oct 4, 2023 (7 checks passed).

torch_dtype = torch.float16 if encode_latents_in_fp16 else None
try:
    vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=torch_dtype)
except:  # for handling SDXL vae fp16 fixed checkpoint

The blanket except isn't great here. We should probably qualify the exception type.
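
For example, something like this (a sketch; the exact exception types raised by AutoencoderKL.from_pretrained and the fallback load for the fp16-fixed VAE checkpoint are assumptions):

try:
    vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=torch_dtype)
except (OSError, ValueError):
    # Assumed fallback: the SDXL fp16-fixed VAE checkpoint has no 'vae' subfolder.
    vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch_dtype)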

    attn_processor = ClippedXFormersAttnProcessor(clip_val=clip_qkv)
else:
    attn_processor = ClippedAttnProcessor2_0(clip_val=clip_qkv)
log.info('Using %s with clip_val %.1f' % (attn_processor.__class__, clip_qkv))

Remove the '%' sign; the logger can lazily format the string by itself.
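
That is, pass the arguments to the logger instead of pre-formatting the string:

log.info('Using %s with clip_val %.1f', attn_processor.__class__, clip_qkv)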

@Skylion007 (Contributor) left a comment:

Great PR, just some minor nits

hidden_states = attn.to_out[1](hidden_states)

if input_ndim == 4:
    assert channel

Print the value in the assert message, so we know why it's False (None, 0, etc.).

Suggested change:
-    assert channel
+    assert channel, f"{channel}"

@felixdae commented:

Why do we need to clip qkv?
