From afa6c662fbd2f4ddc22376609ab49e16bbce9877 Mon Sep 17 00:00:00 2001 From: coryMosaicML <83666378+coryMosaicML@users.noreply.github.com> Date: Thu, 3 Oct 2024 17:02:33 -0700 Subject: [PATCH] Add composer model class for running with precomputed CLIP and T5 text latents (#171) --- diffusion/callbacks/log_diffusion_images.py | 35 +- diffusion/datasets/image_caption_latents.py | 47 +- diffusion/models/__init__.py | 7 +- diffusion/models/models.py | 305 ++++++++++- .../precomputed_text_latent_diffusion.py | 516 ++++++++++++++++++ 5 files changed, 886 insertions(+), 24 deletions(-) create mode 100644 diffusion/models/precomputed_text_latent_diffusion.py diff --git a/diffusion/callbacks/log_diffusion_images.py b/diffusion/callbacks/log_diffusion_images.py index c36b75d9..522c71cd 100644 --- a/diffusion/callbacks/log_diffusion_images.py +++ b/diffusion/callbacks/log_diffusion_images.py @@ -39,6 +39,11 @@ class LogDiffusionImages(Callback): use_table (bool): Whether to make a table of the images or not. Default: ``False``. t5_encoder (str, optional): path to the T5 encoder to as a second text encoder. clip_encoder (str, optional): path to the CLIP encoder as the first text encoder. + t5_latent_key: (str): key to use for the T5 latents in the batch. Default: ``'T5_LATENTS'``. + t5_mask_key: (str): key to use for the T5 attention mask in the batch. Default: ``'T5_ATTENTION_MASK'``. + clip_latent_key: (str): key to use for the CLIP latents in the batch. Default: ``'CLIP_LATENTS'``. + clip_mask_key: (str): key to use for the CLIP attention mask in the batch. Default: ``'CLIP_ATTENTION_MASK'``. + clip_pooled_key: (str): key to use for the CLIP pooled in the batch. Default: ``'CLIP_POOLED'``. cache_dir: (str, optional): path for HF to cache files while downloading model """ @@ -53,6 +58,11 @@ def __init__(self, use_table: bool = False, t5_encoder: Optional[str] = None, clip_encoder: Optional[str] = None, + t5_latent_key: str = 'T5_LATENTS', + t5_mask_key: str = 'T5_ATTENTION_MASK', + clip_latent_key: str = 'CLIP_LATENTS', + clip_mask_key: str = 'CLIP_ATTENTION_MASK', + clip_pooled_key: str = 'CLIP_POOLED', cache_dir: Optional[str] = '/tmp/hf_files'): self.prompts = prompts self.size = (size, size) if isinstance(size, int) else size @@ -61,6 +71,11 @@ def __init__(self, self.rescaled_guidance = rescaled_guidance self.seed = seed self.use_table = use_table + self.t5_latent_key = t5_latent_key + self.t5_mask_key = t5_mask_key + self.clip_latent_key = clip_latent_key + self.clip_mask_key = clip_mask_key + self.clip_pooled_key = clip_pooled_key self.cache_dir = cache_dir # Batch prompts @@ -120,10 +135,11 @@ def __init__(self, clip_pooled = clip_outputs[1].cpu() clip_attention_mask = clip_attention_mask.cpu().to(torch.long) - latent_batch['T5_LATENTS'] = t5_latents - latent_batch['CLIP_LATENTS'] = clip_latents - latent_batch['ATTENTION_MASK'] = torch.cat([t5_attention_mask, clip_attention_mask], dim=1) - latent_batch['CLIP_POOLED'] = clip_pooled + latent_batch[self.t5_latent_key] = t5_latents + latent_batch[self.t5_mask_key] = t5_attention_mask + latent_batch[self.clip_latent_key] = clip_latents + latent_batch[self.clip_mask_key] = clip_attention_mask + latent_batch[self.clip_pooled_key] = clip_pooled self.batched_latents.append(latent_batch) del t5_model @@ -143,12 +159,11 @@ def eval_start(self, state: State, logger: Logger): all_gen_images = [] if self.precomputed_latents: for batch in self.batched_latents: - pooled_prompt = batch['CLIP_POOLED'].cuda() - prompt_mask = batch['ATTENTION_MASK'].cuda() - 
t5_embeds = model.t5_proj(batch['T5_LATENTS'].cuda()) - clip_embeds = model.clip_proj(batch['CLIP_LATENTS'].cuda()) - prompt_embeds = torch.cat([t5_embeds, clip_embeds], dim=1) - + pooled_prompt = batch[self.clip_pooled_key].cuda() + prompt_embeds, prompt_mask = model.prepare_text_embeddings(batch[self.t5_latent_key].cuda(), + batch[self.clip_latent_key].cuda(), + batch[self.t5_mask_key].cuda(), + batch[self.clip_mask_key].cuda()) gen_images = model.generate(prompt_embeds=prompt_embeds, pooled_prompt=pooled_prompt, prompt_mask=prompt_mask, diff --git a/diffusion/datasets/image_caption_latents.py b/diffusion/datasets/image_caption_latents.py index fa1158f5..01ae1e8d 100644 --- a/diffusion/datasets/image_caption_latents.py +++ b/diffusion/datasets/image_caption_latents.py @@ -14,7 +14,8 @@ from torch.utils.data import DataLoader from torchvision import transforms -from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransform, RandomCropSquare +from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropAspectRatioTransform, + RandomCropBucketedAspectRatioTransform, RandomCropSquare) from diffusion.datasets.utils import make_streams log = logging.getLogger(__name__) @@ -32,6 +33,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset): image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``. caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``. caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``. + aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``. text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset. Default: ``('T5_LATENTS', 'CLIP_LATENTS')``. text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset. @@ -40,6 +42,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset): attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset. Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``. latent_dtype (torch.dtype): The dtype to cast the text latents to. Default: ``torch.bfloat16``. + drop_nans (bool): Whether to treat samples with NaN latents as dropped captions. Default: ``True``. **streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader """ @@ -53,10 +56,12 @@ def __init__( image_key: str = 'image', caption_keys: Tuple[str, ...] = ('caption',), caption_selection_probs: Tuple[float, ...] = (1.0,), + aspect_ratio_bucket_key: Optional[str] = None, text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'), text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)), attention_mask_keys: Tuple[str, ...] 
= ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'), latent_dtype: torch.dtype = torch.bfloat16, + drop_nans: bool = True, **streaming_kwargs, ): @@ -72,10 +77,14 @@ def __init__( self.image_key = image_key self.caption_keys = caption_keys self.caption_selection_probs = caption_selection_probs + self.aspect_ratio_bucket_key = aspect_ratio_bucket_key + if isinstance(self.crop, RandomCropBucketedAspectRatioTransform): + assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransform' self.text_latent_keys = text_latent_keys self.text_latent_shapes = text_latent_shapes self.attention_mask_keys = attention_mask_keys self.latent_dtype = latent_dtype + self.drop_nans = drop_nans def __getitem__(self, index): sample = super().__getitem__(index) @@ -90,15 +99,16 @@ def __getitem__(self, index): out['cond_original_size'] = torch.tensor(img.size) # Image transforms - if self.crop is not None: + if isinstance(self.crop, RandomCropBucketedAspectRatioTransform): + img, crop_top, crop_left = self.crop(img, sample[self.aspect_ratio_bucket_key]) + elif self.crop is not None: img, crop_top, crop_left = self.crop(img) else: crop_top, crop_left = 0, 0 - out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left]) - if self.transform is not None: img = self.transform(img) out['image'] = img + out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left]) # Get the new height and width if isinstance(img, torch.Tensor): @@ -140,6 +150,13 @@ def __getitem__(self, index): if 'CLIP_LATENTS' in latent_key: clip_pooled = np.frombuffer(sample[f'{caption_key}_CLIP_POOLED_TEXT'], dtype=np.float32).copy() out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).to(self.latent_dtype).reshape(latent_shape[1]) + if self.drop_nans: + for latent_key, attn_key in zip(self.text_latent_keys, self.attention_mask_keys): + if out[latent_key].isnan().any(): + out[latent_key] = torch.zeros_like(out[latent_key]) + out[attn_key] = torch.zeros_like(out[attn_key]) + if 'CLIP_LATENTS' in latent_key and out['CLIP_POOLED'].isnan().any(): + out['CLIP_POOLED'] = torch.zeros_like(out['CLIP_POOLED']) return out @@ -160,6 +177,7 @@ def build_streaming_image_caption_latents_dataloader( text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)), attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'), latent_dtype: str = 'torch.bfloat16', + aspect_ratio_bucket_key: Optional[str] = None, streaming_kwargs: Optional[Dict] = None, dataloader_kwargs: Optional[Dict] = None, ): @@ -178,11 +196,12 @@ def build_streaming_image_caption_latents_dataloader( ``None``, the bucket with the smallest distance to the current sample's aspect ratio is selected. Default: ``None``. transform (Callable, optional): The transforms to apply to the image. Default: ``None``. - crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio']. + crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']. Default: ``'square'``. image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``. caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``. caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``. + aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. 
Default: ``None``. text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset. Default: ``('T5_LATENTS', 'CLIP_LATENTS')``. text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset. @@ -192,18 +211,22 @@ def build_streaming_image_caption_latents_dataloader( Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``. latent_dtype (str): The torch dtype to cast the text latents to. One of 'torch.float16', 'torch.float32', or 'torch.bfloat16'. Default: ``'torch.bfloat16'``. + aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. + Needed if using ``crop_type='bucketed_aspect_ratio'``. Default: ``None``. streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``. dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``. """ # Check crop type if crop_type is not None: crop_type = crop_type.lower() - if crop_type not in ['square', 'random', 'aspect_ratio']: - raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]') - if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)): + if crop_type not in ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']: raise ValueError( - 'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.') - + f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", "bucketed_aspect_ratio", None]' + ) + if crop_type in ['aspect_ratio', 'bucketed_aspect_ratio'] and (isinstance(resize_size, int) or + isinstance(resize_size[0], int)): + raise ValueError( + 'If using aspect ratio bucketing, specify aspect ratio buckets in resize_size as a tuple of tuples.') # Check latent dtype dtypes = {'torch.float16': torch.float16, 'torch.float32': torch.float32, 'torch.bfloat16': torch.bfloat16} assert latent_dtype in dtypes, f'Invalid latent_dtype: {latent_dtype}. 
Must be one of {list(dtypes.keys())}' @@ -225,6 +248,9 @@ def build_streaming_image_caption_latents_dataloader( crop = RandomCropSquare(resize_size) elif crop_type == 'aspect_ratio': crop = RandomCropAspectRatioTransform(resize_size, ar_bucket_boundaries) # type: ignore + elif crop_type == 'bucketed_aspect_ratio': + assert aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using bucketed_aspect_ratio crop type' + crop = RandomCropBucketedAspectRatioTransform(resize_size) # type: ignore else: crop = None @@ -242,6 +268,7 @@ def build_streaming_image_caption_latents_dataloader( image_key=image_key, caption_keys=caption_keys, caption_selection_probs=caption_selection_probs, + aspect_ratio_bucket_key=aspect_ratio_bucket_key, text_latent_keys=text_latent_keys, text_latent_shapes=text_latent_shapes, attention_mask_keys=attention_mask_keys, diff --git a/diffusion/models/__init__.py b/diffusion/models/__init__.py index 1a4bb0a2..69cf02bd 100644 --- a/diffusion/models/__init__.py +++ b/diffusion/models/__init__.py @@ -4,10 +4,11 @@ """Diffusion models.""" from diffusion.models.models import (build_autoencoder, build_diffusers_autoencoder, continuous_pixel_diffusion, - discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl, - text_to_image_transformer) + discrete_pixel_diffusion, precomputed_text_latent_diffusion, stable_diffusion_2, + stable_diffusion_xl, text_to_image_transformer) from diffusion.models.noop import NoOpModel from diffusion.models.pixel_diffusion import PixelDiffusion +from diffusion.models.precomputed_text_latent_diffusion import PrecomputedTextLatentDiffusion from diffusion.models.stable_diffusion import StableDiffusion __all__ = [ @@ -17,8 +18,10 @@ 'discrete_pixel_diffusion', 'NoOpModel', 'PixelDiffusion', + 'precomputed_text_latent_diffusion', 'stable_diffusion_2', 'stable_diffusion_xl', 'StableDiffusion', + 'PrecomputedTextLatentDiffusion', 'text_to_image_transformer', ] diff --git a/diffusion/models/models.py b/diffusion/models/models.py index a6d782d8..33b02f8d 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -5,19 +5,20 @@ import logging import math -from typing import List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from composer.devices import DeviceGPU from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel from peft import LoraConfig from torchmetrics import MeanSquaredError -from transformers import CLIPTextModel, CLIPTokenizer, PretrainedConfig +from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer, PretrainedConfig from diffusion.models.autoencoder import (AutoEncoder, AutoEncoderLoss, ComposerAutoEncoder, ComposerDiffusersAutoEncoder, load_autoencoder) from diffusion.models.layers import ClippedAttnProcessor2_0, ClippedXFormersAttnProcessor, zero_module from diffusion.models.pixel_diffusion import PixelDiffusion +from diffusion.models.precomputed_text_latent_diffusion import PrecomputedTextLatentDiffusion from diffusion.models.stable_diffusion import StableDiffusion from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer @@ -580,6 +581,306 @@ def stable_diffusion_xl( return model +def precomputed_text_latent_diffusion( + unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', + vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', + autoencoder_path: Optional[str] 
= None,
+    autoencoder_local_path: str = '/tmp/autoencoder_weights.pt',
+    include_text_encoders: bool = False,
+    text_encoder_dtype: str = 'bfloat16',
+    cache_dir: str = '/tmp/hf_files',
+    prediction_type: str = 'epsilon',
+    image_key: str = 'image',
+    t5_latent_key: str = 'T5_LATENTS',
+    t5_mask_key: str = 'T5_ATTENTION_MASK',
+    clip_latent_key: str = 'CLIP_LATENTS',
+    clip_mask_key: str = 'CLIP_ATTENTION_MASK',
+    clip_pooled_key: str = 'CLIP_POOLED',
+    latent_mean: Union[float, Tuple, str] = 0.0,
+    latent_std: Union[float, Tuple, str] = 7.67754318618,
+    text_embed_dim: int = 4096,
+    train_noise_scheduler_params: Optional[Dict[str, Any]] = None,
+    inference_noise_scheduler_params: Optional[Dict[str, Any]] = None,
+    scheduler_shift_resolution: int = 256,
+    train_metrics: Optional[List] = None,
+    val_metrics: Optional[List] = None,
+    quasirandomness: bool = False,
+    train_seed: int = 42,
+    val_seed: int = 1138,
+    fsdp: bool = True,
+    use_xformers: bool = True,
+    lora_rank: Optional[int] = None,
+    lora_alpha: Optional[int] = None,
+):
+    """Latent diffusion model training using precomputed text latents from T5-XXL and CLIP.
+
+    Args:
+        unet_model_name (str): Name of the UNet model to load. Defaults to
+            'stabilityai/stable-diffusion-xl-base-1.0'.
+        vae_model_name (str): Name of the VAE model to load. Defaults to
+            'madebyollin/sdxl-vae-fp16-fix' as the official VAE checkpoint (from
+            'stabilityai/stable-diffusion-xl-base-1.0') is not compatible with fp16.
+        autoencoder_path (optional, str): Path to autoencoder weights if using a custom autoencoder. If not specified,
+            will use the vae from `model_name`. Default `None`.
+        autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`.
+        include_text_encoders (bool): Whether to include text encoders in the model. Should only be done when running
+            inference. Default: `False`.
+        text_encoder_dtype (str): The dtype to use for the text encoder. One of [`float32`, `float16`, `bfloat16`].
+            Default: `bfloat16`.
+        cache_dir (str): Directory to cache the model in if using `include_text_encoders`. Default: `'/tmp/hf_files'`.
+        prediction_type (str): The type of prediction to use. Must be one of 'sample',
+            'epsilon', or 'v_prediction'. Default: `epsilon`.
+        image_key (str): The key to use for the image in the precomputed latents. Default: `'image'`.
+        t5_latent_key (str): The key to use for the T5 latents in the precomputed latents. Default: `'T5_LATENTS'`.
+        t5_mask_key (str): The key to use for the T5 attention mask in the precomputed latents. Default: `'T5_ATTENTION_MASK'`.
+        clip_latent_key (str): The key to use for the CLIP latents in the precomputed latents. Default: `'CLIP_LATENTS'`.
+        clip_mask_key (str): The key to use for the CLIP attention mask in the precomputed latents. Default: `'CLIP_ATTENTION_MASK'`.
+        clip_pooled_key (str): The key to use for the pooled CLIP embedding in the precomputed latents. Default: `'CLIP_POOLED'`.
+        latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value,
+            a tuple of means, or `'latent_statistics'` to try to use the value from the autoencoder
+            checkpoint. Defaults to `0.0`.
+        latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value,
+            a tuple of std. devs., or `'latent_statistics'` to try to use the value from the autoencoder
+            checkpoint. Defaults to `1/0.13025`.
+        text_embed_dim (int): The dimension to project the text embeddings to. Default: `4096`.
+        train_noise_scheduler_params (Dict): Parameters to override in the training noise scheduler. Anything not
+            specified will default to SDXL values. Default: `None`.
+        inference_noise_scheduler_params (Dict): Parameters to override in the inference noise scheduler. Anything
+            not specified will default to SDXL values. Default: `None`.
+        scheduler_shift_resolution (int): The resolution to shift the noise scheduler to. Default: `256`.
+        train_metrics (list, optional): List of metrics to compute during training. If None, defaults to
+            [MeanSquaredError()].
+        val_metrics (list, optional): List of metrics to compute during validation. If None, defaults to
+            [MeanSquaredError()].
+        quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise.
+            Default: `False`.
+        train_seed (int): Seed to use for generating diffusion process noise during training if using
+            quasirandomness. Default: `42`.
+        val_seed (int): Seed to use for generating evaluation images. Defaults to 1138.
+        fsdp (bool): Whether to use FSDP. Defaults to True.
+        use_xformers (bool): Whether to use xformers for attention. Defaults to True.
+        lora_rank (int, optional): If not None, the rank to use for LoRA finetuning. Defaults to None.
+        lora_alpha (int, optional): If not None, the alpha to use for LoRA finetuning. Defaults to None.
+    """
+    latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)
+
+    if train_metrics is None:
+        train_metrics = [MeanSquaredError()]
+    if val_metrics is None:
+        val_metrics = [MeanSquaredError()]
+
+    # Make the autoencoder
+    if autoencoder_path is None:
+        if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics':
+            raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.')
+        downsample_factor = 8
+        # Use the pretrained vae
+        try:
+            vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=torch.float16)
+        except:  # for handling SDXL vae fp16 fixed checkpoint
+            vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=torch.float16)
+    else:
+        # Use a custom autoencoder
+        vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=torch.float16)
+        if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'):
+            raise ValueError(
+                'Must specify latent scale when using a custom autoencoder without tracking latent statistics.')
+        if isinstance(latent_mean, str) and latent_mean == 'latent_statistics':
+            assert isinstance(latent_statistics, dict)
+            latent_mean = tuple(latent_statistics['latent_channel_means'])
+        if isinstance(latent_std, str) and latent_std == 'latent_statistics':
+            assert isinstance(latent_statistics, dict)
+            latent_std = tuple(latent_statistics['latent_channel_stds'])
+        downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
+
+    # Make the unet
+    unet_config = PretrainedConfig.get_config_dict(unet_model_name, subfolder='unet')[0]
+
+    if isinstance(vae, AutoEncoder):
+        # Adapt the unet config to account for differing number of latent channels if necessary
+        unet_config['in_channels'] = vae.config['latent_channels']
+        unet_config['out_channels'] = vae.config['latent_channels']
+    unet_config['cross_attention_dim'] = text_embed_dim
+    # This config variable is the sum of the text encoder projection dimension and
+    # the number of additional time embeddings (6) * addition_time_embed_dim (256)
+    unet_config['projection_class_embeddings_input_dim'] = 2304
+    #
Init the unet from the config + unet = UNet2DConditionModel(**unet_config) + + # Zero initialization trick + for name, layer in unet.named_modules(): + # Final conv in ResNet blocks + if name.endswith('conv2'): + layer = zero_module(layer) + # proj_out in attention blocks + if name.endswith('to_out.0'): + layer = zero_module(layer) + # Last conv block out projection + unet.conv_out = zero_module(unet.conv_out) + + if isinstance(latent_mean, float): + latent_mean = (latent_mean,) * unet_config['in_channels'] + if isinstance(latent_std, float): + latent_std = (latent_std,) * unet_config['in_channels'] + assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) + + # FSDP Wrapping Scheme + if hasattr(unet, 'mid_block') and unet.mid_block is not None: + for attention in unet.mid_block.attentions: + attention._fsdp_wrap = True + for resnet in unet.mid_block.resnets: + resnet._fsdp_wrap = True + for block in unet.up_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + for block in unet.down_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + + # Make the noise schedulers + train_scheduler_params: Dict[str, Any] = { + 'num_train_timesteps': 1000, + 'beta_start': 0.00085, + 'beta_end': 0.012, + 'beta_schedule': 'scaled_linear', + 'variance_type': 'fixed_small', + 'clip_sample': False, + 'prediction_type': prediction_type, + 'sample_max_value': 1.0, + 'timestep_spacing': 'leading', + 'steps_offset': 1, + 'rescale_betas_zero_snr': False, + } + if train_noise_scheduler_params is not None: + train_scheduler_params.update(train_noise_scheduler_params) + noise_scheduler = DDPMScheduler(**train_scheduler_params) + + inference_scheduler_params: Dict[str, Any] = { + 'num_train_timesteps': 1000, + 'beta_start': 0.00085, + 'beta_end': 0.012, + 'beta_schedule': 'scaled_linear', + 'trained_betas': None, + 'prediction_type': prediction_type, + 'interpolation_type': 'linear', + 'use_karras_sigmas': False, + 'timestep_spacing': 'leading', + 'steps_offset': 1, + 'rescale_betas_zero_snr': False, + } + + if inference_noise_scheduler_params is not None: + inference_scheduler_params.update(inference_noise_scheduler_params) + inference_noise_scheduler = EulerDiscreteScheduler(**inference_scheduler_params) + + # Shift noise scheduler to correct for resolution changes + noise_scheduler = shift_noise_schedule(noise_scheduler, + base_dim=32, + shift_dim=scheduler_shift_resolution // downsample_factor) + inference_noise_scheduler = shift_noise_schedule(inference_noise_scheduler, + base_dim=32, + shift_dim=scheduler_shift_resolution // downsample_factor) + + # Optionally load the tokenizers and text encoders + t5_tokenizer, t5_encoder, clip_tokenizer, clip_encoder = None, None, None, None + if include_text_encoders: + dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16} + dtype = dtype_map[text_encoder_dtype] + t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-xxl', cache_dir=cache_dir, local_files_only=True) + clip_tokenizer = AutoTokenizer.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', + subfolder='tokenizer', + cache_dir=cache_dir, + local_files_only=False) + t5_encoder = AutoModel.from_pretrained('google/t5-v1_1-xxl', + torch_dtype=dtype, + 
cache_dir=cache_dir, + local_files_only=False).encoder.eval() + clip_encoder = CLIPTextModel.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', + subfolder='text_encoder', + torch_dtype=dtype, + cache_dir=cache_dir, + local_files_only=False).cuda().eval() + # Make the composer model + model = PrecomputedTextLatentDiffusion( + unet=unet, + vae=vae, + t5_tokenizer=t5_tokenizer, + t5_encoder=t5_encoder, + clip_tokenizer=clip_tokenizer, + clip_encoder=clip_encoder, + noise_scheduler=noise_scheduler, + inference_noise_scheduler=inference_noise_scheduler, + prediction_type=prediction_type, + image_key=image_key, + t5_latent_key=t5_latent_key, + t5_mask_key=t5_mask_key, + clip_latent_key=clip_latent_key, + clip_mask_key=clip_mask_key, + clip_pooled_key=clip_pooled_key, + latent_mean=latent_mean, + latent_std=latent_std, + downsample_factor=downsample_factor, + train_metrics=train_metrics, + val_metrics=val_metrics, + quasirandomness=quasirandomness, + train_seed=train_seed, + val_seed=val_seed, + text_embed_dim=text_embed_dim, + fsdp=fsdp, + ) + + if lora_rank is not None: + assert lora_alpha is not None + model.unet.requires_grad_(False) + for param in model.unet.parameters(): + param.requires_grad_(False) + + unet_lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + init_lora_weights='gaussian', + target_modules=['to_k', 'to_q', 'to_v', 'to_out.0'], + ) + model.unet.add_adapter(unet_lora_config) + model.unet._fsdp_wrap = True + if hasattr(model.unet, 'mid_block') and model.unet.mid_block is not None: + for attention in model.unet.mid_block.attentions: + attention._fsdp_wrap = True + for resnet in model.unet.mid_block.resnets: + resnet._fsdp_wrap = True + for block in model.unet.up_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + for block in model.unet.down_blocks: + if hasattr(block, 'attentions'): + for attention in block.attentions: + attention._fsdp_wrap = True + if hasattr(block, 'resnets'): + for resnet in block.resnets: + resnet._fsdp_wrap = True + + if torch.cuda.is_available(): + model = DeviceGPU().module_to_device(model) + if is_xformers_installed and use_xformers: + model.unet.enable_xformers_memory_efficient_attention() + if hasattr(model.vae, 'enable_xformers_memory_efficient_attention'): + model.vae.enable_xformers_memory_efficient_attention() + + return model + + def text_to_image_transformer( tokenizer_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/tokenizer'), text_encoder_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/text_encoder'), diff --git a/diffusion/models/precomputed_text_latent_diffusion.py b/diffusion/models/precomputed_text_latent_diffusion.py new file mode 100644 index 00000000..d1ee9136 --- /dev/null +++ b/diffusion/models/precomputed_text_latent_diffusion.py @@ -0,0 +1,516 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Diffusion models.""" + +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from composer.models import ComposerModel +from composer.utils import dist +from scipy.stats import qmc +from torchmetrics import MeanSquaredError +from tqdm.auto import tqdm +from transformers import PreTrainedTokenizer + +try: + import xformers # type: ignore + del xformers + is_xformers_installed = True +except: 
+    is_xformers_installed = False
+
+
+class PrecomputedTextLatentDiffusion(ComposerModel):
+    """Diffusion ComposerModel for running with precomputed T5 and CLIP embeddings.
+
+    This is a Latent Diffusion model conditioned on text prompts that are run through
+    pre-trained CLIP and T5 text encoders.
+
+    Args:
+        unet (torch.nn.Module): HuggingFace conditional unet, must accept a
+            (B, C, H, W) input, (B,) timestep array of noise timesteps,
+            and (B, 77, text_embed_dim) text conditioning vectors.
+        vae (torch.nn.Module): HuggingFace or compatible vae. Must support
+            `.encode()` and `.decode()` functions.
+        noise_scheduler (diffusers.SchedulerMixin): HuggingFace diffusers
+            noise scheduler. Used during the forward diffusion process (training).
+        inference_scheduler (diffusers.SchedulerMixin): HuggingFace diffusers
+            noise scheduler. Used during the backward diffusion process (inference).
+        t5_tokenizer (Optional): Tokenizer for T5. Should only be specified during inference. Default: `None`.
+        t5_encoder (Optional): T5 text encoder. Should only be specified during inference. Default: `None`.
+        clip_tokenizer (Optional): Tokenizer for CLIP. Should only be specified during inference. Default: `None`.
+        clip_encoder (Optional): CLIP text encoder. Should only be specified during inference. Default: `None`.
+        text_embed_dim (int): The common dimension to project the text embeddings to. Default: `4096`.
+        prediction_type (str): The type of prediction to use. Must be one of 'sample',
+            'epsilon', or 'v_prediction'. Default: `epsilon`.
+        image_key (str): The key in the batch dict that contains the image. Default: `'image'`.
+        t5_latent_key (str): The key in the batch dict that contains the T5 latents. Default: `'T5_LATENTS'`.
+        t5_mask_key (str): The key in the batch dict that contains the T5 attention mask. Default: `'T5_ATTENTION_MASK'`.
+        clip_latent_key (str): The key in the batch dict that contains the CLIP latents. Default: `'CLIP_LATENTS'`.
+        clip_mask_key (str): The key in the batch dict that contains the CLIP attention mask. Default: `'CLIP_ATTENTION_MASK'`.
+        clip_pooled_key (str): The key in the batch dict that contains the CLIP pooled embeddings. Default: `'CLIP_POOLED'`.
+        latent_mean (Optional[tuple[float]]): The means of the latent space. Default: ``(0.0,) * 4``.
+        latent_std (Optional[tuple[float]]): The standard deviations of the latent space. Default: ``(1/0.13025,)*4``.
+        downsample_factor (int): The factor by which the image is downsampled by the autoencoder. Default `8`.
+        max_seq_len (int): The maximum sequence length for the text encoder. Default: `256`.
+        train_metrics (list): List of torchmetrics to calculate during training.
+            Default: `None`.
+        val_metrics (list): List of torchmetrics to calculate during validation.
+            Default: `None`.
+        quasirandomness (bool): Whether to use quasirandomness for generating diffusion process noise.
+            Default: `False`.
+        train_seed (int): Seed to use for generating diffusion process noise during training if using
+            quasirandomness. Default: `42`.
+        val_seed (int): Seed to use for generating eval images. Default: `1138`.
+        fsdp (bool): Whether to use FSDP. Default: `False`.
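+
+    Example:
+        Illustrative sketch only (assumes ``unet``, ``vae``, and the two noise schedulers are built as in the
+        ``precomputed_text_latent_diffusion`` factory, and that ``batch`` comes from
+        ``build_streaming_image_caption_latents_dataloader`` with the default latent keys)::
+
+            model = PrecomputedTextLatentDiffusion(unet=unet,
+                                                   vae=vae,
+                                                   noise_scheduler=noise_scheduler,
+                                                   inference_noise_scheduler=inference_noise_scheduler)
+            outputs = model(batch)  # (unet prediction, targets, timesteps)
+            loss = model.loss(outputs, batch)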
+ """ + + def __init__( + self, + unet, + vae, + noise_scheduler, + inference_noise_scheduler, + t5_tokenizer: Optional[PreTrainedTokenizer] = None, + t5_encoder: Optional[torch.nn.Module] = None, + clip_tokenizer: Optional[PreTrainedTokenizer] = None, + clip_encoder: Optional[torch.nn.Module] = None, + text_embed_dim: int = 4096, + prediction_type: str = 'epsilon', + image_key: str = 'image', + t5_latent_key: str = 'T5_LATENTS', + t5_mask_key: str = 'T5_ATTENTION_MASK', + clip_latent_key: str = 'CLIP_LATENTS', + clip_mask_key: str = 'CLIP_ATTENTION_MASK', + clip_pooled_key: str = 'CLIP_POOLED', + latent_mean: Tuple[float] = (0.0,) * 4, + latent_std: Tuple[float] = (1 / 0.13025,) * 4, + downsample_factor: int = 8, + max_seq_len: int = 256, + train_metrics: Optional[List] = None, + val_metrics: Optional[List] = None, + quasirandomness: bool = False, + train_seed: int = 42, + val_seed: int = 1138, + fsdp: bool = False, + ): + super().__init__() + self.unet = unet + self.vae = vae + self.t5_tokenizer = t5_tokenizer + self.t5_encoder = t5_encoder + self.clip_tokenizer = clip_tokenizer + self.clip_encoder = clip_encoder + self.noise_scheduler = noise_scheduler + self.prediction_type = prediction_type.lower() + if self.prediction_type not in ['sample', 'epsilon', 'v_prediction']: + raise ValueError(f'prediction type must be one of sample, epsilon, or v_prediction. Got {prediction_type}') + self.downsample_factor = downsample_factor + self.max_seq_len = max_seq_len + self.quasirandomness = quasirandomness + self.train_seed = train_seed + self.val_seed = val_seed + self.image_key = image_key + self.t5_latent_key = t5_latent_key + self.t5_mask_key = t5_mask_key + self.clip_latent_key = clip_latent_key + self.clip_mask_key = clip_mask_key + self.clip_pooled_key = clip_pooled_key + self.latent_mean = torch.tensor(latent_mean).view(1, -1, 1, 1) + self.latent_std = torch.tensor(latent_std).view(1, -1, 1, 1) + self.train_metrics = train_metrics if train_metrics is not None else [MeanSquaredError()] + self.val_metrics = val_metrics if val_metrics is not None else [MeanSquaredError()] + self.inference_scheduler = inference_noise_scheduler + # freeze VAE during diffusion training + self.vae.requires_grad_(False) + self.vae = self.vae.bfloat16() + if fsdp: + # only wrap models we are training + self.vae._fsdp_wrap = False + self.unet._fsdp_wrap = True + + # Optional rng generator + self.rng_generator: Optional[torch.Generator] = None + if self.quasirandomness: + self.sobol = qmc.Sobol(d=1, scramble=True, seed=self.train_seed) + + # Projection layers for the text embeddings + self.clip_proj = nn.Linear(768, text_embed_dim) + self.t5_proj = nn.Linear(4096, text_embed_dim) + # Layernorms for the text embeddings + self.clip_ln = nn.LayerNorm(text_embed_dim) + self.t5_ln = nn.LayerNorm(text_embed_dim) + # Learnable position embeddings for the conitioning sequences + t5_position_embeddings = torch.randn(self.max_seq_len, text_embed_dim) + t5_position_embeddings /= math.sqrt(text_embed_dim) + self.t5_position_embedding = torch.nn.Parameter(t5_position_embeddings, requires_grad=True) + clip_position_embeddings = torch.randn(self.max_seq_len, text_embed_dim) + clip_position_embeddings /= math.sqrt(text_embed_dim) + self.clip_position_embedding = torch.nn.Parameter(clip_position_embeddings, requires_grad=True) + + def _apply(self, fn): + super(PrecomputedTextLatentDiffusion, self)._apply(fn) + self.latent_mean = fn(self.latent_mean) + self.latent_std = fn(self.latent_std) + return self + + def 
_generate_timesteps(self, latents: torch.Tensor): + if not self.unet.training: + # Sample equally spaced timesteps across all devices + global_batch_size = latents.shape[0] * dist.get_world_size() + global_timesteps = torch.linspace(0, len(self.noise_scheduler) - 1, global_batch_size).to(torch.int64) + # Get this device's subset of all the timesteps + idx_offset = dist.get_global_rank() * latents.shape[0] + timesteps = global_timesteps[idx_offset:idx_offset + latents.shape[0]].to(latents.device) + else: + if self.quasirandomness: + # Generate a quasirandom sequence of timesteps equal to the global batch size + global_batch_size = latents.shape[0] * dist.get_world_size() + sampled_fractions = torch.tensor(self.sobol.random(global_batch_size), device=latents.device) + timesteps = (len(self.noise_scheduler) * sampled_fractions).squeeze() + timesteps = torch.floor(timesteps).long() + # Get this device's subset of all the timesteps + idx_offset = dist.get_global_rank() * latents.shape[0] + timesteps = timesteps[idx_offset:idx_offset + latents.shape[0]].to(latents.device) + else: + timesteps = torch.randint(0, + len(self.noise_scheduler), (latents.shape[0],), + device=latents.device, + generator=self.rng_generator) + return timesteps + + def set_rng_generator(self, rng_generator: torch.Generator): + """Sets the rng generator for the model.""" + self.rng_generator = rng_generator + + def encode_images(self, inputs, dtype=torch.bfloat16): + with torch.amp.autocast('cuda', enabled=False): + latents = self.vae.encode(inputs.to(dtype))['latent_dist'].sample().data + latents = (latents - self.latent_mean) / self.latent_std # scale latents + return latents + + def decode_latents(self, latents): + latents = latents * self.latent_std + self.latent_mean + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + return image + + def encode_text(self, text, device): + assert self.t5_tokenizer is not None and self.t5_encoder is not None + assert self.clip_tokenizer is not None and self.clip_encoder is not None + # Encode with T5 + t5_tokenizer_out = self.t5_tokenizer(text, + padding='max_length', + max_length=self.t5_tokenizer.model_max_length, + truncation=True, + return_tensors='pt') + t5_tokenized_captions = t5_tokenizer_out['input_ids'].to(device) + t5_attn_mask = t5_tokenizer_out['attention_mask'].to(torch.bool).to(device) + t5_embed = self.t5_encoder(input_ids=t5_tokenized_captions, attention_mask=t5_attn_mask)[0] + # Encode with CLIP + clip_tokenizer_out = self.clip_tokenizer(text, + padding='max_length', + max_length=self.clip_tokenizer.model_max_length, + truncation=True, + return_tensors='pt') + clip_tokenized_captions = clip_tokenizer_out['input_ids'].to(device) + clip_attn_mask = clip_tokenizer_out['attention_mask'].to(torch.bool).to(device) + clip_out = self.clip_encoder(input_ids=clip_tokenized_captions, + attention_mask=clip_attn_mask, + output_hidden_states=True) + clip_embed = clip_out.hidden_states[-2] + pooled_embeddings = clip_out[1] + return t5_embed, clip_embed, t5_attn_mask, clip_attn_mask, pooled_embeddings + + def prepare_text_embeddings(self, t5_embed: torch.Tensor, clip_embed: torch.Tensor, t5_mask: torch.Tensor, + clip_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if t5_embed.shape[1] > self.max_seq_len: + t5_embed = t5_embed[:, :self.max_seq_len] + t5_mask = t5_mask[:, :self.max_seq_len] + if clip_embed.shape[1] > self.max_seq_len: + clip_embed = clip_embed[:, :self.max_seq_len] + clip_mask = clip_mask[:, :self.max_seq_len] + t5_embed = 
self.t5_proj(t5_embed) + clip_embed = self.clip_proj(clip_embed) + # Add position embeddings + t5_embed = 0.707 * t5_embed + 0.707 * self.t5_position_embedding[:t5_embed.shape[1]].unsqueeze(0) + clip_embed = 0.707 * clip_embed + 0.707 * self.clip_position_embedding[:clip_embed.shape[1]].unsqueeze(0) + # Apply layernorms + t5_embed = self.t5_ln(t5_embed) + clip_embed = self.clip_ln(clip_embed) + # Concatenate the text embeddings + text_embeds = torch.cat([t5_embed, clip_embed], dim=1) + encoder_attention_mask = torch.cat([t5_mask, clip_mask], dim=1) + return text_embeds, encoder_attention_mask + + def diffusion_forward(self, latents, timesteps): + # Add noise to the inputs (forward diffusion) + noise = torch.randn(*latents.shape, device=latents.device, generator=self.rng_generator) + noised_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) + # Generate the targets + if self.prediction_type == 'epsilon': + targets = noise + elif self.prediction_type == 'sample': + targets = latents + elif self.prediction_type == 'v_prediction': + targets = self.noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError( + f'prediction type must be one of sample, epsilon, or v_prediction. Got {self.prediction_type}') + return noised_latents, targets + + def forward(self, batch): + latents, text_embeds, text_pooled_embeds, encoder_attention_mask = None, None, None, None + + # Encode the images with the autoencoder encoder + inputs = batch[self.image_key] + latents = self.encode_images(inputs) + + # Text embeddings are shape (B, seq_len, emb_dim), optionally truncate to a max length + t5_embed = batch[self.t5_latent_key] + t5_mask = batch[self.t5_mask_key] + clip_embed = batch[self.clip_latent_key] + clip_mask = batch[self.clip_mask_key] + text_pooled_embeds = batch[self.clip_pooled_key] + text_embeds, encoder_attention_mask = self.prepare_text_embeddings(t5_embed, clip_embed, t5_mask, clip_mask) + + # Sample the diffusion timesteps + timesteps = self._generate_timesteps(latents) + noised_latents, targets = self.diffusion_forward(latents, timesteps) + + # Prepare added time ids & embeddings + add_time_ids = torch.cat( + [batch['cond_original_size'], batch['cond_crops_coords_top_left'], batch['cond_target_size']], dim=1) + added_cond_kwargs = {'text_embeds': text_pooled_embeds, 'time_ids': add_time_ids} + + # Forward through the model + return self.unet(noised_latents, + timesteps, + text_embeds, + encoder_attention_mask=encoder_attention_mask, + added_cond_kwargs=added_cond_kwargs)['sample'], targets, timesteps + + def loss(self, outputs, batch): + """Loss between unet output and added noise, typically mse.""" + loss = F.mse_loss(outputs[0], outputs[1]) + return loss + + def eval_forward(self, batch, outputs=None): + """For stable diffusion, eval forward computes unet outputs as well as some samples.""" + # Skip this if outputs have already been computed, e.g. 
during training + if outputs is not None: + return outputs + return self.forward(batch) + + def get_metrics(self, is_train: bool = False): + if is_train: + metrics = self.train_metrics + else: + metrics = self.val_metrics + metrics_dict = {metric.__class__.__name__: metric for metric in metrics} + return metrics_dict + + def update_metric(self, batch, outputs, metric): + metric.update(outputs[0], outputs[1]) + + @torch.no_grad() + def generate( + self, + prompt: Optional[list] = None, + negative_prompt: Optional[list] = None, + prompt_embeds: Optional[torch.Tensor] = None, + pooled_prompt: Optional[torch.Tensor] = None, + prompt_mask: Optional[torch.Tensor] = None, + neg_prompt_embeds: Optional[torch.Tensor] = None, + pooled_neg_prompt: Optional[torch.Tensor] = None, + neg_prompt_mask: Optional[torch.Tensor] = None, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 50, + guidance_scale: float = 3.0, + rescaled_guidance: Optional[float] = None, + num_images_per_prompt: int = 1, + seed: Optional[int] = None, + progress_bar: bool = True, + crop_params: Optional[torch.Tensor] = None, + input_size_params: Optional[torch.Tensor] = None, + ): + """Generates image from noise. + + Performs the backward diffusion process, each inference step takes + one forward pass through the unet. + + Args: + prompt (List[str]): The prompts to guide the image generation. Only use if not + using embeddings. Default: `None`. + negative_prompt (str or List[str]): The prompt or prompts to guide the + image generation away from. Ignored when not using guidance + (i.e., ignored if guidance_scale is less than 1). Must be the same length + as list of prompts. Only use if not using negative embeddings. Default: `None`. + prompt_embeds (torch.Tensor): Optionally pass pre-tokenized prompts instead + of string prompts. Default: `None`. + pooled_prompt (torch.Tensor): Optionally pass a precomputed pooled prompt embedding + if using embeddings. Default: `None`. + prompt_mask (torch.Tensor): Optionally pass a precomputed attention mask for the + prompt embeddings. Default: `None`. + neg_prompt_embeds (torch.Tensor): Optionally pass pre-embedded negative + prompts instead of string negative prompts. Default: `None`. + pooled_neg_prompt (torch.Tensor): Optionally pass a precomputed pooled negative + prompt embedding if using embeddings. Default: `None`. + neg_prompt_mask (torch.Tensor): Optionally pass a precomputed attention mask for the + negative prompt embeddings. Default: `None`. + height (int, optional): The height in pixels of the generated image. + Default: `1024`. + width (int, optional): The width in pixels of the generated image. + Default: `1024`. + num_inference_steps (int): The number of denoising steps. + More denoising steps usually lead to a higher quality image at the expense + of slower inference. Default: `50`. + guidance_scale (float): Guidance scale as defined in + Classifier-Free Diffusion Guidance. guidance_scale is defined as w of equation + 2. of Imagen Paper. Guidance scale is enabled by setting guidance_scale > 1. + Higher guidance scale encourages to generate images that are closely linked + to the text prompt, usually at the expense of lower image quality. + Default: `3.0`. + rescaled_guidance (float, optional): Rescaled guidance scale. If not specified, rescaled guidance will + not be used. Default: `None`. + num_images_per_prompt (int): The number of images to generate per prompt. + Default: `1`. + progress_bar (bool): Whether to use the tqdm progress bar during generation. 
+ Default: `True`. + seed (int): Random seed to use for generation. Set a seed for reproducible generation. + Default: `None`. + crop_params (torch.Tensor of size [Bx2], optional): Crop parameters to use + when generating images with SDXL. Default: `None`. + input_size_params (torch.Tensor of size [Bx2], optional): Size parameters + (representing original size of input image) to use when generating images with SDXL. + Default: `None`. + """ + # Create rng for the generation + device = self.vae.device + rng_generator = torch.Generator(device=device) + if seed: + rng_generator = rng_generator.manual_seed(seed) + + # Check that inputs are consistent with all embeddings or text inputs. All embeddings should be provided if using + # embeddings, and none if using text. + if (prompt_embeds is None) == (prompt is None): + raise ValueError('One and only one of prompt or prompt_embeds should be provided.') + if (pooled_prompt is None) != (prompt_embeds is None): + raise ValueError('pooled_prompt should be provided if and only if using embeddings') + if (prompt_mask is None) != (prompt_embeds is None): + raise ValueError('prompt_mask should be provided if and only if using embeddings') + if (neg_prompt_mask is None) != (neg_prompt_embeds is None): + raise ValueError('neg_prompt_mask should be provided if and only if using embeddings') + if (pooled_neg_prompt is None) != (neg_prompt_embeds is None): + raise ValueError('pooled_neg_prompt should be provided if and only if using embeddings') + + # If the prompt is specified as text, encode it. + if prompt is not None: + t5_embed, clip_embed, t5_attn_mask, clip_attn_mask, pooled_prompt = self.encode_text( + prompt, self.vae.device) + prompt_embeds, prompt_mask = self.prepare_text_embeddings(t5_embed, clip_embed, t5_attn_mask, + clip_attn_mask) + # If negative prompt is specified as text, encode it. 
+ if negative_prompt is not None: + t5_embed, clip_embed, t5_attn_mask, clip_attn_mask, pooled_neg_prompt = self.encode_text( + negative_prompt, self.vae.device) + neg_prompt_embeds, neg_prompt_mask = self.prepare_text_embeddings(t5_embed, clip_embed, t5_attn_mask, + clip_attn_mask) + + text_embeddings = _duplicate_tensor(prompt_embeds, num_images_per_prompt) + pooled_embeddings = _duplicate_tensor(pooled_prompt, num_images_per_prompt) + encoder_attn_mask = _duplicate_tensor(prompt_mask, num_images_per_prompt) + + batch_size = len(text_embeddings) # len prompts * num_images_per_prompt + # classifier free guidance + negative prompts + # negative prompt is given in place of the unconditional input in classifier free guidance + if neg_prompt_embeds is None: + # Negative prompt is empty and we want to zero it out + neg_prompt_embeds = torch.zeros_like(text_embeddings) + pooled_neg_prompt = torch.zeros_like(pooled_embeddings) + neg_prompt_mask = torch.zeros_like(encoder_attn_mask) + else: + neg_prompt_embeds = _duplicate_tensor(neg_prompt_embeds, num_images_per_prompt) + pooled_neg_prompt = _duplicate_tensor(pooled_neg_prompt, num_images_per_prompt) + neg_prompt_mask = _duplicate_tensor(neg_prompt_mask, num_images_per_prompt) + + # concat uncond + prompt + text_embeddings = torch.cat([neg_prompt_embeds, text_embeddings]) + pooled_embeddings = torch.cat([pooled_neg_prompt, pooled_embeddings]) + encoder_attn_mask = torch.cat([neg_prompt_mask, encoder_attn_mask]) + + # prepare for diffusion generation process + latents = torch.randn( + (batch_size, self.unet.config.in_channels, height // self.downsample_factor, + width // self.downsample_factor), + device=device, + dtype=self.unet.dtype, + generator=rng_generator, + ) + + self.inference_scheduler.set_timesteps(num_inference_steps) + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.inference_scheduler.init_noise_sigma + + added_cond_kwargs = {} + # if using SDXL, prepare added time ids & embeddings + + if crop_params is None: + crop_params = torch.zeros((batch_size, 2), dtype=text_embeddings.dtype) + if input_size_params is None: + input_size_params = torch.tensor([[width, height]] * batch_size, dtype=text_embeddings.dtype) + output_size_params = torch.tensor([[width, height]] * batch_size, dtype=text_embeddings.dtype) + + crop_params = torch.cat([crop_params, crop_params]) + input_size_params = torch.cat([input_size_params, input_size_params]) + output_size_params = torch.cat([output_size_params, output_size_params]) + + add_time_ids = torch.cat([input_size_params, crop_params, output_size_params], dim=1).to(device) + added_cond_kwargs = {'text_embeds': pooled_embeddings, 'time_ids': add_time_ids} + + # backward diffusion process + for t in tqdm(self.inference_scheduler.timesteps, disable=not progress_bar): + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.inference_scheduler.scale_model_input(latent_model_input, t) + # Make timestep + timestep = t.unsqueeze(0).repeat(latent_model_input.shape[0]).to(device) + # Model prediction + pred = self.unet(latent_model_input, + timestep, + encoder_hidden_states=text_embeddings, + encoder_attention_mask=encoder_attn_mask, + added_cond_kwargs=added_cond_kwargs).sample + + # perform guidance. 
Note this is only technically correct for prediction_type 'epsilon'
+            pred_uncond, pred_text = pred.chunk(2)
+            pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)
+            # Optionally rescale the classifier free guidance
+            if rescaled_guidance is not None:
+                std_pos = torch.std(pred_text, dim=(1, 2, 3), keepdim=True)
+                std_cfg = torch.std(pred, dim=(1, 2, 3), keepdim=True)
+                pred_rescaled = pred * (std_pos / std_cfg)
+                pred = pred_rescaled * rescaled_guidance + pred * (1 - rescaled_guidance)
+            # compute the previous noisy sample x_t -> x_t-1
+            latents = self.inference_scheduler.step(pred, t, latents, generator=rng_generator).prev_sample
+
+        # We now use the vae to decode the generated latents back into the image.
+        # scale and decode the image latents with vae
+        image = self.decode_latents(latents)
+        return image.detach().float()  # (batch*num_images_per_prompt, channel, h, w)
+
+
+def _duplicate_tensor(tensor, num_images_per_prompt):
+    """Duplicate tensor for multiple generations from a single prompt."""
+    batch_size, seq_len = tensor.shape[:2]
+    tensor = tensor.repeat(1, num_images_per_prompt, *[
+        1,
+    ] * len(tensor.shape[2:]))
+    return tensor.view(batch_size * num_images_per_prompt, seq_len, *[
+        -1,
+    ] * len(tensor.shape[2:]))
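
A minimal usage sketch of the new precomputed-latent generation path. The ``batch`` dict, device
placement, and guidance settings below are placeholder assumptions; ``batch`` is assumed to come from
``build_streaming_image_caption_latents_dataloader`` using the default key names.

    from diffusion.models import precomputed_text_latent_diffusion

    # Build the composer model (SDXL unet config + fp16-fix VAE by default).
    model = precomputed_text_latent_diffusion()

    # Project and concatenate the precomputed T5/CLIP latents, then generate.
    prompt_embeds, prompt_mask = model.prepare_text_embeddings(batch['T5_LATENTS'].cuda(),
                                                               batch['CLIP_LATENTS'].cuda(),
                                                               batch['T5_ATTENTION_MASK'].cuda(),
                                                               batch['CLIP_ATTENTION_MASK'].cuda())
    images = model.generate(prompt_embeds=prompt_embeds,
                            pooled_prompt=batch['CLIP_POOLED'].cuda(),
                            prompt_mask=prompt_mask,
                            height=1024,
                            width=1024,
                            guidance_scale=7.0,
                            num_inference_steps=30,
                            seed=42)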