Add partial SDXL model (#61)
* add sdxl unet

* fix stochastic failures in streaming datasets

* add some debug logging

* unpin some reqs

* add yamls

* remove debug prints

* allow passing vae model path

* add base

* remove trailing whitespace

* split sdxl into separate model

* remove local yamls

* clean up sd2 doc

* one more doc fix

* add NotImplementedError, fix docs
jazcollins authored Aug 21, 2023
1 parent a735ae7 commit bacab36
Showing 1 changed file: diffusion/models/models.py (123 additions, 6 deletions)
@@ -44,8 +44,8 @@ def stable_diffusion_2(
prompts.
Args:
- model_name (str, optional): Name of the model to load. Defaults to 'stabilityai/stable-diffusion-2-base'.
- pretrained (bool, optional): Whether to load pretrained weights. Defaults to True.
+ model_name (str): Name of the model to load. Defaults to 'stabilityai/stable-diffusion-2-base'.
+ pretrained (bool): Whether to load pretrained weights. Defaults to True.
prediction_type (str): The type of prediction to use. Must be one of 'sample',
'epsilon', or 'v_prediction'. Default: `epsilon`.
train_metrics (list, optional): List of metrics to compute during training. If None, defaults to
@@ -54,12 +54,12 @@ def stable_diffusion_2(
[MeanSquaredError(), FrechetInceptionDistance(normalize=True)].
val_guidance_scales (list, optional): List of scales to use for validation guidance. If None, defaults to
[1.0, 3.0, 7.0].
- val_seed (int, optional): Seed to use for generating evaluation images. Defaults to 1138.
+ val_seed (int): Seed to use for generating evaluation images. Defaults to 1138.
loss_bins (list, optional): List of tuples of (min, max) values to use for loss binning. If None, defaults to
[(0, 1)].
- precomputed_latents (bool, optional): Whether to use precomputed latents. Defaults to False.
- encode_latents_in_fp16 (bool, optional): Whether to encode latents in fp16. Defaults to True.
- fsdp (bool, optional): Whether to use FSDP. Defaults to True.
+ precomputed_latents (bool): Whether to use precomputed latents. Defaults to False.
+ encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True.
+ fsdp (bool): Whether to use FSDP. Defaults to True.
"""
if train_metrics is None:
train_metrics = [MeanSquaredError()]
@@ -123,6 +123,123 @@ def stable_diffusion_2(
return model


def stable_diffusion_xl(
model_name: str = 'stabilityai/stable-diffusion-2-base',
unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0',
vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix',
pretrained: bool = True,
prediction_type: str = 'epsilon',
train_metrics: Optional[List] = None,
val_metrics: Optional[List] = None,
val_guidance_scales: Optional[List] = None,
val_seed: int = 1138,
loss_bins: Optional[List] = None,
precomputed_latents: bool = False,
encode_latents_in_fp16: bool = True,
fsdp: bool = True,
):
"""Stable diffusion 2 training setup + SDXL UNet and VAE.
Requires batches of matched images and text prompts to train. Generates images from text
prompts. Currently uses UNet and VAE config from SDXL, but text encoder/tokenizer from SD2.
Args:
model_name (str): Name of the model to load. Determines the text encoder, tokenizer,
and noise scheduler. Defaults to 'stabilityai/stable-diffusion-2-base'.
unet_model_name (str): Name of the UNet model to load. Defaults to
'stabilityai/stable-diffusion-xl-base-1.0'.
vae_model_name (str): Name of the VAE model to load. Defaults to
'madebyollin/sdxl-vae-fp16-fix' as the official VAE checkpoint (from
'stabilityai/stable-diffusion-xl-base-1.0') is not compatible with fp16.
pretrained (bool): Whether to load pretrained weights. Defaults to True.
prediction_type (str): The type of prediction to use. Must be one of 'sample',
'epsilon', or 'v_prediction'. Default: `epsilon`.
train_metrics (list, optional): List of metrics to compute during training. If None, defaults to
[MeanSquaredError()].
val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to
[MeanSquaredError(), FrechetInceptionDistance(normalize=True)].
val_guidance_scales (list, optional): List of scales to use for validation guidance. If None, defaults to
[1.0, 3.0, 7.0].
val_seed (int): Seed to use for generating evaluation images. Defaults to 1138.
loss_bins (list, optional): List of tuples of (min, max) values to use for loss binning. If None, defaults to
[(0, 1)].
precomputed_latents (bool): Whether to use precomputed latents. Defaults to False.
encode_latents_in_fp16 (bool): Whether to encode latents in fp16. Defaults to True.
fsdp (bool): Whether to use FSDP. Defaults to True.
"""
if train_metrics is None:
train_metrics = [MeanSquaredError()]
if val_metrics is None:
val_metrics = [MeanSquaredError(), FrechetInceptionDistance(normalize=True)]
if val_guidance_scales is None:
val_guidance_scales = [1.0, 3.0, 7.0]
if loss_bins is None:
loss_bins = [(0, 1)]
# Fix a bug where CLIPScore requires grad
for metric in val_metrics:
if isinstance(metric, CLIPScore):
metric.requires_grad_(False)

if pretrained:
raise NotImplementedError('Full SDXL pipeline not implemented yet.')
else:
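# PretrainedConfig.get_config_dict returns a (config_dict, unused_kwargs)
# tuple, which is why the dict is accessed as config[0] below.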
config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')
# Currently not doing micro-conditioning, so set config appropriately
config[0]['addition_embed_type'] = None
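# 1024 is the hidden size of the SD2 OpenCLIP text encoder, so the SDXL
# UNet's cross-attention layers can consume SD2 text embeddings directly.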
config[0]['cross_attention_dim'] = 1024
unet = UNet2DConditionModel(**config[0])

# Prevent fsdp from wrapping up_blocks and down_blocks because the forward pass calls length on these
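# (`_fsdp_wrap` is the attribute Composer's FSDP integration reads when
# deciding whether to wrap a given submodule individually.)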
unet.up_blocks._fsdp_wrap = False
unet.down_blocks._fsdp_wrap = False
for block in unet.up_blocks:
block._fsdp_wrap = True
for block in unet.down_blocks:
block._fsdp_wrap = True

if encode_latents_in_fp16:
vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch.float16)
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder', torch_dtype=torch.float16)
else:
vae = AutoencoderKL.from_pretrained(vae_model_name)
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder='text_encoder')

tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer')
noise_scheduler = DDPMScheduler.from_pretrained(model_name, subfolder='scheduler')
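# DDIM is used for faster sampling at inference; it copies the betas from the
# DDPM training schedule so generation stays consistent with training.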
inference_noise_scheduler = DDIMScheduler(num_train_timesteps=noise_scheduler.config.num_train_timesteps,
beta_start=noise_scheduler.config.beta_start,
beta_end=noise_scheduler.config.beta_end,
beta_schedule=noise_scheduler.config.beta_schedule,
trained_betas=noise_scheduler.config.trained_betas,
clip_sample=noise_scheduler.config.clip_sample,
set_alpha_to_one=noise_scheduler.config.set_alpha_to_one,
prediction_type=prediction_type)

model = StableDiffusion(
unet=unet,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
noise_scheduler=noise_scheduler,
inference_noise_scheduler=inference_noise_scheduler,
prediction_type=prediction_type,
train_metrics=train_metrics,
val_metrics=val_metrics,
val_guidance_scales=val_guidance_scales,
val_seed=val_seed,
loss_bins=loss_bins,
precomputed_latents=precomputed_latents,
encode_latents_in_fp16=encode_latents_in_fp16,
fsdp=fsdp,
)
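# Move the assembled model to GPU when available and enable xformers
# memory-efficient attention to reduce attention memory usage.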
if torch.cuda.is_available():
model = DeviceGPU().module_to_device(model)
if is_xformers_installed:
model.unet.enable_xformers_memory_efficient_attention()
model.vae.enable_xformers_memory_efficient_attention()
return model


def discrete_pixel_diffusion(clip_model_name: str = 'openai/clip-vit-large-patch14', prediction_type='epsilon'):
"""Discrete pixel diffusion training setup.
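For reference, a minimal construction sketch for the new entry point, assuming the package import path mirrors the file path diffusion/models/models.py. Since pretrained=True currently raises NotImplementedError, the sketch passes pretrained=False; the remaining kwargs shown are the function's own defaults.

from diffusion.models.models import stable_diffusion_xl

# Builds the SDXL UNet and fp16-safe VAE with the SD2 text encoder,
# tokenizer, and noise schedulers. Pretrained SDXL weights are not
# supported yet, so the UNet is initialized from config for training.
model = stable_diffusion_xl(
    model_name='stabilityai/stable-diffusion-2-base',
    unet_model_name='stabilityai/stable-diffusion-xl-base-1.0',
    vae_model_name='madebyollin/sdxl-vae-fp16-fix',
    pretrained=False,
)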
