From ef74f2b3384db1fd1986d58348ca537511b77076 Mon Sep 17 00:00:00 2001 From: coryMosaicML <83666378+coryMosaicML@users.noreply.github.com> Date: Fri, 26 Jul 2024 16:26:07 -0700 Subject: [PATCH] MMDiT implementation and text-to-image training with rectified flows (#155) --- diffusion/inference/__init__.py | 4 +- diffusion/inference/inference_model.py | 78 ++++ diffusion/models/__init__.py | 4 +- diffusion/models/models.py | 132 +++++++ diffusion/models/t2i_transformer.py | 459 ++++++++++++++++++++++++ diffusion/models/transformer.py | 478 +++++++++++++++++++++++++ diffusion/train.py | 32 +- tests/test_transformer.py | 94 +++++ 8 files changed, 1277 insertions(+), 4 deletions(-) create mode 100644 diffusion/models/t2i_transformer.py create mode 100644 diffusion/models/transformer.py create mode 100644 tests/test_transformer.py diff --git a/diffusion/inference/__init__.py b/diffusion/inference/__init__.py index 34dc748c..00acafed 100644 --- a/diffusion/inference/__init__.py +++ b/diffusion/inference/__init__.py @@ -3,6 +3,6 @@ """Inference endpoint.""" -from diffusion.inference.inference_model import StableDiffusionInference, StableDiffusionXLInference +from diffusion.inference.inference_model import ModelInference, StableDiffusionInference, StableDiffusionXLInference -__all__ = ['StableDiffusionInference', 'StableDiffusionXLInference'] +__all__ = ['ModelInference', 'StableDiffusionInference', 'StableDiffusionXLInference'] diff --git a/diffusion/inference/inference_model.py b/diffusion/inference/inference_model.py index 68fa86ee..d6925baf 100644 --- a/diffusion/inference/inference_model.py +++ b/diffusion/inference/inference_model.py @@ -11,6 +11,7 @@ from composer.utils.file_helpers import get_file from PIL import Image +import diffusion.models from diffusion.models import stable_diffusion_2, stable_diffusion_xl # Local checkpoint params @@ -225,3 +226,80 @@ def predict(self, model_requests: List[Dict[str, Any]]): base64_encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8') png_images.append(base64_encoded_image) return png_images + + +class ModelInference(): + """Generic inference endpoint class for diffusion models with a model.generate() method. + + Args: + model_name (str): Name of the model from `diffusion.models` to load. Ex: for stable diffusion xl, use 'stable_diffusion_xl'. + local_checkpoint_path (str): Path to the local checkpoint. Default: '/tmp/model.pt'. + strict (bool): Whether to load the model weights strictly. Default: False. + **model_kwargs: Keyword arguments to pass to the model initialization. + """ + + def __init__(self, model_name, local_checkpoint_path: str = LOCAL_CHECKPOINT_PATH, strict=False, **model_kwargs): + self.device = torch.cuda.current_device() + model_factory = getattr(diffusion.models, model_name) + model = model_factory(**model_kwargs) + + if 'pretrained' in model_kwargs and model_kwargs['pretrained']: + pass + else: + state_dict = torch.load(local_checkpoint_path) + for key in list(state_dict['state']['model'].keys()): + if 'val_metrics.' 
in key: + del state_dict['state']['model'][key] + model.load_state_dict(state_dict['state']['model'], strict=strict) + model.to(self.device) + self.model = model.eval() + + def predict(self, model_requests: List[Dict[str, Any]]): + prompts = [] + negative_prompts = [] + generate_kwargs = {} + + # assumes the same generate_kwargs across all samples + for req in model_requests: + if 'input' not in req: + raise RuntimeError('"input" must be provided to generate call') + inputs = req['input'] + + # Prompts and negative prompts if available + if isinstance(inputs, str): + prompts.append(inputs) + elif isinstance(inputs, Dict): + if 'prompt' not in inputs: + raise RuntimeError('"prompt" must be provided to generate call if using a dict as input') + prompts.append(inputs['prompt']) + if 'negative_prompt' in inputs: + negative_prompts.append(inputs['negative_prompt']) + else: + raise RuntimeError(f'Input must be of type string or dict, but it is type: {type(inputs)}') + + generate_kwargs = req['parameters'] + + # Check for prompts + if len(prompts) == 0: + raise RuntimeError('No prompts provided, must be either a string or dictionary with "prompt"') + + # Check negative prompt length + if len(negative_prompts) == 0: + negative_prompts = None + elif len(prompts) != len(negative_prompts): + raise RuntimeError('There must be the same number of negative prompts as prompts.') + + # Generate images + with torch.cuda.amp.autocast(True): + imgs = self.model.generate(prompt=prompts, negative_prompt=negative_prompts, **generate_kwargs).cpu() + + # Send as bytes + png_images = [] + for i in range(imgs.shape[0]): + img = (imgs[i].permute(1, 2, 0).numpy() * 255).round().astype('uint8') + pil_image = Image.fromarray(img, 'RGB') + img_byte_arr = io.BytesIO() + pil_image.save(img_byte_arr, format='PNG') + base64_encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8') + png_images.append(base64_encoded_image) + return png_images diff --git a/diffusion/models/__init__.py b/diffusion/models/__init__.py index 0928c83a..1a4bb0a2 100644 --- a/diffusion/models/__init__.py +++ b/diffusion/models/__init__.py @@ -4,7 +4,8 @@ """Diffusion models.""" from diffusion.models.models import (build_autoencoder, build_diffusers_autoencoder, continuous_pixel_diffusion, - discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl) + discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl, + text_to_image_transformer) from diffusion.models.noop import NoOpModel from diffusion.models.pixel_diffusion import PixelDiffusion from diffusion.models.stable_diffusion import StableDiffusion @@ -19,4 +20,5 @@ 'stable_diffusion_2', 'stable_diffusion_xl', 'StableDiffusion', + 'text_to_image_transformer', ] diff --git a/diffusion/models/models.py b/diffusion/models/models.py index eaaee630..4326cf1e 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -4,6 +4,7 @@ """Constructors for diffusion models.""" import logging +import math from typing import List, Optional, Tuple, Union import torch @@ -17,7 +18,9 @@ from diffusion.models.layers import ClippedAttnProcessor2_0, ClippedXFormersAttnProcessor, zero_module from diffusion.models.pixel_diffusion import PixelDiffusion from diffusion.models.stable_diffusion import StableDiffusion +from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer +from diffusion.models.transformer import DiffusionTransformer from diffusion.schedulers.schedulers import 
ContinuousTimeScheduler from diffusion.schedulers.utils import shift_noise_schedule @@ -496,6 +499,135 @@ def stable_diffusion_xl( return model +def text_to_image_transformer( + tokenizer_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/tokenizer'), + text_encoder_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/text_encoder'), + vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', + autoencoder_path: Optional[str] = None, + autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', + num_layers: int = 28, + max_image_side: int = 1280, + conditioning_features: int = 768, + conditioning_max_sequence_length: int = 77, + patch_size: int = 2, + latent_mean: Union[float, Tuple, str] = 0.0, + latent_std: Union[float, Tuple, str] = 7.67754318618, + timestep_mean: float = 0.0, + timestep_std: float = 1.0, + timestep_shift: float = 1.0, + image_key: str = 'image', + caption_key: str = 'captions', + caption_mask_key: str = 'attention_mask', + pretrained: bool = False, +): + """Text to image transformer training setup. + + Args: + tokenizer_names (str, Tuple[str, ...]): HuggingFace name(s) of the tokenizer(s) to load. + Default: ``('stabilityai/stable-diffusion-xl-base-1.0/tokenizer')``. + text_encoder_names (str, Tuple[str, ...]): HuggingFace name(s) of the text encoder(s) to load. + Default: ``('stabilityai/stable-diffusion-xl-base-1.0/text_encoder')``. + vae_model_name (str): Name of the VAE model to load. Defaults to 'madebyollin/sdxl-vae-fp16-fix'. + autoencoder_path (optional, str): Path to autoencoder weights if using custom autoencoder. If not specified, + will use the vae from `model_name`. Default `None`. + autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`. + num_layers (int): Number of layers in the transformer. Number of heads and layer width are determined by + this according to `num_features = 64 * num_layers`, and `num_heads = num_layers`. Default: `28`. + max_image_side (int): Maximum side length of the image. Default: `1280`. + conditioning_features (int): Number of features in the conditioning transformer. Default: `768`. + conditioning_max_sequence_length (int): Maximum sequence length for the conditioning transformer. Default: `77`. + patch_size (int): Patch size for the transformer. Default: `2`. + latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value, + a tuple of means, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `0.0`. + latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value, + a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `1/0.13025`. + timestep_mean (float): The mean of the timesteps. Default: `0.0`. + timestep_std (float): The std. dev. of the timesteps. Default: `1.0`. + timestep_shift (float): The shift of the timesteps. Default: `1.0`. + image_key (str): The key for the image in the batch. Default: `image`. + caption_key (str): The key for the captions in the batch. Default: `captions`. + caption_mask_key (str): The key for the caption mask in the batch. Default: `attention_mask`. + pretrained (bool): Whether to load pretrained weights. Not used. Defaults to False. 
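+
+    Returns:
+        ComposerTextToImageMMDiT: The wrapped transformer model, ready for training with Composer.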
+ """ + latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) + + if (isinstance(tokenizer_names, tuple) or + isinstance(text_encoder_names, tuple)) and len(tokenizer_names) != len(text_encoder_names): + raise ValueError('Number of tokenizer_names and text_encoder_names must be equal') + + # Make the tokenizer and text encoder + tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names) + text_encoder = MultiTextEncoder(model_names=text_encoder_names, encode_latents_in_fp16=True, pretrained_sdxl=False) + + precision = torch.float16 + # Make the autoencoder + if autoencoder_path is None: + if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics': + raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.') + downsample_factor = 8 + autoencoder_channels = 4 + # Use the pretrained vae + try: + vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=precision) + except: # for handling SDXL vae fp16 fixed checkpoint + vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=precision) + else: + # Use a custom autoencoder + vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=precision) + if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'): + raise ValueError( + 'Must specify latent scale when using a custom autoencoder without tracking latent statistics.') + if isinstance(latent_mean, str) and latent_mean == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_mean = tuple(latent_statistics['latent_channel_means']) + if isinstance(latent_std, str) and latent_std == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_std = tuple(latent_statistics['latent_channel_stds']) + downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1) + autoencoder_channels = vae.config['latent_channels'] + assert isinstance(vae, torch.nn.Module) + if isinstance(latent_mean, float): + latent_mean = (latent_mean,) * autoencoder_channels + if isinstance(latent_std, float): + latent_std = (latent_std,) * autoencoder_channels + assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) + # Figure out the maximum input sequence length + input_max_sequence_length = math.ceil(max_image_side / (downsample_factor * patch_size)) + # Make the transformer model + transformer = DiffusionTransformer(num_features=64 * num_layers, + num_heads=num_layers, + num_layers=num_layers, + input_features=autoencoder_channels * (patch_size**2), + input_max_sequence_length=input_max_sequence_length, + input_dimension=2, + conditioning_features=conditioning_features, + conditioning_max_sequence_length=conditioning_max_sequence_length, + conditioning_dimension=1, + expansion_factor=4) + # Make the composer model + model = ComposerTextToImageMMDiT(model=transformer, + autoencoder=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + latent_mean=latent_mean, + latent_std=latent_std, + patch_size=patch_size, + downsample_factor=downsample_factor, + latent_channels=autoencoder_channels, + timestep_mean=timestep_mean, + timestep_std=timestep_std, + timestep_shift=timestep_shift, + image_key=image_key, + caption_key=caption_key, + caption_mask_key=caption_mask_key) + + if torch.cuda.is_available(): + model = DeviceGPU().module_to_device(model) + return model + + def build_autoencoder(input_channels: int = 3, output_channels: int = 3, hidden_channels: 
int = 128, diff --git a/diffusion/models/t2i_transformer.py b/diffusion/models/t2i_transformer.py new file mode 100644 index 00000000..ab623d16 --- /dev/null +++ b/diffusion/models/t2i_transformer.py @@ -0,0 +1,459 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Composer model for text to image generation with a multimodal transformer.""" + +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from composer.models import ComposerModel +from composer.utils import dist +from torchmetrics import MeanSquaredError +from tqdm.auto import tqdm + +from diffusion.models.transformer import DiffusionTransformer, VectorEmbedding + + +def patchify(latents: torch.Tensor, patch_size: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to extract non-overlapping patches from image-like latents. + + Converts a tensor of shape [B, C, H, W] to patches of shape [B, num_patches, C * patch_size * patch_size]. + Coordinates of the patches are also returned to allow for unpatching and sequence embedding. + + Args: + latents (torch.Tensor): Latents of shape [B, C, H, W]. + patch_size (int): Size of the patches. + + Returns: + torch.Tensor: Patches of shape [B, num_patches, C * patch_size * patch_size]. + torch.Tensor: Coordinates of the patches. Shape [B, num_patches, 2]. + """ + # Assume img is a tensor of shape [B, C, H, W] + B, C, H, W = latents.shape + assert H % patch_size == 0 and W % patch_size == 0, 'Image dimensions must be divisible by patch_size' + # Reshape and permute to get non-overlapping patches + num_H_patches = H // patch_size + num_W_patches = W // patch_size + patches = latents.reshape(B, C, num_H_patches, patch_size, num_W_patches, patch_size) + patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * patch_size * patch_size) + # Generate coordinates for each patch + coords = torch.tensor([(i, j) for i in range(num_H_patches) for j in range(num_W_patches)]) + coords = coords.unsqueeze(0).expand(B, -1, -1).reshape(B, -1, 2) + return patches, coords + + +def unpatchify(patches: torch.Tensor, coords: torch.Tensor, patch_size: int) -> torch.Tensor: + """Recover an image-like tensor from a sequence of patches and their coordinates. + + Converts a tensor of shape [num_patches, C * patch_size * patch_size] to an image of shape [C, H, W]. + Coordinates are used to place the patches in the correct location in the image. + + Args: + patches (torch.Tensor): Patches of shape [num_patches, C * patch_size * patch_size]. + coords (torch.Tensor): Coordinates of the patches. Shape [num_patches, 2]. + patch_size (int): Size of the patches. 
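+
+    Returns:
+        torch.Tensor: The reconstructed image-like tensor of shape [C, H, W].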
+ """ + # Assume patches is a tensor of shape [num_patches, C * patch_size * patch_size] + C = patches.shape[1] // (patch_size * patch_size) + # Calculate the height and width of the original image from the coordinates + H = coords[:, 0].max() * patch_size + patch_size + W = coords[:, 1].max() * patch_size + patch_size + # Initialize an empty tensor for the reconstructed image + img = torch.zeros((C, H, W), device=patches.device, dtype=patches.dtype) # type: ignore + # Iterate over the patches and their coordinates + for patch, (y, x) in zip(patches, patch_size * coords): + # Reshape the patch to [C, patch_size, patch_size] + patch = patch.view(C, patch_size, patch_size) + # Place the patch in the corresponding location in the image + img[:, y:y + patch_size, x:x + patch_size] = patch + return img + + +class ComposerTextToImageMMDiT(ComposerModel): + """ComposerModel for text to image with a diffusion transformer. + + Args: + model (DiffusionTransformer): Core diffusion model. + autoencoder (torch.nn.Module): HuggingFace or compatible vae. + must support `.encode()` and `decode()` functions. + text_encoder (torch.nn.Module): HuggingFace CLIP or LLM text enoder. + tokenizer (transformers.PreTrainedTokenizer): Tokenizer used for + text_encoder. For a `CLIPTextModel` this will be the + `CLIPTokenizer` from HuggingFace transformers. + latent_mean (Optional[tuple[float]]): The means of the latent space. If not specified, defaults to + 4 * (0.0,). Default: `None`. + latent_std (Optional[tuple[float]]): The standard deviations of the latent space. If not specified, + defaults to 4 * (1/0.13025,). Default: `None`. + patch_size (int): The size of the patches in the image latents. Default: `2`. + downsample_factor (int): The factor by which the image is downsampled by the autoencoder. Default `8`. + latent_channels (int): The number of channels in the autoencoder latent space. Default: `4`. + timestep_mean (float): The mean of the logit-normal distribution for sampling timesteps. Default: `0.0`. + timestep_std (float): The standard deviation of the logit-normal distribution for sampling timesteps. + Default: `1.0`. + timestep_shift (float): The shift parameter for the logit-normal distribution for sampling timesteps. + A value of `1.0` is no shift. Default: `1.0`. + image_key (str): The name of the images in the dataloader batch. Default: `image`. + caption_key (str): The name of the caption in the dataloader batch. Default: `caption`. + caption_mask_key (str): The name of the caption mask in the dataloader batch. Default: `caption_mask`. + pooled_embedding_features (int): The number of features in the pooled text embeddings. Default: `768`. 
+ """ + + def __init__( + self, + model: DiffusionTransformer, + autoencoder: torch.nn.Module, + text_encoder: torch.nn.Module, + tokenizer, + latent_mean: Optional[tuple[float]] = None, + latent_std: Optional[tuple[float]] = None, + patch_size: int = 2, + downsample_factor: int = 8, + latent_channels: int = 4, + timestep_mean: float = 0.0, + timestep_std: float = 1.0, + timestep_shift: float = 1.0, + image_key: str = 'image', + caption_key: str = 'caption', + caption_mask_key: str = 'caption_mask', + pooled_embedding_features: int = 768, + ): + super().__init__() + self.model = model + self.autoencoder = autoencoder + self.text_encoder = text_encoder + self.tokenizer = tokenizer + if latent_mean is None: + self.latent_mean = 4 * (0.0) + if latent_std is None: + self.latent_std = 4 * (1 / 0.18215,) + self.latent_mean = torch.tensor(latent_mean).view(1, -1, 1, 1) + self.latent_std = torch.tensor(latent_std).view(1, -1, 1, 1) + self.patch_size = patch_size + self.downsample_factor = downsample_factor + self.latent_channels = latent_channels + self.timestep_mean = timestep_mean + self.timestep_std = timestep_std + self.timestep_shift = timestep_shift + self.image_key = image_key + self.caption_key = caption_key + self.caption_mask_key = caption_mask_key + self.pooled_embedding_features = pooled_embedding_features + + # Embedding MLP for the pooled text embeddings + self.pooled_embedding_mlp = VectorEmbedding(pooled_embedding_features, model.num_features) + + # freeze text_encoder during diffusion training and use half precision + self.autoencoder.requires_grad_(False) + self.text_encoder.requires_grad_(False) + self.autoencoder = self.autoencoder.half() + self.text_encoder = self.text_encoder.half() + + # Only FSDP wrap models we are training + self.model._fsdp_wrap = True + self.autoencoder._fsdp_wrap = False + self.text_encoder._fsdp_wrap = True + + # Param counts relevant for MFU computation + # First calc the AdaLN params separately + self.adaLN_params = sum(p.numel() for n, p in self.model.named_parameters() if 'adaLN_mlp_linear' in n) + # For MFU calc we must be careful to prevent double counting of MMDiT flops. + # Here, count the number of params applied to each sequence element. 
+ # Last block must be handled differently since post attn layers don't run on conditioning sequence + self.n_seq_params_per_block = self.model.num_features**2 * (4 + 2 * self.model.expansion_factor) + self.n_seq_params = self.n_seq_params_per_block * (self.model.num_layers - 1) + self.n_seq_params += 3 * (self.model.num_features**2) + self.n_last_layer_params = self.model.num_features**2 * (1 + 2 * self.model.expansion_factor) + # Params only on the input sequence + self.n_input_params = self.model.input_features * self.model.num_features + # Params only on the conditioning sequence + self.n_cond_params = self.model.conditioning_features * self.model.num_features + + # Set up metrics + self.train_metrics = [MeanSquaredError()] + self.val_metrics = [MeanSquaredError()] + + # Optional rng generator + self.rng_generator: Optional[torch.Generator] = None + + def _apply(self, fn): + super(ComposerTextToImageMMDiT, self)._apply(fn) + self.latent_mean = fn(self.latent_mean) + self.latent_std = fn(self.latent_std) + return self + + def set_rng_generator(self, rng_generator: torch.Generator): + """Sets the rng generator for the model.""" + self.rng_generator = rng_generator + + def flops_per_batch(self, batch) -> int: + batch_size = batch[self.image_key].shape[0] + height, width = batch[self.image_key].shape[2:] + input_seq_len = height * width / (self.patch_size**2 * self.downsample_factor**2) + cond_seq_len = batch[self.caption_key].shape[1] + seq_len = input_seq_len + cond_seq_len + # Calulate forward flops on full sequence excluding attention + param_flops = 2 * self.n_seq_params * batch_size * seq_len + # Last block contributes a bit less than other blocks + param_flops += 2 * self.n_last_layer_params * batch_size * input_seq_len + # Include input sequence params (comparatively small) + param_flops += 2 * self.n_input_params * batch_size * input_seq_len + # Include conditioning sequence params (comparatively small) + param_flops += 2 * self.n_cond_params * batch_size * cond_seq_len + # Include flops from adaln + param_flops += 2 * self.adaLN_params * batch_size + # Calculate flops for attention layers + attention_flops = 4 * self.model.num_layers * seq_len**2 * self.model.num_features * batch_size + return 3 * param_flops + 3 * attention_flops + + def encode_image(self, image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode an image tensor with the autoencoder and patchify the latents.""" + with torch.cuda.amp.autocast(enabled=False): + latents = self.autoencoder.encode(image.half())['latent_dist'].sample().data + # Scale and patchify the latents + latents = (latents - self.latent_mean) / self.latent_std + latent_patches, latent_coords = patchify(latents, self.patch_size) + return latent_patches, latent_coords + + @torch.no_grad() + def decode_image(self, latent_patches: torch.Tensor, latent_coords: torch.Tensor) -> torch.Tensor: + """Decode image latent patches and unpatchify the image.""" + # Unpatchify the latents + latents = [ + unpatchify(latent_patches[i], latent_coords[i], self.patch_size) for i in range(latent_patches.shape[0]) + ] + latents = torch.stack(latents) + # Scale the latents back to the original scale + latents = latents * self.latent_std + self.latent_mean + # Decode the latents + with torch.cuda.amp.autocast(enabled=False): + image = self.autoencoder.decode(latents.half()).sample + image = (image / 2 + 0.5).clamp(0, 1) + return image + + def tokenize_prompts(self, prompts: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]: + """Tokenize the 
prompts using the model's tokenizer.""" + tokenized_out = self.tokenizer(prompts, + padding='max_length', + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors='pt') + return tokenized_out['input_ids'], tokenized_out['attention_mask'] + + def combine_attention_masks(self, attention_masks: torch.Tensor) -> torch.Tensor: + """Combine attention masks for the encoder if there are multiple text encoders.""" + if len(attention_masks.shape) == 2: + return attention_masks + elif len(attention_masks.shape) == 3: + encoder_attention_masks = attention_masks[:, 0] + for i in range(1, attention_masks.shape[1]): + encoder_attention_masks |= attention_masks[:, i] + return encoder_attention_masks + else: + raise ValueError(f'attention_mask should have either 2 or 3 dimensions: {attention_masks.shape}') + + def make_text_embeddings_coords(self, text_embeddings: torch.Tensor) -> torch.Tensor: + """Make text embeddings coordinates for the transformer.""" + text_embeddings_coords = torch.arange(text_embeddings.shape[1], device=text_embeddings.device) + text_embeddings_coords = text_embeddings_coords.unsqueeze(0).expand(text_embeddings.shape[0], -1) + text_embeddings_coords = text_embeddings_coords.unsqueeze(-1) + return text_embeddings_coords + + def embed_tokenized_prompts( + self, tokenized_prompts: torch.Tensor, + attention_masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Use the model's text encoder to embed tokenized prompts and create pooled text embeddings.""" + with torch.cuda.amp.autocast(enabled=False): + # Ensure text embeddings are not longer than the model can handle + if tokenized_prompts.shape[1] > self.model.conditioning_max_sequence_length: + tokenized_prompts = tokenized_prompts[:, :self.model.conditioning_max_sequence_length] + text_encoder_out = self.text_encoder(tokenized_prompts, attention_mask=attention_masks) + text_embeddings, pooled_text_embeddings = text_encoder_out[0], text_encoder_out[1] + text_mask = self.combine_attention_masks(attention_masks) + text_embeddings_coords = self.make_text_embeddings_coords(text_embeddings) + # Ensure the embeddings are the same dtype as the model + text_embeddings = text_embeddings.to(next(self.model.parameters()).dtype) + pooled_text_embeddings = pooled_text_embeddings.to(next(self.pooled_embedding_mlp.parameters()).dtype) + # Encode the pooled embeddings + pooled_text_embeddings = self.pooled_embedding_mlp(pooled_text_embeddings) + return text_embeddings, text_embeddings_coords, text_mask, pooled_text_embeddings + + def diffusion_forward_process(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Diffusion forward process using a rectified flow.""" + if not self.model.training: + # Sample equally spaced timesteps across all devices + global_batch_size = inputs.shape[0] * dist.get_world_size() + global_timesteps = torch.linspace(0, 1, global_batch_size) + # Get this device's subset of all the timesteps + idx_offset = dist.get_global_rank() * inputs.shape[0] + timesteps = global_timesteps[idx_offset:idx_offset + inputs.shape[0]].to(inputs.device) + timesteps = timesteps.view(-1, 1, 1) + else: + # Sample timesteps according to a logit-normal distribution + u = torch.randn(inputs.shape[0], device=inputs.device, generator=self.rng_generator, dtype=inputs.dtype) + u = self.timestep_mean + self.timestep_std * u + timesteps = torch.sigmoid(u).view(-1, 1, 1) + timesteps = self.timestep_shift * timesteps / (1 + (self.timestep_shift - 1) * timesteps) + # 
Then, add the noise to the latents according to the recitified flow + noise = torch.randn(*inputs.shape, device=inputs.device, generator=self.rng_generator) + noised_inputs = (1 - timesteps) * inputs + timesteps * noise + # Compute the targets, which are the velocities + targets = noise - inputs + return noised_inputs, targets, timesteps[:, 0, 0] + + def forward(self, batch): + # Get the inputs + image, caption, caption_mask = batch[self.image_key], batch[self.caption_key], batch[self.caption_mask_key] + # Get the image latents + latent_patches, latent_coords = self.encode_image(image) + # Get the text embeddings and their coords + text_embeddings, text_embeddings_coords, caption_mask, pooled_text_embeddings = self.embed_tokenized_prompts( + caption, caption_mask) + # Diffusion forward process + noised_inputs, targets, timesteps = self.diffusion_forward_process(latent_patches) + # Forward through the model + model_out = self.model(noised_inputs, + latent_coords, + timesteps, + conditioning=text_embeddings, + conditioning_coords=text_embeddings_coords, + input_mask=None, + conditioning_mask=caption_mask, + constant_conditioning=pooled_text_embeddings) + return {'predictions': model_out, 'targets': targets, 'timesteps': timesteps} + + def loss(self, outputs, batch): + """MSE loss between outputs and targets.""" + loss = F.mse_loss(outputs['predictions'], outputs['targets']) + return loss + + def eval_forward(self, batch, outputs=None): + # Skip this if outputs have already been computed, e.g. during training + if outputs is not None: + return outputs + return self.forward(batch) + + def get_metrics(self, is_train: bool = False): + if is_train: + metrics_dict = {metric.__class__.__name__: metric for metric in self.train_metrics} + else: + metrics_dict = {metric.__class__.__name__: metric for metric in self.val_metrics} + return metrics_dict + + def update_metric(self, batch, outputs, metric): + if isinstance(metric, MeanSquaredError): + metric.update(outputs['predictions'], outputs['targets']) + else: + raise ValueError(f'Unrecognized metric {metric.__class__.__name__}') + + def make_sampling_timesteps(self, N: int): + timesteps = torch.linspace(1, 0, N + 1) + timesteps = self.timestep_shift * timesteps / (1 + (self.timestep_shift - 1) * timesteps) + # Make timestep differences + delta_t = timesteps[:-1] - timesteps[1:] + return timesteps[:-1], delta_t + + @torch.no_grad() + def generate(self, + prompt: Union[str, list], + negative_prompt: Optional[Union[str, list]] = None, + height: int = 256, + width: int = 256, + guidance_scale: float = 7.0, + rescaled_guidance: Optional[float] = None, + num_inference_steps: int = 50, + num_images_per_prompt: int = 1, + progress_bar: bool = True, + seed: Optional[int] = None): + """Run generation for the model. + + Args: + prompt (str, list): Prompt or prompts for the generation. + negative_prompt (Optional[str, list]): Negative prompt or prompts for the generation. Default: `None`. + height (int): Height of the generated images. Default: `256`. + width (int): Width of the generated images. Default: `256`. + guidance_scale (float): Scale for the guidance. Default: `7.0`. + rescaled_guidance (Optional[float]): Rescale the guidance. Default: `None`. + num_inference_steps (int): Number of inference steps. Default: `50`. + num_images_per_prompt (int): Number of images per prompt. Default: `1`. + progress_bar (bool): Whether to show a progress bar. Default: `True`. + seed (Optional[int]): Seed for the generation. Default: `None`. 
+ + Returns: + torch.Tensor: Generated images. Shape [batch*num_images_per_prompt, channel, h, w]. + """ + device = next(self.model.parameters()).device + # Create rng for the generation + rng_generator = torch.Generator(device=device) + if seed: + rng_generator = rng_generator.manual_seed(seed) + + # Set default negative prompts to empty string if not provided + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * len(prompt) + elif isinstance(negative_prompt, list): + assert len(negative_prompt) == len(prompt), 'Prompt and negative prompt must have the same length.' + elif negative_prompt is None: + negative_prompt = ['' for _ in prompt] + # Duplicate the images in the prompt and negative prompt if needed. + prompt = [item for item in prompt for _ in range(num_images_per_prompt)] + negative_prompt = [item for item in negative_prompt for _ in range(num_images_per_prompt)] + # Tokenize both prompt and negative prompts + prompt_tokens, prompt_mask = self.tokenize_prompts(prompt) + prompt_tokens, prompt_mask = prompt_tokens.to(device), prompt_mask.to(device) + negative_prompt_tokens, negative_prompt_mask = self.tokenize_prompts(negative_prompt) + negative_prompt_tokens, negative_prompt_mask = negative_prompt_tokens.to(device), negative_prompt_mask.to( + device) + # Embed the tokenized prompts and negative prompts + text_embeddings, text_embeddings_coords, prompt_mask, pooled_embedding = self.embed_tokenized_prompts( + prompt_tokens, prompt_mask) + neg_text_embeddings, neg_text_embeddings_coords, neg_prompt_mask, pooled_neg_embedding = self.embed_tokenized_prompts( + negative_prompt_tokens, negative_prompt_mask) + + # Generate initial noise + latent_height = height // self.downsample_factor + latent_width = width // self.downsample_factor + latents = torch.randn(text_embeddings.shape[0], + self.latent_channels, + latent_height, + latent_width, + device=device) + latent_patches, latent_coords = patchify(latents, self.patch_size) + + # Set up for CFG + text_embeddings = torch.cat([text_embeddings, neg_text_embeddings], dim=0) + text_embeddings_coords = torch.cat([text_embeddings_coords, neg_text_embeddings_coords], dim=0) + text_embeddings_mask = torch.cat([prompt_mask, neg_prompt_mask], dim=0) + pooled_embedding = torch.cat([pooled_embedding, pooled_neg_embedding], dim=0) + latent_coords_input = torch.cat([latent_coords, latent_coords], dim=0) + + # backward diffusion process + timesteps, delta_t = self.make_sampling_timesteps(num_inference_steps) + timesteps, delta_t = timesteps.to(device), delta_t.to(device) + for i, t in tqdm(enumerate(timesteps), disable=not progress_bar): + latent_patches_input = torch.cat([latent_patches, latent_patches], dim=0) + # Get the model prediction + model_out = self.model(latent_patches_input, + latent_coords_input, + t.unsqueeze(0), + conditioning=text_embeddings, + conditioning_coords=text_embeddings_coords, + input_mask=None, + conditioning_mask=text_embeddings_mask, + constant_conditioning=pooled_embedding) + # Do CFG + pred_cond, pred_uncond = model_out.chunk(2, dim=0) + pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Optionally rescale the classifer free guidance + if rescaled_guidance is not None: + std_pos = torch.std(pred_cond, dim=(1, 2), keepdim=True) + std_cfg = torch.std(pred, dim=(1, 2), keepdim=True) + pred_rescaled = pred * (std_pos / std_cfg) + pred = pred_rescaled * rescaled_guidance + pred * (1 - rescaled_guidance) + # Update the latents + 
latent_patches = latent_patches - pred * delta_t[i] + # Decode the latents + image = self.decode_image(latent_patches, latent_coords) + return image.detach() # (batch*num_images_per_prompt, channel, h, w) diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py new file mode 100644 index 00000000..453208fd --- /dev/null +++ b/diffusion/models/transformer.py @@ -0,0 +1,478 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Diffusion Transformer model.""" + +import math +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Modulate the input with the shift and scale.""" + return x * (1.0 + scale) + shift + + +def get_multidimensional_position_embeddings(position_embeddings: torch.Tensor, coords: torch.Tensor) -> torch.Tensor: + """Extracts position embeddings for a multidimensional sequence given by coordinates. + + Position embeddings are shape (D, T, F). Coords are shape (B, S, D). Position embeddings should be + interpreted as D dimensional embeddings with F features each for a maximum of T timesteps. + Coordinates or `coords` is a batch of size B of sequences of length S with D dimensional integer + coordinates. For example, if D=2, then each of the B, S elements of the sequence would have a 2D + X,Y coordinate. + + Args: + position_embeddings (torch.Tensor): Position embeddings of shape (D, T, F). + coords (torch.Tensor): Coordinates of shape (B, S, D). + + Returns: + torch.Tensor: Sequenced embeddings of shape (B, S, F, D) + """ + B, S, D = coords.shape + F = position_embeddings.shape[2] + coords = coords.reshape(B * S, D) + sequenced_embeddings = [position_embeddings[d, coords[:, d]] for d in range(D)] + sequenced_embeddings = torch.stack(sequenced_embeddings, dim=-1) + sequenced_embeddings = sequenced_embeddings.view(B, S, F, D) + return sequenced_embeddings # (B, S, F, D) + + +class AdaptiveLayerNorm(nn.Module): + """Adaptive LayerNorm. + + Scales and shifts the output of a LayerNorm using an MLP conditioned on a scalar. + + Args: + num_features (int): Number of input features. + """ + + def __init__(self, num_features: int): + super().__init__() + self.num_features = num_features + # MLP for computing modulations. + # Initialized to zero so modulation acts as identity at initialization. + self.adaLN_mlp_linear = nn.Linear(self.num_features, 2 * self.num_features, bias=True) + nn.init.zeros_(self.adaLN_mlp_linear.weight) + nn.init.zeros_(self.adaLN_mlp_linear.bias) + self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear) + # LayerNorm + self.layernorm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) + + def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + # Calculate the modulations + mods = self.adaLN_mlp(t).unsqueeze(1).chunk(2, dim=2) + # Apply the modulations + return modulate(self.layernorm(x), mods[0], mods[1]) + + +class ModulationLayer(nn.Module): + """Modulation layer. + + Scales the input by a factor determined by a scalar input. + + Args: + num_features (int): Number of input features. + """ + + def __init__(self, num_features: int): + super().__init__() + self.num_features = num_features + # MLP for computing modulation. + # Initialized to zero so modulation starts off at zero. 
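+        # With zero weights and bias, adaLN_mlp outputs zero for any conditioning input, so each
+        # modulated branch contributes nothing to the residual stream at initialization.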
+ self.adaLN_mlp_linear = nn.Linear(self.num_features, self.num_features, bias=True) + nn.init.zeros_(self.adaLN_mlp_linear.weight) + nn.init.zeros_(self.adaLN_mlp_linear.bias) + self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear) + + def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + # Calculate the modulations + mods = self.adaLN_mlp(t).unsqueeze(1) + return x * mods + + +class ScalarEmbedding(nn.Module): + """Embedding block for scalars. + + Embeds a scalar into a vector of size `num_features` using a sinusoidal embedding followed by an MLP. + + Args: + num_features (int): The size of the output vector. + sinusoidal_embedding_dim (int): The size of the intermediate sinusoidal embedding. Default: `256`. + + Returns: + torch.Tensor: The embedded scalar + """ + + def __init__(self, num_features: int, sinusoidal_embedding_dim: int = 256): + super().__init__() + self.num_features = num_features + self.sinusoidal_embedding_dim = sinusoidal_embedding_dim + self.linear_1 = nn.Linear(self.sinusoidal_embedding_dim, self.num_features) + self.linear_2 = nn.Linear(self.num_features, self.num_features) + self.mlp = nn.Sequential(self.linear_1, nn.SiLU(), self.linear_2) + + @staticmethod + def timestep_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor: + """Create sinusoidal timestep embeddings. + + Args: + timesteps (torch.Tensor): The timesteps to embed. + dim (int): The size of the output embedding. + max_period (int): The maximum period of the sinusoidal embedding. Default: `10000`. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / + half).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, x: torch.Tensor) -> torch.Tensor: + sinusoidal_embedding = self.timestep_embedding(x, self.sinusoidal_embedding_dim) + # Ensure embedding is the correct dtype + sinusoidal_embedding = sinusoidal_embedding.to(next(self.parameters()).dtype) + return self.mlp(sinusoidal_embedding) + + +class VectorEmbedding(nn.Module): + """Embedding block for vectors. + + Embeds vectors via an MLP into a vector of size `num_features`. + + Args: + input_features (int): The size of the input vector. + num_features (int): The size of the output vector. + """ + + def __init__(self, input_features: int, num_features: int): + super().__init__() + self.input_features = input_features + self.num_features = num_features + self.linear_1 = nn.Linear(self.input_features, self.num_features) + self.linear_2 = nn.Linear(self.num_features, self.num_features) + self.mlp = nn.Sequential(self.linear_1, nn.SiLU(), self.linear_2) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +class PreAttentionBlock(nn.Module): + """Block to compute QKV before attention. + + Includes QK layernorms and an adaptive layernorms. + + Args: + num_features (int): Number of input features. + """ + + def __init__(self, num_features: int): + super().__init__() + self.num_features = num_features + # Adaptive layernorm + self.adaptive_layernorm = AdaptiveLayerNorm(self.num_features) + # Linear layer to get q, k, and v + self.qkv = nn.Linear(self.num_features, 3 * self.num_features) + # QK layernorms. Original MMDiT used RMSNorm here. 
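+        # Normalizing q and k keeps the attention logits bounded, which helps stabilize low-precision training.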
+ self.q_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) + self.k_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) + # Initialize all biases to zero + nn.init.zeros_(self.qkv.bias) + # Init the standard deviation of the weights to 0.02 as is tradition + nn.init.normal_(self.qkv.weight, std=0.02) + + def forward(self, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + x = self.adaptive_layernorm(x, t) + # Calculate the query, key, and values all in one go + q, k, v = self.qkv(x).chunk(3, dim=-1) + q = self.q_norm(q) + k = self.k_norm(k) + return q, k, v + + +class SelfAttention(nn.Module): + """Standard multihead self attention layer that supports masking. + + Args: + num_features (int): Number of input features. + num_heads (int): Number of attention heads. + """ + + def __init__(self, num_features: int, num_heads: int): + super().__init__() + self.num_features = num_features + self.num_heads = num_heads + + def forward(self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: + # Get the shape of the inputs + B, T, C = v.size() + # Reshape the query, key, and values for multi-head attention + q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) + k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) + v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) + # Native torch attention + attention_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # (B, H, T, C/H) + # Swap the sequence length and the head dimension back and get rid of num_heads. + attention_out = attention_out.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C) + return attention_out + + +class PostAttentionBlock(nn.Module): + """Block to postprocess v after attention. + + Includes adaptive layernorms. + + Args: + num_features (int): Number of input features. + expansion_factor (int): Expansion factor for the MLP. Default: `4`. 
+ """ + + def __init__(self, num_features: int, expansion_factor: int = 4): + super().__init__() + self.num_features = num_features + self.expansion_factor = expansion_factor + # Input modulation + self.modulate_v = ModulationLayer(self.num_features) + # Linear layer to process v + self.linear_v = nn.Linear(self.num_features, self.num_features) + # Layernorm for the output + self.output_norm = AdaptiveLayerNorm(self.num_features) + # Transformer style MLP layers + self.linear_1 = nn.Linear(self.num_features, self.expansion_factor * self.num_features) + self.nonlinearity = nn.GELU(approximate='tanh') + self.linear_2 = nn.Linear(self.expansion_factor * self.num_features, self.num_features) + # Initialize all biases to zero + nn.init.zeros_(self.linear_1.bias) + nn.init.zeros_(self.linear_2.bias) + # Output MLP + self.output_mlp = nn.Sequential(self.linear_1, self.nonlinearity, self.linear_2) + # Output modulation + self.modulate_output = ModulationLayer(self.num_features) + + def forward(self, v: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """Forward takes v from self attention and the original sequence x with scalar conditioning t.""" + # Postprocess v with linear + gating modulation + y = self.modulate_v(self.linear_v(v), t) + y = x + y + # Adaptive layernorm + y = self.output_norm(y, t) + # Output MLP + y = self.output_mlp(y) + # Gating modulation for the output + y = self.modulate_output(y, t) + y = x + y + return y + + +class MMDiTBlock(nn.Module): + """Transformer block that supports masking, multimodal attention, and adaptive norms. + + Can optionally be the last block in the network, in which case it does not apply post-attention layers to the + conditioning sequence, as those params may not be used. + + Args: + num_features (int): Number of input features. + num_heads (int): Number of attention heads. + expansion_factor (int): Expansion factor for the MLP. Default: `4`. + is_last (bool): Whether this is the last block in the network. Default: `False`. 
+ """ + + def __init__(self, num_features: int, num_heads: int, expansion_factor: int = 4, is_last: bool = False): + super().__init__() + self.num_features = num_features + self.num_heads = num_heads + self.expansion_factor = expansion_factor + self.is_last = is_last + # Pre-attention blocks for two modalities + self.pre_attention_block_1 = PreAttentionBlock(self.num_features) + self.pre_attention_block_2 = PreAttentionBlock(self.num_features) + # Self-attention + self.attention = SelfAttention(self.num_features, self.num_heads) + # Post-attention blocks for two modalities + self.post_attention_block_1 = PostAttentionBlock(self.num_features, self.expansion_factor) + if not self.is_last: + self.post_attention_block_2 = PostAttentionBlock(self.num_features, self.expansion_factor) + + def forward(self, + x1: torch.Tensor, + x2: torch.Tensor, + t: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + # Pre-attention for the two modalities + q1, k1, v1 = self.pre_attention_block_1(x1, t) + q2, k2, v2 = self.pre_attention_block_2(x2, t) + # Concat q, k, v along the sequence dimension + q = torch.cat([q1, q2], dim=1) + k = torch.cat([k1, k2], dim=1) + v = torch.cat([v1, v2], dim=1) + # Self-attention + v = self.attention(q, k, v, mask=mask) + # Split the attention output back into the two modalities + seq_len_1, seq_len_2 = x1.size(1), x2.size(1) + y1, y2 = v.split([seq_len_1, seq_len_2], dim=1) + # Post-attention for the two modalities + y1 = self.post_attention_block_1(y1, x1, t) + if not self.is_last: + y2 = self.post_attention_block_2(y2, x2, t) + return y1, y2 + + +class DiffusionTransformer(nn.Module): + """Transformer model for generic diffusion. + + Supports input and conditioning sequences with different lengths and dimensions. + + Args: + num_features (int): Number of hidden features. + num_heads (int): Number of attention heads. + num_layers (int): Number of transformer layers. + input_features (int): Number of features in the input sequence. Default: `192`. + input_max_sequence_length (int): Maximum sequence length for the input sequence. Default: `1024`. + input_dimension (int): Dimension of the input sequence. Default: `2`. + conditioning_features (int): Number of features in the conditioning sequence. Default: `1024`. + conditioning_max_sequence_length (int): Maximum sequence length for the conditioning sequence. Default: `77`. + conditioning_dimension (int): Dimension of the conditioning sequence. Default: `1`. + expansion_factor (int): Expansion factor for the MLPs. Default: `4`. 
+ """ + + def __init__(self, + num_features: int, + num_heads: int, + num_layers: int, + input_features: int = 192, + input_max_sequence_length: int = 1024, + input_dimension: int = 2, + conditioning_features: int = 1024, + conditioning_max_sequence_length: int = 77, + conditioning_dimension: int = 1, + expansion_factor: int = 4): + super().__init__() + # Params for the network architecture + self.num_features = num_features + self.num_heads = num_heads + self.num_layers = num_layers + self.expansion_factor = expansion_factor + # Params for input embeddings + self.input_features = input_features + self.input_dimension = input_dimension + self.input_max_sequence_length = input_max_sequence_length + # Params for conditioning embeddings + self.conditioning_features = conditioning_features + self.conditioning_dimension = conditioning_dimension + self.conditioning_max_sequence_length = conditioning_max_sequence_length + + # Embedding block for the timestep + self.timestep_embedding = ScalarEmbedding(self.num_features) + # Projection layer for the input sequence + self.input_embedding = nn.Linear(self.input_features, self.num_features) + # Embedding layer for the input sequence + input_position_embedding = torch.randn(self.input_dimension, self.input_max_sequence_length, self.num_features) + input_position_embedding /= math.sqrt(self.num_features) + self.input_position_embedding = torch.nn.Parameter(input_position_embedding, requires_grad=True) + # Projection layer for the conditioning sequence + self.conditioning_embedding = nn.Linear(self.conditioning_features, self.num_features) + # Embedding layer for the conditioning sequence + conditioning_position_embedding = torch.randn(self.conditioning_dimension, + self.conditioning_max_sequence_length, self.num_features) + conditioning_position_embedding /= math.sqrt(self.num_features) + self.conditioning_position_embedding = torch.nn.Parameter(conditioning_position_embedding, requires_grad=True) + # Transformer blocks + self.transformer_blocks = nn.ModuleList([ + MMDiTBlock(self.num_features, self.num_heads, expansion_factor=self.expansion_factor) + for _ in range(self.num_layers - 1) + ]) + # Turn off post attn layers for conditioning sequence in final block + self.transformer_blocks.append( + MMDiTBlock(self.num_features, self.num_heads, expansion_factor=self.expansion_factor, is_last=True)) + # Output projection layer + self.final_norm = AdaptiveLayerNorm(self.num_features) + self.final_linear = nn.Linear(self.num_features, self.input_features) + # Init the output layer to zero + nn.init.zeros_(self.final_linear.weight) + nn.init.zeros_(self.final_linear.bias) + + def fsdp_wrap_fn(self, module: nn.Module) -> bool: + if isinstance(module, MMDiTBlock): + return True + return False + + def activation_checkpointing_fn(self, module: nn.Module) -> bool: + if isinstance(module, MMDiTBlock): + return True + return False + + def forward(self, + x: torch.Tensor, + input_coords: torch.Tensor, + t: torch.Tensor, + conditioning: torch.Tensor, + conditioning_coords: torch.Tensor, + input_mask: Optional[torch.Tensor] = None, + conditioning_mask: Optional[torch.Tensor] = None, + constant_conditioning: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass through the diffusion transformer. + + Args: + x (torch.Tensor): The input sequence of shape (B, T1, C1). + input_coords (torch.Tensor): The coordinates of the D dimensional input sequence of shape (B, T1, D). + t (torch.Tensor): The scalar timesteps of shape (B, 1). 
+ conditioning (torch.Tensor): The conditioning sequence of shape (B, T2, C2). + conditioning_coords (torch.Tensor): The coordinates of the D dimensional conditioning sequence of shape (B, T2, D). + input_mask (Optional[torch.Tensor]): The mask for the input sequence of shape (B, T1). + conditioning_mask (Optional[torch.Tensor]): The mask for the conditioning sequence of shape (B, T2). + constant_conditioning (Optional[torch.Tensor]): Optional additional constant conditioning (B, num_features). + + Returns: + torch.Tensor: The output sequence of shape (B, T1, C1). + """ + # Embed the timestep + t = self.timestep_embedding(t) + # Optionally add constant conditioning. This assumes it has been embedded already. + if constant_conditioning is not None: + t = t + constant_conditioning + # Embed the input + y = self.input_embedding(x) # (B, T1, C) + # Get the input position embeddings and add them to the input + y_position_embeddings = get_multidimensional_position_embeddings(self.input_position_embedding, input_coords) + y_position_embeddings = y_position_embeddings.sum(dim=-1) # (B, T1, C) + y = y + y_position_embeddings # (B, T1, C) + if input_mask is None: + mask = torch.ones(x.shape[0], x.shape[1], device=x.device) + else: + mask = input_mask + + # Embed the conditioning + c = self.conditioning_embedding(conditioning) # (B, T2, C) + # Get the conditioning position embeddings and add them to the conditioning + c_position_embeddings = get_multidimensional_position_embeddings(self.conditioning_position_embedding, + conditioning_coords) + c_position_embeddings = c_position_embeddings.sum(dim=-1) # (B, T2, C) + c = c + c_position_embeddings # (B, T2, C) + # Concatenate the masks + if conditioning_mask is None: + conditioning_mask = torch.ones(conditioning.shape[0], conditioning.shape[1], device=conditioning.device) + mask = torch.cat([mask, conditioning_mask], dim=1) # (B, T1 + T2) + + # Expand the mask to the right shape + mask = mask.bool() + mask = mask.unsqueeze(-1) & mask.unsqueeze(1) # (B, T1 + T2, T1 + T2) + identity = torch.eye(mask.shape[1], device=mask.device, dtype=mask.dtype).unsqueeze(0) + mask = mask | identity + mask = mask.unsqueeze(1) # (B, 1, T1 + T2, T1 + T2) + + # Pass through the transformer blocks + for block in self.transformer_blocks: + y, c = block(y, c, t, mask=mask) + # Pass through the output layers to get the right number of elements + y = self.final_norm(y, t) + y = self.final_linear(y) + return y diff --git a/diffusion/train.py b/diffusion/train.py index 1d6799ff..9a9dbaf5 100644 --- a/diffusion/train.py +++ b/diffusion/train.py @@ -20,6 +20,7 @@ from torch.optim import Optimizer from diffusion.models.autoencoder import ComposerAutoEncoder, ComposerDiffusersAutoEncoder +from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT def make_autoencoder_optimizer(config: DictConfig, model: ComposerModel) -> Optimizer: @@ -50,6 +51,31 @@ def make_autoencoder_optimizer(config: DictConfig, model: ComposerModel) -> Opti return optimizer +def make_transformer_optimizer(config: DictConfig, model: ComposerModel) -> Optimizer: + """Configures the optimizer for use with a transformer model.""" + print('Configuring optimizer for transformer') + assert isinstance(model, ComposerTextToImageMMDiT) + + # Turn off weight decay for biases, norms, and positional embeddings. 
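+    # Matching is by substring, so e.g. 'adaLN_mlp_linear.bias' and 'input_position_embedding' fall into the no-decay group.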
+ no_decay = ['bias', 'norm', 'position_embedding'] + params_with_no_decay = [] + params_with_decay = [] + for name, param in model.named_parameters(): + if any(nd in name for nd in no_decay): + params_with_no_decay.append(param) + else: + params_with_decay.append(param) + no_decay_dict = dict(config.optimizer.items()) + no_decay_dict['params'] = params_with_no_decay + no_decay_dict['weight_decay'] = 0.0 + + decay_dict = dict(config.optimizer.items()) + decay_dict['params'] = params_with_decay + + optimizer = hydra.utils.instantiate(config.optimizer, [no_decay_dict, decay_dict]) + return optimizer + + def train(config: DictConfig) -> None: """Train a model. @@ -62,10 +88,14 @@ def train(config: DictConfig) -> None: model: ComposerModel = hydra.utils.instantiate(config.model) - # Check if this is training an autoencoder. If so, the optimizer needs different param groups if hasattr(model, 'autoencoder_loss'): + # Check if this is training an autoencoder. If so, the optimizer needs different param groups optimizer = make_autoencoder_optimizer(config, model) tokenizer = None + elif isinstance(model, ComposerTextToImageMMDiT): + # Check if this is training a transformer. If so, the optimizer needs different param groups + optimizer = make_transformer_optimizer(config, model) + tokenizer = model.tokenizer else: optimizer = hydra.utils.instantiate(config.optimizer, params=model.parameters()) tokenizer = model.tokenizer diff --git a/tests/test_transformer.py b/tests/test_transformer.py new file mode 100644 index 00000000..e138f9fb --- /dev/null +++ b/tests/test_transformer.py @@ -0,0 +1,94 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from diffusion.models.models import text_to_image_transformer +from diffusion.models.t2i_transformer import patchify, unpatchify +from diffusion.models.transformer import get_multidimensional_position_embeddings + + +def test_multidimensional_position_embeddings(): + B = 32 + D = 2 + T = 16 + F = 64 + position_embeddings = torch.randn(D, T, F) + # Coords should be shape (B, S, D). So for sequence element B, S, one should get D embeddings. + # These should correspond to the D elements for which T = S in the position embeddings. 
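+    # Build a 3x3 grid of (i, j) coordinates and tile it across the batch dimension.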
+    coords = torch.tensor([(i, j) for i in range(3) for j in range(3)])
+    coords = coords.unsqueeze(0).expand(B, -1, -1).reshape(B, -1, 2)
+    S = coords.shape[1]
+    # Get the position embeddings from the coords
+    sequenced_embeddings = get_multidimensional_position_embeddings(position_embeddings, coords)
+    # Test that they are the right shape
+    assert sequenced_embeddings.shape == (B, S, F, D)
+    # Test that the embeddings are correct
+    assert torch.allclose(sequenced_embeddings[0, 0, :, 0], position_embeddings[0, 0, :])
+    assert torch.allclose(sequenced_embeddings[1, 2, :, 1], position_embeddings[1, coords[1, 2, 1], :])
+
+
+@pytest.mark.parametrize('patch_size', [1, 2, 4])
+@pytest.mark.parametrize('batch_size', [1, 4])
+@pytest.mark.parametrize('C', [3, 4])
+@pytest.mark.parametrize('H', [32, 64])
+@pytest.mark.parametrize('W', [32, 64])
+def test_patch_and_unpatch(patch_size, batch_size, C, H, W):
+    # Fake image data
+    image = torch.randn(batch_size, C, H, W)
+    # Patchify
+    image_patches, image_coords = patchify(image, patch_size)
+    # Verify patches are the correct size
+    assert image_patches.shape == (batch_size, H * W // patch_size**2, C * patch_size**2)
+    # Verify coords are the correct size
+    assert image_coords.shape == (batch_size, H * W // patch_size**2, 2)
+    # Unpatchify
+    image_recon = [unpatchify(image_patches[i], image_coords[i], patch_size) for i in range(image_patches.shape[0])]
+    # Verify reconstructed image is the correct size
+    assert len(image_recon) == batch_size
+    assert image_recon[0].shape == (C, H, W)
+    # Verify reconstructed image is close to the original
+    for i in range(batch_size):
+        assert torch.allclose(image_recon[i], image[i], atol=1e-6)
+
+
+@pytest.mark.gpu
+def test_t2i_transformer_forward():
+    # fp16 vae does not run on cpu
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    model = text_to_image_transformer(num_layers=2)
+    batch_size = 1
+    H = 32
+    W = 32
+    image = torch.randn(batch_size, 3, H, W, device=device).half()
+    caption = torch.randint(low=0, high=128, size=(
+        batch_size,
+        77,
+    ), dtype=torch.long, device=device)
+    caption_mask = torch.ones_like(caption, dtype=torch.bool, device=device)
+    batch = {'image': image, 'captions': caption, 'attention_mask': caption_mask}
+    outputs = model(batch)  # model.forward returns the transformer predictions and the rectified flow velocity targets.
+    # Desired output shape
+    seq_len = (H // (8 * 2)) * (W // (8 * 2))
+    output_shape = (1, seq_len, 4 * 2 * 2)
+    assert outputs['predictions'].shape == output_shape
+    assert outputs['targets'].shape == output_shape
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize('guidance_scale', [0.0, 3.0])
+@pytest.mark.parametrize('negative_prompt', [None, 'so cool'])
+def test_t2i_transformer_generate(guidance_scale, negative_prompt):
+    model = text_to_image_transformer(num_layers=2)
+    output = model.generate(
+        prompt='a cool doge',
+        negative_prompt=negative_prompt,
+        num_inference_steps=1,
+        num_images_per_prompt=1,
+        height=64,
+        width=64,
+        guidance_scale=guidance_scale,
+        progress_bar=False,
+    )
+    assert output.shape == (1, 3, 32, 32)
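Example usage (a minimal sketch based on the tests and `ModelInference` in this patch; the prompt, image size,
and sampling settings are illustrative):

    from diffusion.models import text_to_image_transformer

    model = text_to_image_transformer(num_layers=28)
    images = model.generate(
        prompt='a photo of a corgi wearing a spacesuit',
        negative_prompt='low quality',
        height=256,
        width=256,
        guidance_scale=7.0,
        num_inference_steps=50,
    )  # tensor of shape (num_images, 3, 256, 256) with values in [0, 1]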