From 293d30d7b05966cd53ea69a70400e713d92ffb80 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Sat, 15 Jun 2024 23:35:47 +0000 Subject: [PATCH 01/27] Initial transformer --- diffusion/models/transformer.py | 464 ++++++++++++++++++++++++++++++++ 1 file changed, 464 insertions(+) create mode 100644 diffusion/models/transformer.py diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py new file mode 100644 index 00000000..90558dbc --- /dev/null +++ b/diffusion/models/transformer.py @@ -0,0 +1,464 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Diffusion Transformer model.""" + +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from composer.models import ComposerModel +from torchmetrics import MeanSquaredError +from tqdm.auto import tqdm + + +def modulate(x, shift, scale): + """Modulate the input with the shift and scale.""" + return x * (1.0 + scale) + shift + + +def timestep_embedding(timesteps, dim, max_period=10000): + """Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / + half).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class SelfAttention(nn.Module): + """Standard self attention layer that supports masking.""" + + def __init__(self, num_features, num_heads): + super().__init__() + self.num_features = num_features + self.num_heads = num_heads + # Linear layer to get q, k, and v + self.qkv = nn.Linear(self.num_features, 3 * self.num_features) + # QK layernorms + self.q_norm = nn.LayerNorm(self.num_features, elementwise_affine=False, eps=1e-6) + self.k_norm = nn.LayerNorm(self.num_features, elementwise_affine=False, eps=1e-6) + # Linear layer to get the output + self.output_layer = nn.Linear(self.num_features, self.num_features) + # Initialize all biases to zero + nn.init.zeros_(self.qkv.bias) + nn.init.zeros_(self.output_layer.bias) + # Init the standard deviation of the weights to 0.02 + nn.init.normal_(self.qkv.weight, std=0.02) + nn.init.normal_(self.output_layer.weight, std=0.02) + + def forward(self, x, mask=None): + # Get the shape of the input + B, T, C = x.size() + # Calculate the query, key, and values all in one go + q, k, v = self.qkv(x).chunk(3, dim=-1) + q = self.q_norm(q) + k = self.k_norm(k) + # After this, q, k, and v will have shape (B, T, C) + # Reshape the query, key, and values for multi-head attention + # Also want to swap the sequence length and the head dimension for later matmuls + q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) + k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) + v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) + # Native torch attention + attention_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # (B, H, T, C/H) + # Swap the sequence length and the head dimension back and get rid of num_heads. 
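+        # e.g. with B=2, T=16, num_heads=4, C=128: attention_out is (2, 4, 16, 32) here,
+        # becomes (2, 16, 4, 32) after the transpose, and view(B, T, C) merges the heads back to (2, 16, 128).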
+ attention_out = attention_out.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C) + # Final linear layer to get the output + out = self.output_layer(attention_out) + return out + + +class DiTBlock(nn.Module): + """Transformer block that supports masking.""" + + def __init__(self, num_features, num_heads, expansion_factor=4): + super().__init__() + self.num_features = num_features + self.num_heads = num_heads + self.expansion_factor = expansion_factor + # Layer norm before the self attention + self.layer_norm_1 = nn.LayerNorm(self.num_features, elementwise_affine=False, eps=1e-6) + self.attention = SelfAttention(self.num_features, self.num_heads) + # Layer norm before the MLP + self.layer_norm_2 = nn.LayerNorm(self.num_features, elementwise_affine=False, eps=1e-6) + # MLP layers. The MLP expands and then contracts the features. + self.linear_1 = nn.Linear(self.num_features, self.expansion_factor * self.num_features) + self.nonlinearity = nn.GELU(approximate='tanh') + self.linear_2 = nn.Linear(self.expansion_factor * self.num_features, self.num_features) + # Initialize all biases to zero + nn.init.zeros_(self.linear_1.bias) + nn.init.zeros_(self.linear_2.bias) + # Initialize the linear layer weights to have a standard deviation of 0.02 + nn.init.normal_(self.linear_1.weight, std=0.02) + nn.init.normal_(self.linear_2.weight, std=0.02) + # AdaLN MLP + self.adaLN_mlp_linear = nn.Linear(self.num_features, 6 * self.num_features, bias=True) + # Initialize the modulations to zero. This will ensure the block acts as identity at initialization + nn.init.zeros_(self.adaLN_mlp_linear.weight) + nn.init.zeros_(self.adaLN_mlp_linear.bias) + self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear) + + def forward(self, x, c, mask=None): + # Calculate the modulations. Each is shape (B, num_features). 
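+        # adaLN_mlp(c) is (B, 6 * num_features); unsqueeze(1) makes each chunk (B, 1, num_features),
+        # so the shift, scale, and gate terms broadcast over the sequence dimension below.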
+ mods = self.adaLN_mlp(c).unsqueeze(1).chunk(6, dim=2) + # Forward, with modulations + y = modulate(self.layer_norm_1(x), mods[0], mods[1]) + y = mods[2] * self.attention(y, mask=mask) + x = x + y + y = modulate(self.layer_norm_2(x), mods[3], mods[4]) + y = self.linear_1(y) + y = self.nonlinearity(y) + y = mods[5] * self.linear_2(y) + x = x + y + return x + + +class DiffusionTransformer(nn.Module): + """Transformer model for diffusion.""" + + def __init__(self, + num_features: int, + num_heads: int, + num_layers: int, + input_features: int = 192, + input_max_sequence_length: int = 1024, + input_dimension: int = 2, + conditioning_features: int = 1024, + conditioning_max_sequence_length: int = 77, + conditioning_dimension: int = 2, + expansion_factor: int = 4): + super().__init__() + # Params for the network architecture + self.num_features = num_features + self.num_heads = num_heads + self.num_layers = num_layers + self.expansion_factor = expansion_factor + # Params for input embeddings + self.input_features = input_features + self.input_dimension = input_dimension + self.input_max_sequence_length = input_max_sequence_length + # Params for conditioning embeddings + self.conditioning_features = conditioning_features + self.conditioning_dimension = conditioning_dimension + self.conditioning_max_sequence_length = conditioning_max_sequence_length + + # Projection layer for the input sequence + self.input_embedding = nn.Linear(self.input_features, self.num_features) + # Embedding layer for the input sequence + input_position_embedding = torch.randn(self.input_dimension, self.input_max_sequence_length, self.num_features) + input_position_embedding /= math.sqrt(self.num_features) + self.input_position_embedding = torch.nn.Parameter(input_position_embedding, requires_grad=True) + # Projection layer for the conditioning sequence + self.conditioning_embedding = nn.Linear(self.conditioning_features, self.num_features) + # Embedding layer for the conditioning sequence + conditioning_position_embedding = torch.randn(self.conditioning_dimension, + self.conditioning_max_sequence_length, self.num_features) + conditioning_position_embedding /= math.sqrt(self.num_features) + self.conditioning_position_embedding = torch.nn.Parameter(conditioning_position_embedding, requires_grad=True) + # Transformer blocks + self.transformer_blocks = nn.ModuleList([ + DiTBlock(self.num_features, self.num_heads, expansion_factor=self.expansion_factor) + for _ in range(self.num_layers) + ]) + # Output projection layer + self.final_norm = nn.LayerNorm(self.num_features, elementwise_affine=False, eps=1e-6) + self.final_linear = nn.Linear(self.num_features, self.input_features) + # Init the output layer to zero + nn.init.zeros_(self.final_linear.weight) + nn.init.zeros_(self.final_linear.bias) + # AdaLN MLP for the output layer + self.adaLN_mlp_linear = nn.Linear(self.num_features, 2 * self.num_features) + # Init the modulations to zero. 
This will ensure the block acts as identity at initialization + nn.init.zeros_(self.adaLN_mlp_linear.weight) + nn.init.zeros_(self.adaLN_mlp_linear.bias) + self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear) + + def forward(self, + x, + input_coords, + t, + conditioning=None, + conditioning_coords=None, + input_mask=None, + conditioning_mask=None): + # TODO: Fix embeddings, fix embedding norms + # Embed the timestep + t = timestep_embedding(t, self.num_features) + + # Embed the input + y = self.input_embedding(x) # (B, T1, C) + # Get the input position embeddings and add them to the input + input_grid = torch.arange(self.input_dimension).view(1, 1, self.input_dimension).expand( + y.shape[0], y.shape[1], self.input_dimension) + y_position_embeddings = self.input_position_embedding[input_grid, + input_coords, :] # (B, T1, input_dimension, C) + y_position_embeddings = y_position_embeddings.sum(dim=2) # (B, T1, C) + y = y + y_position_embeddings # (B, T1, C) + if input_mask is None: + mask = torch.ones(x.shape[0], x.shape[1], device=x.device) + else: + mask = input_mask + + if conditioning is not None: + assert conditioning_coords is not None + # Embed the conditioning + c = self.conditioning_embedding(conditioning) # (B, T2, C) + # Get the conditioning position embeddings and add them to the conditioning + c_grid = torch.arange(self.conditioning_dimension).view(1, 1, self.conditioning_dimension).expand( + c.shape[0], c.shape[1], self.conditioning_dimension) + c_position_embeddings = self.conditioning_position_embedding[ + c_grid, conditioning_coords, :] # (B, T2, conditioning_dimension, C) + c_position_embeddings = c_position_embeddings.sum(dim=2) # (B, T2, C) + c = c + c_position_embeddings # (B, T2, C) + # Concatenate the input and conditioning sequences + y = torch.cat([y, c], dim=1) # (B, T1 + T2, C) + # Concatenate the masks + if conditioning_mask is None: + conditioning_mask = torch.ones(conditioning.shape[0], conditioning.shape[1], device=conditioning.device) + mask = torch.cat([mask, conditioning_mask], dim=1) # (B, T1 + T2) + + # Expand the mask to the right shape + mask = mask.bool() + mask = mask.unsqueeze(-1) & mask.unsqueeze(1) # (B, T1 + T2, T1 + T2) + identity = torch.eye(mask.shape[1], device=mask.device, + dtype=mask.dtype).unsqueeze(0).expand(mask.shape[0], -1, -1) + mask = mask | identity + mask = mask.unsqueeze(1) # (B, 1, T1 + T2, T1 + T2) + + # Pass through the transformer blocks + for block in self.transformer_blocks: + y = block(y, t, mask=mask) + # Throw away the conditioning tokens + y = y[:, 0:x.shape[1], :] + # Pass through the output layers to get the right number of elements + mods = self.adaLN_mlp(t).unsqueeze(1).chunk(2, dim=2) + y = modulate(self.final_norm(y), mods[0], mods[1]) + y = self.final_linear(y) + return y + + +class ComposerDiffusionTransformer(ComposerModel): + """Diffusion transformer ComposerModel. + + Args: + model (DiffusionTransformer): Core diffusion model. + prediction_type (str): The type of prediction to use. Currently `epsilon`, `v_prediction` are supported. + T_max (int): The maximum number of timesteps. Default: 1000. + input_key (str): The name of the inputs in the dataloader batch. Default: `input`. + input_coords_key (str): The name of the input coordinates in the dataloader batch. Default: `input_coords`. + input_mask_key (str): The name of the input mask in the dataloader batch. Default: `input_mask`. + conditioning_key (str): The name of the conditioning info in the dataloader batch. Default: `conditioning`. 
+ conditioning_coords_key (str): The name of the conditioning coordinates in the dataloader batch. Default: `conditioning_coords`. + conditioning_mask_key (str): The name of the conditioning mask in the dataloader batch. Default: `conditioning_mask`. + """ + + def __init__( + self, + model: DiffusionTransformer, + prediction_type: str = 'epsilon', + T_max: int = 1000, + input_key: str = 'input', + input_coords_key: str = 'input_coords', + input_mask_key: str = 'input_mask', + conditioning_key: str = 'conditioning', + conditioning_coords_key: str = 'conditioning_coords', + conditioning_mask_key: str = 'conditioning_mask', + ): + super().__init__() + self.model = model + self.model._fsdp_wrap = True + + # Diffusion parameters + self.prediction_type = prediction_type.lower() + if self.prediction_type not in ['epsilon', 'v_prediction']: + raise ValueError(f'Unrecognized prediction type {self.prediction_type}') + self.T_max = T_max + + # Set up input keys + self.input_key = input_key + self.input_coords_key = input_coords_key + self.input_mask_key = input_mask_key + # Set up conditioning keys + self.conditioning_key = conditioning_key + self.conditioning_coords_key = conditioning_coords_key + self.conditioning_mask_key = conditioning_mask_key + + # Params for MFU computation, subtract off the embedding params + self.n_params = sum(p.numel() for p in self.model.parameters()) + self.n_params -= self.model.input_position_embedding.numel() + self.n_params -= self.model.conditioning_position_embedding.numel() + + # Set up metrics + self.train_metrics = [MeanSquaredError()] + self.val_metrics = [MeanSquaredError()] + + # Optional rng generator + self.rng_generator: Optional[torch.Generator] = None + + def set_rng_generator(self, rng_generator: torch.Generator): + """Sets the rng generator for the model.""" + self.rng_generator = rng_generator + + def flops_per_batch(self, batch): + batch_size, input_seq_len = batch[self.input_key].shape[0:2] + cond_seq_len = batch[self.conditioning_key].shape[1] + seq_len = input_seq_len + cond_seq_len + # Calulate forward flops excluding attention + param_flops = 2 * self.n_params * batch_size * seq_len + # Calculate flops for attention layers + attention_flops = 4 * self.model.num_layers * seq_len**2 * self.model.num_features * batch_size + return 3 * param_flops + 3 * attention_flops + + def diffusion_forward_process(self, inputs: torch.Tensor): + """Diffusion forward process.""" + # Sample a timestep for every element in the batch + timesteps = self.T_max * torch.rand(inputs.shape[0], device=inputs.device, generator=self.rng_generator) + # Generate the noise, applied to the whole input sequence + noise = torch.randn(*inputs.shape, device=inputs.device, generator=self.rng_generator) + # Add the noise to the latents according to the natural schedule + cos_t = torch.cos(timesteps * torch.pi / (2 * self.T_max)).view(-1, 1, 1) + sin_t = torch.sin(timesteps * torch.pi / (2 * self.T_max)).view(-1, 1, 1) + noised_inputs = cos_t * inputs + sin_t * noise + if self.prediction_type == 'epsilon': + # Get the (epsilon) targets + targets = noise + elif self.prediction_type == 'v_prediction': + # Get the (velocity) targets + targets = -sin_t * inputs + cos_t * noise + else: + raise ValueError(f'Unrecognized prediction type {self.prediction_type}') + # TODO: Implement other prediction types + return noised_inputs, targets, timesteps + + def forward(self, batch): + # Get the inputs + inputs = batch[self.input_key] + inputs_coords = batch[self.input_coords_key] + inputs_mask = 
batch[self.input_mask_key] + # Get the conditioning + conditioning = batch[self.conditioning_key] + conditioning_coords = batch[self.conditioning_coords_key] + conditioning_mask = batch[self.conditioning_mask_key] + # Diffusion forward process + noised_inputs, targets, timesteps = self.diffusion_forward_process(inputs) + # Forward through the model + model_out = self.model(noised_inputs, + inputs_coords, + timesteps, + conditioning=conditioning, + conditioning_coords=conditioning_coords, + input_mask=inputs_mask, + conditioning_mask=conditioning_mask) + return {'predictions': model_out, 'targets': targets, 'timesteps': timesteps} + + def loss(self, outputs, batch): + """MSE loss between outputs and targets.""" + losses = {} + # Need to mask out elements in the loss that are not present in the input + mask = batch[self.input_mask_key] # (B, T1), 1 if included, 0 otherwise. + loss = (outputs['predictions'] - outputs['targets'])**2 # (B, T1, C) + loss = loss.mean(dim=2) # (B, T1) + losses['total'] = (loss * mask).sum() / mask.sum() + return losses + + def eval_forward(self, batch, outputs=None): + # Skip this if outputs have already been computed, e.g. during training + if outputs is not None: + return outputs + return self.forward(batch) + + def get_metrics(self, is_train: bool = False): + if is_train: + metrics_dict = {metric.__class__.__name__: metric for metric in self.train_metrics} + else: + metrics_dict = {metric.__class__.__name__: metric for metric in self.val_metrics} + return metrics_dict + + def update_metric(self, batch, outputs, metric): + if isinstance(metric, MeanSquaredError): + metric.update(outputs['predictions'], outputs['targets']) + else: + raise ValueError(f'Unrecognized metric {metric.__class__.__name__}') + + def update_inputs(self, inputs, predictions, t, delta_t): + """Gets the input update.""" + angle = t * torch.pi / (2 * self.T_max) + cos_t = torch.cos(angle).view(-1, 1, 1) + sin_t = torch.sin(angle).view(-1, 1, 1) + if self.prediction_type == 'epsilon': + if angle == torch.pi / 2: + # Optimal update here is to do nothing. 
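+                # At angle = pi/2 (t = T_max) the schedule gives cos_t = 0, so the noised input is pure
+                # noise and the general update below would divide by zero. Leaving the input unchanged
+                # is the stable choice at this endpoint.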
+ pass + elif torch.abs(torch.pi / 2 - angle) < 1e-4: + # Need to avoid instability near t = T_max + inputs = inputs - (predictions - sin_t * inputs) + else: + inputs = inputs - (predictions - sin_t * inputs) * delta_t / cos_t + elif self.prediction_type == 'v_prediction': + inputs = inputs - delta_t * predictions + return inputs + + def generate(self, + input_coords: torch.Tensor, + input_mask: torch.Tensor, + conditioning: torch.Tensor, + conditioning_coords: torch.Tensor, + conditioning_mask: torch.Tensor, + guidance_scale: float = 7.0, + num_timesteps: int = 50, + progress_bar: bool = True, + seed: Optional[int] = None): + """Generate from the model.""" + device = next(self.model.parameters()).device + # Create rng for the generation + rng_generator = torch.Generator(device=device) + if seed: + rng_generator = rng_generator.manual_seed(seed) + # From the input coordinates, generate a noisy input sequence + inputs = torch.randn(*input_coords.shape[:-1], + self.model.input_features, + device=device, + generator=rng_generator) + # Set up for CFG + input_coords = torch.cat([input_coords, input_coords], dim=0) + input_mask = torch.cat([input_mask, input_mask], dim=0).to(device) + conditioning = torch.cat([torch.zeros_like(conditioning), conditioning], dim=0).to(device) + conditioning_coords = torch.cat([conditioning_coords, conditioning_coords], dim=0) + conditioning_mask = torch.cat([torch.zeros_like(conditioning_mask), conditioning_mask], dim=0).to(device) + # Make the timesteps + timesteps = torch.linspace(self.T_max, 0, num_timesteps + 1, device=device) + time_deltas = -torch.diff(timesteps) * (torch.pi / (2 * self.T_max)) + timesteps = timesteps[:-1] + # backward diffusion process + for i, t in enumerate(tqdm(timesteps, disable=not progress_bar)): + # Expand t to the batch size + t = t * torch.ones(inputs.shape[0], device=device) + # Duplicate the inputs for CFG + doubled_inputs = torch.cat([inputs, inputs], dim=0) + # Get the model prediction + model_out = self.model(doubled_inputs, + input_coords, + t, + conditioning=conditioning, + conditioning_coords=conditioning_coords, + input_mask=input_mask, + conditioning_mask=conditioning_mask) + # Do CFG + pred_uncond, pred_cond = model_out.chunk(2, dim=0) + pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Update the inputs + inputs = self.update_inputs(inputs, pred, t, time_deltas[i]) + return inputs From a882d67a6d888f5dac3e8ebd885dd29d533346c4 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Sun, 16 Jun 2024 07:18:22 +0000 Subject: [PATCH 02/27] Initial running composer model --- diffusion/models/models.py | 109 ++++++++++ diffusion/models/transformer.py | 339 +++++++++++++++++++++----------- 2 files changed, 335 insertions(+), 113 deletions(-) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index eaaee630..c8a933f1 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -18,6 +18,7 @@ from diffusion.models.pixel_diffusion import PixelDiffusion from diffusion.models.stable_diffusion import StableDiffusion from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer +from diffusion.models.transformer import ComposerTextToImageDiT, DiffusionTransformer from diffusion.schedulers.schedulers import ContinuousTimeScheduler from diffusion.schedulers.utils import shift_noise_schedule @@ -496,6 +497,114 @@ def stable_diffusion_xl( return model +def text_to_image_transformer( + tokenizer_names: Union[str, Tuple[str, ...]] = 
('stabilityai/stable-diffusion-xl-base-1.0/tokenizer', + 'stabilityai/stable-diffusion-xl-base-1.0/tokenizer_2'), + text_encoder_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/text_encoder', + 'stabilityai/stable-diffusion-xl-base-1.0/text_encoder_2'), + unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', + vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', + autoencoder_path: Optional[str] = None, + autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', + prediction_type: str = 'epsilon', + latent_mean: Union[float, Tuple, str] = 0.0, + latent_std: Union[float, Tuple, str] = 7.67754318618, + beta_schedule: str = 'scaled_linear', + zero_terminal_snr: bool = False, + use_karras_sigmas: bool = False): + latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) + + if (isinstance(tokenizer_names, tuple) or + isinstance(text_encoder_names, tuple)) and len(tokenizer_names) != len(text_encoder_names): + raise ValueError('Number of tokenizer_names and text_encoder_names must be equal') + + # Make the tokenizer and text encoder + tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names) + text_encoder = MultiTextEncoder(model_names=text_encoder_names, encode_latents_in_fp16=True, pretrained_sdxl=False) + + precision = torch.float16 + # Make the autoencoder + if autoencoder_path is None: + if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics': + raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.') + downsample_factor = 8 + # Use the pretrained vae + try: + vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=precision) + except: # for handling SDXL vae fp16 fixed checkpoint + vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=precision) + else: + # Use a custom autoencoder + vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=precision) + if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'): + raise ValueError( + 'Must specify latent scale when using a custom autoencoder without tracking latent statistics.') + if isinstance(latent_mean, str) and latent_mean == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_mean = tuple(latent_statistics['latent_channel_means']) + if isinstance(latent_std, str) and latent_std == 'latent_statistics': + assert isinstance(latent_statistics, dict) + latent_std = tuple(latent_statistics['latent_channel_stds']) + downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1) + + # Make the noise schedulers + noise_scheduler = DDPMScheduler(num_train_timesteps=1000, + beta_start=0.0000085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + variance_type='fixed_small', + clip_sample=False, + prediction_type=prediction_type, + sample_max_value=1.0, + timestep_spacing='leading', + steps_offset=1, + rescale_betas_zero_snr=zero_terminal_snr) + inference_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000, + beta_start=0.0000085, + beta_end=0.012, + beta_schedule=beta_schedule, + trained_betas=None, + prediction_type=prediction_type, + interpolation_type='linear', + use_karras_sigmas=use_karras_sigmas, + timestep_spacing='leading', + steps_offset=1, + rescale_betas_zero_snr=zero_terminal_snr) + + # Make the transformer model + transformer = DiffusionTransformer(num_features=256, + num_heads=4, + num_layers=4, 
+ input_features=16, + input_max_sequence_length=1024, + input_dimension=2, + conditioning_features=768, + conditioning_max_sequence_length=77, + conditioning_dimension=1, + expansion_factor=4) + # Make the composer model + model = ComposerTextToImageDiT(model=transformer, + autoencoder=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + noise_scheduler=noise_scheduler, + inference_noise_scheduler=inference_noise_scheduler, + prediction_type=prediction_type, + latent_mean=latent_mean, + latent_std=latent_std, + patch_size=2, + downsample_factor=8, + latent_channels=4, + image_key='image', + caption_key='captions', + caption_mask_key='attention_mask') + + if torch.cuda.is_available(): + model = DeviceGPU().module_to_device(model) + return model + + def build_autoencoder(input_channels: int = 3, output_channels: int = 3, hidden_channels: int = 128, diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index 90558dbc..e33c557c 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -138,7 +138,7 @@ def __init__(self, input_dimension: int = 2, conditioning_features: int = 1024, conditioning_max_sequence_length: int = 77, - conditioning_dimension: int = 2, + conditioning_dimension: int = 1, expansion_factor: int = 4): super().__init__() # Params for the network architecture @@ -250,51 +250,85 @@ def forward(self, return y -class ComposerDiffusionTransformer(ComposerModel): - """Diffusion transformer ComposerModel. +class ComposerTextToImageDiT(ComposerModel): + """ComposerModel for text to image with a diffusion transformer. Args: model (DiffusionTransformer): Core diffusion model. + autoencoder (torch.nn.Module): HuggingFace or compatible vae. + must support `.encode()` and `decode()` functions. + text_encoder (torch.nn.Module): HuggingFace CLIP or LLM text enoder. + tokenizer (transformers.PreTrainedTokenizer): Tokenizer used for + text_encoder. For a `CLIPTextModel` this will be the + `CLIPTokenizer` from HuggingFace transformers. + noise_scheduler (diffusers.SchedulerMixin): HuggingFace diffusers + noise scheduler. Used during the forward diffusion process (training). + inference_noise_scheduler (diffusers.SchedulerMixin): HuggingFace diffusers + noise scheduler. Used during the backward diffusion process (inference). prediction_type (str): The type of prediction to use. Currently `epsilon`, `v_prediction` are supported. - T_max (int): The maximum number of timesteps. Default: 1000. - input_key (str): The name of the inputs in the dataloader batch. Default: `input`. - input_coords_key (str): The name of the input coordinates in the dataloader batch. Default: `input_coords`. - input_mask_key (str): The name of the input mask in the dataloader batch. Default: `input_mask`. - conditioning_key (str): The name of the conditioning info in the dataloader batch. Default: `conditioning`. - conditioning_coords_key (str): The name of the conditioning coordinates in the dataloader batch. Default: `conditioning_coords`. - conditioning_mask_key (str): The name of the conditioning mask in the dataloader batch. Default: `conditioning_mask`. + latent_mean (Optional[tuple[float]]): The means of the latent space. If not specified, defaults to + 4 * (0.0,). Default: `None`. + latent_std (Optional[tuple[float]]): The standard deviations of the latent space. If not specified, + defaults to 4 * (1/0.13025,). Default: `None`. + patch_size (int): The size of the patches in the image latents. Default: `2`. 
+        downsample_factor (int): The factor by which the image is downsampled by the autoencoder. Default: `8`.
+        latent_channels (int): The number of channels in the autoencoder latent space. Default: `4`.
+        image_key (str): The name of the images in the dataloader batch. Default: `image`.
+        caption_key (str): The name of the caption in the dataloader batch. Default: `caption`.
+        caption_mask_key (str): The name of the caption mask in the dataloader batch. Default: `caption_mask`.
     """
 
     def __init__(
         self,
         model: DiffusionTransformer,
+        autoencoder: torch.nn.Module,
+        text_encoder: torch.nn.Module,
+        tokenizer,
+        noise_scheduler,
+        inference_noise_scheduler,
         prediction_type: str = 'epsilon',
-        T_max: int = 1000,
-        input_key: str = 'input',
-        input_coords_key: str = 'input_coords',
-        input_mask_key: str = 'input_mask',
-        conditioning_key: str = 'conditioning',
-        conditioning_coords_key: str = 'conditioning_coords',
-        conditioning_mask_key: str = 'conditioning_mask',
+        latent_mean: Optional[tuple[float]] = None,
+        latent_std: Optional[tuple[float]] = None,
+        patch_size: int = 2,
+        downsample_factor: int = 8,
+        latent_channels: int = 4,
+        image_key: str = 'image',
+        caption_key: str = 'caption',
+        caption_mask_key: str = 'caption_mask',
     ):
         super().__init__()
         self.model = model
-        self.model._fsdp_wrap = True
-
-        # Diffusion parameters
+        self.autoencoder = autoencoder
+        self.text_encoder = text_encoder
+        self.tokenizer = tokenizer
+        self.noise_scheduler = noise_scheduler
+        self.inference_scheduler = inference_noise_scheduler
         self.prediction_type = prediction_type.lower()
         if self.prediction_type not in ['epsilon', 'v_prediction']:
             raise ValueError(f'Unrecognized prediction type {self.prediction_type}')
-        self.T_max = T_max
-
-        # Set up input keys
-        self.input_key = input_key
-        self.input_coords_key = input_coords_key
-        self.input_mask_key = input_mask_key
-        # Set up conditioning keys
-        self.conditioning_key = conditioning_key
-        self.conditioning_coords_key = conditioning_coords_key
-        self.conditioning_mask_key = conditioning_mask_key
+        if latent_mean is None:
+            latent_mean = 4 * (0.0,)
+        if latent_std is None:
+            latent_std = 4 * (1 / 0.18215,)
+        self.latent_mean = torch.tensor(latent_mean).view(1, -1, 1, 1)
+        self.latent_std = torch.tensor(latent_std).view(1, -1, 1, 1)
+        self.patch_size = patch_size
+        self.downsample_factor = downsample_factor
+        self.latent_channels = latent_channels
+        self.image_key = image_key
+        self.caption_key = caption_key
+        self.caption_mask_key = caption_mask_key
+
+        # Freeze the autoencoder and text_encoder during diffusion training and use half precision
+        self.autoencoder.requires_grad_(False)
+        self.text_encoder.requires_grad_(False)
+        self.autoencoder = self.autoencoder.half()
+        self.text_encoder = self.text_encoder.half()
+
+        # Only FSDP wrap models we are training
+        self.model._fsdp_wrap = True
+        self.autoencoder._fsdp_wrap = False
+        self.text_encoder._fsdp_wrap = False
 
         # Params for MFU computation, subtract off the embedding params
         self.n_params = sum(p.numel() for p in self.model.parameters())
@@ -308,13 +342,19 @@ def __init__(
         # Optional rng generator
         self.rng_generator: Optional[torch.Generator] = None
 
+    def _apply(self, fn):
+        super(ComposerTextToImageDiT, self)._apply(fn)
+        self.latent_mean = fn(self.latent_mean)
+        self.latent_std = fn(self.latent_std)
+        return self
+
     def set_rng_generator(self, rng_generator: torch.Generator):
         """Sets the rng generator for the model."""
         self.rng_generator = rng_generator
 
     def flops_per_batch(self, batch):
-        batch_size, input_seq_len = 
batch[self.input_key].shape[0:2] - cond_seq_len = batch[self.conditioning_key].shape[1] + batch_size, input_seq_len = batch[self.image_key].shape[0:2] + cond_seq_len = batch[self.caption_key].shape[1] seq_len = input_seq_len + cond_seq_len # Calulate forward flops excluding attention param_flops = 2 * self.n_params * batch_size * seq_len @@ -322,57 +362,92 @@ def flops_per_batch(self, batch): attention_flops = 4 * self.model.num_layers * seq_len**2 * self.model.num_features * batch_size return 3 * param_flops + 3 * attention_flops + def patchify(self, latents): + # Assume img is a tensor of shape [B, C, H, W] + B, C, H, W = latents.shape + assert H % self.patch_size == 0 and W % self.patch_size == 0, 'Image dimensions must be divisible by patch_size' + # Reshape and permute to get non-overlapping patches + num_H_patches = H // self.patch_size + num_W_patches = W // self.patch_size + patches = latents.reshape(B, C, num_H_patches, self.patch_size, num_W_patches, self.patch_size) + patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * self.patch_size * self.patch_size) + # Generate coordinates for each patch + coords = torch.tensor([(i, j) for i in range(num_H_patches) for j in range(num_W_patches)]) + coords = coords.unsqueeze(0).expand(B, -1, -1).reshape(B, -1, 2) + return patches, coords + + def unpatchify(self, patches, coords): + # Assume patches is a tensor of shape [num_patches, C * patch_size * patch_size] + C = patches.shape[1] // (self.patch_size * self.patch_size) + # Calculate the height and width of the original image from the coordinates + H = coords[:, 0].max() * self.patch_size + self.patch_size + W = coords[:, 1].max() * self.patch_size + self.patch_size + # Initialize an empty tensor for the reconstructed image + img = torch.zeros((C, H, W), device=patches.device, dtype=patches.dtype) + # Iterate over the patches and their coordinates + for patch, (y, x) in zip(patches, self.patch_size * coords): + # Reshape the patch to [C, patch_size, patch_size] + patch = patch.view(C, self.patch_size, self.patch_size) + # Place the patch in the corresponding location in the image + img[:, y:y + self.patch_size, x:x + self.patch_size] = patch + return img + def diffusion_forward_process(self, inputs: torch.Tensor): """Diffusion forward process.""" # Sample a timestep for every element in the batch - timesteps = self.T_max * torch.rand(inputs.shape[0], device=inputs.device, generator=self.rng_generator) + timesteps = torch.randint(0, + len(self.noise_scheduler), (inputs.shape[0],), + device=inputs.device, + generator=self.rng_generator) # Generate the noise, applied to the whole input sequence noise = torch.randn(*inputs.shape, device=inputs.device, generator=self.rng_generator) - # Add the noise to the latents according to the natural schedule - cos_t = torch.cos(timesteps * torch.pi / (2 * self.T_max)).view(-1, 1, 1) - sin_t = torch.sin(timesteps * torch.pi / (2 * self.T_max)).view(-1, 1, 1) - noised_inputs = cos_t * inputs + sin_t * noise + # Add the noise to the latents according to the schedule + noised_inputs = self.noise_scheduler.add_noise(inputs, noise, timesteps) + # Generate the targets if self.prediction_type == 'epsilon': - # Get the (epsilon) targets targets = noise + elif self.prediction_type == 'sample': + targets = inputs elif self.prediction_type == 'v_prediction': - # Get the (velocity) targets - targets = -sin_t * inputs + cos_t * noise + targets = self.noise_scheduler.get_velocity(inputs, noise, timesteps) else: - raise ValueError(f'Unrecognized prediction 
type {self.prediction_type}') - # TODO: Implement other prediction types + raise ValueError( + f'prediction type must be one of sample, epsilon, or v_prediction. Got {self.prediction_type}') return noised_inputs, targets, timesteps def forward(self, batch): # Get the inputs - inputs = batch[self.input_key] - inputs_coords = batch[self.input_coords_key] - inputs_mask = batch[self.input_mask_key] - # Get the conditioning - conditioning = batch[self.conditioning_key] - conditioning_coords = batch[self.conditioning_coords_key] - conditioning_mask = batch[self.conditioning_mask_key] + image, caption, caption_mask = batch[self.image_key], batch[self.caption_key], batch[self.caption_mask_key] + # Get the text embeddings and image latents + with torch.cuda.amp.autocast(enabled=False): + latents = self.autoencoder.encode(image.half())['latent_dist'].sample().data + text_encoder_out = self.text_encoder(caption, attention_mask=caption_mask) + text_embeddings = text_encoder_out[0] + # Make the text embedding coords + text_embeddings_coords = torch.arange(text_embeddings.shape[1], device=text_embeddings.device) + text_embeddings_coords = text_embeddings_coords.unsqueeze(0).expand(text_embeddings.shape[0], -1).unsqueeze(-1) + # Zero dropped captions if needed + if 'drop_caption_mask' in batch.keys(): + text_embeddings *= batch['drop_caption_mask'].view(-1, 1, 1) + # Scale and patchify the latents + latents = (latents - self.latent_mean) / self.latent_std + latent_patches, latent_coords = self.patchify(latents) # Diffusion forward process - noised_inputs, targets, timesteps = self.diffusion_forward_process(inputs) + noised_inputs, targets, timesteps = self.diffusion_forward_process(latent_patches) # Forward through the model model_out = self.model(noised_inputs, - inputs_coords, + latent_coords, timesteps, - conditioning=conditioning, - conditioning_coords=conditioning_coords, - input_mask=inputs_mask, - conditioning_mask=conditioning_mask) + conditioning=text_embeddings, + conditioning_coords=text_embeddings_coords, + input_mask=None, + conditioning_mask=caption_mask) return {'predictions': model_out, 'targets': targets, 'timesteps': timesteps} def loss(self, outputs, batch): """MSE loss between outputs and targets.""" - losses = {} - # Need to mask out elements in the loss that are not present in the input - mask = batch[self.input_mask_key] # (B, T1), 1 if included, 0 otherwise. - loss = (outputs['predictions'] - outputs['targets'])**2 # (B, T1, C) - loss = loss.mean(dim=2) # (B, T1) - losses['total'] = (loss * mask).sum() / mask.sum() - return losses + loss = F.mse_loss(outputs['predictions'], outputs['targets']) + return loss def eval_forward(self, batch, outputs=None): # Skip this if outputs have already been computed, e.g. during training @@ -393,32 +468,38 @@ def update_metric(self, batch, outputs, metric): else: raise ValueError(f'Unrecognized metric {metric.__class__.__name__}') - def update_inputs(self, inputs, predictions, t, delta_t): - """Gets the input update.""" - angle = t * torch.pi / (2 * self.T_max) - cos_t = torch.cos(angle).view(-1, 1, 1) - sin_t = torch.sin(angle).view(-1, 1, 1) - if self.prediction_type == 'epsilon': - if angle == torch.pi / 2: - # Optimal update here is to do nothing. 
- pass - elif torch.abs(torch.pi / 2 - angle) < 1e-4: - # Need to avoid instability near t = T_max - inputs = inputs - (predictions - sin_t * inputs) - else: - inputs = inputs - (predictions - sin_t * inputs) * delta_t / cos_t - elif self.prediction_type == 'v_prediction': - inputs = inputs - delta_t * predictions - return inputs + def combine_attention_masks(self, attention_mask): + if len(attention_mask.shape) == 2: + return attention_mask + elif len(attention_mask.shape) == 3: + encoder_attention_mask = attention_mask[:, 0] + for i in range(1, attention_mask.shape[1]): + encoder_attention_mask |= attention_mask[:, i] + return encoder_attention_mask + else: + raise ValueError(f'attention_mask should have either 2 or 3 dimensions: {attention_mask.shape}') + + def embed_prompt(self, prompt): + with torch.cuda.amp.autocast(enabled=False): + tokenized_out = self.tokenizer(prompt, + padding='max_length', + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors='pt') + tokenized_prompts = tokenized_out['input_ids'].to(self.text_encoder.device) + prompt_mask = tokenized_out['attention_mask'].to(self.text_encoder.device) + text_embeddings = self.text_encoder(tokenized_prompts, attention_mask=prompt_mask)[0] + prompt_mask = self.combine_attention_masks(prompt_mask) + return text_embeddings, prompt_mask def generate(self, - input_coords: torch.Tensor, - input_mask: torch.Tensor, - conditioning: torch.Tensor, - conditioning_coords: torch.Tensor, - conditioning_mask: torch.Tensor, + prompt: Optional[list] = None, + negative_prompt: Optional[list] = None, + height: Optional[int] = None, + width: Optional[int] = None, guidance_scale: float = 7.0, - num_timesteps: int = 50, + rescaled_guidance: Optional[float] = None, + num_inference_steps: int = 50, progress_bar: bool = True, seed: Optional[int] = None): """Generate from the model.""" @@ -427,38 +508,70 @@ def generate(self, rng_generator = torch.Generator(device=device) if seed: rng_generator = rng_generator.manual_seed(seed) - # From the input coordinates, generate a noisy input sequence - inputs = torch.randn(*input_coords.shape[:-1], - self.model.input_features, - device=device, - generator=rng_generator) + + # Get the text embeddings + if prompt is not None: + text_embeddings, prompt_mask = self.embed_prompt(prompt) + text_embeddings_coords = torch.arange(text_embeddings.shape[1], device=text_embeddings.device) + text_embeddings_coords = text_embeddings_coords.unsqueeze(0).expand(text_embeddings.shape[0], -1) + text_embeddings_coords = text_embeddings_coords.unsqueeze(-1) + else: + raise ValueError('Prompt must be specified') + if negative_prompt is not None: + negative_text_embeddings, negative_prompt_mask = self.embed_prompt(negative_prompt) + else: + negative_text_embeddings = torch.zeros_like(text_embeddings) + negative_prompt_mask = torch.zeros_like(prompt_mask) + negative_text_embeddings_coords = torch.arange(negative_text_embeddings.shape[1], + device=negative_text_embeddings.device) + negative_text_embeddings_coords = negative_text_embeddings_coords.unsqueeze(0).expand( + negative_text_embeddings.shape[0], -1) + negative_text_embeddings_coords = negative_text_embeddings_coords.unsqueeze(-1) + + # Generate initial noise + latent_height = height // self.downsample_factor + latent_width = width // self.downsample_factor + latents = torch.randn(text_embeddings.shape[0], + self.latent_channels, + latent_height, + latent_width, + device=device) + latent_patches, latent_coords = self.patchify(latents) + # Set up for CFG 
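+        # Classifier-free guidance: stack the unconditional (negative prompt or zeroed) embeddings
+        # ahead of the prompt-conditioned ones so a single forward pass produces both predictions,
+        # matching the pred_uncond, pred_cond = model_out.chunk(2, dim=0) split below.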
-        input_coords = torch.cat([input_coords, input_coords], dim=0)
-        input_mask = torch.cat([input_mask, input_mask], dim=0).to(device)
-        conditioning = torch.cat([torch.zeros_like(conditioning), conditioning], dim=0).to(device)
-        conditioning_coords = torch.cat([conditioning_coords, conditioning_coords], dim=0)
-        conditioning_mask = torch.cat([torch.zeros_like(conditioning_mask), conditioning_mask], dim=0).to(device)
-        # Make the timesteps
-        timesteps = torch.linspace(self.T_max, 0, num_timesteps + 1, device=device)
-        time_deltas = -torch.diff(timesteps) * (torch.pi / (2 * self.T_max))
-        timesteps = timesteps[:-1]
+        text_embeddings = torch.cat([negative_text_embeddings, text_embeddings], dim=0)
+        text_embeddings_coords = torch.cat([negative_text_embeddings_coords, text_embeddings_coords], dim=0)
+        text_embeddings_mask = torch.cat([negative_prompt_mask, prompt_mask], dim=0)
+        latent_coords_input = torch.cat([latent_coords] * 2)
+
+        # Prep for reverse process
+        self.inference_scheduler.set_timesteps(num_inference_steps)
+        # scale the initial noise by the standard deviation required by the scheduler
+        latent_patches = latent_patches * self.inference_scheduler.init_noise_sigma
+
         # backward diffusion process
-        for i, t in enumerate(tqdm(timesteps, disable=not progress_bar)):
-            # Expand t to the batch size
-            t = t * torch.ones(inputs.shape[0], device=device)
-            # Duplicate the inputs for CFG
-            doubled_inputs = torch.cat([inputs, inputs], dim=0)
+        for t in tqdm(self.inference_scheduler.timesteps, disable=not progress_bar):
+            latent_patches_input = torch.cat([latent_patches] * 2)
+            latent_patches_input = self.inference_scheduler.scale_model_input(latent_patches_input, t)
             # Get the model prediction
-            model_out = self.model(doubled_inputs,
-                                   input_coords,
-                                   t,
-                                   conditioning=conditioning,
-                                   conditioning_coords=conditioning_coords,
-                                   input_mask=input_mask,
-                                   conditioning_mask=conditioning_mask)
+            model_out = self.model(latent_patches_input,
+                                   latent_coords_input,
+                                   t.unsqueeze(0),
+                                   conditioning=text_embeddings,
+                                   conditioning_coords=text_embeddings_coords,
+                                   input_mask=None,
+                                   conditioning_mask=text_embeddings_mask)
             # Do CFG
             pred_uncond, pred_cond = model_out.chunk(2, dim=0)
             pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
             # Update the inputs
-            inputs = self.update_inputs(inputs, pred, t, time_deltas[i])
-        return inputs
+            latent_patches = self.inference_scheduler.step(pred, t, latent_patches, generator=rng_generator).prev_sample
+        # Unpatchify the latents
+        latents = [self.unpatchify(latent_patches[i], latent_coords[i]) for i in range(latent_patches.shape[0])]
+        latents = torch.stack(latents)
+        # Scale the latents back to the original scale
+        latents = latents * self.latent_std + self.latent_mean
+        # Decode the latents
+        image = self.autoencoder.decode(latents).sample
+        image = (image / 2 + 0.5).clamp(0, 1)
+        return image.detach()  # (batch*num_images_per_prompt, channel, h, w)

From 307bc12ddc0fb7714c918574192afa311912e44c Mon Sep 17 00:00:00 2001
From: Cory Stephenson
Date: Sun, 16 Jun 2024 20:43:21 +0000
Subject: [PATCH 03/27] Configurable model size

---
 diffusion/models/models.py | 30 +++++++++++++++++++-----------
 1 file changed, 19 insertions(+), 11 deletions(-)

diff --git a/diffusion/models/models.py b/diffusion/models/models.py
index c8a933f1..c5e32ea5 100644
--- a/diffusion/models/models.py
+++ b/diffusion/models/models.py
@@ -502,10 +502,16 @@ def text_to_image_transformer(
                                                          'stabilityai/stable-diffusion-xl-base-1.0/tokenizer_2'),
     text_encoder_names: Union[str, Tuple[str, ...]] = 
('stabilityai/stable-diffusion-xl-base-1.0/text_encoder', 'stabilityai/stable-diffusion-xl-base-1.0/text_encoder_2'), - unet_model_name: str = 'stabilityai/stable-diffusion-xl-base-1.0', vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', autoencoder_path: Optional[str] = None, autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', + num_features: int = 1152, + num_heads: int = 16, + num_layers: int = 28, + input_max_sequence_length: int = 1024, + conditioning_features: int = 768, + conditioning_max_sequence_length: int = 77, + patch_size: int = 2, prediction_type: str = 'epsilon', latent_mean: Union[float, Tuple, str] = 0.0, latent_std: Union[float, Tuple, str] = 7.67754318618, @@ -528,6 +534,7 @@ def text_to_image_transformer( if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics': raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.') downsample_factor = 8 + autoencoder_channels = 4 # Use the pretrained vae try: vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=precision) @@ -546,6 +553,7 @@ def text_to_image_transformer( assert isinstance(latent_statistics, dict) latent_std = tuple(latent_statistics['latent_channel_stds']) downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1) + autoencoder_channels = vae.config['latent_channels'] # Make the noise schedulers noise_scheduler = DDPMScheduler(num_train_timesteps=1000, @@ -573,14 +581,14 @@ def text_to_image_transformer( rescale_betas_zero_snr=zero_terminal_snr) # Make the transformer model - transformer = DiffusionTransformer(num_features=256, - num_heads=4, - num_layers=4, - input_features=16, - input_max_sequence_length=1024, + transformer = DiffusionTransformer(num_features=num_features, + num_heads=num_heads, + num_layers=num_layers, + input_features=autoencoder_channels * (patch_size ** 2), + input_max_sequence_length=input_max_sequence_length, input_dimension=2, - conditioning_features=768, - conditioning_max_sequence_length=77, + conditioning_features=conditioning_features, + conditioning_max_sequence_length=conditioning_max_sequence_length, conditioning_dimension=1, expansion_factor=4) # Make the composer model @@ -593,9 +601,9 @@ def text_to_image_transformer( prediction_type=prediction_type, latent_mean=latent_mean, latent_std=latent_std, - patch_size=2, - downsample_factor=8, - latent_channels=4, + patch_size=patch_size, + downsample_factor=downsample_factor, + latent_channels=autoencoder_channels, image_key='image', caption_key='captions', caption_mask_key='attention_mask') From e2dd53ec2eefc1f668e704368c5661a0fa2571ea Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Sun, 16 Jun 2024 21:36:47 +0000 Subject: [PATCH 04/27] Fix calculation of input sequence length --- diffusion/models/models.py | 2 +- diffusion/models/transformer.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index c5e32ea5..6b445730 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -584,7 +584,7 @@ def text_to_image_transformer( transformer = DiffusionTransformer(num_features=num_features, num_heads=num_heads, num_layers=num_layers, - input_features=autoencoder_channels * (patch_size ** 2), + input_features=autoencoder_channels * (patch_size**2), input_max_sequence_length=input_max_sequence_length, input_dimension=2, conditioning_features=conditioning_features, diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index 
e33c557c..e83a090a 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -353,7 +353,9 @@ def set_rng_generator(self, rng_generator: torch.Generator): self.rng_generator = rng_generator def flops_per_batch(self, batch): - batch_size, input_seq_len = batch[self.image_key].shape[0:2] + batch_size = batch[self.image_key].shape[0] + height, width = batch[self.image_key].shape[2:] + input_seq_len = height * width / self.patch_size**2 cond_seq_len = batch[self.caption_key].shape[1] seq_len = input_seq_len + cond_seq_len # Calulate forward flops excluding attention From 3317a09a5bc2d7c303b15bed9c12efb06fe5d3fe Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Sun, 16 Jun 2024 22:29:25 +0000 Subject: [PATCH 05/27] Forgot downsample factor in sequence length calc --- diffusion/models/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index e83a090a..a63631a2 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -355,7 +355,7 @@ def set_rng_generator(self, rng_generator: torch.Generator): def flops_per_batch(self, batch): batch_size = batch[self.image_key].shape[0] height, width = batch[self.image_key].shape[2:] - input_seq_len = height * width / self.patch_size**2 + input_seq_len = height * width / (self.patch_size**2 * self.downsample_factor**2) cond_seq_len = batch[self.caption_key].shape[1] seq_len = input_seq_len + cond_seq_len # Calulate forward flops excluding attention From 312c13ee89d6c66c73db333bdfe159933ffe2a26 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Mon, 17 Jun 2024 05:52:58 +0000 Subject: [PATCH 06/27] Need affines for the layernorms --- diffusion/models/transformer.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index a63631a2..d8f777cc 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -48,8 +48,8 @@ def __init__(self, num_features, num_heads): # Linear layer to get q, k, and v self.qkv = nn.Linear(self.num_features, 3 * self.num_features) # QK layernorms - self.q_norm = nn.LayerNorm(self.num_features, elementwise_affine=False, eps=1e-6) - self.k_norm = nn.LayerNorm(self.num_features, elementwise_affine=False, eps=1e-6) + self.q_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) + self.k_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) # Linear layer to get the output self.output_layer = nn.Linear(self.num_features, self.num_features) # Initialize all biases to zero @@ -90,10 +90,10 @@ def __init__(self, num_features, num_heads, expansion_factor=4): self.num_heads = num_heads self.expansion_factor = expansion_factor # Layer norm before the self attention - self.layer_norm_1 = nn.LayerNorm(self.num_features, elementwise_affine=False, eps=1e-6) + self.layer_norm_1 = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) self.attention = SelfAttention(self.num_features, self.num_heads) # Layer norm before the MLP - self.layer_norm_2 = nn.LayerNorm(self.num_features, elementwise_affine=False, eps=1e-6) + self.layer_norm_2 = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) # MLP layers. The MLP expands and then contracts the features. 
self.linear_1 = nn.Linear(self.num_features, self.expansion_factor * self.num_features) self.nonlinearity = nn.GELU(approximate='tanh') @@ -174,7 +174,7 @@ def __init__(self, for _ in range(self.num_layers) ]) # Output projection layer - self.final_norm = nn.LayerNorm(self.num_features, elementwise_affine=False, eps=1e-6) + self.final_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) self.final_linear = nn.Linear(self.num_features, self.input_features) # Init the output layer to zero nn.init.zeros_(self.final_linear.weight) @@ -304,7 +304,7 @@ def __init__( self.noise_scheduler = noise_scheduler self.inference_scheduler = inference_noise_scheduler self.prediction_type = prediction_type.lower() - if self.prediction_type not in ['epsilon', 'v_prediction']: + if self.prediction_type not in ['epsilon', 'sample', 'v_prediction']: raise ValueError(f'Unrecognized prediction type {self.prediction_type}') if latent_mean is None: self.latent_mean = 4 * (0.0) @@ -497,8 +497,8 @@ def embed_prompt(self, prompt): def generate(self, prompt: Optional[list] = None, negative_prompt: Optional[list] = None, - height: Optional[int] = None, - width: Optional[int] = None, + height: int = 256, + width: int = 256, guidance_scale: float = 7.0, rescaled_guidance: Optional[float] = None, num_inference_steps: int = 50, From 36b343250ee3d8aae9a9053d812abab4f3a45aac Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Mon, 17 Jun 2024 05:54:50 +0000 Subject: [PATCH 07/27] Turn off weight decay for biases, norms, and position embeddings --- diffusion/train.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/diffusion/train.py b/diffusion/train.py index 1d6799ff..f994c3a9 100644 --- a/diffusion/train.py +++ b/diffusion/train.py @@ -20,6 +20,7 @@ from torch.optim import Optimizer from diffusion.models.autoencoder import ComposerAutoEncoder, ComposerDiffusersAutoEncoder +from diffusion.models.transformer import ComposerTextToImageDiT def make_autoencoder_optimizer(config: DictConfig, model: ComposerModel) -> Optimizer: @@ -50,6 +51,32 @@ def make_autoencoder_optimizer(config: DictConfig, model: ComposerModel) -> Opti return optimizer +def make_transformer_optimizer(config: DictConfig, model: ComposerModel) -> Optimizer: + """Configures the optimizer for use with a transformer model.""" + print('Configuring optimizer for transformer') + assert isinstance(model, ComposerTextToImageDiT) + + # Turn off weight decay for the positional embeddings + no_decay = ['bias', 'layer_norm', 'position_embedding'] + params_with_no_decay = [] + params_with_decay = [] + for name, param in model.named_parameters(): + if any(nd in name for nd in no_decay): + print(f'No decay: {name}') + params_with_no_decay.append(param) + else: + params_with_decay.append(param) + no_decay_dict = dict(config.optimizer.items()) + no_decay_dict['params'] = params_with_no_decay + no_decay_dict['weight_decay'] = 0.0 + + decay_dict = dict(config.optimizer.items()) + decay_dict['params'] = params_with_decay + + optimizer = hydra.utils.instantiate(config.optimizer, [no_decay_dict, decay_dict]) + return optimizer + + def train(config: DictConfig) -> None: """Train a model. @@ -62,10 +89,14 @@ def train(config: DictConfig) -> None: model: ComposerModel = hydra.utils.instantiate(config.model) - # Check if this is training an autoencoder. If so, the optimizer needs different param groups if hasattr(model, 'autoencoder_loss'): + # Check if this is training an autoencoder. 
If so, the optimizer needs different param groups optimizer = make_autoencoder_optimizer(config, model) tokenizer = None + elif isinstance(model, ComposerTextToImageDiT): + # Check if this is training a transformer. If so, the optimizer needs different param groups + optimizer = make_transformer_optimizer(config, model) + tokenizer = model.tokenizer else: optimizer = hydra.utils.instantiate(config.optimizer, params=model.parameters()) tokenizer = model.tokenizer From 40b21cc38ff6ff9ed88ac27819af6f86150afad0 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Tue, 18 Jun 2024 05:10:07 +0000 Subject: [PATCH 08/27] Refactor and add tests --- diffusion/models/transformer.py | 109 +++++++++++++++++--------------- diffusion/train.py | 2 +- tests/test_transformer.py | 51 +++++++++++++++ 3 files changed, 110 insertions(+), 52 deletions(-) create mode 100644 tests/test_transformer.py diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index d8f777cc..c5c8bf70 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -57,7 +57,6 @@ def __init__(self, num_features, num_heads): nn.init.zeros_(self.output_layer.bias) # Init the standard deviation of the weights to 0.02 nn.init.normal_(self.qkv.weight, std=0.02) - nn.init.normal_(self.output_layer.weight, std=0.02) def forward(self, x, mask=None): # Get the shape of the input @@ -101,9 +100,6 @@ def __init__(self, num_features, num_heads, expansion_factor=4): # Initialize all biases to zero nn.init.zeros_(self.linear_1.bias) nn.init.zeros_(self.linear_2.bias) - # Initialize the linear layer weights to have a standard deviation of 0.02 - nn.init.normal_(self.linear_1.weight, std=0.02) - nn.init.normal_(self.linear_2.weight, std=0.02) # AdaLN MLP self.adaLN_mlp_linear = nn.Linear(self.num_features, 6 * self.num_features, bias=True) # Initialize the modulations to zero. This will ensure the block acts as identity at initialization @@ -126,6 +122,17 @@ def forward(self, x, c, mask=None): return x +def get_multidimensional_position_embeddings(position_embeddings, coords): + """Position embeddings are shape (D, T, F). 
Coords are shape (B, S, D).""" + B, S, D = coords.shape + F = position_embeddings.shape[2] + coords = coords.reshape(B * S, D) + sequenced_embeddings = [position_embeddings[d, coords[:, d]] for d in range(D)] + sequenced_embeddings = torch.stack(sequenced_embeddings, dim=-1) + sequenced_embeddings = sequenced_embeddings.view(B, S, F, D) + return sequenced_embeddings # (B, S, F, D) + + class DiffusionTransformer(nn.Module): """Transformer model for diffusion.""" @@ -197,15 +204,11 @@ def forward(self, # TODO: Fix embeddings, fix embedding norms # Embed the timestep t = timestep_embedding(t, self.num_features) - # Embed the input y = self.input_embedding(x) # (B, T1, C) # Get the input position embeddings and add them to the input - input_grid = torch.arange(self.input_dimension).view(1, 1, self.input_dimension).expand( - y.shape[0], y.shape[1], self.input_dimension) - y_position_embeddings = self.input_position_embedding[input_grid, - input_coords, :] # (B, T1, input_dimension, C) - y_position_embeddings = y_position_embeddings.sum(dim=2) # (B, T1, C) + y_position_embeddings = get_multidimensional_position_embeddings(self.input_position_embedding, input_coords) + y_position_embeddings = y_position_embeddings.sum(dim=-1) # (B, T1, C) y = y + y_position_embeddings # (B, T1, C) if input_mask is None: mask = torch.ones(x.shape[0], x.shape[1], device=x.device) @@ -217,11 +220,9 @@ def forward(self, # Embed the conditioning c = self.conditioning_embedding(conditioning) # (B, T2, C) # Get the conditioning position embeddings and add them to the conditioning - c_grid = torch.arange(self.conditioning_dimension).view(1, 1, self.conditioning_dimension).expand( - c.shape[0], c.shape[1], self.conditioning_dimension) - c_position_embeddings = self.conditioning_position_embedding[ - c_grid, conditioning_coords, :] # (B, T2, conditioning_dimension, C) - c_position_embeddings = c_position_embeddings.sum(dim=2) # (B, T2, C) + c_position_embeddings = get_multidimensional_position_embeddings(self.conditioning_position_embedding, + conditioning_coords) + c_position_embeddings = c_position_embeddings.sum(dim=-1) # (B, T2, C) c = c + c_position_embeddings # (B, T2, C) # Concatenate the input and conditioning sequences y = torch.cat([y, c], dim=1) # (B, T1 + T2, C) @@ -250,6 +251,40 @@ def forward(self, return y +def patchify(latents, patch_size): + """Converts a tensor of shape [B, C, H, W] to patches of shape [B, num_patches, C * patch_size * patch_size].""" + # Assume img is a tensor of shape [B, C, H, W] + B, C, H, W = latents.shape + assert H % patch_size == 0 and W % patch_size == 0, 'Image dimensions must be divisible by patch_size' + # Reshape and permute to get non-overlapping patches + num_H_patches = H // patch_size + num_W_patches = W // patch_size + patches = latents.reshape(B, C, num_H_patches, patch_size, num_W_patches, patch_size) + patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * patch_size * patch_size) + # Generate coordinates for each patch + coords = torch.tensor([(i, j) for i in range(num_H_patches) for j in range(num_W_patches)]) + coords = coords.unsqueeze(0).expand(B, -1, -1).reshape(B, -1, 2) + return patches, coords + + +def unpatchify(patches, coords, patch_size): + """Converts a tensor of shape [num_patches, C * patch_size * patch_size] to an image of shape [C, H, W].""" + # Assume patches is a tensor of shape [num_patches, C * patch_size * patch_size] + C = patches.shape[1] // (patch_size * patch_size) + # Calculate the height and width of the original image from the 
coordinates + H = coords[:, 0].max() * patch_size + patch_size + W = coords[:, 1].max() * patch_size + patch_size + # Initialize an empty tensor for the reconstructed image + img = torch.zeros((C, H, W), device=patches.device, dtype=patches.dtype) + # Iterate over the patches and their coordinates + for patch, (y, x) in zip(patches, patch_size * coords): + # Reshape the patch to [C, patch_size, patch_size] + patch = patch.view(C, patch_size, patch_size) + # Place the patch in the corresponding location in the image + img[:, y:y + patch_size, x:x + patch_size] = patch + return img + + class ComposerTextToImageDiT(ComposerModel): """ComposerModel for text to image with a diffusion transformer. @@ -364,36 +399,6 @@ def flops_per_batch(self, batch): attention_flops = 4 * self.model.num_layers * seq_len**2 * self.model.num_features * batch_size return 3 * param_flops + 3 * attention_flops - def patchify(self, latents): - # Assume img is a tensor of shape [B, C, H, W] - B, C, H, W = latents.shape - assert H % self.patch_size == 0 and W % self.patch_size == 0, 'Image dimensions must be divisible by patch_size' - # Reshape and permute to get non-overlapping patches - num_H_patches = H // self.patch_size - num_W_patches = W // self.patch_size - patches = latents.reshape(B, C, num_H_patches, self.patch_size, num_W_patches, self.patch_size) - patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * self.patch_size * self.patch_size) - # Generate coordinates for each patch - coords = torch.tensor([(i, j) for i in range(num_H_patches) for j in range(num_W_patches)]) - coords = coords.unsqueeze(0).expand(B, -1, -1).reshape(B, -1, 2) - return patches, coords - - def unpatchify(self, patches, coords): - # Assume patches is a tensor of shape [num_patches, C * patch_size * patch_size] - C = patches.shape[1] // (self.patch_size * self.patch_size) - # Calculate the height and width of the original image from the coordinates - H = coords[:, 0].max() * self.patch_size + self.patch_size - W = coords[:, 1].max() * self.patch_size + self.patch_size - # Initialize an empty tensor for the reconstructed image - img = torch.zeros((C, H, W), device=patches.device, dtype=patches.dtype) - # Iterate over the patches and their coordinates - for patch, (y, x) in zip(patches, self.patch_size * coords): - # Reshape the patch to [C, patch_size, patch_size] - patch = patch.view(C, self.patch_size, self.patch_size) - # Place the patch in the corresponding location in the image - img[:, y:y + self.patch_size, x:x + self.patch_size] = patch - return img - def diffusion_forward_process(self, inputs: torch.Tensor): """Diffusion forward process.""" # Sample a timestep for every element in the batch @@ -433,7 +438,7 @@ def forward(self, batch): text_embeddings *= batch['drop_caption_mask'].view(-1, 1, 1) # Scale and patchify the latents latents = (latents - self.latent_mean) / self.latent_std - latent_patches, latent_coords = self.patchify(latents) + latent_patches, latent_coords = patchify(latents, self.patch_size) # Diffusion forward process noised_inputs, targets, timesteps = self.diffusion_forward_process(latent_patches) # Forward through the model @@ -538,13 +543,13 @@ def generate(self, latent_height, latent_width, device=device) - latent_patches, latent_coords = self.patchify(latents) + latent_patches, latent_coords = patchify(latents, self.patch_size) # Set up for CFG text_embeddings = torch.cat([text_embeddings, negative_text_embeddings], dim=0) text_embeddings_coords = torch.cat([text_embeddings_coords, 
negative_text_embeddings_coords], dim=0) text_embeddings_mask = torch.cat([prompt_mask, negative_prompt_mask], dim=0) - latent_coords_input = torch.cat([latent_coords] * 2) + latent_coords_input = torch.cat([latent_coords, latent_coords], dim=0) # Prep for reverse process self.inference_scheduler.set_timesteps(num_inference_steps) @@ -553,7 +558,7 @@ def generate(self, # backward diffusion process for t in tqdm(self.inference_scheduler.timesteps, disable=not progress_bar): - latent_patches_input = torch.cat([latent_patches] * 2) + latent_patches_input = torch.cat([latent_patches, latent_patches], dim=0) latent_patches_input = self.inference_scheduler.scale_model_input(latent_patches_input, t) # Get the model prediction model_out = self.model(latent_patches_input, @@ -564,12 +569,14 @@ def generate(self, input_mask=None, conditioning_mask=text_embeddings_mask) # Do CFG - pred_uncond, pred_cond = model_out.chunk(2, dim=0) + pred_cond, pred_uncond = model_out.chunk(2, dim=0) pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) # Update the inputs latent_patches = self.inference_scheduler.step(pred, t, latent_patches, generator=rng_generator).prev_sample # Unpatchify the latents - latents = [self.unpatchify(latent_patches[i], latent_coords[i]) for i in range(latent_patches.shape[0])] + latents = [ + unpatchify(latent_patches[i], latent_coords[i], self.patch_size) for i in range(latent_patches.shape[0]) + ] latents = torch.stack(latents) # Scale the latents back to the original scale latents = latents * self.latent_std + self.latent_mean diff --git a/diffusion/train.py b/diffusion/train.py index f994c3a9..d8b6ce84 100644 --- a/diffusion/train.py +++ b/diffusion/train.py @@ -62,7 +62,7 @@ def make_transformer_optimizer(config: DictConfig, model: ComposerModel) -> Opti params_with_decay = [] for name, param in model.named_parameters(): if any(nd in name for nd in no_decay): - print(f'No decay: {name}') + #print(f'No decay: {name}') params_with_no_decay.append(param) else: params_with_decay.append(param) diff --git a/tests/test_transformer.py b/tests/test_transformer.py new file mode 100644 index 00000000..da70a8a6 --- /dev/null +++ b/tests/test_transformer.py @@ -0,0 +1,51 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from diffusion.models.transformer import get_multidimensional_position_embeddings, patchify, unpatchify + + +def test_multidimensional_position_embeddings(): + B = 32 + D = 2 + T = 16 + F = 64 + position_embeddings = torch.randn(D, T, F) + # Coords should be shape (B, S, D). So for sequence element B, S, one should get D embeddings. + # These should correspond to the D elements for which T = S in the position embeddings. 
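# Illustrative sketch (not from the patch itself) of the lookup this test exercises:
# with a (D, T, F) table and integer coords, dimension d of each coordinate selects
# row coords[..., d] from table d, so the result satisfies
# out[b, s, :, d] == position_embeddings[d, coords[b, s, d]]. The model later sums
# over the last axis to combine the per-dimension embeddings.
import torch

def naive_multidimensional_lookup(position_embeddings: torch.Tensor, coords: torch.Tensor) -> torch.Tensor:
    B, S, D = coords.shape
    F = position_embeddings.shape[-1]
    out = torch.empty(B, S, F, D, dtype=position_embeddings.dtype)
    for b in range(B):
        for s in range(S):
            for d in range(D):
                out[b, s, :, d] = position_embeddings[d, coords[b, s, d]]
    return out

# For valid inputs, torch.allclose(naive_multidimensional_lookup(pe, coords),
# get_multidimensional_position_embeddings(pe, coords)) should hold.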
+ coords = torch.tensor([(i, j) for i in range(3) for j in range(3)]) + coords = coords.unsqueeze(0).expand(B, -1, -1).reshape(B, -1, 2) + S = coords.shape[1] + # Get the posistion embeddings from the coords + sequenced_embeddings = get_multidimensional_position_embeddings(position_embeddings, coords) + # Test that they are the right shape + assert sequenced_embeddings.shape == (B, S, F, D) + # Test that the embeddings are correct + assert torch.allclose(sequenced_embeddings[0, 0, :, 0], position_embeddings[0, 0, :]) + assert torch.allclose(sequenced_embeddings[1, 2, :, 1], position_embeddings[1, coords[1, 2, 1], :]) + + +@pytest.mark.parametrize('patch_size', [1, 2, 4]) +@pytest.mark.parametrize('batch_size', [1, 4]) +@pytest.mark.parametrize('C', [3, 4]) +@pytest.mark.parametrize('H', [32, 64]) +@pytest.mark.parametrize('W', [32, 64]) +def test_patch_and_unpatch(patch_size, batch_size, C, H, W): + # Fake image data + image = torch.randn(batch_size, C, H, W) + # Patchify + image_patches, image_coords = patchify(image, patch_size) + # Verify patches are the correct size + assert image_patches.shape == (batch_size, H * W // patch_size**2, C * patch_size**2) + # Verify coords are the correct size + assert image_coords.shape == (batch_size, H * W // patch_size**2, 2) + # Unpatchify + image_recon = [unpatchify(image_patches[i], image_coords[i], patch_size) for i in range(image_patches.shape[0])] + # Verify reconstructed image is the correct size + assert len(image_recon) == batch_size + assert image_recon[0].shape == (C, H, W) + # Verify reconstructed image is close to the original + for i in range(batch_size): + assert torch.allclose(image_recon[i], image[i], atol=1e-6) From b396fa67f0c6e30befa693ae66aea44b6c29cf9a Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Wed, 19 Jun 2024 17:54:42 +0000 Subject: [PATCH 09/27] Wrapping, pooled conditioning, flop calc fix --- diffusion/models/transformer.py | 38 ++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index c5c8bf70..ed61f106 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -193,6 +193,16 @@ def __init__(self, nn.init.zeros_(self.adaLN_mlp_linear.bias) self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear) + def fsdp_wrap_fn(self, module): + if isinstance(module, DiTBlock): + return True + return False + + def activation_checkpointing_fn(self, module): + if isinstance(module, DiTBlock): + return True + return False + def forward(self, x, input_coords, @@ -200,10 +210,13 @@ def forward(self, conditioning=None, conditioning_coords=None, input_mask=None, - conditioning_mask=None): - # TODO: Fix embeddings, fix embedding norms + conditioning_mask=None, + constant_conditioning=None): # Embed the timestep t = timestep_embedding(t, self.num_features) + # Optionally add constant conditioning + if constant_conditioning is not None: + t = t + constant_conditioning # Embed the input y = self.input_embedding(x) # (B, T1, C) # Get the input position embeddings and add them to the input @@ -354,6 +367,9 @@ def __init__( self.caption_key = caption_key self.caption_mask_key = caption_mask_key + # Projection layer for the pooled text embeddings + self.pooled_projection_layer = nn.Linear(self.model.conditioning_features, self.model.num_features) + # freeze text_encoder during diffusion training and use half precision self.autoencoder.requires_grad_(False) self.text_encoder.requires_grad_(False) @@ 
-365,8 +381,12 @@ def __init__( self.autoencoder._fsdp_wrap = False self.text_encoder._fsdp_wrap = False - # Params for MFU computation, subtract off the embedding params + # Params for MFU computation + # First calc the AdaLN params separately + self.adaLN_params = sum(p.numel() for n, p in self.model.named_parameters() if 'adaLN_mlp_linear' in n) self.n_params = sum(p.numel() for p in self.model.parameters()) + self.n_params -= self.adaLN_params + # Subtract off the embedding params self.n_params -= self.model.input_position_embedding.numel() self.n_params -= self.model.conditioning_position_embedding.numel() @@ -395,6 +415,8 @@ def flops_per_batch(self, batch): seq_len = input_seq_len + cond_seq_len # Calulate forward flops excluding attention param_flops = 2 * self.n_params * batch_size * seq_len + # Include flops from adaln + param_flops += 2 * self.adaLN_params * batch_size # Calculate flops for attention layers attention_flops = 4 * self.model.num_layers * seq_len**2 * self.model.num_features * batch_size return 3 * param_flops + 3 * attention_flops @@ -430,12 +452,16 @@ def forward(self, batch): latents = self.autoencoder.encode(image.half())['latent_dist'].sample().data text_encoder_out = self.text_encoder(caption, attention_mask=caption_mask) text_embeddings = text_encoder_out[0] + pooled_text_embeddings = text_encoder_out[1] # Make the text embedding coords text_embeddings_coords = torch.arange(text_embeddings.shape[1], device=text_embeddings.device) text_embeddings_coords = text_embeddings_coords.unsqueeze(0).expand(text_embeddings.shape[0], -1).unsqueeze(-1) + # Project the text embeddings + pooled_text_embeddings = self.pooled_projection_layer(pooled_text_embeddings) # Zero dropped captions if needed if 'drop_caption_mask' in batch.keys(): text_embeddings *= batch['drop_caption_mask'].view(-1, 1, 1) + pooled_text_embeddings *= batch['drop_caption_mask'].view(-1, 1) # Scale and patchify the latents latents = (latents - self.latent_mean) / self.latent_std latent_patches, latent_coords = patchify(latents, self.patch_size) @@ -448,7 +474,8 @@ def forward(self, batch): conditioning=text_embeddings, conditioning_coords=text_embeddings_coords, input_mask=None, - conditioning_mask=caption_mask) + conditioning_mask=caption_mask, + constant_conditioning=pooled_text_embeddings) return {'predictions': model_out, 'targets': targets, 'timesteps': timesteps} def loss(self, outputs, batch): @@ -499,6 +526,7 @@ def embed_prompt(self, prompt): prompt_mask = self.combine_attention_masks(prompt_mask) return text_embeddings, prompt_mask + @torch.no_grad() def generate(self, prompt: Optional[list] = None, negative_prompt: Optional[list] = None, @@ -563,7 +591,7 @@ def generate(self, # Get the model prediction model_out = self.model(latent_patches_input, latent_coords_input, - t.unsqueeze(0), + t.unsqueeze(0).to(device), conditioning=text_embeddings, conditioning_coords=text_embeddings_coords, input_mask=None, From 36df692d34e3e1056ea6d68cfa0773bf6bc1af87 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Wed, 19 Jun 2024 20:20:48 +0000 Subject: [PATCH 10/27] Initial working MMDiT implementation --- diffusion/models/models.py | 39 ++-- diffusion/models/transformer.py | 311 +++++++++++++++++++------------- diffusion/train.py | 8 +- 3 files changed, 215 insertions(+), 143 deletions(-) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 6b445730..706bfe9b 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -18,7 +18,7 @@ from 
diffusion.models.pixel_diffusion import PixelDiffusion from diffusion.models.stable_diffusion import StableDiffusion from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer -from diffusion.models.transformer import ComposerTextToImageDiT, DiffusionTransformer +from diffusion.models.transformer import ComposerTextToImageMMDiT, DiffusionTransformer from diffusion.schedulers.schedulers import ContinuousTimeScheduler from diffusion.schedulers.utils import shift_noise_schedule @@ -518,6 +518,7 @@ def text_to_image_transformer( beta_schedule: str = 'scaled_linear', zero_terminal_snr: bool = False, use_karras_sigmas: bool = False): + """Text to image transformer training setup.""" latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) if (isinstance(tokenizer_names, tuple) or @@ -554,6 +555,12 @@ def text_to_image_transformer( latent_std = tuple(latent_statistics['latent_channel_stds']) downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1) autoencoder_channels = vae.config['latent_channels'] + assert isinstance(vae, torch.nn.Module) + if isinstance(latent_mean, float): + latent_mean = (latent_mean,) * autoencoder_channels + if isinstance(latent_std, float): + latent_std = (latent_std,) * autoencoder_channels + assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) # Make the noise schedulers noise_scheduler = DDPMScheduler(num_train_timesteps=1000, @@ -592,21 +599,21 @@ def text_to_image_transformer( conditioning_dimension=1, expansion_factor=4) # Make the composer model - model = ComposerTextToImageDiT(model=transformer, - autoencoder=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - noise_scheduler=noise_scheduler, - inference_noise_scheduler=inference_noise_scheduler, - prediction_type=prediction_type, - latent_mean=latent_mean, - latent_std=latent_std, - patch_size=patch_size, - downsample_factor=downsample_factor, - latent_channels=autoencoder_channels, - image_key='image', - caption_key='captions', - caption_mask_key='attention_mask') + model = ComposerTextToImageMMDiT(model=transformer, + autoencoder=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + noise_scheduler=noise_scheduler, + inference_noise_scheduler=inference_noise_scheduler, + prediction_type=prediction_type, + latent_mean=latent_mean, + latent_std=latent_std, + patch_size=patch_size, + downsample_factor=downsample_factor, + latent_channels=autoencoder_channels, + image_key='image', + caption_key='captions', + caption_mask_key='attention_mask') if torch.cuda.is_available(): model = DeviceGPU().module_to_device(model) diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index ed61f106..fc3933fe 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -38,36 +38,99 @@ def timestep_embedding(timesteps, dim, max_period=10000): return embedding -class SelfAttention(nn.Module): - """Standard self attention layer that supports masking.""" +def get_multidimensional_position_embeddings(position_embeddings, coords): + """Position embeddings are shape (D, T, F). 
Coords are shape (B, S, D).""" + B, S, D = coords.shape + F = position_embeddings.shape[2] + coords = coords.reshape(B * S, D) + sequenced_embeddings = [position_embeddings[d, coords[:, d]] for d in range(D)] + sequenced_embeddings = torch.stack(sequenced_embeddings, dim=-1) + sequenced_embeddings = sequenced_embeddings.view(B, S, F, D) + return sequenced_embeddings # (B, S, F, D) - def __init__(self, num_features, num_heads): + +def patchify(latents, patch_size): + """Converts a tensor of shape [B, C, H, W] to patches of shape [B, num_patches, C * patch_size * patch_size].""" + # Assume img is a tensor of shape [B, C, H, W] + B, C, H, W = latents.shape + assert H % patch_size == 0 and W % patch_size == 0, 'Image dimensions must be divisible by patch_size' + # Reshape and permute to get non-overlapping patches + num_H_patches = H // patch_size + num_W_patches = W // patch_size + patches = latents.reshape(B, C, num_H_patches, patch_size, num_W_patches, patch_size) + patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * patch_size * patch_size) + # Generate coordinates for each patch + coords = torch.tensor([(i, j) for i in range(num_H_patches) for j in range(num_W_patches)]) + coords = coords.unsqueeze(0).expand(B, -1, -1).reshape(B, -1, 2) + return patches, coords + + +def unpatchify(patches, coords, patch_size): + """Converts a tensor of shape [num_patches, C * patch_size * patch_size] to an image of shape [C, H, W].""" + # Assume patches is a tensor of shape [num_patches, C * patch_size * patch_size] + C = patches.shape[1] // (patch_size * patch_size) + # Calculate the height and width of the original image from the coordinates + H = coords[:, 0].max() * patch_size + patch_size + W = coords[:, 1].max() * patch_size + patch_size + # Initialize an empty tensor for the reconstructed image + img = torch.zeros((C, H, W), device=patches.device, dtype=patches.dtype) + # Iterate over the patches and their coordinates + for patch, (y, x) in zip(patches, patch_size * coords): + # Reshape the patch to [C, patch_size, patch_size] + patch = patch.view(C, patch_size, patch_size) + # Place the patch in the corresponding location in the image + img[:, y:y + patch_size, x:x + patch_size] = patch + return img + + +class PreAttentionBlock(nn.Module): + """Block to compute QKV before attention.""" + + def __init__(self, num_features): super().__init__() self.num_features = num_features - self.num_heads = num_heads + + # AdaLN MLP for pre-attention. Initialized to zero so modulation acts as identity at initialization. + self.adaLN_mlp_linear = nn.Linear(self.num_features, 2 * self.num_features, bias=True) + nn.init.zeros_(self.adaLN_mlp_linear.weight) + nn.init.zeros_(self.adaLN_mlp_linear.bias) + self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear) + # Input layernorm + self.input_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) # Linear layer to get q, k, and v self.qkv = nn.Linear(self.num_features, 3 * self.num_features) - # QK layernorms + # QK layernorms. Original MMDiT used RMSNorm here. 
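# The q_norm/k_norm layers below use LayerNorm, while the comment above notes that
# the original MMDiT used RMSNorm. A minimal RMSNorm sketch (not part of this patch;
# eps and the learnable scale are assumptions) that could be dropped in for them:
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, num_features: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale by the reciprocal root-mean-square over the feature dimension.
        rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x * rms * self.weight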
self.q_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) self.k_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) - # Linear layer to get the output - self.output_layer = nn.Linear(self.num_features, self.num_features) # Initialize all biases to zero nn.init.zeros_(self.qkv.bias) - nn.init.zeros_(self.output_layer.bias) - # Init the standard deviation of the weights to 0.02 + # Init the standard deviation of the weights to 0.02 as is tradition nn.init.normal_(self.qkv.weight, std=0.02) - def forward(self, x, mask=None): - # Get the shape of the input - B, T, C = x.size() + def forward(self, x, t): + # Calculate the modulations + mods = self.adaLN_mlp(t).unsqueeze(1).chunk(2, dim=2) + # Forward, with modulations + x = modulate(self.input_norm(x), mods[0], mods[1]) # Calculate the query, key, and values all in one go q, k, v = self.qkv(x).chunk(3, dim=-1) q = self.q_norm(q) k = self.k_norm(k) - # After this, q, k, and v will have shape (B, T, C) + return q, k, v + + +class SelfAttention(nn.Module): + """Standard self attention layer that supports masking.""" + + def __init__(self, num_features, num_heads): + super().__init__() + self.num_features = num_features + self.num_heads = num_heads + + def forward(self, q, k, v, mask=None): + # Get the shape of the inputs + B, T, C = v.size() # Reshape the query, key, and values for multi-head attention - # Also want to swap the sequence length and the head dimension for later matmuls q = q.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) k = k.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) v = v.view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) @@ -75,62 +138,89 @@ def forward(self, x, mask=None): attention_out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) # (B, H, T, C/H) # Swap the sequence length and the head dimension back and get rid of num_heads. attention_out = attention_out.transpose(1, 2).contiguous().view(B, T, C) # (B, T, C) - # Final linear layer to get the output - out = self.output_layer(attention_out) - return out + return attention_out -class DiTBlock(nn.Module): - """Transformer block that supports masking.""" +class PostAttentionBlock(nn.Module): + """Block to postprocess V after attention.""" - def __init__(self, num_features, num_heads, expansion_factor=4): + def __init__(self, num_features, expansion_factor=4): super().__init__() self.num_features = num_features - self.num_heads = num_heads self.expansion_factor = expansion_factor - # Layer norm before the self attention - self.layer_norm_1 = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) - self.attention = SelfAttention(self.num_features, self.num_heads) - # Layer norm before the MLP - self.layer_norm_2 = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) - # MLP layers. The MLP expands and then contracts the features. + # AdaLN MLP for post-attention. Initialized to zero so modulation acts as identity at initialization. 
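# A quick standalone check (not part of this patch) of the adaLN-Zero property the
# comment above relies on: with the modulation MLP zero-initialized, every shift,
# scale, and gate it emits is zero, so modulate(x, 0, 0) == x and the gated residual
# branches add nothing, leaving the block as an identity map at initialization.
import torch

def modulate(x, shift, scale):
    return x * (1.0 + scale) + shift

x = torch.randn(2, 8, 16)
zero = torch.zeros(2, 1, 16)
assert torch.equal(modulate(x, zero, zero), x)           # adaptive norm reduces to the plain norm output
assert torch.equal(x + zero * torch.randn(2, 8, 16), x)  # a zero gate removes the residual branch entirely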
+ self.adaLN_mlp_linear = nn.Linear(self.num_features, 4 * self.num_features, bias=True) + nn.init.zeros_(self.adaLN_mlp_linear.weight) + nn.init.zeros_(self.adaLN_mlp_linear.bias) + self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear) + # Linear layer to process v + self.linear_v = nn.Linear(self.num_features, self.num_features) + # Layernorm for the output + self.output_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) + # Transformer style MLP layers self.linear_1 = nn.Linear(self.num_features, self.expansion_factor * self.num_features) self.nonlinearity = nn.GELU(approximate='tanh') self.linear_2 = nn.Linear(self.expansion_factor * self.num_features, self.num_features) # Initialize all biases to zero nn.init.zeros_(self.linear_1.bias) nn.init.zeros_(self.linear_2.bias) - # AdaLN MLP - self.adaLN_mlp_linear = nn.Linear(self.num_features, 6 * self.num_features, bias=True) - # Initialize the modulations to zero. This will ensure the block acts as identity at initialization - nn.init.zeros_(self.adaLN_mlp_linear.weight) - nn.init.zeros_(self.adaLN_mlp_linear.bias) - self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear) + # Output MLP + self.output_mlp = nn.Sequential(self.linear_1, self.nonlinearity, self.linear_2) + + def forward(self, v, x, t): + """Forward takes v from self attention and the original sequence x with scalar conditioning t.""" + # Calculate the modulations + mods = self.adaLN_mlp(t).unsqueeze(1).chunk(4, dim=2) + # Postprocess v with linear + gating modulation + y = mods[0] * self.linear_v(v) + y = x + y + # Adaptive layernorm + y = modulate(self.output_norm(y), mods[1], mods[2]) + # Output MLP + y = self.output_mlp(y) + # Gating modulation for the output + y = mods[3] * y + y = x + y + return y - def forward(self, x, c, mask=None): - # Calculate the modulations. Each is shape (B, num_features). - mods = self.adaLN_mlp(c).unsqueeze(1).chunk(6, dim=2) - # Forward, with modulations - y = modulate(self.layer_norm_1(x), mods[0], mods[1]) - y = mods[2] * self.attention(y, mask=mask) - x = x + y - y = modulate(self.layer_norm_2(x), mods[3], mods[4]) - y = self.linear_1(y) - y = self.nonlinearity(y) - y = mods[5] * self.linear_2(y) - x = x + y - return x +class MMDiTBlock(nn.Module): + """Transformer block that supports masking, multimodal attention, and adaptive norms.""" -def get_multidimensional_position_embeddings(position_embeddings, coords): - """Position embeddings are shape (D, T, F). 
Coords are shape (B, S, D).""" - B, S, D = coords.shape - F = position_embeddings.shape[2] - coords = coords.reshape(B * S, D) - sequenced_embeddings = [position_embeddings[d, coords[:, d]] for d in range(D)] - sequenced_embeddings = torch.stack(sequenced_embeddings, dim=-1) - sequenced_embeddings = sequenced_embeddings.view(B, S, F, D) - return sequenced_embeddings # (B, S, F, D) + def __init__(self, num_features, num_heads, expansion_factor=4, is_last=False): + super().__init__() + self.num_features = num_features + self.num_heads = num_heads + self.expansion_factor = expansion_factor + self.is_last = is_last + # Pre-attention blocks for two modalities + self.pre_attention_block_1 = PreAttentionBlock(self.num_features) + self.pre_attention_block_2 = PreAttentionBlock(self.num_features) + # Self-attention + self.attention = SelfAttention(self.num_features, self.num_heads) + # Post-attention blocks for two modalities + self.post_attention_block_1 = PostAttentionBlock(self.num_features, self.expansion_factor) + if not self.is_last: + self.post_attention_block_2 = PostAttentionBlock(self.num_features, self.expansion_factor) + + def forward(self, x1, x2, t, mask=None): + # Pre-attention for the two modalities + q1, k1, v1 = self.pre_attention_block_1(x1, t) + q2, k2, v2 = self.pre_attention_block_2(x2, t) + # Concat q, k, v along the sequence dimension + q = torch.cat([q1, q2], dim=1) + k = torch.cat([k1, k2], dim=1) + v = torch.cat([v1, v2], dim=1) + # Self-attention + v = self.attention(q, k, v, mask=mask) + # Split the attention output back into the two modalities + seq_len_1, seq_len_2 = x1.size(1), x2.size(1) + y1, y2 = v.split([seq_len_1, seq_len_2], dim=1) + # Post-attention for the two modalities + y1 = self.post_attention_block_1(y1, x1, t) + if not self.is_last: + y2 = self.post_attention_block_2(y2, x2, t) + return y1, y2 class DiffusionTransformer(nn.Module): @@ -177,9 +267,12 @@ def __init__(self, self.conditioning_position_embedding = torch.nn.Parameter(conditioning_position_embedding, requires_grad=True) # Transformer blocks self.transformer_blocks = nn.ModuleList([ - DiTBlock(self.num_features, self.num_heads, expansion_factor=self.expansion_factor) - for _ in range(self.num_layers) + MMDiTBlock(self.num_features, self.num_heads, expansion_factor=self.expansion_factor) + for _ in range(self.num_layers - 1) ]) + # Turn off post attn layers for conditioning sequence in final block + self.transformer_blocks.append( + MMDiTBlock(self.num_features, self.num_heads, expansion_factor=self.expansion_factor, is_last=True)) # Output projection layer self.final_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) self.final_linear = nn.Linear(self.num_features, self.input_features) @@ -194,12 +287,12 @@ def __init__(self, self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear) def fsdp_wrap_fn(self, module): - if isinstance(module, DiTBlock): + if isinstance(module, MMDiTBlock): return True return False def activation_checkpointing_fn(self, module): - if isinstance(module, DiTBlock): + if isinstance(module, MMDiTBlock): return True return False @@ -207,8 +300,8 @@ def forward(self, x, input_coords, t, - conditioning=None, - conditioning_coords=None, + conditioning, + conditioning_coords, input_mask=None, conditioning_mask=None, constant_conditioning=None): @@ -228,21 +321,17 @@ def forward(self, else: mask = input_mask - if conditioning is not None: - assert conditioning_coords is not None - # Embed the conditioning - c = 
self.conditioning_embedding(conditioning) # (B, T2, C) - # Get the conditioning position embeddings and add them to the conditioning - c_position_embeddings = get_multidimensional_position_embeddings(self.conditioning_position_embedding, - conditioning_coords) - c_position_embeddings = c_position_embeddings.sum(dim=-1) # (B, T2, C) - c = c + c_position_embeddings # (B, T2, C) - # Concatenate the input and conditioning sequences - y = torch.cat([y, c], dim=1) # (B, T1 + T2, C) - # Concatenate the masks - if conditioning_mask is None: - conditioning_mask = torch.ones(conditioning.shape[0], conditioning.shape[1], device=conditioning.device) - mask = torch.cat([mask, conditioning_mask], dim=1) # (B, T1 + T2) + # Embed the conditioning + c = self.conditioning_embedding(conditioning) # (B, T2, C) + # Get the conditioning position embeddings and add them to the conditioning + c_position_embeddings = get_multidimensional_position_embeddings(self.conditioning_position_embedding, + conditioning_coords) + c_position_embeddings = c_position_embeddings.sum(dim=-1) # (B, T2, C) + c = c + c_position_embeddings # (B, T2, C) + # Concatenate the masks + if conditioning_mask is None: + conditioning_mask = torch.ones(conditioning.shape[0], conditioning.shape[1], device=conditioning.device) + mask = torch.cat([mask, conditioning_mask], dim=1) # (B, T1 + T2) # Expand the mask to the right shape mask = mask.bool() @@ -254,9 +343,7 @@ def forward(self, # Pass through the transformer blocks for block in self.transformer_blocks: - y = block(y, t, mask=mask) - # Throw away the conditioning tokens - y = y[:, 0:x.shape[1], :] + y, c = block(y, c, t, mask=mask) # Pass through the output layers to get the right number of elements mods = self.adaLN_mlp(t).unsqueeze(1).chunk(2, dim=2) y = modulate(self.final_norm(y), mods[0], mods[1]) @@ -264,41 +351,7 @@ def forward(self, return y -def patchify(latents, patch_size): - """Converts a tensor of shape [B, C, H, W] to patches of shape [B, num_patches, C * patch_size * patch_size].""" - # Assume img is a tensor of shape [B, C, H, W] - B, C, H, W = latents.shape - assert H % patch_size == 0 and W % patch_size == 0, 'Image dimensions must be divisible by patch_size' - # Reshape and permute to get non-overlapping patches - num_H_patches = H // patch_size - num_W_patches = W // patch_size - patches = latents.reshape(B, C, num_H_patches, patch_size, num_W_patches, patch_size) - patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * patch_size * patch_size) - # Generate coordinates for each patch - coords = torch.tensor([(i, j) for i in range(num_H_patches) for j in range(num_W_patches)]) - coords = coords.unsqueeze(0).expand(B, -1, -1).reshape(B, -1, 2) - return patches, coords - - -def unpatchify(patches, coords, patch_size): - """Converts a tensor of shape [num_patches, C * patch_size * patch_size] to an image of shape [C, H, W].""" - # Assume patches is a tensor of shape [num_patches, C * patch_size * patch_size] - C = patches.shape[1] // (patch_size * patch_size) - # Calculate the height and width of the original image from the coordinates - H = coords[:, 0].max() * patch_size + patch_size - W = coords[:, 1].max() * patch_size + patch_size - # Initialize an empty tensor for the reconstructed image - img = torch.zeros((C, H, W), device=patches.device, dtype=patches.dtype) - # Iterate over the patches and their coordinates - for patch, (y, x) in zip(patches, patch_size * coords): - # Reshape the patch to [C, patch_size, patch_size] - patch = patch.view(C, patch_size, 
patch_size) - # Place the patch in the corresponding location in the image - img[:, y:y + patch_size, x:x + patch_size] = patch - return img - - -class ComposerTextToImageDiT(ComposerModel): +class ComposerTextToImageMMDiT(ComposerModel): """ComposerModel for text to image with a diffusion transformer. Args: @@ -381,14 +434,20 @@ def __init__( self.autoencoder._fsdp_wrap = False self.text_encoder._fsdp_wrap = False - # Params for MFU computation + # Param counts relevant for MFU computation # First calc the AdaLN params separately self.adaLN_params = sum(p.numel() for n, p in self.model.named_parameters() if 'adaLN_mlp_linear' in n) - self.n_params = sum(p.numel() for p in self.model.parameters()) - self.n_params -= self.adaLN_params - # Subtract off the embedding params - self.n_params -= self.model.input_position_embedding.numel() - self.n_params -= self.model.conditioning_position_embedding.numel() + # For MFU calc we must be careful to prevent double counting of MMDiT flops. + # Here, count the number of params applied to each sequence element. + # Last block must be handled differently since post attn layers don't run on conditioning sequence + self.n_seq_params_per_block = self.model.num_features**2 * (4 + 2 * self.model.expansion_factor) + self.n_seq_params = self.n_seq_params_per_block * (self.model.num_layers - 1) + self.n_seq_params += 3 * (self.model.num_features**2) + self.n_last_layer_params = self.model.num_features**2 * (1 + 2 * self.model.expansion_factor) + # Params only on the input sequence + self.n_input_params = self.model.input_features * self.model.num_features + # Params only on the conditioning sequence + self.n_cond_params = self.model.conditioning_features * self.model.num_features # Set up metrics self.train_metrics = [MeanSquaredError()] @@ -398,7 +457,7 @@ def __init__( self.rng_generator: Optional[torch.Generator] = None def _apply(self, fn): - super(ComposerTextToImageDiT, self)._apply(fn) + super(ComposerTextToImageMMDiT, self)._apply(fn) self.latent_mean = fn(self.latent_mean) self.latent_std = fn(self.latent_std) return self @@ -413,8 +472,14 @@ def flops_per_batch(self, batch): input_seq_len = height * width / (self.patch_size**2 * self.downsample_factor**2) cond_seq_len = batch[self.caption_key].shape[1] seq_len = input_seq_len + cond_seq_len - # Calulate forward flops excluding attention - param_flops = 2 * self.n_params * batch_size * seq_len + # Calulate forward flops on full sequence excluding attention + param_flops = 2 * self.n_seq_params * batch_size * seq_len + # Last block contributes a bit less than other blocks + param_flops += 2 * self.n_last_layer_params * batch_size * input_seq_len + # Include input sequence params (comparatively small) + param_flops += 2 * self.n_input_params * batch_size * input_seq_len + # Include conditioning sequence params (comparatively small) + param_flops += 2 * self.n_cond_params * batch_size * cond_seq_len # Include flops from adaln param_flops += 2 * self.adaLN_params * batch_size # Calculate flops for attention layers diff --git a/diffusion/train.py b/diffusion/train.py index d8b6ce84..33766c3b 100644 --- a/diffusion/train.py +++ b/diffusion/train.py @@ -20,7 +20,7 @@ from torch.optim import Optimizer from diffusion.models.autoencoder import ComposerAutoEncoder, ComposerDiffusersAutoEncoder -from diffusion.models.transformer import ComposerTextToImageDiT +from diffusion.models.transformer import ComposerTextToImageMMDiT def make_autoencoder_optimizer(config: DictConfig, model: ComposerModel) -> Optimizer: @@ 
-54,10 +54,10 @@ def make_autoencoder_optimizer(config: DictConfig, model: ComposerModel) -> Opti def make_transformer_optimizer(config: DictConfig, model: ComposerModel) -> Optimizer: """Configures the optimizer for use with a transformer model.""" print('Configuring optimizer for transformer') - assert isinstance(model, ComposerTextToImageDiT) + assert isinstance(model, ComposerTextToImageMMDiT) # Turn off weight decay for the positional embeddings - no_decay = ['bias', 'layer_norm', 'position_embedding'] + no_decay = ['bias', 'norm', 'position_embedding'] params_with_no_decay = [] params_with_decay = [] for name, param in model.named_parameters(): @@ -93,7 +93,7 @@ def train(config: DictConfig) -> None: # Check if this is training an autoencoder. If so, the optimizer needs different param groups optimizer = make_autoencoder_optimizer(config, model) tokenizer = None - elif isinstance(model, ComposerTextToImageDiT): + elif isinstance(model, ComposerTextToImageMMDiT): # Check if this is training a transformer. If so, the optimizer needs different param groups optimizer = make_transformer_optimizer(config, model) tokenizer = model.tokenizer From d4401425359b4529e8e2edd45a5fa60d9d2e315c Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 20 Jun 2024 23:23:59 +0000 Subject: [PATCH 11/27] Simplify mask logic --- diffusion/models/transformer.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index fc3933fe..91d83f6d 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -336,8 +336,7 @@ def forward(self, # Expand the mask to the right shape mask = mask.bool() mask = mask.unsqueeze(-1) & mask.unsqueeze(1) # (B, T1 + T2, T1 + T2) - identity = torch.eye(mask.shape[1], device=mask.device, - dtype=mask.dtype).unsqueeze(0).expand(mask.shape[0], -1, -1) + identity = torch.eye(mask.shape[1], device=mask.device, dtype=mask.dtype).unsqueeze(0) mask = mask | identity mask = mask.unsqueeze(1) # (B, 1, T1 + T2, T1 + T2) @@ -396,6 +395,7 @@ def __init__( image_key: str = 'image', caption_key: str = 'caption', caption_mask_key: str = 'caption_mask', + use_pooled_embedding: bool = False, ): super().__init__() self.model = model @@ -419,9 +419,11 @@ def __init__( self.image_key = image_key self.caption_key = caption_key self.caption_mask_key = caption_mask_key + self.use_pooled_embedding = use_pooled_embedding # Projection layer for the pooled text embeddings - self.pooled_projection_layer = nn.Linear(self.model.conditioning_features, self.model.num_features) + if self.use_pooled_embedding: + self.pooled_projection_layer = nn.Linear(self.model.conditioning_features, self.model.num_features) # freeze text_encoder during diffusion training and use half precision self.autoencoder.requires_grad_(False) @@ -432,7 +434,7 @@ def __init__( # Only FSDP wrap models we are training self.model._fsdp_wrap = True self.autoencoder._fsdp_wrap = False - self.text_encoder._fsdp_wrap = False + self.text_encoder._fsdp_wrap = True # Param counts relevant for MFU computation # First calc the AdaLN params separately @@ -512,21 +514,33 @@ def diffusion_forward_process(self, inputs: torch.Tensor): def forward(self, batch): # Get the inputs image, caption, caption_mask = batch[self.image_key], batch[self.caption_key], batch[self.caption_mask_key] + # Get the text embeddings and image latents with torch.cuda.amp.autocast(enabled=False): latents = 
self.autoencoder.encode(image.half())['latent_dist'].sample().data text_encoder_out = self.text_encoder(caption, attention_mask=caption_mask) text_embeddings = text_encoder_out[0] + # Ensure text embeddings are not longer than the model can handle + if text_embeddings.shape[1] > self.model.conditioning_max_sequence_length: + text_embeddings = text_embeddings[:, :self.model.conditioning_max_sequence_length] + caption_mask = caption_mask[:, :self.model.conditioning_max_sequence_length] + + # Optionally use pooled embeddings + if self.use_pooled_embedding: pooled_text_embeddings = text_encoder_out[1] + pooled_text_embeddings = self.pooled_projection_layer(pooled_text_embeddings) + else: + pooled_text_embeddings = None + # Make the text embedding coords text_embeddings_coords = torch.arange(text_embeddings.shape[1], device=text_embeddings.device) text_embeddings_coords = text_embeddings_coords.unsqueeze(0).expand(text_embeddings.shape[0], -1).unsqueeze(-1) # Project the text embeddings - pooled_text_embeddings = self.pooled_projection_layer(pooled_text_embeddings) # Zero dropped captions if needed if 'drop_caption_mask' in batch.keys(): text_embeddings *= batch['drop_caption_mask'].view(-1, 1, 1) - pooled_text_embeddings *= batch['drop_caption_mask'].view(-1, 1) + if self.use_pooled_embedding: + pooled_text_embeddings *= batch['drop_caption_mask'].view(-1, 1) # Scale and patchify the latents latents = (latents - self.latent_mean) / self.latent_std latent_patches, latent_coords = patchify(latents, self.patch_size) @@ -649,6 +663,12 @@ def generate(self, # scale the initial noise by the standard deviation required by the scheduler latent_patches = latent_patches * self.inference_scheduler.init_noise_sigma + # Ensure text embeddings, mask, and coords are not longer than the model can handle + if text_embeddings.shape[1] > self.model.conditioning_max_sequence_length: + text_embeddings = text_embeddings[:, :self.model.conditioning_max_sequence_length] + text_embeddings_coords = text_embeddings_coords[:, :self.model.conditioning_max_sequence_length] + text_embeddings_mask = text_embeddings_mask[:, :self.model.conditioning_max_sequence_length] + # backward diffusion process for t in tqdm(self.inference_scheduler.timesteps, disable=not progress_bar): latent_patches_input = torch.cat([latent_patches, latent_patches], dim=0) From de3d50e30cec42db1260ef4b5cfa0c6191492dbe Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Mon, 24 Jun 2024 07:11:16 +0000 Subject: [PATCH 12/27] Rectified flows and pooled embeddings --- diffusion/models/models.py | 71 ++++------- diffusion/models/transformer.py | 211 ++++++++++++++++++-------------- 2 files changed, 141 insertions(+), 141 deletions(-) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 706bfe9b..7527a865 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -498,26 +498,26 @@ def stable_diffusion_xl( def text_to_image_transformer( - tokenizer_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/tokenizer', - 'stabilityai/stable-diffusion-xl-base-1.0/tokenizer_2'), - text_encoder_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/text_encoder', - 'stabilityai/stable-diffusion-xl-base-1.0/text_encoder_2'), - vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', - autoencoder_path: Optional[str] = None, - autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', - num_features: int = 1152, - num_heads: int = 16, - num_layers: int = 28, - 
input_max_sequence_length: int = 1024, - conditioning_features: int = 768, - conditioning_max_sequence_length: int = 77, - patch_size: int = 2, - prediction_type: str = 'epsilon', - latent_mean: Union[float, Tuple, str] = 0.0, - latent_std: Union[float, Tuple, str] = 7.67754318618, - beta_schedule: str = 'scaled_linear', - zero_terminal_snr: bool = False, - use_karras_sigmas: bool = False): + tokenizer_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/tokenizer', + 'stabilityai/stable-diffusion-xl-base-1.0/tokenizer_2'), + text_encoder_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/text_encoder', + 'stabilityai/stable-diffusion-xl-base-1.0/text_encoder_2'), + vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', + autoencoder_path: Optional[str] = None, + autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', + num_features: int = 1152, + num_heads: int = 16, + num_layers: int = 28, + input_max_sequence_length: int = 1024, + conditioning_features: int = 768, + conditioning_max_sequence_length: int = 77, + patch_size: int = 2, + latent_mean: Union[float, Tuple, str] = 0.0, + latent_std: Union[float, Tuple, str] = 7.67754318618, + timestep_mean: float = 0.0, + timestep_std: float = 1.0, + timestep_shift: float = 1.0, +): """Text to image transformer training setup.""" latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) @@ -562,31 +562,6 @@ def text_to_image_transformer( latent_std = (latent_std,) * autoencoder_channels assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) - # Make the noise schedulers - noise_scheduler = DDPMScheduler(num_train_timesteps=1000, - beta_start=0.0000085, - beta_end=0.012, - beta_schedule=beta_schedule, - trained_betas=None, - variance_type='fixed_small', - clip_sample=False, - prediction_type=prediction_type, - sample_max_value=1.0, - timestep_spacing='leading', - steps_offset=1, - rescale_betas_zero_snr=zero_terminal_snr) - inference_noise_scheduler = EulerDiscreteScheduler(num_train_timesteps=1000, - beta_start=0.0000085, - beta_end=0.012, - beta_schedule=beta_schedule, - trained_betas=None, - prediction_type=prediction_type, - interpolation_type='linear', - use_karras_sigmas=use_karras_sigmas, - timestep_spacing='leading', - steps_offset=1, - rescale_betas_zero_snr=zero_terminal_snr) - # Make the transformer model transformer = DiffusionTransformer(num_features=num_features, num_heads=num_heads, @@ -603,14 +578,14 @@ def text_to_image_transformer( autoencoder=vae, text_encoder=text_encoder, tokenizer=tokenizer, - noise_scheduler=noise_scheduler, - inference_noise_scheduler=inference_noise_scheduler, - prediction_type=prediction_type, latent_mean=latent_mean, latent_std=latent_std, patch_size=patch_size, downsample_factor=downsample_factor, latent_channels=autoencoder_channels, + timestep_mean=timestep_mean, + timestep_std=timestep_std, + timestep_shift=timestep_shift, image_key='image', caption_key='captions', caption_mask_key='attention_mask') diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index 91d83f6d..a432de41 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -19,25 +19,6 @@ def modulate(x, shift, scale): return x * (1.0 + scale) + shift -def timestep_embedding(timesteps, dim, max_period=10000): - """Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. 
- :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / - half).to(device=timesteps.device) - args = timesteps[:, None].float() * freqs[None] - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def get_multidimensional_position_embeddings(position_embeddings, coords): """Position embeddings are shape (D, T, F). Coords are shape (B, S, D).""" B, S, D = coords.shape @@ -83,6 +64,56 @@ def unpatchify(patches, coords, patch_size): return img +class ScalarEmbedding(nn.Module): + """Embedding block for scalars.""" + + def __init__(self, num_features, sinusoidal_embedding_dim=256): + super().__init__() + self.num_features = num_features + self.sinusoidal_embedding_dim = sinusoidal_embedding_dim + self.linear_1 = nn.Linear(self.sinusoidal_embedding_dim, self.num_features) + self.linear_2 = nn.Linear(self.num_features, self.num_features) + self.mlp = nn.Sequential(self.linear_1, nn.SiLU(), self.linear_2) + + @staticmethod + def timestep_embedding(timesteps, dim, max_period=10000): + """Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / + half).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, x): + sinusoidal_embedding = self.timestep_embedding(x, self.sinusoidal_embedding_dim) + return self.mlp(sinusoidal_embedding) + + +class VectorEmbedding(nn.Module): + """Embedding block for vectors.""" + + def __init__(self, input_features, num_features): + super().__init__() + self.input_features = input_features + self.num_features = num_features + self.linear_1 = nn.Linear(self.input_features, self.num_features) + self.linear_2 = nn.Linear(self.num_features, self.num_features) + self.mlp = nn.Sequential(self.linear_1, nn.SiLU(), self.linear_2) + + def forward(self, x): + return self.mlp(x) + + class PreAttentionBlock(nn.Module): """Block to compute QKV before attention.""" @@ -252,6 +283,8 @@ def __init__(self, self.conditioning_dimension = conditioning_dimension self.conditioning_max_sequence_length = conditioning_max_sequence_length + # Embedding block for the timestep + self.timestep_embedding = ScalarEmbedding(self.num_features) # Projection layer for the input sequence self.input_embedding = nn.Linear(self.input_features, self.num_features) # Embedding layer for the input sequence @@ -306,8 +339,8 @@ def forward(self, conditioning_mask=None, constant_conditioning=None): # Embed the timestep - t = timestep_embedding(t, self.num_features) - # Optionally add constant conditioning + t = self.timestep_embedding(t) + # Optionally add constant conditioning. This assumes it has been embedded already. 
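# A minimal usage sketch (not part of this patch) of how that pre-embedded constant
# conditioning is produced upstream, assuming the module layout and defaults in this
# patch series (num_features=1152, pooled text-encoder features=768): the pooled
# caption embedding passes through a VectorEmbedding and is added to the
# ScalarEmbedding of the timestep before being handed to every transformer block.
import torch
from diffusion.models.transformer import ScalarEmbedding, VectorEmbedding

timestep_embedding = ScalarEmbedding(num_features=1152)
pooled_embedding_mlp = VectorEmbedding(input_features=768, num_features=1152)
timesteps = torch.rand(4)          # rectified-flow timesteps in (0, 1), one per sample
pooled_text = torch.randn(4, 768)  # pooled output of the text encoder
conditioning = timestep_embedding(timesteps) + pooled_embedding_mlp(pooled_text)  # (4, 1152)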
if constant_conditioning is not None: t = t + constant_conditioning # Embed the input @@ -384,29 +417,24 @@ def __init__( autoencoder: torch.nn.Module, text_encoder: torch.nn.Module, tokenizer, - noise_scheduler, - inference_noise_scheduler, - prediction_type: str = 'epsilon', latent_mean: Optional[tuple[float]] = None, latent_std: Optional[tuple[float]] = None, patch_size: int = 2, downsample_factor: int = 8, latent_channels: int = 4, + timestep_mean: float = 0.0, + timestep_std: float = 1.0, + timestep_shift: float = 1.0, image_key: str = 'image', caption_key: str = 'caption', caption_mask_key: str = 'caption_mask', - use_pooled_embedding: bool = False, + pooled_embedding_features: int = 768, ): super().__init__() self.model = model self.autoencoder = autoencoder self.text_encoder = text_encoder self.tokenizer = tokenizer - self.noise_scheduler = noise_scheduler - self.inference_scheduler = inference_noise_scheduler - self.prediction_type = prediction_type.lower() - if self.prediction_type not in ['epsilon', 'sample', 'v_prediction']: - raise ValueError(f'Unrecognized prediction type {self.prediction_type}') if latent_mean is None: self.latent_mean = 4 * (0.0) if latent_std is None: @@ -416,14 +444,16 @@ def __init__( self.patch_size = patch_size self.downsample_factor = downsample_factor self.latent_channels = latent_channels + self.timestep_mean = timestep_mean + self.timestep_std = timestep_std + self.timestep_shift = timestep_shift self.image_key = image_key self.caption_key = caption_key self.caption_mask_key = caption_mask_key - self.use_pooled_embedding = use_pooled_embedding + self.pooled_embedding_features = pooled_embedding_features - # Projection layer for the pooled text embeddings - if self.use_pooled_embedding: - self.pooled_projection_layer = nn.Linear(self.model.conditioning_features, self.model.num_features) + # Embeeding MLP for the pooled text embeddings + self.pooled_embedding_mlp = VectorEmbedding(pooled_embedding_features, model.num_features) # freeze text_encoder during diffusion training and use half precision self.autoencoder.requires_grad_(False) @@ -489,58 +519,43 @@ def flops_per_batch(self, batch): return 3 * param_flops + 3 * attention_flops def diffusion_forward_process(self, inputs: torch.Tensor): - """Diffusion forward process.""" - # Sample a timestep for every element in the batch - timesteps = torch.randint(0, - len(self.noise_scheduler), (inputs.shape[0],), - device=inputs.device, - generator=self.rng_generator) - # Generate the noise, applied to the whole input sequence + """Diffusion forward process using a rectified flow.""" + # First, sample timesteps according to a logit-normal distribution + u = torch.randn(inputs.shape[0], device=inputs.device, generator=self.rng_generator) + u = self.timestep_mean + self.timestep_std * u + timesteps = torch.sigmoid(u).view(-1, 1, 1) + timesteps = self.timestep_shift * timesteps / (1 + (self.timestep_shift - 1) * timesteps) + # Then, add the noise to the latents according to the recitified flow noise = torch.randn(*inputs.shape, device=inputs.device, generator=self.rng_generator) - # Add the noise to the latents according to the schedule - noised_inputs = self.noise_scheduler.add_noise(inputs, noise, timesteps) - # Generate the targets - if self.prediction_type == 'epsilon': - targets = noise - elif self.prediction_type == 'sample': - targets = inputs - elif self.prediction_type == 'v_prediction': - targets = self.noise_scheduler.get_velocity(inputs, noise, timesteps) - else: - raise ValueError( - f'prediction 
type must be one of sample, epsilon, or v_prediction. Got {self.prediction_type}') - return noised_inputs, targets, timesteps + noised_inputs = (1 - timesteps) * inputs + timesteps * noise + # Compute the targets, which are the velocities + targets = noise - inputs + return noised_inputs, targets, timesteps[:, 0, 0] def forward(self, batch): # Get the inputs image, caption, caption_mask = batch[self.image_key], batch[self.caption_key], batch[self.caption_mask_key] - + # Get the text embeddings and image latents with torch.cuda.amp.autocast(enabled=False): latents = self.autoencoder.encode(image.half())['latent_dist'].sample().data text_encoder_out = self.text_encoder(caption, attention_mask=caption_mask) text_embeddings = text_encoder_out[0] + pooled_text_embeddings = text_encoder_out[1] # Ensure text embeddings are not longer than the model can handle if text_embeddings.shape[1] > self.model.conditioning_max_sequence_length: text_embeddings = text_embeddings[:, :self.model.conditioning_max_sequence_length] caption_mask = caption_mask[:, :self.model.conditioning_max_sequence_length] - # Optionally use pooled embeddings - if self.use_pooled_embedding: - pooled_text_embeddings = text_encoder_out[1] - pooled_text_embeddings = self.pooled_projection_layer(pooled_text_embeddings) - else: - pooled_text_embeddings = None - + # Encode the pooled embeddings + pooled_text_embeddings = self.pooled_embedding_mlp(pooled_text_embeddings) # Make the text embedding coords text_embeddings_coords = torch.arange(text_embeddings.shape[1], device=text_embeddings.device) text_embeddings_coords = text_embeddings_coords.unsqueeze(0).expand(text_embeddings.shape[0], -1).unsqueeze(-1) - # Project the text embeddings # Zero dropped captions if needed if 'drop_caption_mask' in batch.keys(): text_embeddings *= batch['drop_caption_mask'].view(-1, 1, 1) - if self.use_pooled_embedding: - pooled_text_embeddings *= batch['drop_caption_mask'].view(-1, 1) + pooled_text_embeddings *= batch['drop_caption_mask'].view(-1, 1) # Scale and patchify the latents latents = (latents - self.latent_mean) / self.latent_std latent_patches, latent_coords = patchify(latents, self.patch_size) @@ -601,13 +616,25 @@ def embed_prompt(self, prompt): return_tensors='pt') tokenized_prompts = tokenized_out['input_ids'].to(self.text_encoder.device) prompt_mask = tokenized_out['attention_mask'].to(self.text_encoder.device) - text_embeddings = self.text_encoder(tokenized_prompts, attention_mask=prompt_mask)[0] + text_encoder_out = self.text_encoder(tokenized_prompts, attention_mask=prompt_mask) + text_embeddings, pooled_text_embeddings = text_encoder_out[0], text_encoder_out[1] prompt_mask = self.combine_attention_masks(prompt_mask) - return text_embeddings, prompt_mask + return text_embeddings, prompt_mask, pooled_text_embeddings + + def make_text_embeddings_coords(self, text_embeddings): + text_embeddings_coords = torch.arange(text_embeddings.shape[1], device=text_embeddings.device) + text_embeddings_coords = text_embeddings_coords.unsqueeze(0).expand(text_embeddings.shape[0], -1) + text_embeddings_coords = text_embeddings_coords.unsqueeze(-1) + return text_embeddings_coords + + def make_sampling_timesteps(self, N: int): + timesteps = torch.linspace(1, 0, N) + timesteps = self.timestep_shift * timesteps / (1 + (self.timestep_shift - 1) * timesteps) + return timesteps @torch.no_grad() def generate(self, - prompt: Optional[list] = None, + prompt: list, negative_prompt: Optional[list] = None, height: int = 256, width: int = 256, @@ -623,24 +650,17 @@ 
def generate(self, if seed: rng_generator = rng_generator.manual_seed(seed) - # Get the text embeddings - if prompt is not None: - text_embeddings, prompt_mask = self.embed_prompt(prompt) - text_embeddings_coords = torch.arange(text_embeddings.shape[1], device=text_embeddings.device) - text_embeddings_coords = text_embeddings_coords.unsqueeze(0).expand(text_embeddings.shape[0], -1) - text_embeddings_coords = text_embeddings_coords.unsqueeze(-1) - else: - raise ValueError('Prompt must be specified') + # Get the text embeddings and their coords + text_embeddings, prompt_mask, pooled_embedding = self.embed_prompt(prompt) + text_embeddings_coords = self.make_text_embeddings_coords(text_embeddings) + # Create the negative prompt if it exists, or use all zeros if it doesn't if negative_prompt is not None: - negative_text_embeddings, negative_prompt_mask = self.embed_prompt(negative_prompt) + negative_text_embeddings, negative_prompt_mask, pooled_neg_embedding = self.embed_prompt(negative_prompt) else: negative_text_embeddings = torch.zeros_like(text_embeddings) negative_prompt_mask = torch.zeros_like(prompt_mask) - negative_text_embeddings_coords = torch.arange(negative_text_embeddings.shape[1], - device=negative_text_embeddings.device) - negative_text_embeddings_coords = negative_text_embeddings_coords.unsqueeze(0).expand( - negative_text_embeddings.shape[0], -1) - negative_text_embeddings_coords = negative_text_embeddings_coords.unsqueeze(-1) + pooled_neg_embedding = torch.zeros_like(pooled_embedding) + negative_text_embeddings_coords = self.make_text_embeddings_coords(negative_text_embeddings) # Generate initial noise latent_height = height // self.downsample_factor @@ -656,36 +676,41 @@ def generate(self, text_embeddings = torch.cat([text_embeddings, negative_text_embeddings], dim=0) text_embeddings_coords = torch.cat([text_embeddings_coords, negative_text_embeddings_coords], dim=0) text_embeddings_mask = torch.cat([prompt_mask, negative_prompt_mask], dim=0) + pooled_embedding = torch.cat([pooled_embedding, pooled_neg_embedding], dim=0) latent_coords_input = torch.cat([latent_coords, latent_coords], dim=0) - # Prep for reverse process - self.inference_scheduler.set_timesteps(num_inference_steps) - # scale the initial noise by the standard deviation required by the scheduler - latent_patches = latent_patches * self.inference_scheduler.init_noise_sigma - # Ensure text embeddings, mask, and coords are not longer than the model can handle if text_embeddings.shape[1] > self.model.conditioning_max_sequence_length: text_embeddings = text_embeddings[:, :self.model.conditioning_max_sequence_length] text_embeddings_coords = text_embeddings_coords[:, :self.model.conditioning_max_sequence_length] text_embeddings_mask = text_embeddings_mask[:, :self.model.conditioning_max_sequence_length] + # Encode the pooled embeddings + pooled_embedding = self.pooled_embedding_mlp(pooled_embedding) + # backward diffusion process - for t in tqdm(self.inference_scheduler.timesteps, disable=not progress_bar): + timesteps = self.make_sampling_timesteps(num_inference_steps).to(device) + for i, t in tqdm(enumerate(timesteps), disable=not progress_bar): latent_patches_input = torch.cat([latent_patches, latent_patches], dim=0) - latent_patches_input = self.inference_scheduler.scale_model_input(latent_patches_input, t) # Get the model prediction model_out = self.model(latent_patches_input, latent_coords_input, - t.unsqueeze(0).to(device), + t.unsqueeze(0), conditioning=text_embeddings, 
conditioning_coords=text_embeddings_coords, input_mask=None, - conditioning_mask=text_embeddings_mask) + conditioning_mask=text_embeddings_mask, + constant_conditioning=pooled_embedding) # Do CFG pred_cond, pred_uncond = model_out.chunk(2, dim=0) pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) - # Update the inputs - latent_patches = self.inference_scheduler.step(pred, t, latent_patches, generator=rng_generator).prev_sample + # compute the time delta. + if i < len(timesteps) - 1: + delta_t = timesteps[i] - timesteps[(i + 1)] + else: + delta_t = timesteps[i] + # Update the latents + latent_patches = latent_patches - pred * delta_t # Unpatchify the latents latents = [ unpatchify(latent_patches[i], latent_coords[i], self.patch_size) for i in range(latent_patches.shape[0]) From 2ad5fc26c9de047ee0b83240061737243dcdf6aa Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Tue, 25 Jun 2024 18:59:36 +0000 Subject: [PATCH 13/27] Prep for inference --- diffusion/inference/__init__.py | 4 +- diffusion/inference/inference_model.py | 78 ++++++++++++++++++++++++++ diffusion/models/__init__.py | 4 +- diffusion/models/models.py | 6 +- diffusion/models/transformer.py | 8 ++- 5 files changed, 91 insertions(+), 9 deletions(-) diff --git a/diffusion/inference/__init__.py b/diffusion/inference/__init__.py index 34dc748c..00acafed 100644 --- a/diffusion/inference/__init__.py +++ b/diffusion/inference/__init__.py @@ -3,6 +3,6 @@ """Inference endpoint.""" -from diffusion.inference.inference_model import StableDiffusionInference, StableDiffusionXLInference +from diffusion.inference.inference_model import ModelInference, StableDiffusionInference, StableDiffusionXLInference -__all__ = ['StableDiffusionInference', 'StableDiffusionXLInference'] +__all__ = ['ModelInference', 'StableDiffusionInference', 'StableDiffusionXLInference'] diff --git a/diffusion/inference/inference_model.py b/diffusion/inference/inference_model.py index 68fa86ee..d6925baf 100644 --- a/diffusion/inference/inference_model.py +++ b/diffusion/inference/inference_model.py @@ -11,6 +11,7 @@ from composer.utils.file_helpers import get_file from PIL import Image +import diffusion.models from diffusion.models import stable_diffusion_2, stable_diffusion_xl # Local checkpoint params @@ -225,3 +226,80 @@ def predict(self, model_requests: List[Dict[str, Any]]): base64_encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8') png_images.append(base64_encoded_image) return png_images + + +class ModelInference(): + """Generic inference endpoint class for diffusion models with a model.generate() method. + + Args: + model_name (str): Name of the model from `diffusion.models` to load. Ex: for stable diffusion xl, use 'stable_diffusion_xl'. + local_checkpoint_path (str): Path to the local checkpoint. Default: '/tmp/model.pt'. + strict (bool): Whether to load the model weights strictly. Default: False. + **model_kwargs: Keyword arguments to pass to the model initialization. + """ + + def __init__(self, model_name, local_checkpoint_path: str = LOCAL_CHECKPOINT_PATH, strict=False, **model_kwargs): + self.device = torch.cuda.current_device() + model_factory = getattr(diffusion.models, model_name) + model = model_factory(**model_kwargs) + + if 'pretrained' in model_kwargs and model_kwargs['pretrained']: + pass + else: + state_dict = torch.load(local_checkpoint_path) + for key in list(state_dict['state']['model'].keys()): + if 'val_metrics.' 
in key: + del state_dict['state']['model'][key] + model.load_state_dict(state_dict['state']['model'], strict=strict) + model.to(self.device) + self.model = model.eval() + + def predict(self, model_requests: List[Dict[str, Any]]): + prompts = [] + negative_prompts = [] + generate_kwargs = {} + + # assumes the same generate_kwargs across all samples + for req in model_requests: + if 'input' not in req: + raise RuntimeError('"input" must be provided to generate call') + inputs = req['input'] + + # Prompts and negative prompts if available + if isinstance(inputs, str): + prompts.append(inputs) + elif isinstance(inputs, Dict): + if 'prompt' not in inputs: + raise RuntimeError('"prompt" must be provided to generate call if using a dict as input') + prompts.append(inputs['prompt']) + if 'negative_prompt' in inputs: + negative_prompts.append(inputs['negative_prompt']) + else: + raise RuntimeError(f'Input must be of type string or dict, but it is type: {type(inputs)}') + + generate_kwargs = req['parameters'] + + # Check for prompts + if len(prompts) == 0: + raise RuntimeError('No prompts provided, must be either a string or dictionary with "prompt"') + + # Check negative prompt length + if len(negative_prompts) == 0: + negative_prompts = None + elif len(prompts) != len(negative_prompts): + raise RuntimeError('There must be the same number of negative prompts as prompts.') + + # Generate images + with torch.cuda.amp.autocast(True): + imgs = self.model.generate(prompt=prompts, negative_prompt=negative_prompts, **generate_kwargs).cpu() + + # Send as bytes + png_images = [] + for i in range(imgs.shape[0]): + img = (imgs[i].permute(1, 2, 0).numpy() * 255).round().astype('uint8') + pil_image = Image.fromarray(img, 'RGB') + img_byte_arr = io.BytesIO() + pil_image.save(img_byte_arr, format='PNG') + base64_encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8') + png_images.append(base64_encoded_image) + return png_images diff --git a/diffusion/models/__init__.py b/diffusion/models/__init__.py index 0928c83a..1a4bb0a2 100644 --- a/diffusion/models/__init__.py +++ b/diffusion/models/__init__.py @@ -4,7 +4,8 @@ """Diffusion models.""" from diffusion.models.models import (build_autoencoder, build_diffusers_autoencoder, continuous_pixel_diffusion, - discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl) + discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl, + text_to_image_transformer) from diffusion.models.noop import NoOpModel from diffusion.models.pixel_diffusion import PixelDiffusion from diffusion.models.stable_diffusion import StableDiffusion @@ -19,4 +20,5 @@ 'stable_diffusion_2', 'stable_diffusion_xl', 'StableDiffusion', + 'text_to_image_transformer', ] diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 7527a865..0e113d31 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -498,10 +498,8 @@ def stable_diffusion_xl( def text_to_image_transformer( - tokenizer_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/tokenizer', - 'stabilityai/stable-diffusion-xl-base-1.0/tokenizer_2'), - text_encoder_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/text_encoder', - 'stabilityai/stable-diffusion-xl-base-1.0/text_encoder_2'), + tokenizer_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/tokenizer'), + text_encoder_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/text_encoder'), vae_model_name: str = 
'madebyollin/sdxl-vae-fp16-fix', autoencoder_path: Optional[str] = None, autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index a432de41..74bca7df 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -452,7 +452,7 @@ def __init__( self.caption_mask_key = caption_mask_key self.pooled_embedding_features = pooled_embedding_features - # Embeeding MLP for the pooled text embeddings + # Embedding MLP for the pooled text embeddings self.pooled_embedding_mlp = VectorEmbedding(pooled_embedding_features, model.num_features) # freeze text_encoder during diffusion training and use half precision @@ -535,7 +535,6 @@ def diffusion_forward_process(self, inputs: torch.Tensor): def forward(self, batch): # Get the inputs image, caption, caption_mask = batch[self.image_key], batch[self.caption_key], batch[self.caption_mask_key] - # Get the text embeddings and image latents with torch.cuda.amp.autocast(enabled=False): latents = self.autoencoder.encode(image.half())['latent_dist'].sample().data @@ -641,6 +640,7 @@ def generate(self, guidance_scale: float = 7.0, rescaled_guidance: Optional[float] = None, num_inference_steps: int = 50, + num_images_per_prompt: int = 1, progress_bar: bool = True, seed: Optional[int] = None): """Generate from the model.""" @@ -650,11 +650,15 @@ def generate(self, if seed: rng_generator = rng_generator.manual_seed(seed) + # Duplicate the images in the prompt if needed. + prompt = [item for item in prompt for _ in range(num_images_per_prompt)] # Get the text embeddings and their coords text_embeddings, prompt_mask, pooled_embedding = self.embed_prompt(prompt) text_embeddings_coords = self.make_text_embeddings_coords(text_embeddings) # Create the negative prompt if it exists, or use all zeros if it doesn't if negative_prompt is not None: + # Duplicate the images in the negative prompt if needed. 
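# Editor's note (not part of the patch): the list comprehension below simply repeats each entry
# num_images_per_prompt times, e.g. with hypothetical prompts ['a cat', 'a dog'] and
# num_images_per_prompt=2 it yields ['a cat', 'a cat', 'a dog', 'a dog'].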
+ negative_prompt = [item for item in negative_prompt for _ in range(num_images_per_prompt)] negative_text_embeddings, negative_prompt_mask, pooled_neg_embedding = self.embed_prompt(negative_prompt) else: negative_text_embeddings = torch.zeros_like(text_embeddings) From 499e7eeba51128cd073c1d584eb490c5e2e31013 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Wed, 26 Jun 2024 17:56:01 +0000 Subject: [PATCH 14/27] Pooled embeddings should be zeroed after embedding for cfg --- diffusion/models/transformer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index 74bca7df..41549a16 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -691,7 +691,9 @@ def generate(self, # Encode the pooled embeddings pooled_embedding = self.pooled_embedding_mlp(pooled_embedding) - + # Zero out the embedded pooled embeddings for the negative prompt if there isn't one + if negative_prompt is None: + pooled_embedding[len(prompt):] *= 0.0 # backward diffusion process timesteps = self.make_sampling_timesteps(num_inference_steps).to(device) for i, t in tqdm(enumerate(timesteps), disable=not progress_bar): From d06580ac8f0e975d737bf99e5973ab20fe864f98 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Wed, 26 Jun 2024 21:19:43 +0000 Subject: [PATCH 15/27] Use shared functions to reduce error surface --- diffusion/models/transformer.py | 174 +++++++++++++++----------------- 1 file changed, 83 insertions(+), 91 deletions(-) diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index 41549a16..7cc540f3 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -518,6 +518,67 @@ def flops_per_batch(self, batch): attention_flops = 4 * self.model.num_layers * seq_len**2 * self.model.num_features * batch_size return 3 * param_flops + 3 * attention_flops + def encode_image(self, image): + with torch.cuda.amp.autocast(enabled=False): + latents = self.autoencoder.encode(image.half())['latent_dist'].sample().data + # Scale and patchify the latents + latents = (latents - self.latent_mean) / self.latent_std + latent_patches, latent_coords = patchify(latents, self.patch_size) + return latent_patches, latent_coords + + @torch.no_grad() + def decode_image(self, latent_patches, latent_coords): + # Unpatchify the latents + latents = [ + unpatchify(latent_patches[i], latent_coords[i], self.patch_size) for i in range(latent_patches.shape[0]) + ] + latents = torch.stack(latents) + # Scale the latents back to the original scale + latents = latents * self.latent_std + self.latent_mean + # Decode the latents + with torch.cuda.amp.autocast(enabled=False): + image = self.autoencoder.decode(latents.half()).sample + image = (image / 2 + 0.5).clamp(0, 1) + return image + + def tokenize_prompts(self, prompts): + tokenized_out = self.tokenizer(prompts, + padding='max_length', + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors='pt') + return tokenized_out['input_ids'], tokenized_out['attention_mask'] + + def combine_attention_masks(self, attention_masks): + if len(attention_masks.shape) == 2: + return attention_masks + elif len(attention_masks.shape) == 3: + encoder_attention_masks = attention_masks[:, 0] + for i in range(1, attention_masks.shape[1]): + encoder_attention_masks |= attention_masks[:, i] + return encoder_attention_masks + else: + raise ValueError(f'attention_mask should have either 2 or 3 dimensions: {attention_masks.shape}') + + 
def make_text_embeddings_coords(self, text_embeddings): + text_embeddings_coords = torch.arange(text_embeddings.shape[1], device=text_embeddings.device) + text_embeddings_coords = text_embeddings_coords.unsqueeze(0).expand(text_embeddings.shape[0], -1) + text_embeddings_coords = text_embeddings_coords.unsqueeze(-1) + return text_embeddings_coords + + def embed_tokenized_prompts(self, tokenized_prompts, attention_masks): + with torch.cuda.amp.autocast(enabled=False): + # Ensure text embeddings are not longer than the model can handle + if tokenized_prompts.shape[1] > self.model.conditioning_max_sequence_length: + tokenized_prompts = tokenized_prompts[:, :self.model.conditioning_max_sequence_length] + text_encoder_out = self.text_encoder(tokenized_prompts, attention_mask=attention_masks) + text_embeddings, pooled_text_embeddings = text_encoder_out[0], text_encoder_out[1] + text_mask = self.combine_attention_masks(attention_masks) + text_embeddings_coords = self.make_text_embeddings_coords(text_embeddings) + # Encode the pooled embeddings + pooled_text_embeddings = self.pooled_embedding_mlp(pooled_text_embeddings) + return text_embeddings, text_embeddings_coords, text_mask, pooled_text_embeddings + def diffusion_forward_process(self, inputs: torch.Tensor): """Diffusion forward process using a rectified flow.""" # First, sample timesteps according to a logit-normal distribution @@ -535,29 +596,11 @@ def diffusion_forward_process(self, inputs: torch.Tensor): def forward(self, batch): # Get the inputs image, caption, caption_mask = batch[self.image_key], batch[self.caption_key], batch[self.caption_mask_key] - # Get the text embeddings and image latents - with torch.cuda.amp.autocast(enabled=False): - latents = self.autoencoder.encode(image.half())['latent_dist'].sample().data - text_encoder_out = self.text_encoder(caption, attention_mask=caption_mask) - text_embeddings = text_encoder_out[0] - pooled_text_embeddings = text_encoder_out[1] - # Ensure text embeddings are not longer than the model can handle - if text_embeddings.shape[1] > self.model.conditioning_max_sequence_length: - text_embeddings = text_embeddings[:, :self.model.conditioning_max_sequence_length] - caption_mask = caption_mask[:, :self.model.conditioning_max_sequence_length] - - # Encode the pooled embeddings - pooled_text_embeddings = self.pooled_embedding_mlp(pooled_text_embeddings) - # Make the text embedding coords - text_embeddings_coords = torch.arange(text_embeddings.shape[1], device=text_embeddings.device) - text_embeddings_coords = text_embeddings_coords.unsqueeze(0).expand(text_embeddings.shape[0], -1).unsqueeze(-1) - # Zero dropped captions if needed - if 'drop_caption_mask' in batch.keys(): - text_embeddings *= batch['drop_caption_mask'].view(-1, 1, 1) - pooled_text_embeddings *= batch['drop_caption_mask'].view(-1, 1) - # Scale and patchify the latents - latents = (latents - self.latent_mean) / self.latent_std - latent_patches, latent_coords = patchify(latents, self.patch_size) + # Get the image latents + latent_patches, latent_coords = self.encode_image(image) + # Get the text embeddings and their coords + text_embeddings, text_embeddings_coords, caption_mask, pooled_text_embeddings = self.embed_tokenized_prompts( + caption, caption_mask) # Diffusion forward process noised_inputs, targets, timesteps = self.diffusion_forward_process(latent_patches) # Forward through the model @@ -595,37 +638,6 @@ def update_metric(self, batch, outputs, metric): else: raise ValueError(f'Unrecognized metric {metric.__class__.__name__}') 
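# Editor's note: an illustrative, self-contained sketch (not part of the patch) of the rectified flow
# forward process implemented in diffusion_forward_process above. Names and shapes here are assumptions
# made for the example; the patch additionally applies the timestep_shift warping shown above.
import torch


def rectified_flow_forward_example(x0: torch.Tensor, timestep_mean: float = 0.0, timestep_std: float = 1.0):
    """Toy logit-normal rectified flow forward process for a (B, T, C) batch of latent patches."""
    # Sample one timestep per batch element from a logit-normal distribution
    u = timestep_mean + timestep_std * torch.randn(x0.shape[0])
    t = torch.sigmoid(u).view(-1, 1, 1)
    # Linearly interpolate between the clean input and gaussian noise
    noise = torch.randn_like(x0)
    x_t = (1 - t) * x0 + t * noise
    # The regression target is the velocity pointing from data to noise
    target = noise - x0
    return x_t, target, t[:, 0, 0]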
- def combine_attention_masks(self, attention_mask): - if len(attention_mask.shape) == 2: - return attention_mask - elif len(attention_mask.shape) == 3: - encoder_attention_mask = attention_mask[:, 0] - for i in range(1, attention_mask.shape[1]): - encoder_attention_mask |= attention_mask[:, i] - return encoder_attention_mask - else: - raise ValueError(f'attention_mask should have either 2 or 3 dimensions: {attention_mask.shape}') - - def embed_prompt(self, prompt): - with torch.cuda.amp.autocast(enabled=False): - tokenized_out = self.tokenizer(prompt, - padding='max_length', - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors='pt') - tokenized_prompts = tokenized_out['input_ids'].to(self.text_encoder.device) - prompt_mask = tokenized_out['attention_mask'].to(self.text_encoder.device) - text_encoder_out = self.text_encoder(tokenized_prompts, attention_mask=prompt_mask) - text_embeddings, pooled_text_embeddings = text_encoder_out[0], text_encoder_out[1] - prompt_mask = self.combine_attention_masks(prompt_mask) - return text_embeddings, prompt_mask, pooled_text_embeddings - - def make_text_embeddings_coords(self, text_embeddings): - text_embeddings_coords = torch.arange(text_embeddings.shape[1], device=text_embeddings.device) - text_embeddings_coords = text_embeddings_coords.unsqueeze(0).expand(text_embeddings.shape[0], -1) - text_embeddings_coords = text_embeddings_coords.unsqueeze(-1) - return text_embeddings_coords - def make_sampling_timesteps(self, N: int): timesteps = torch.linspace(1, 0, N) timesteps = self.timestep_shift * timesteps / (1 + (self.timestep_shift - 1) * timesteps) @@ -650,21 +662,20 @@ def generate(self, if seed: rng_generator = rng_generator.manual_seed(seed) - # Duplicate the images in the prompt if needed. + # Set default negative prompts to empty string if not provided + if negative_prompt is None: + negative_prompt = ['' for _ in prompt] + # Duplicate the images in the prompt and negative prompt if needed. prompt = [item for item in prompt for _ in range(num_images_per_prompt)] - # Get the text embeddings and their coords - text_embeddings, prompt_mask, pooled_embedding = self.embed_prompt(prompt) - text_embeddings_coords = self.make_text_embeddings_coords(text_embeddings) - # Create the negative prompt if it exists, or use all zeros if it doesn't - if negative_prompt is not None: - # Duplicate the images in the negative prompt if needed. 
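# Editor's note: a small sketch (not part of the patch) of the shifted sampling schedule built by
# make_sampling_timesteps, using a hypothetical shift value. Larger shifts concentrate steps at high
# noise levels; the per-step deltas are what the Euler update multiplies the predicted velocity by.
import torch


def shifted_sampling_schedule_example(num_steps: int = 8, shift: float = 3.0):
    t = torch.linspace(1, 0, num_steps + 1)
    t = shift * t / (1 + (shift - 1) * t)
    delta_t = t[:-1] - t[1:]  # step sizes; they sum to 1
    return t[:-1], delta_t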
- negative_prompt = [item for item in negative_prompt for _ in range(num_images_per_prompt)] - negative_text_embeddings, negative_prompt_mask, pooled_neg_embedding = self.embed_prompt(negative_prompt) - else: - negative_text_embeddings = torch.zeros_like(text_embeddings) - negative_prompt_mask = torch.zeros_like(prompt_mask) - pooled_neg_embedding = torch.zeros_like(pooled_embedding) - negative_text_embeddings_coords = self.make_text_embeddings_coords(negative_text_embeddings) + negative_prompt = [item for item in negative_prompt for _ in range(num_images_per_prompt)] + # Tokenize both prompt and negative prompts + prompt_tokens, prompt_mask = self.tokenize_prompts(prompt) + negative_prompt_tokens, negative_prompt_mask = self.tokenize_prompts(negative_prompt) + # Embed the tokenized prompts and negative prompts + text_embeddings, text_embeddings_coords, prompt_mask, pooled_embedding = self.embed_tokenized_prompts( + prompt_tokens, prompt_mask) + neg_text_embeddings, neg_text_embeddings_coords, neg_prompt_mask, pooled_neg_embedding = self.embed_tokenized_prompts( + negative_prompt_tokens, negative_prompt_mask) # Generate initial noise latent_height = height // self.downsample_factor @@ -677,23 +688,12 @@ def generate(self, latent_patches, latent_coords = patchify(latents, self.patch_size) # Set up for CFG - text_embeddings = torch.cat([text_embeddings, negative_text_embeddings], dim=0) - text_embeddings_coords = torch.cat([text_embeddings_coords, negative_text_embeddings_coords], dim=0) - text_embeddings_mask = torch.cat([prompt_mask, negative_prompt_mask], dim=0) + text_embeddings = torch.cat([text_embeddings, neg_text_embeddings], dim=0) + text_embeddings_coords = torch.cat([text_embeddings_coords, neg_text_embeddings_coords], dim=0) + text_embeddings_mask = torch.cat([prompt_mask, neg_prompt_mask], dim=0) pooled_embedding = torch.cat([pooled_embedding, pooled_neg_embedding], dim=0) latent_coords_input = torch.cat([latent_coords, latent_coords], dim=0) - # Ensure text embeddings, mask, and coords are not longer than the model can handle - if text_embeddings.shape[1] > self.model.conditioning_max_sequence_length: - text_embeddings = text_embeddings[:, :self.model.conditioning_max_sequence_length] - text_embeddings_coords = text_embeddings_coords[:, :self.model.conditioning_max_sequence_length] - text_embeddings_mask = text_embeddings_mask[:, :self.model.conditioning_max_sequence_length] - - # Encode the pooled embeddings - pooled_embedding = self.pooled_embedding_mlp(pooled_embedding) - # Zero out the embedded pooled embeddings for the negative prompt if there isn't one - if negative_prompt is None: - pooled_embedding[len(prompt):] *= 0.0 # backward diffusion process timesteps = self.make_sampling_timesteps(num_inference_steps).to(device) for i, t in tqdm(enumerate(timesteps), disable=not progress_bar): @@ -717,14 +717,6 @@ def generate(self, delta_t = timesteps[i] # Update the latents latent_patches = latent_patches - pred * delta_t - # Unpatchify the latents - latents = [ - unpatchify(latent_patches[i], latent_coords[i], self.patch_size) for i in range(latent_patches.shape[0]) - ] - latents = torch.stack(latents) - # Scale the latents back to the original scale - latents = latents * self.latent_std + self.latent_mean # Decode the latents - image = self.autoencoder.decode(latents).sample - image = (image / 2 + 0.5).clamp(0, 1) + image = self.decode_image(latent_patches, latent_coords) return image.detach() # (batch*num_images_per_prompt, channel, h, w) From 
4e21bc05832be60c1414994917ff47c56bb4ca7c Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 27 Jun 2024 00:28:17 +0000 Subject: [PATCH 16/27] Docs and a subtle timestep bug fix --- diffusion/models/transformer.py | 58 ++++++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 12 deletions(-) diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index 7cc540f3..a8401ef3 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -20,7 +20,21 @@ def modulate(x, shift, scale): def get_multidimensional_position_embeddings(position_embeddings, coords): - """Position embeddings are shape (D, T, F). Coords are shape (B, S, D).""" + """Extracts position embeddings for a multidimensional sequence given by coordinates. + + Position embeddings are shape (D, T, F). Coords are shape (B, S, D). Position embeddings should be + interpreted as D dimensional embeddings with F features each for a maximum of T timesteps. + Coordinates or `coords` is a batch of size B of sequences of length S with D dimensional integer + coordinates. For example, if D=2, then each of the B, S elements of the sequence would have a 2D + X,Y coordinate. + + Args: + position_embeddings (torch.Tensor): Position embeddings of shape (D, T, F). + coords (torch.Tensor): Coordinates of shape (B, S, D). + + Returns: + torch.Tensor: Sequenced embeddings of shape (B, S, F, D) + """ B, S, D = coords.shape F = position_embeddings.shape[2] coords = coords.reshape(B * S, D) @@ -31,7 +45,19 @@ def get_multidimensional_position_embeddings(position_embeddings, coords): def patchify(latents, patch_size): - """Converts a tensor of shape [B, C, H, W] to patches of shape [B, num_patches, C * patch_size * patch_size].""" + """Function to extract non-overlapping patches from image-like latents. + + Converts a tensor of shape [B, C, H, W] to patches of shape [B, num_patches, C * patch_size * patch_size]. + Coordinates of the patches are also returned to allow for unpatching and sequence embedding. + + Args: + latents (torch.Tensor): Latents of shape [B, C, H, W]. + patch_size (int): Size of the patches. + + Returns: + torch.Tensor: Patches of shape [B, num_patches, C * patch_size * patch_size]. + torch.Tensor: Coordinates of the patches. Shape [B, num_patches, 2]. + """ # Assume img is a tensor of shape [B, C, H, W] B, C, H, W = latents.shape assert H % patch_size == 0 and W % patch_size == 0, 'Image dimensions must be divisible by patch_size' @@ -47,7 +73,16 @@ def patchify(latents, patch_size): def unpatchify(patches, coords, patch_size): - """Converts a tensor of shape [num_patches, C * patch_size * patch_size] to an image of shape [C, H, W].""" + """Recover an image-like tensor from a sequence of patches and their coordinates. + + Converts a tensor of shape [num_patches, C * patch_size * patch_size] to an image of shape [C, H, W]. + Coordinates are used to place the patches in the correct location in the image. + + Args: + patches (torch.Tensor): Patches of shape [num_patches, C * patch_size * patch_size]. + coords (torch.Tensor): Coordinates of the patches. Shape [num_patches, 2]. + patch_size (int): Size of the patches. 
+ """ # Assume patches is a tensor of shape [num_patches, C * patch_size * patch_size] C = patches.shape[1] // (patch_size * patch_size) # Calculate the height and width of the original image from the coordinates @@ -639,9 +674,11 @@ def update_metric(self, batch, outputs, metric): raise ValueError(f'Unrecognized metric {metric.__class__.__name__}') def make_sampling_timesteps(self, N: int): - timesteps = torch.linspace(1, 0, N) + timesteps = torch.linspace(1, 0, N + 1) timesteps = self.timestep_shift * timesteps / (1 + (self.timestep_shift - 1) * timesteps) - return timesteps + # Make timestep differences + delta_t = timesteps[:-1] - timesteps[1:] + return timesteps[:-1], delta_t @torch.no_grad() def generate(self, @@ -695,8 +732,10 @@ def generate(self, latent_coords_input = torch.cat([latent_coords, latent_coords], dim=0) # backward diffusion process - timesteps = self.make_sampling_timesteps(num_inference_steps).to(device) + timesteps, delta_t = self.make_sampling_timesteps(num_inference_steps) + timesteps, delta_t = timesteps.to(device), delta_t.to(device) for i, t in tqdm(enumerate(timesteps), disable=not progress_bar): + print(t, delta_t[i]) latent_patches_input = torch.cat([latent_patches, latent_patches], dim=0) # Get the model prediction model_out = self.model(latent_patches_input, @@ -710,13 +749,8 @@ def generate(self, # Do CFG pred_cond, pred_uncond = model_out.chunk(2, dim=0) pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) - # compute the time delta. - if i < len(timesteps) - 1: - delta_t = timesteps[i] - timesteps[(i + 1)] - else: - delta_t = timesteps[i] # Update the latents - latent_patches = latent_patches - pred * delta_t + latent_patches = latent_patches - pred * delta_t[i] # Decode the latents image = self.decode_image(latent_patches, latent_coords) return image.detach() # (batch*num_images_per_prompt, channel, h, w) From 602c0a92288456a5ddd8ac3db9763d24ab523c0c Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 27 Jun 2024 05:45:36 +0000 Subject: [PATCH 17/27] Docs and types for transformer --- diffusion/models/transformer.py | 164 ++++++++++++++++++++++++-------- 1 file changed, 123 insertions(+), 41 deletions(-) diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index a8401ef3..c7f2a8bb 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -4,7 +4,7 @@ """Diffusion Transformer model.""" import math -from typing import Optional +from typing import Optional, Tuple import torch import torch.nn as nn @@ -14,12 +14,12 @@ from tqdm.auto import tqdm -def modulate(x, shift, scale): +def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: """Modulate the input with the shift and scale.""" return x * (1.0 + scale) + shift -def get_multidimensional_position_embeddings(position_embeddings, coords): +def get_multidimensional_position_embeddings(position_embeddings: torch.Tensor, coords: torch.Tensor) -> torch.Tensor: """Extracts position embeddings for a multidimensional sequence given by coordinates. Position embeddings are shape (D, T, F). Coords are shape (B, S, D). Position embeddings should be @@ -44,7 +44,7 @@ def get_multidimensional_position_embeddings(position_embeddings, coords): return sequenced_embeddings # (B, S, F, D) -def patchify(latents, patch_size): +def patchify(latents: torch.Tensor, patch_size: int) -> Tuple[torch.Tensor, torch.Tensor]: """Function to extract non-overlapping patches from image-like latents. 
Converts a tensor of shape [B, C, H, W] to patches of shape [B, num_patches, C * patch_size * patch_size]. @@ -72,7 +72,7 @@ def patchify(latents, patch_size): return patches, coords -def unpatchify(patches, coords, patch_size): +def unpatchify(patches: torch.Tensor, coords: torch.Tensor, patch_size: int) -> torch.Tensor: """Recover an image-like tensor from a sequence of patches and their coordinates. Converts a tensor of shape [num_patches, C * patch_size * patch_size] to an image of shape [C, H, W]. @@ -89,7 +89,7 @@ def unpatchify(patches, coords, patch_size): H = coords[:, 0].max() * patch_size + patch_size W = coords[:, 1].max() * patch_size + patch_size # Initialize an empty tensor for the reconstructed image - img = torch.zeros((C, H, W), device=patches.device, dtype=patches.dtype) + img = torch.zeros((C, H, W), device=patches.device, dtype=patches.dtype) # type: ignore # Iterate over the patches and their coordinates for patch, (y, x) in zip(patches, patch_size * coords): # Reshape the patch to [C, patch_size, patch_size] @@ -100,9 +100,19 @@ def unpatchify(patches, coords, patch_size): class ScalarEmbedding(nn.Module): - """Embedding block for scalars.""" + """Embedding block for scalars. - def __init__(self, num_features, sinusoidal_embedding_dim=256): + Embeds a scalar into a vector of size `num_features` using a sinusoidal embedding followed by an MLP. + + Args: + num_features (int): The size of the output vector. + sinusoidal_embedding_dim (int): The size of the intermediate sinusoidal embedding. Default: `256`. + + Returns: + torch.Tensor: The embedded scalar + """ + + def __init__(self, num_features: int, sinusoidal_embedding_dim: int = 256): super().__init__() self.num_features = num_features self.sinusoidal_embedding_dim = sinusoidal_embedding_dim @@ -111,14 +121,13 @@ def __init__(self, num_features, sinusoidal_embedding_dim=256): self.mlp = nn.Sequential(self.linear_1, nn.SiLU(), self.linear_2) @staticmethod - def timestep_embedding(timesteps, dim, max_period=10000): + def timestep_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor: """Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. + Args: + timesteps (torch.Tensor): The timesteps to embed. + dim (int): The size of the output embedding. + max_period (int): The maximum period of the sinusoidal embedding. Default: `10000`. """ half = dim // 2 freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / @@ -129,15 +138,22 @@ def timestep_embedding(timesteps, dim, max_period=10000): embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: sinusoidal_embedding = self.timestep_embedding(x, self.sinusoidal_embedding_dim) return self.mlp(sinusoidal_embedding) class VectorEmbedding(nn.Module): - """Embedding block for vectors.""" + """Embedding block for vectors. + + Embeds vectors via an MLP into a vector of size `num_features`. + + Args: + input_features (int): The size of the input vector. + num_features (int): The size of the output vector. 
+ """ - def __init__(self, input_features, num_features): + def __init__(self, input_features: int, num_features: int): super().__init__() self.input_features = input_features self.num_features = num_features @@ -145,14 +161,20 @@ def __init__(self, input_features, num_features): self.linear_2 = nn.Linear(self.num_features, self.num_features) self.mlp = nn.Sequential(self.linear_1, nn.SiLU(), self.linear_2) - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.mlp(x) class PreAttentionBlock(nn.Module): - """Block to compute QKV before attention.""" + """Block to compute QKV before attention. + + Includes QK layernorms and an adaptive layernorms. + + Args: + num_features (int): Number of input features. + """ - def __init__(self, num_features): + def __init__(self, num_features: int): super().__init__() self.num_features = num_features @@ -173,7 +195,7 @@ def __init__(self, num_features): # Init the standard deviation of the weights to 0.02 as is tradition nn.init.normal_(self.qkv.weight, std=0.02) - def forward(self, x, t): + def forward(self, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Calculate the modulations mods = self.adaLN_mlp(t).unsqueeze(1).chunk(2, dim=2) # Forward, with modulations @@ -186,14 +208,23 @@ def forward(self, x, t): class SelfAttention(nn.Module): - """Standard self attention layer that supports masking.""" + """Standard multihead self attention layer that supports masking. + + Args: + num_features (int): Number of input features. + num_heads (int): Number of attention heads. + """ - def __init__(self, num_features, num_heads): + def __init__(self, num_features: int, num_heads: int): super().__init__() self.num_features = num_features self.num_heads = num_heads - def forward(self, q, k, v, mask=None): + def forward(self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> torch.Tensor: # Get the shape of the inputs B, T, C = v.size() # Reshape the query, key, and values for multi-head attention @@ -208,9 +239,16 @@ def forward(self, q, k, v, mask=None): class PostAttentionBlock(nn.Module): - """Block to postprocess V after attention.""" + """Block to postprocess v after attention. + + Includes adaptive layernorms. + + Args: + num_features (int): Number of input features. + expansion_factor (int): Expansion factor for the MLP. Default: `4`. + """ - def __init__(self, num_features, expansion_factor=4): + def __init__(self, num_features: int, expansion_factor: int = 4): super().__init__() self.num_features = num_features self.expansion_factor = expansion_factor @@ -233,7 +271,7 @@ def __init__(self, num_features, expansion_factor=4): # Output MLP self.output_mlp = nn.Sequential(self.linear_1, self.nonlinearity, self.linear_2) - def forward(self, v, x, t): + def forward(self, v: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: """Forward takes v from self attention and the original sequence x with scalar conditioning t.""" # Calculate the modulations mods = self.adaLN_mlp(t).unsqueeze(1).chunk(4, dim=2) @@ -251,9 +289,19 @@ def forward(self, v, x, t): class MMDiTBlock(nn.Module): - """Transformer block that supports masking, multimodal attention, and adaptive norms.""" + """Transformer block that supports masking, multimodal attention, and adaptive norms. + + Can optionally be the last block in the network, in which case it does not apply post-attention layers to the + conditioning sequence, as those params may not be used. 
+ + Args: + num_features (int): Number of input features. + num_heads (int): Number of attention heads. + expansion_factor (int): Expansion factor for the MLP. Default: `4`. + is_last (bool): Whether this is the last block in the network. Default: `False`. + """ - def __init__(self, num_features, num_heads, expansion_factor=4, is_last=False): + def __init__(self, num_features: int, num_heads: int, expansion_factor: int = 4, is_last: bool = False): super().__init__() self.num_features = num_features self.num_heads = num_heads @@ -269,7 +317,11 @@ def __init__(self, num_features, num_heads, expansion_factor=4, is_last=False): if not self.is_last: self.post_attention_block_2 = PostAttentionBlock(self.num_features, self.expansion_factor) - def forward(self, x1, x2, t, mask=None): + def forward(self, + x1: torch.Tensor, + x2: torch.Tensor, + t: torch.Tensor, + mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: # Pre-attention for the two modalities q1, k1, v1 = self.pre_attention_block_1(x1, t) q2, k2, v2 = self.pre_attention_block_2(x2, t) @@ -290,7 +342,22 @@ def forward(self, x1, x2, t, mask=None): class DiffusionTransformer(nn.Module): - """Transformer model for diffusion.""" + """Transformer model for generic diffusion. + + Supports input and conditioning sequences with different lengths and dimensions. + + Args: + num_features (int): Number of hidden features. + num_heads (int): Number of attention heads. + num_layers (int): Number of transformer layers. + input_features (int): Number of features in the input sequence. Default: `192`. + input_max_sequence_length (int): Maximum sequence length for the input sequence. Default: `1024`. + input_dimension (int): Dimension of the input sequence. Default: `2`. + conditioning_features (int): Number of features in the conditioning sequence. Default: `1024`. + conditioning_max_sequence_length (int): Maximum sequence length for the conditioning sequence. Default: `77`. + conditioning_dimension (int): Dimension of the conditioning sequence. Default: `1`. + expansion_factor (int): Expansion factor for the MLPs. Default: `4`. + """ def __init__(self, num_features: int, @@ -354,25 +421,40 @@ def __init__(self, nn.init.zeros_(self.adaLN_mlp_linear.bias) self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear) - def fsdp_wrap_fn(self, module): + def fsdp_wrap_fn(self, module: nn.Module) -> bool: if isinstance(module, MMDiTBlock): return True return False - def activation_checkpointing_fn(self, module): + def activation_checkpointing_fn(self, module: nn.Module) -> bool: if isinstance(module, MMDiTBlock): return True return False def forward(self, - x, - input_coords, - t, - conditioning, - conditioning_coords, - input_mask=None, - conditioning_mask=None, - constant_conditioning=None): + x: torch.Tensor, + input_coords: torch.Tensor, + t: torch.Tensor, + conditioning: torch.Tensor, + conditioning_coords: torch.Tensor, + input_mask: Optional[torch.Tensor] = None, + conditioning_mask: Optional[torch.Tensor] = None, + constant_conditioning: Optional[torch.Tensor] = None) -> torch.Tensor: + """Forward pass through the diffusion transformer. + + Args: + x (torch.Tensor): The input sequence of shape (B, T1, C1). + input_coords (torch.Tensor): The coordinates of the D dimensional input sequence of shape (B, T1, D). + t (torch.Tensor): The scalar timesteps of shape (B, 1). + conditioning (torch.Tensor): The conditioning sequence of shape (B, T2, C2). 
+ conditioning_coords (torch.Tensor): The coordinates of the D dimensional conditioning sequence of shape (B, T2, D). + input_mask (Optional[torch.Tensor]): The mask for the input sequence of shape (B, T1). + conditioning_mask (Optional[torch.Tensor]): The mask for the conditioning sequence of shape (B, T2). + constant_conditioning (Optional[torch.Tensor]): Optional additional constant conditioning (B, num_features). + + Returns: + torch.Tensor: The output sequence of shape (B, T1, C1). + """ # Embed the timestep t = self.timestep_embedding(t) # Optionally add constant conditioning. This assumes it has been embedded already. From a59ee6d6f2cb2d81bcb3a8099d32beee32784369 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 27 Jun 2024 18:48:50 +0000 Subject: [PATCH 18/27] Refactor composer model to be separate from base transformer --- diffusion/models/models.py | 32 ++- diffusion/models/t2i_transformer.py | 406 ++++++++++++++++++++++++++++ diffusion/models/transformer.py | 396 --------------------------- diffusion/train.py | 2 +- 4 files changed, 437 insertions(+), 399 deletions(-) create mode 100644 diffusion/models/t2i_transformer.py diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 0e113d31..c29e987a 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -17,8 +17,9 @@ from diffusion.models.layers import ClippedAttnProcessor2_0, ClippedXFormersAttnProcessor, zero_module from diffusion.models.pixel_diffusion import PixelDiffusion from diffusion.models.stable_diffusion import StableDiffusion +from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer -from diffusion.models.transformer import ComposerTextToImageMMDiT, DiffusionTransformer +from diffusion.models.transformer import DiffusionTransformer from diffusion.schedulers.schedulers import ContinuousTimeScheduler from diffusion.schedulers.utils import shift_noise_schedule @@ -516,7 +517,34 @@ def text_to_image_transformer( timestep_std: float = 1.0, timestep_shift: float = 1.0, ): - """Text to image transformer training setup.""" + """Text to image transformer training setup. + + Args: + tokenizer_names (str, Tuple[str, ...]): HuggingFace name(s) of the tokenizer(s) to load. + Default: ``('stabilityai/stable-diffusion-xl-base-1.0/tokenizer')``. + text_encoder_names (str, Tuple[str, ...]): HuggingFace name(s) of the text encoder(s) to load. + Default: ``('stabilityai/stable-diffusion-xl-base-1.0/text_encoder')``. + vae_model_name (str): Name of the VAE model to load. Defaults to 'madebyollin/sdxl-vae-fp16-fix'. + autoencoder_path (optional, str): Path to autoencoder weights if using custom autoencoder. If not specified, + will use the vae from `model_name`. Default `None`. + autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`. + num_features (int): Number of features in the transformer. Default: `1152`. + num_heads (int): Number of heads in the transformer. Default: `16`. + num_layers (int): Number of layers in the transformer. Default: `28`. + input_max_sequence_length (int): Maximum sequence length for the input. Default: `1024`. + conditioning_features (int): Number of features in the conditioning transformer. Default: `768`. + conditioning_max_sequence_length (int): Maximum sequence length for the conditioning transformer. Default: `77`. + patch_size (int): Patch size for the transformer. Default: `2`. 
+ latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value, + a tuple of means, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `0.0`. + latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value, + a tuple of std_devs, or or `'latent_statistics'` to try to use the value from the autoencoder + checkpoint. Defaults to `1/0.13025`. + timestep_mean (float): The mean of the timesteps. Default: `0.0`. + timestep_std (float): The std. dev. of the timesteps. Default: `1.0`. + timestep_shift (float): The shift of the timesteps. Default: `1.0`. + """ latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) if (isinstance(tokenizer_names, tuple) or diff --git a/diffusion/models/t2i_transformer.py b/diffusion/models/t2i_transformer.py new file mode 100644 index 00000000..4e318d58 --- /dev/null +++ b/diffusion/models/t2i_transformer.py @@ -0,0 +1,406 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Composer model for text to image generation with a multimodal transformer.""" + +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F +from composer.models import ComposerModel +from torchmetrics import MeanSquaredError +from tqdm.auto import tqdm + +from diffusion.models.transformer import DiffusionTransformer, VectorEmbedding + + +def patchify(latents: torch.Tensor, patch_size: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to extract non-overlapping patches from image-like latents. + + Converts a tensor of shape [B, C, H, W] to patches of shape [B, num_patches, C * patch_size * patch_size]. + Coordinates of the patches are also returned to allow for unpatching and sequence embedding. + + Args: + latents (torch.Tensor): Latents of shape [B, C, H, W]. + patch_size (int): Size of the patches. + + Returns: + torch.Tensor: Patches of shape [B, num_patches, C * patch_size * patch_size]. + torch.Tensor: Coordinates of the patches. Shape [B, num_patches, 2]. + """ + # Assume img is a tensor of shape [B, C, H, W] + B, C, H, W = latents.shape + assert H % patch_size == 0 and W % patch_size == 0, 'Image dimensions must be divisible by patch_size' + # Reshape and permute to get non-overlapping patches + num_H_patches = H // patch_size + num_W_patches = W // patch_size + patches = latents.reshape(B, C, num_H_patches, patch_size, num_W_patches, patch_size) + patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * patch_size * patch_size) + # Generate coordinates for each patch + coords = torch.tensor([(i, j) for i in range(num_H_patches) for j in range(num_W_patches)]) + coords = coords.unsqueeze(0).expand(B, -1, -1).reshape(B, -1, 2) + return patches, coords + + +def unpatchify(patches: torch.Tensor, coords: torch.Tensor, patch_size: int) -> torch.Tensor: + """Recover an image-like tensor from a sequence of patches and their coordinates. + + Converts a tensor of shape [num_patches, C * patch_size * patch_size] to an image of shape [C, H, W]. + Coordinates are used to place the patches in the correct location in the image. + + Args: + patches (torch.Tensor): Patches of shape [num_patches, C * patch_size * patch_size]. + coords (torch.Tensor): Coordinates of the patches. Shape [num_patches, 2]. + patch_size (int): Size of the patches. 
+    """
+    # Assume patches is a tensor of shape [num_patches, C * patch_size * patch_size]
+    C = patches.shape[1] // (patch_size * patch_size)
+    # Calculate the height and width of the original image from the coordinates
+    H = coords[:, 0].max() * patch_size + patch_size
+    W = coords[:, 1].max() * patch_size + patch_size
+    # Initialize an empty tensor for the reconstructed image
+    img = torch.zeros((C, H, W), device=patches.device, dtype=patches.dtype)  # type: ignore
+    # Iterate over the patches and their coordinates
+    for patch, (y, x) in zip(patches, patch_size * coords):
+        # Reshape the patch to [C, patch_size, patch_size]
+        patch = patch.view(C, patch_size, patch_size)
+        # Place the patch in the corresponding location in the image
+        img[:, y:y + patch_size, x:x + patch_size] = patch
+    return img
+
+
+class ComposerTextToImageMMDiT(ComposerModel):
+    """ComposerModel for text to image with a diffusion transformer.
+
+    Args:
+        model (DiffusionTransformer): Core diffusion model.
+        autoencoder (torch.nn.Module): HuggingFace or compatible vae.
+            Must support `.encode()` and `decode()` functions.
+        text_encoder (torch.nn.Module): HuggingFace CLIP or LLM text encoder.
+        tokenizer (transformers.PreTrainedTokenizer): Tokenizer used for
+            text_encoder. For a `CLIPTextModel` this will be the
+            `CLIPTokenizer` from HuggingFace transformers.
+        latent_mean (Optional[tuple[float]]): The means of the latent space. If not specified, defaults to
+            4 * (0.0,). Default: `None`.
+        latent_std (Optional[tuple[float]]): The standard deviations of the latent space. If not specified,
+            defaults to 4 * (1/0.13025,). Default: `None`.
+        patch_size (int): The size of the patches in the image latents. Default: `2`.
+        downsample_factor (int): The factor by which the image is downsampled by the autoencoder. Default `8`.
+        latent_channels (int): The number of channels in the autoencoder latent space. Default: `4`.
+        timestep_mean (float): The mean of the timesteps. Default: `0.0`.
+        timestep_std (float): The std. dev. of the timesteps. Default: `1.0`.
+        timestep_shift (float): The shift of the timesteps. Default: `1.0`.
+        image_key (str): The name of the images in the dataloader batch. Default: `image`.
+        caption_key (str): The name of the caption in the dataloader batch. Default: `caption`.
+        caption_mask_key (str): The name of the caption mask in the dataloader batch. Default: `caption_mask`.
+        pooled_embedding_features (int): The number of features in the pooled text embeddings. Default: `768`.
+    """
+
+    def __init__(
+        self,
+        model: DiffusionTransformer,
+        autoencoder: torch.nn.Module,
+        text_encoder: torch.nn.Module,
+        tokenizer,
+        latent_mean: Optional[tuple[float]] = None,
+        latent_std: Optional[tuple[float]] = None,
+        patch_size: int = 2,
+        downsample_factor: int = 8,
+        latent_channels: int = 4,
+        timestep_mean: float = 0.0,
+        timestep_std: float = 1.0,
+        timestep_shift: float = 1.0,
+        image_key: str = 'image',
+        caption_key: str = 'caption',
+        caption_mask_key: str = 'caption_mask',
+        pooled_embedding_features: int = 768,
+    ):
+        super().__init__()
+        self.model = model
+        self.autoencoder = autoencoder
+        self.text_encoder = text_encoder
+        self.tokenizer = tokenizer
+        if latent_mean is None:
+            latent_mean = 4 * (0.0,)
+        if latent_std is None:
+            latent_std = 4 * (1 / 0.18215,)
+        self.latent_mean = torch.tensor(latent_mean).view(1, -1, 1, 1)
+        self.latent_std = torch.tensor(latent_std).view(1, -1, 1, 1)
+        self.patch_size = patch_size
+        self.downsample_factor = downsample_factor
+        self.latent_channels = latent_channels
+        self.timestep_mean = timestep_mean
+        self.timestep_std = timestep_std
+        self.timestep_shift = timestep_shift
+        self.image_key = image_key
+        self.caption_key = caption_key
+        self.caption_mask_key = caption_mask_key
+        self.pooled_embedding_features = pooled_embedding_features
+
+        # Embedding MLP for the pooled text embeddings
+        self.pooled_embedding_mlp = VectorEmbedding(pooled_embedding_features, model.num_features)
+
+        # freeze text_encoder during diffusion training and use half precision
+        self.autoencoder.requires_grad_(False)
+        self.text_encoder.requires_grad_(False)
+        self.autoencoder = self.autoencoder.half()
+        self.text_encoder = self.text_encoder.half()
+
+        # Only FSDP wrap models we are training
+        self.model._fsdp_wrap = True
+        self.autoencoder._fsdp_wrap = False
+        self.text_encoder._fsdp_wrap = True
+
+        # Param counts relevant for MFU computation
+        # First calc the AdaLN params separately
+        self.adaLN_params = sum(p.numel() for n, p in self.model.named_parameters() if 'adaLN_mlp_linear' in n)
+        # For MFU calc we must be careful to prevent double counting of MMDiT flops.
+        # Here, count the number of params applied to each sequence element.
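# Editor's note (not part of the patch): the counts above and below tally roughly how many parameters act
# on each token, so flops_per_batch can estimate compute as about 2 * params * tokens per forward pass
# (times 3 for forward plus backward), with attention flops accounted for separately.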
+ # Last block must be handled differently since post attn layers don't run on conditioning sequence + self.n_seq_params_per_block = self.model.num_features**2 * (4 + 2 * self.model.expansion_factor) + self.n_seq_params = self.n_seq_params_per_block * (self.model.num_layers - 1) + self.n_seq_params += 3 * (self.model.num_features**2) + self.n_last_layer_params = self.model.num_features**2 * (1 + 2 * self.model.expansion_factor) + # Params only on the input sequence + self.n_input_params = self.model.input_features * self.model.num_features + # Params only on the conditioning sequence + self.n_cond_params = self.model.conditioning_features * self.model.num_features + + # Set up metrics + self.train_metrics = [MeanSquaredError()] + self.val_metrics = [MeanSquaredError()] + + # Optional rng generator + self.rng_generator: Optional[torch.Generator] = None + + def _apply(self, fn): + super(ComposerTextToImageMMDiT, self)._apply(fn) + self.latent_mean = fn(self.latent_mean) + self.latent_std = fn(self.latent_std) + return self + + def set_rng_generator(self, rng_generator: torch.Generator): + """Sets the rng generator for the model.""" + self.rng_generator = rng_generator + + def flops_per_batch(self, batch): + batch_size = batch[self.image_key].shape[0] + height, width = batch[self.image_key].shape[2:] + input_seq_len = height * width / (self.patch_size**2 * self.downsample_factor**2) + cond_seq_len = batch[self.caption_key].shape[1] + seq_len = input_seq_len + cond_seq_len + # Calulate forward flops on full sequence excluding attention + param_flops = 2 * self.n_seq_params * batch_size * seq_len + # Last block contributes a bit less than other blocks + param_flops += 2 * self.n_last_layer_params * batch_size * input_seq_len + # Include input sequence params (comparatively small) + param_flops += 2 * self.n_input_params * batch_size * input_seq_len + # Include conditioning sequence params (comparatively small) + param_flops += 2 * self.n_cond_params * batch_size * cond_seq_len + # Include flops from adaln + param_flops += 2 * self.adaLN_params * batch_size + # Calculate flops for attention layers + attention_flops = 4 * self.model.num_layers * seq_len**2 * self.model.num_features * batch_size + return 3 * param_flops + 3 * attention_flops + + def encode_image(self, image): + with torch.cuda.amp.autocast(enabled=False): + latents = self.autoencoder.encode(image.half())['latent_dist'].sample().data + # Scale and patchify the latents + latents = (latents - self.latent_mean) / self.latent_std + latent_patches, latent_coords = patchify(latents, self.patch_size) + return latent_patches, latent_coords + + @torch.no_grad() + def decode_image(self, latent_patches, latent_coords): + # Unpatchify the latents + latents = [ + unpatchify(latent_patches[i], latent_coords[i], self.patch_size) for i in range(latent_patches.shape[0]) + ] + latents = torch.stack(latents) + # Scale the latents back to the original scale + latents = latents * self.latent_std + self.latent_mean + # Decode the latents + with torch.cuda.amp.autocast(enabled=False): + image = self.autoencoder.decode(latents.half()).sample + image = (image / 2 + 0.5).clamp(0, 1) + return image + + def tokenize_prompts(self, prompts): + tokenized_out = self.tokenizer(prompts, + padding='max_length', + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors='pt') + return tokenized_out['input_ids'], tokenized_out['attention_mask'] + + def combine_attention_masks(self, attention_masks): + if len(attention_masks.shape) == 2: + return 
attention_masks + elif len(attention_masks.shape) == 3: + encoder_attention_masks = attention_masks[:, 0] + for i in range(1, attention_masks.shape[1]): + encoder_attention_masks |= attention_masks[:, i] + return encoder_attention_masks + else: + raise ValueError(f'attention_mask should have either 2 or 3 dimensions: {attention_masks.shape}') + + def make_text_embeddings_coords(self, text_embeddings): + text_embeddings_coords = torch.arange(text_embeddings.shape[1], device=text_embeddings.device) + text_embeddings_coords = text_embeddings_coords.unsqueeze(0).expand(text_embeddings.shape[0], -1) + text_embeddings_coords = text_embeddings_coords.unsqueeze(-1) + return text_embeddings_coords + + def embed_tokenized_prompts(self, tokenized_prompts, attention_masks): + with torch.cuda.amp.autocast(enabled=False): + # Ensure text embeddings are not longer than the model can handle + if tokenized_prompts.shape[1] > self.model.conditioning_max_sequence_length: + tokenized_prompts = tokenized_prompts[:, :self.model.conditioning_max_sequence_length] + text_encoder_out = self.text_encoder(tokenized_prompts, attention_mask=attention_masks) + text_embeddings, pooled_text_embeddings = text_encoder_out[0], text_encoder_out[1] + text_mask = self.combine_attention_masks(attention_masks) + text_embeddings_coords = self.make_text_embeddings_coords(text_embeddings) + # Encode the pooled embeddings + pooled_text_embeddings = self.pooled_embedding_mlp(pooled_text_embeddings) + return text_embeddings, text_embeddings_coords, text_mask, pooled_text_embeddings + + def diffusion_forward_process(self, inputs: torch.Tensor): + """Diffusion forward process using a rectified flow.""" + # First, sample timesteps according to a logit-normal distribution + u = torch.randn(inputs.shape[0], device=inputs.device, generator=self.rng_generator) + u = self.timestep_mean + self.timestep_std * u + timesteps = torch.sigmoid(u).view(-1, 1, 1) + timesteps = self.timestep_shift * timesteps / (1 + (self.timestep_shift - 1) * timesteps) + # Then, add the noise to the latents according to the recitified flow + noise = torch.randn(*inputs.shape, device=inputs.device, generator=self.rng_generator) + noised_inputs = (1 - timesteps) * inputs + timesteps * noise + # Compute the targets, which are the velocities + targets = noise - inputs + return noised_inputs, targets, timesteps[:, 0, 0] + + def forward(self, batch): + # Get the inputs + image, caption, caption_mask = batch[self.image_key], batch[self.caption_key], batch[self.caption_mask_key] + # Get the image latents + latent_patches, latent_coords = self.encode_image(image) + # Get the text embeddings and their coords + text_embeddings, text_embeddings_coords, caption_mask, pooled_text_embeddings = self.embed_tokenized_prompts( + caption, caption_mask) + # Diffusion forward process + noised_inputs, targets, timesteps = self.diffusion_forward_process(latent_patches) + # Forward through the model + model_out = self.model(noised_inputs, + latent_coords, + timesteps, + conditioning=text_embeddings, + conditioning_coords=text_embeddings_coords, + input_mask=None, + conditioning_mask=caption_mask, + constant_conditioning=pooled_text_embeddings) + return {'predictions': model_out, 'targets': targets, 'timesteps': timesteps} + + def loss(self, outputs, batch): + """MSE loss between outputs and targets.""" + loss = F.mse_loss(outputs['predictions'], outputs['targets']) + return loss + + def eval_forward(self, batch, outputs=None): + # Skip this if outputs have already been computed, e.g. 
during training + if outputs is not None: + return outputs + return self.forward(batch) + + def get_metrics(self, is_train: bool = False): + if is_train: + metrics_dict = {metric.__class__.__name__: metric for metric in self.train_metrics} + else: + metrics_dict = {metric.__class__.__name__: metric for metric in self.val_metrics} + return metrics_dict + + def update_metric(self, batch, outputs, metric): + if isinstance(metric, MeanSquaredError): + metric.update(outputs['predictions'], outputs['targets']) + else: + raise ValueError(f'Unrecognized metric {metric.__class__.__name__}') + + def make_sampling_timesteps(self, N: int): + timesteps = torch.linspace(1, 0, N + 1) + timesteps = self.timestep_shift * timesteps / (1 + (self.timestep_shift - 1) * timesteps) + # Make timestep differences + delta_t = timesteps[:-1] - timesteps[1:] + return timesteps[:-1], delta_t + + @torch.no_grad() + def generate(self, + prompt: list, + negative_prompt: Optional[list] = None, + height: int = 256, + width: int = 256, + guidance_scale: float = 7.0, + rescaled_guidance: Optional[float] = None, + num_inference_steps: int = 50, + num_images_per_prompt: int = 1, + progress_bar: bool = True, + seed: Optional[int] = None): + """Generate from the model.""" + device = next(self.model.parameters()).device + # Create rng for the generation + rng_generator = torch.Generator(device=device) + if seed: + rng_generator = rng_generator.manual_seed(seed) + + # Set default negative prompts to empty string if not provided + if negative_prompt is None: + negative_prompt = ['' for _ in prompt] + # Duplicate the images in the prompt and negative prompt if needed. + prompt = [item for item in prompt for _ in range(num_images_per_prompt)] + negative_prompt = [item for item in negative_prompt for _ in range(num_images_per_prompt)] + # Tokenize both prompt and negative prompts + prompt_tokens, prompt_mask = self.tokenize_prompts(prompt) + negative_prompt_tokens, negative_prompt_mask = self.tokenize_prompts(negative_prompt) + # Embed the tokenized prompts and negative prompts + text_embeddings, text_embeddings_coords, prompt_mask, pooled_embedding = self.embed_tokenized_prompts( + prompt_tokens, prompt_mask) + neg_text_embeddings, neg_text_embeddings_coords, neg_prompt_mask, pooled_neg_embedding = self.embed_tokenized_prompts( + negative_prompt_tokens, negative_prompt_mask) + + # Generate initial noise + latent_height = height // self.downsample_factor + latent_width = width // self.downsample_factor + latents = torch.randn(text_embeddings.shape[0], + self.latent_channels, + latent_height, + latent_width, + device=device) + latent_patches, latent_coords = patchify(latents, self.patch_size) + + # Set up for CFG + text_embeddings = torch.cat([text_embeddings, neg_text_embeddings], dim=0) + text_embeddings_coords = torch.cat([text_embeddings_coords, neg_text_embeddings_coords], dim=0) + text_embeddings_mask = torch.cat([prompt_mask, neg_prompt_mask], dim=0) + pooled_embedding = torch.cat([pooled_embedding, pooled_neg_embedding], dim=0) + latent_coords_input = torch.cat([latent_coords, latent_coords], dim=0) + + # backward diffusion process + timesteps, delta_t = self.make_sampling_timesteps(num_inference_steps) + timesteps, delta_t = timesteps.to(device), delta_t.to(device) + for i, t in tqdm(enumerate(timesteps), disable=not progress_bar): + latent_patches_input = torch.cat([latent_patches, latent_patches], dim=0) + # Get the model prediction + model_out = self.model(latent_patches_input, + latent_coords_input, + t.unsqueeze(0), + 
conditioning=text_embeddings, + conditioning_coords=text_embeddings_coords, + input_mask=None, + conditioning_mask=text_embeddings_mask, + constant_conditioning=pooled_embedding) + # Do CFG + pred_cond, pred_uncond = model_out.chunk(2, dim=0) + pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Update the latents + latent_patches = latent_patches - pred * delta_t[i] + # Decode the latents + image = self.decode_image(latent_patches, latent_coords) + return image.detach() # (batch*num_images_per_prompt, channel, h, w) diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index c7f2a8bb..13d1217e 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -9,9 +9,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from composer.models import ComposerModel -from torchmetrics import MeanSquaredError -from tqdm.auto import tqdm def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: @@ -44,61 +41,6 @@ def get_multidimensional_position_embeddings(position_embeddings: torch.Tensor, return sequenced_embeddings # (B, S, F, D) -def patchify(latents: torch.Tensor, patch_size: int) -> Tuple[torch.Tensor, torch.Tensor]: - """Function to extract non-overlapping patches from image-like latents. - - Converts a tensor of shape [B, C, H, W] to patches of shape [B, num_patches, C * patch_size * patch_size]. - Coordinates of the patches are also returned to allow for unpatching and sequence embedding. - - Args: - latents (torch.Tensor): Latents of shape [B, C, H, W]. - patch_size (int): Size of the patches. - - Returns: - torch.Tensor: Patches of shape [B, num_patches, C * patch_size * patch_size]. - torch.Tensor: Coordinates of the patches. Shape [B, num_patches, 2]. - """ - # Assume img is a tensor of shape [B, C, H, W] - B, C, H, W = latents.shape - assert H % patch_size == 0 and W % patch_size == 0, 'Image dimensions must be divisible by patch_size' - # Reshape and permute to get non-overlapping patches - num_H_patches = H // patch_size - num_W_patches = W // patch_size - patches = latents.reshape(B, C, num_H_patches, patch_size, num_W_patches, patch_size) - patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(B, -1, C * patch_size * patch_size) - # Generate coordinates for each patch - coords = torch.tensor([(i, j) for i in range(num_H_patches) for j in range(num_W_patches)]) - coords = coords.unsqueeze(0).expand(B, -1, -1).reshape(B, -1, 2) - return patches, coords - - -def unpatchify(patches: torch.Tensor, coords: torch.Tensor, patch_size: int) -> torch.Tensor: - """Recover an image-like tensor from a sequence of patches and their coordinates. - - Converts a tensor of shape [num_patches, C * patch_size * patch_size] to an image of shape [C, H, W]. - Coordinates are used to place the patches in the correct location in the image. - - Args: - patches (torch.Tensor): Patches of shape [num_patches, C * patch_size * patch_size]. - coords (torch.Tensor): Coordinates of the patches. Shape [num_patches, 2]. - patch_size (int): Size of the patches. 
- """ - # Assume patches is a tensor of shape [num_patches, C * patch_size * patch_size] - C = patches.shape[1] // (patch_size * patch_size) - # Calculate the height and width of the original image from the coordinates - H = coords[:, 0].max() * patch_size + patch_size - W = coords[:, 1].max() * patch_size + patch_size - # Initialize an empty tensor for the reconstructed image - img = torch.zeros((C, H, W), device=patches.device, dtype=patches.dtype) # type: ignore - # Iterate over the patches and their coordinates - for patch, (y, x) in zip(patches, patch_size * coords): - # Reshape the patch to [C, patch_size, patch_size] - patch = patch.view(C, patch_size, patch_size) - # Place the patch in the corresponding location in the image - img[:, y:y + patch_size, x:x + patch_size] = patch - return img - - class ScalarEmbedding(nn.Module): """Embedding block for scalars. @@ -498,341 +440,3 @@ def forward(self, y = modulate(self.final_norm(y), mods[0], mods[1]) y = self.final_linear(y) return y - - -class ComposerTextToImageMMDiT(ComposerModel): - """ComposerModel for text to image with a diffusion transformer. - - Args: - model (DiffusionTransformer): Core diffusion model. - autoencoder (torch.nn.Module): HuggingFace or compatible vae. - must support `.encode()` and `decode()` functions. - text_encoder (torch.nn.Module): HuggingFace CLIP or LLM text enoder. - tokenizer (transformers.PreTrainedTokenizer): Tokenizer used for - text_encoder. For a `CLIPTextModel` this will be the - `CLIPTokenizer` from HuggingFace transformers. - noise_scheduler (diffusers.SchedulerMixin): HuggingFace diffusers - noise scheduler. Used during the forward diffusion process (training). - inference_noise_scheduler (diffusers.SchedulerMixin): HuggingFace diffusers - noise scheduler. Used during the backward diffusion process (inference). - prediction_type (str): The type of prediction to use. Currently `epsilon`, `v_prediction` are supported. - latent_mean (Optional[tuple[float]]): The means of the latent space. If not specified, defaults to - 4 * (0.0,). Default: `None`. - latent_std (Optional[tuple[float]]): The standard deviations of the latent space. If not specified, - defaults to 4 * (1/0.13025,). Default: `None`. - patch_size (int): The size of the patches in the image latents. Default: `2`. - downsample_factor (int): The factor by which the image is downsampled by the autoencoder. Default `8`. - latent_channels (int): The number of channels in the autoencoder latent space. Default: `4`. - image_key (str): The name of the images in the dataloader batch. Default: `image`. - caption_key (str): The name of the caption in the dataloader batch. Default: `caption`. - caption_mask_key (str): The name of the caption mask in the dataloader batch. Default: `caption_mask`. 
- """ - - def __init__( - self, - model: DiffusionTransformer, - autoencoder: torch.nn.Module, - text_encoder: torch.nn.Module, - tokenizer, - latent_mean: Optional[tuple[float]] = None, - latent_std: Optional[tuple[float]] = None, - patch_size: int = 2, - downsample_factor: int = 8, - latent_channels: int = 4, - timestep_mean: float = 0.0, - timestep_std: float = 1.0, - timestep_shift: float = 1.0, - image_key: str = 'image', - caption_key: str = 'caption', - caption_mask_key: str = 'caption_mask', - pooled_embedding_features: int = 768, - ): - super().__init__() - self.model = model - self.autoencoder = autoencoder - self.text_encoder = text_encoder - self.tokenizer = tokenizer - if latent_mean is None: - self.latent_mean = 4 * (0.0) - if latent_std is None: - self.latent_std = 4 * (1 / 0.18215,) - self.latent_mean = torch.tensor(latent_mean).view(1, -1, 1, 1) - self.latent_std = torch.tensor(latent_std).view(1, -1, 1, 1) - self.patch_size = patch_size - self.downsample_factor = downsample_factor - self.latent_channels = latent_channels - self.timestep_mean = timestep_mean - self.timestep_std = timestep_std - self.timestep_shift = timestep_shift - self.image_key = image_key - self.caption_key = caption_key - self.caption_mask_key = caption_mask_key - self.pooled_embedding_features = pooled_embedding_features - - # Embedding MLP for the pooled text embeddings - self.pooled_embedding_mlp = VectorEmbedding(pooled_embedding_features, model.num_features) - - # freeze text_encoder during diffusion training and use half precision - self.autoencoder.requires_grad_(False) - self.text_encoder.requires_grad_(False) - self.autoencoder = self.autoencoder.half() - self.text_encoder = self.text_encoder.half() - - # Only FSDP wrap models we are training - self.model._fsdp_wrap = True - self.autoencoder._fsdp_wrap = False - self.text_encoder._fsdp_wrap = True - - # Param counts relevant for MFU computation - # First calc the AdaLN params separately - self.adaLN_params = sum(p.numel() for n, p in self.model.named_parameters() if 'adaLN_mlp_linear' in n) - # For MFU calc we must be careful to prevent double counting of MMDiT flops. - # Here, count the number of params applied to each sequence element. 
- # Last block must be handled differently since post attn layers don't run on conditioning sequence - self.n_seq_params_per_block = self.model.num_features**2 * (4 + 2 * self.model.expansion_factor) - self.n_seq_params = self.n_seq_params_per_block * (self.model.num_layers - 1) - self.n_seq_params += 3 * (self.model.num_features**2) - self.n_last_layer_params = self.model.num_features**2 * (1 + 2 * self.model.expansion_factor) - # Params only on the input sequence - self.n_input_params = self.model.input_features * self.model.num_features - # Params only on the conditioning sequence - self.n_cond_params = self.model.conditioning_features * self.model.num_features - - # Set up metrics - self.train_metrics = [MeanSquaredError()] - self.val_metrics = [MeanSquaredError()] - - # Optional rng generator - self.rng_generator: Optional[torch.Generator] = None - - def _apply(self, fn): - super(ComposerTextToImageMMDiT, self)._apply(fn) - self.latent_mean = fn(self.latent_mean) - self.latent_std = fn(self.latent_std) - return self - - def set_rng_generator(self, rng_generator: torch.Generator): - """Sets the rng generator for the model.""" - self.rng_generator = rng_generator - - def flops_per_batch(self, batch): - batch_size = batch[self.image_key].shape[0] - height, width = batch[self.image_key].shape[2:] - input_seq_len = height * width / (self.patch_size**2 * self.downsample_factor**2) - cond_seq_len = batch[self.caption_key].shape[1] - seq_len = input_seq_len + cond_seq_len - # Calulate forward flops on full sequence excluding attention - param_flops = 2 * self.n_seq_params * batch_size * seq_len - # Last block contributes a bit less than other blocks - param_flops += 2 * self.n_last_layer_params * batch_size * input_seq_len - # Include input sequence params (comparatively small) - param_flops += 2 * self.n_input_params * batch_size * input_seq_len - # Include conditioning sequence params (comparatively small) - param_flops += 2 * self.n_cond_params * batch_size * cond_seq_len - # Include flops from adaln - param_flops += 2 * self.adaLN_params * batch_size - # Calculate flops for attention layers - attention_flops = 4 * self.model.num_layers * seq_len**2 * self.model.num_features * batch_size - return 3 * param_flops + 3 * attention_flops - - def encode_image(self, image): - with torch.cuda.amp.autocast(enabled=False): - latents = self.autoencoder.encode(image.half())['latent_dist'].sample().data - # Scale and patchify the latents - latents = (latents - self.latent_mean) / self.latent_std - latent_patches, latent_coords = patchify(latents, self.patch_size) - return latent_patches, latent_coords - - @torch.no_grad() - def decode_image(self, latent_patches, latent_coords): - # Unpatchify the latents - latents = [ - unpatchify(latent_patches[i], latent_coords[i], self.patch_size) for i in range(latent_patches.shape[0]) - ] - latents = torch.stack(latents) - # Scale the latents back to the original scale - latents = latents * self.latent_std + self.latent_mean - # Decode the latents - with torch.cuda.amp.autocast(enabled=False): - image = self.autoencoder.decode(latents.half()).sample - image = (image / 2 + 0.5).clamp(0, 1) - return image - - def tokenize_prompts(self, prompts): - tokenized_out = self.tokenizer(prompts, - padding='max_length', - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors='pt') - return tokenized_out['input_ids'], tokenized_out['attention_mask'] - - def combine_attention_masks(self, attention_masks): - if len(attention_masks.shape) == 2: - return 
attention_masks - elif len(attention_masks.shape) == 3: - encoder_attention_masks = attention_masks[:, 0] - for i in range(1, attention_masks.shape[1]): - encoder_attention_masks |= attention_masks[:, i] - return encoder_attention_masks - else: - raise ValueError(f'attention_mask should have either 2 or 3 dimensions: {attention_masks.shape}') - - def make_text_embeddings_coords(self, text_embeddings): - text_embeddings_coords = torch.arange(text_embeddings.shape[1], device=text_embeddings.device) - text_embeddings_coords = text_embeddings_coords.unsqueeze(0).expand(text_embeddings.shape[0], -1) - text_embeddings_coords = text_embeddings_coords.unsqueeze(-1) - return text_embeddings_coords - - def embed_tokenized_prompts(self, tokenized_prompts, attention_masks): - with torch.cuda.amp.autocast(enabled=False): - # Ensure text embeddings are not longer than the model can handle - if tokenized_prompts.shape[1] > self.model.conditioning_max_sequence_length: - tokenized_prompts = tokenized_prompts[:, :self.model.conditioning_max_sequence_length] - text_encoder_out = self.text_encoder(tokenized_prompts, attention_mask=attention_masks) - text_embeddings, pooled_text_embeddings = text_encoder_out[0], text_encoder_out[1] - text_mask = self.combine_attention_masks(attention_masks) - text_embeddings_coords = self.make_text_embeddings_coords(text_embeddings) - # Encode the pooled embeddings - pooled_text_embeddings = self.pooled_embedding_mlp(pooled_text_embeddings) - return text_embeddings, text_embeddings_coords, text_mask, pooled_text_embeddings - - def diffusion_forward_process(self, inputs: torch.Tensor): - """Diffusion forward process using a rectified flow.""" - # First, sample timesteps according to a logit-normal distribution - u = torch.randn(inputs.shape[0], device=inputs.device, generator=self.rng_generator) - u = self.timestep_mean + self.timestep_std * u - timesteps = torch.sigmoid(u).view(-1, 1, 1) - timesteps = self.timestep_shift * timesteps / (1 + (self.timestep_shift - 1) * timesteps) - # Then, add the noise to the latents according to the recitified flow - noise = torch.randn(*inputs.shape, device=inputs.device, generator=self.rng_generator) - noised_inputs = (1 - timesteps) * inputs + timesteps * noise - # Compute the targets, which are the velocities - targets = noise - inputs - return noised_inputs, targets, timesteps[:, 0, 0] - - def forward(self, batch): - # Get the inputs - image, caption, caption_mask = batch[self.image_key], batch[self.caption_key], batch[self.caption_mask_key] - # Get the image latents - latent_patches, latent_coords = self.encode_image(image) - # Get the text embeddings and their coords - text_embeddings, text_embeddings_coords, caption_mask, pooled_text_embeddings = self.embed_tokenized_prompts( - caption, caption_mask) - # Diffusion forward process - noised_inputs, targets, timesteps = self.diffusion_forward_process(latent_patches) - # Forward through the model - model_out = self.model(noised_inputs, - latent_coords, - timesteps, - conditioning=text_embeddings, - conditioning_coords=text_embeddings_coords, - input_mask=None, - conditioning_mask=caption_mask, - constant_conditioning=pooled_text_embeddings) - return {'predictions': model_out, 'targets': targets, 'timesteps': timesteps} - - def loss(self, outputs, batch): - """MSE loss between outputs and targets.""" - loss = F.mse_loss(outputs['predictions'], outputs['targets']) - return loss - - def eval_forward(self, batch, outputs=None): - # Skip this if outputs have already been computed, e.g. 
during training - if outputs is not None: - return outputs - return self.forward(batch) - - def get_metrics(self, is_train: bool = False): - if is_train: - metrics_dict = {metric.__class__.__name__: metric for metric in self.train_metrics} - else: - metrics_dict = {metric.__class__.__name__: metric for metric in self.val_metrics} - return metrics_dict - - def update_metric(self, batch, outputs, metric): - if isinstance(metric, MeanSquaredError): - metric.update(outputs['predictions'], outputs['targets']) - else: - raise ValueError(f'Unrecognized metric {metric.__class__.__name__}') - - def make_sampling_timesteps(self, N: int): - timesteps = torch.linspace(1, 0, N + 1) - timesteps = self.timestep_shift * timesteps / (1 + (self.timestep_shift - 1) * timesteps) - # Make timestep differences - delta_t = timesteps[:-1] - timesteps[1:] - return timesteps[:-1], delta_t - - @torch.no_grad() - def generate(self, - prompt: list, - negative_prompt: Optional[list] = None, - height: int = 256, - width: int = 256, - guidance_scale: float = 7.0, - rescaled_guidance: Optional[float] = None, - num_inference_steps: int = 50, - num_images_per_prompt: int = 1, - progress_bar: bool = True, - seed: Optional[int] = None): - """Generate from the model.""" - device = next(self.model.parameters()).device - # Create rng for the generation - rng_generator = torch.Generator(device=device) - if seed: - rng_generator = rng_generator.manual_seed(seed) - - # Set default negative prompts to empty string if not provided - if negative_prompt is None: - negative_prompt = ['' for _ in prompt] - # Duplicate the images in the prompt and negative prompt if needed. - prompt = [item for item in prompt for _ in range(num_images_per_prompt)] - negative_prompt = [item for item in negative_prompt for _ in range(num_images_per_prompt)] - # Tokenize both prompt and negative prompts - prompt_tokens, prompt_mask = self.tokenize_prompts(prompt) - negative_prompt_tokens, negative_prompt_mask = self.tokenize_prompts(negative_prompt) - # Embed the tokenized prompts and negative prompts - text_embeddings, text_embeddings_coords, prompt_mask, pooled_embedding = self.embed_tokenized_prompts( - prompt_tokens, prompt_mask) - neg_text_embeddings, neg_text_embeddings_coords, neg_prompt_mask, pooled_neg_embedding = self.embed_tokenized_prompts( - negative_prompt_tokens, negative_prompt_mask) - - # Generate initial noise - latent_height = height // self.downsample_factor - latent_width = width // self.downsample_factor - latents = torch.randn(text_embeddings.shape[0], - self.latent_channels, - latent_height, - latent_width, - device=device) - latent_patches, latent_coords = patchify(latents, self.patch_size) - - # Set up for CFG - text_embeddings = torch.cat([text_embeddings, neg_text_embeddings], dim=0) - text_embeddings_coords = torch.cat([text_embeddings_coords, neg_text_embeddings_coords], dim=0) - text_embeddings_mask = torch.cat([prompt_mask, neg_prompt_mask], dim=0) - pooled_embedding = torch.cat([pooled_embedding, pooled_neg_embedding], dim=0) - latent_coords_input = torch.cat([latent_coords, latent_coords], dim=0) - - # backward diffusion process - timesteps, delta_t = self.make_sampling_timesteps(num_inference_steps) - timesteps, delta_t = timesteps.to(device), delta_t.to(device) - for i, t in tqdm(enumerate(timesteps), disable=not progress_bar): - print(t, delta_t[i]) - latent_patches_input = torch.cat([latent_patches, latent_patches], dim=0) - # Get the model prediction - model_out = self.model(latent_patches_input, - latent_coords_input, 
- t.unsqueeze(0), - conditioning=text_embeddings, - conditioning_coords=text_embeddings_coords, - input_mask=None, - conditioning_mask=text_embeddings_mask, - constant_conditioning=pooled_embedding) - # Do CFG - pred_cond, pred_uncond = model_out.chunk(2, dim=0) - pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) - # Update the latents - latent_patches = latent_patches - pred * delta_t[i] - # Decode the latents - image = self.decode_image(latent_patches, latent_coords) - return image.detach() # (batch*num_images_per_prompt, channel, h, w) diff --git a/diffusion/train.py b/diffusion/train.py index 33766c3b..9adeb296 100644 --- a/diffusion/train.py +++ b/diffusion/train.py @@ -20,7 +20,7 @@ from torch.optim import Optimizer from diffusion.models.autoencoder import ComposerAutoEncoder, ComposerDiffusersAutoEncoder -from diffusion.models.transformer import ComposerTextToImageMMDiT +from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT def make_autoencoder_optimizer(config: DictConfig, model: ComposerModel) -> Optimizer: From 5db34d4c527838cb3c16ac35506e10239707d37c Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 27 Jun 2024 21:17:26 +0000 Subject: [PATCH 19/27] Docs and types for composer model --- diffusion/models/t2i_transformer.py | 61 ++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 15 deletions(-) diff --git a/diffusion/models/t2i_transformer.py b/diffusion/models/t2i_transformer.py index 4e318d58..12924bb8 100644 --- a/diffusion/models/t2i_transformer.py +++ b/diffusion/models/t2i_transformer.py @@ -3,7 +3,7 @@ """Composer model for text to image generation with a multimodal transformer.""" -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -80,11 +80,6 @@ class ComposerTextToImageMMDiT(ComposerModel): tokenizer (transformers.PreTrainedTokenizer): Tokenizer used for text_encoder. For a `CLIPTextModel` this will be the `CLIPTokenizer` from HuggingFace transformers. - noise_scheduler (diffusers.SchedulerMixin): HuggingFace diffusers - noise scheduler. Used during the forward diffusion process (training). - inference_noise_scheduler (diffusers.SchedulerMixin): HuggingFace diffusers - noise scheduler. Used during the backward diffusion process (inference). - prediction_type (str): The type of prediction to use. Currently `epsilon`, `v_prediction` are supported. latent_mean (Optional[tuple[float]]): The means of the latent space. If not specified, defaults to 4 * (0.0,). Default: `None`. latent_std (Optional[tuple[float]]): The standard deviations of the latent space. If not specified, @@ -92,9 +87,15 @@ class ComposerTextToImageMMDiT(ComposerModel): patch_size (int): The size of the patches in the image latents. Default: `2`. downsample_factor (int): The factor by which the image is downsampled by the autoencoder. Default `8`. latent_channels (int): The number of channels in the autoencoder latent space. Default: `4`. + timestep_mean (float): The mean of the logit-normal distribution for sampling timesteps. Default: `0.0`. + timestep_std (float): The standard deviation of the logit-normal distribution for sampling timesteps. + Default: `1.0`. + timestep_shift (float): The shift parameter for the logit-normal distribution for sampling timesteps. + A value of `1.0` is no shift. Default: `1.0`. image_key (str): The name of the images in the dataloader batch. Default: `image`. caption_key (str): The name of the caption in the dataloader batch. 
Default: `caption`. caption_mask_key (str): The name of the caption mask in the dataloader batch. Default: `caption_mask`. + pooled_embedding_features (int): The number of features in the pooled text embeddings. Default: `768`. """ def __init__( @@ -184,7 +185,7 @@ def set_rng_generator(self, rng_generator: torch.Generator): """Sets the rng generator for the model.""" self.rng_generator = rng_generator - def flops_per_batch(self, batch): + def flops_per_batch(self, batch) -> int: batch_size = batch[self.image_key].shape[0] height, width = batch[self.image_key].shape[2:] input_seq_len = height * width / (self.patch_size**2 * self.downsample_factor**2) @@ -204,7 +205,8 @@ def flops_per_batch(self, batch): attention_flops = 4 * self.model.num_layers * seq_len**2 * self.model.num_features * batch_size return 3 * param_flops + 3 * attention_flops - def encode_image(self, image): + def encode_image(self, image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Encode an image tensor with the autoencoder and patchify the latents.""" with torch.cuda.amp.autocast(enabled=False): latents = self.autoencoder.encode(image.half())['latent_dist'].sample().data # Scale and patchify the latents @@ -213,7 +215,8 @@ def encode_image(self, image): return latent_patches, latent_coords @torch.no_grad() - def decode_image(self, latent_patches, latent_coords): + def decode_image(self, latent_patches: torch.Tensor, latent_coords: torch.Tensor) -> torch.Tensor: + """Decode image latent patches and unpatchify the image.""" # Unpatchify the latents latents = [ unpatchify(latent_patches[i], latent_coords[i], self.patch_size) for i in range(latent_patches.shape[0]) @@ -227,7 +230,8 @@ def decode_image(self, latent_patches, latent_coords): image = (image / 2 + 0.5).clamp(0, 1) return image - def tokenize_prompts(self, prompts): + def tokenize_prompts(self, prompts: Union[str, List[str]]) -> Tuple[torch.Tensor, torch.Tensor]: + """Tokenize the prompts using the model's tokenizer.""" tokenized_out = self.tokenizer(prompts, padding='max_length', max_length=self.tokenizer.model_max_length, @@ -235,7 +239,8 @@ def tokenize_prompts(self, prompts): return_tensors='pt') return tokenized_out['input_ids'], tokenized_out['attention_mask'] - def combine_attention_masks(self, attention_masks): + def combine_attention_masks(self, attention_masks: torch.Tensor) -> torch.Tensor: + """Combine attention masks for the encoder if there are multiple text encoders.""" if len(attention_masks.shape) == 2: return attention_masks elif len(attention_masks.shape) == 3: @@ -246,13 +251,17 @@ def combine_attention_masks(self, attention_masks): else: raise ValueError(f'attention_mask should have either 2 or 3 dimensions: {attention_masks.shape}') - def make_text_embeddings_coords(self, text_embeddings): + def make_text_embeddings_coords(self, text_embeddings: torch.Tensor) -> torch.Tensor: + """Make text embeddings coordinates for the transformer.""" text_embeddings_coords = torch.arange(text_embeddings.shape[1], device=text_embeddings.device) text_embeddings_coords = text_embeddings_coords.unsqueeze(0).expand(text_embeddings.shape[0], -1) text_embeddings_coords = text_embeddings_coords.unsqueeze(-1) return text_embeddings_coords - def embed_tokenized_prompts(self, tokenized_prompts, attention_masks): + def embed_tokenized_prompts( + self, tokenized_prompts: torch.Tensor, + attention_masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Use the model's text encoder to embed tokenized prompts and create 
pooled text embeddings.""" with torch.cuda.amp.autocast(enabled=False): # Ensure text embeddings are not longer than the model can handle if tokenized_prompts.shape[1] > self.model.conditioning_max_sequence_length: @@ -265,7 +274,7 @@ def embed_tokenized_prompts(self, tokenized_prompts, attention_masks): pooled_text_embeddings = self.pooled_embedding_mlp(pooled_text_embeddings) return text_embeddings, text_embeddings_coords, text_mask, pooled_text_embeddings - def diffusion_forward_process(self, inputs: torch.Tensor): + def diffusion_forward_process(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Diffusion forward process using a rectified flow.""" # First, sample timesteps according to a logit-normal distribution u = torch.randn(inputs.shape[0], device=inputs.device, generator=self.rng_generator) @@ -343,7 +352,23 @@ def generate(self, num_images_per_prompt: int = 1, progress_bar: bool = True, seed: Optional[int] = None): - """Generate from the model.""" + """Run generation for the model. + + Args: + prompt (list): List of prompts for the generation. + negative_prompt (Optional[list]): List of negative prompts for the generation. Default: `None`. + height (int): Height of the generated images. Default: `256`. + width (int): Width of the generated images. Default: `256`. + guidance_scale (float): Scale for the guidance. Default: `7.0`. + rescaled_guidance (Optional[float]): Rescale the guidance. Default: `None`. + num_inference_steps (int): Number of inference steps. Default: `50`. + num_images_per_prompt (int): Number of images per prompt. Default: `1`. + progress_bar (bool): Whether to show a progress bar. Default: `True`. + seed (Optional[int]): Seed for the generation. Default: `None`. + + Returns: + torch.Tensor: Generated images. Shape [batch*num_images_per_prompt, channel, h, w]. + """ device = next(self.model.parameters()).device # Create rng for the generation rng_generator = torch.Generator(device=device) @@ -399,6 +424,12 @@ def generate(self, # Do CFG pred_cond, pred_uncond = model_out.chunk(2, dim=0) pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond) + # Optionally rescale the classifer free guidance + if rescaled_guidance is not None: + std_pos = torch.std(pred_cond, dim=(1, 2), keepdim=True) + std_cfg = torch.std(pred, dim=(1, 2), keepdim=True) + pred_rescaled = pred * (std_pos / std_cfg) + pred = pred_rescaled * rescaled_guidance + pred * (1 - rescaled_guidance) # Update the latents latent_patches = latent_patches - pred * delta_t[i] # Decode the latents From f2ffcb6b3d349f1458ef37a0905f9ac26d87564e Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 27 Jun 2024 21:22:55 +0000 Subject: [PATCH 20/27] Add dummy pretrained flag --- diffusion/models/models.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index c29e987a..8ccb623b 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -504,8 +504,6 @@ def text_to_image_transformer( vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix', autoencoder_path: Optional[str] = None, autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', - num_features: int = 1152, - num_heads: int = 16, num_layers: int = 28, input_max_sequence_length: int = 1024, conditioning_features: int = 768, @@ -528,9 +526,8 @@ def text_to_image_transformer( autoencoder_path (optional, str): Path to autoencoder weights if using custom autoencoder. 
If not specified, will use the vae from `model_name`. Default `None`. autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`. - num_features (int): Number of features in the transformer. Default: `1152`. - num_heads (int): Number of heads in the transformer. Default: `16`. - num_layers (int): Number of layers in the transformer. Default: `28`. + num_layers (int): Number of layers in the transformer. Number of heads and layer width are determined by + this according to `num_features = 64 * num_layers`, and `num_heads = num_layers`. Default: `28`. input_max_sequence_length (int): Maximum sequence length for the input. Default: `1024`. conditioning_features (int): Number of features in the conditioning transformer. Default: `768`. conditioning_max_sequence_length (int): Maximum sequence length for the conditioning transformer. Default: `77`. @@ -544,6 +541,7 @@ def text_to_image_transformer( timestep_mean (float): The mean of the timesteps. Default: `0.0`. timestep_std (float): The std. dev. of the timesteps. Default: `1.0`. timestep_shift (float): The shift of the timesteps. Default: `1.0`. + pretrained (bool): Whether to load pretrained weights. Not used. Defaults to False. """ latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) @@ -589,8 +587,8 @@ def text_to_image_transformer( assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) # Make the transformer model - transformer = DiffusionTransformer(num_features=num_features, - num_heads=num_heads, + transformer = DiffusionTransformer(num_features=64 * num_layers, + num_heads=num_layers, num_layers=num_layers, input_features=autoencoder_channels * (patch_size**2), input_max_sequence_length=input_max_sequence_length, From b8aaa4c9ef8ae454b12da1eccdcd5ff98c69d16a Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 27 Jun 2024 21:26:12 +0000 Subject: [PATCH 21/27] Minor cleanup --- diffusion/train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/diffusion/train.py b/diffusion/train.py index 9adeb296..9a9dbaf5 100644 --- a/diffusion/train.py +++ b/diffusion/train.py @@ -56,13 +56,12 @@ def make_transformer_optimizer(config: DictConfig, model: ComposerModel) -> Opti print('Configuring optimizer for transformer') assert isinstance(model, ComposerTextToImageMMDiT) - # Turn off weight decay for the positional embeddings + # Turn off weight decay for biases, norms, and positional embeddings. 
no_decay = ['bias', 'norm', 'position_embedding'] params_with_no_decay = [] params_with_decay = [] for name, param in model.named_parameters(): if any(nd in name for nd in no_decay): - #print(f'No decay: {name}') params_with_no_decay.append(param) else: params_with_decay.append(param) From 40b5d9fca08d2c959a34222815e6206243bd19db Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 27 Jun 2024 22:54:53 +0000 Subject: [PATCH 22/27] Update tests --- diffusion/models/models.py | 13 +++++++-- diffusion/models/t2i_transformer.py | 24 ++++++++++++---- diffusion/models/transformer.py | 2 ++ tests/test_transformer.py | 43 ++++++++++++++++++++++++++++- 4 files changed, 72 insertions(+), 10 deletions(-) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 8ccb623b..7f30cba6 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -514,6 +514,10 @@ def text_to_image_transformer( timestep_mean: float = 0.0, timestep_std: float = 1.0, timestep_shift: float = 1.0, + image_key: str = 'image', + caption_key: str = 'captions', + caption_mask_key: str = 'attention_mask', + pretrained: bool = False, ): """Text to image transformer training setup. @@ -541,6 +545,9 @@ def text_to_image_transformer( timestep_mean (float): The mean of the timesteps. Default: `0.0`. timestep_std (float): The std. dev. of the timesteps. Default: `1.0`. timestep_shift (float): The shift of the timesteps. Default: `1.0`. + image_key (str): The key for the image in the batch. Default: `image`. + caption_key (str): The key for the captions in the batch. Default: `captions`. + caption_mask_key (str): The key for the caption mask in the batch. Default: `attention_mask`. pretrained (bool): Whether to load pretrained weights. Not used. Defaults to False. 
""" latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std) @@ -610,9 +617,9 @@ def text_to_image_transformer( timestep_mean=timestep_mean, timestep_std=timestep_std, timestep_shift=timestep_shift, - image_key='image', - caption_key='captions', - caption_mask_key='attention_mask') + image_key=image_key, + caption_key=caption_key, + caption_mask_key=caption_mask_key) if torch.cuda.is_available(): model = DeviceGPU().module_to_device(model) diff --git a/diffusion/models/t2i_transformer.py b/diffusion/models/t2i_transformer.py index 12924bb8..99d3cca5 100644 --- a/diffusion/models/t2i_transformer.py +++ b/diffusion/models/t2i_transformer.py @@ -270,6 +270,9 @@ def embed_tokenized_prompts( text_embeddings, pooled_text_embeddings = text_encoder_out[0], text_encoder_out[1] text_mask = self.combine_attention_masks(attention_masks) text_embeddings_coords = self.make_text_embeddings_coords(text_embeddings) + # Ensure the embeddings are the same dtype as the model + text_embeddings = text_embeddings.to(next(self.model.parameters()).dtype) + pooled_text_embeddings = pooled_text_embeddings.to(next(self.pooled_embedding_mlp.parameters()).dtype) # Encode the pooled embeddings pooled_text_embeddings = self.pooled_embedding_mlp(pooled_text_embeddings) return text_embeddings, text_embeddings_coords, text_mask, pooled_text_embeddings @@ -277,7 +280,7 @@ def embed_tokenized_prompts( def diffusion_forward_process(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Diffusion forward process using a rectified flow.""" # First, sample timesteps according to a logit-normal distribution - u = torch.randn(inputs.shape[0], device=inputs.device, generator=self.rng_generator) + u = torch.randn(inputs.shape[0], device=inputs.device, generator=self.rng_generator, dtype=inputs.dtype) u = self.timestep_mean + self.timestep_std * u timesteps = torch.sigmoid(u).view(-1, 1, 1) timesteps = self.timestep_shift * timesteps / (1 + (self.timestep_shift - 1) * timesteps) @@ -342,8 +345,8 @@ def make_sampling_timesteps(self, N: int): @torch.no_grad() def generate(self, - prompt: list, - negative_prompt: Optional[list] = None, + prompt: Union[str, list], + negative_prompt: Optional[Union[str, list]] = None, height: int = 256, width: int = 256, guidance_scale: float = 7.0, @@ -355,8 +358,8 @@ def generate(self, """Run generation for the model. Args: - prompt (list): List of prompts for the generation. - negative_prompt (Optional[list]): List of negative prompts for the generation. Default: `None`. + prompt (str, list): Prompt or prompts for the generation. + negative_prompt (Optional[str, list]): Negative prompt or prompts for the generation. Default: `None`. height (int): Height of the generated images. Default: `256`. width (int): Width of the generated images. Default: `256`. guidance_scale (float): Scale for the guidance. Default: `7.0`. @@ -376,14 +379,23 @@ def generate(self, rng_generator = rng_generator.manual_seed(seed) # Set default negative prompts to empty string if not provided - if negative_prompt is None: + if isinstance(prompt, str): + prompt = [prompt] + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * len(prompt) + elif isinstance(negative_prompt, list): + assert len(negative_prompt) == len(prompt), 'Prompt and negative prompt must have the same length.' + elif negative_prompt is None: negative_prompt = ['' for _ in prompt] # Duplicate the images in the prompt and negative prompt if needed. 
prompt = [item for item in prompt for _ in range(num_images_per_prompt)] negative_prompt = [item for item in negative_prompt for _ in range(num_images_per_prompt)] # Tokenize both prompt and negative prompts prompt_tokens, prompt_mask = self.tokenize_prompts(prompt) + prompt_tokens, prompt_mask = prompt_tokens.to(device), prompt_mask.to(device) negative_prompt_tokens, negative_prompt_mask = self.tokenize_prompts(negative_prompt) + negative_prompt_tokens, negative_prompt_mask = negative_prompt_tokens.to(device), negative_prompt_mask.to( + device) # Embed the tokenized prompts and negative prompts text_embeddings, text_embeddings_coords, prompt_mask, pooled_embedding = self.embed_tokenized_prompts( prompt_tokens, prompt_mask) diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index 13d1217e..9a35112a 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -82,6 +82,8 @@ def timestep_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 1000 def forward(self, x: torch.Tensor) -> torch.Tensor: sinusoidal_embedding = self.timestep_embedding(x, self.sinusoidal_embedding_dim) + # Ensure embedding is the correct dtype + sinusoidal_embedding = sinusoidal_embedding.to(next(self.parameters()).dtype) return self.mlp(sinusoidal_embedding) diff --git a/tests/test_transformer.py b/tests/test_transformer.py index da70a8a6..024c272b 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -4,7 +4,9 @@ import pytest import torch -from diffusion.models.transformer import get_multidimensional_position_embeddings, patchify, unpatchify +from diffusion.models.models import text_to_image_transformer +from diffusion.models.t2i_transformer import patchify, unpatchify +from diffusion.models.transformer import get_multidimensional_position_embeddings def test_multidimensional_position_embeddings(): @@ -49,3 +51,42 @@ def test_patch_and_unpatch(patch_size, batch_size, C, H, W): # Verify reconstructed image is close to the original for i in range(batch_size): assert torch.allclose(image_recon[i], image[i], atol=1e-6) + + +def test_t2i_transformer_forward(): + # fp16 vae does not run on cpu + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + model = text_to_image_transformer(num_layers=2) + batch_size = 1 + H = 32 + W = 32 + image = torch.randn(batch_size, 3, H, W, device=device).half() + caption = torch.randint(low=0, high=128, size=( + batch_size, + 77, + ), dtype=torch.long, device=device) + caption_mask = torch.ones_like(caption, dtype=torch.bool, device=device) + batch = {'image': image, 'captions': caption, 'attention_mask': caption_mask} + outputs = model(batch) # model.forward generates the unet output noise or v_pred target. 
+ # Desired output shape + seq_len = H / (8 * 2) * W / (8 * 2) + output_shape = (1, seq_len, 4 * 2 * 2) + assert outputs['predictions'].shape == output_shape + assert outputs['targets'].shape == output_shape + + +@pytest.mark.parametrize('guidance_scale', [0.0, 3.0]) +@pytest.mark.parametrize('negative_prompt', [None, 'so cool']) +def test_t2i_transformer_generate(guidance_scale, negative_prompt): + model = model = text_to_image_transformer(num_layers=2) + output = model.generate( + prompt='a cool doge', + negative_prompt=negative_prompt, + num_inference_steps=1, + num_images_per_prompt=1, + height=32, + width=32, + guidance_scale=guidance_scale, + progress_bar=False, + ) + assert output.shape == (1, 3, 32, 32) From f62f647863b44e3df12524cb2a95be33ca02756b Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 27 Jun 2024 23:15:13 +0000 Subject: [PATCH 23/27] Figure out max input sequence length from image size --- diffusion/models/models.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/diffusion/models/models.py b/diffusion/models/models.py index 7f30cba6..4326cf1e 100644 --- a/diffusion/models/models.py +++ b/diffusion/models/models.py @@ -4,6 +4,7 @@ """Constructors for diffusion models.""" import logging +import math from typing import List, Optional, Tuple, Union import torch @@ -505,7 +506,7 @@ def text_to_image_transformer( autoencoder_path: Optional[str] = None, autoencoder_local_path: str = '/tmp/autoencoder_weights.pt', num_layers: int = 28, - input_max_sequence_length: int = 1024, + max_image_side: int = 1280, conditioning_features: int = 768, conditioning_max_sequence_length: int = 77, patch_size: int = 2, @@ -532,7 +533,7 @@ def text_to_image_transformer( autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`. num_layers (int): Number of layers in the transformer. Number of heads and layer width are determined by this according to `num_features = 64 * num_layers`, and `num_heads = num_layers`. Default: `28`. - input_max_sequence_length (int): Maximum sequence length for the input. Default: `1024`. + max_image_side (int): Maximum side length of the image. Default: `1280`. conditioning_features (int): Number of features in the conditioning transformer. Default: `768`. conditioning_max_sequence_length (int): Maximum sequence length for the conditioning transformer. Default: `77`. patch_size (int): Patch size for the transformer. Default: `2`. 
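The next hunk derives `input_max_sequence_length` from `max_image_side` instead of taking it as an explicit argument. As a rough, illustrative check of the sizing rules documented in the docstring above (the names below simply mirror the constructor defaults; this calculation is not part of the patch):

import math

num_layers = 28                 # default from the docstring above
num_features = 64 * num_layers  # 1792-wide transformer
num_heads = num_layers          # 28 heads, i.e. 64 features per head

max_image_side = 1280
downsample_factor = 8           # autoencoder spatial downsampling
patch_size = 2                  # latent patch size
# Maximum number of patches along one side of the latent image
input_max_sequence_length = math.ceil(max_image_side / (downsample_factor * patch_size))  # 80

With these defaults, a square 1280x1280 image maps to a 160x160 latent, so at most 80 patch positions are needed along each spatial dimension.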
@@ -592,7 +593,8 @@ def text_to_image_transformer( if isinstance(latent_std, float): latent_std = (latent_std,) * autoencoder_channels assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple) - + # Figure out the maximum input sequence length + input_max_sequence_length = math.ceil(max_image_side / (downsample_factor * patch_size)) # Make the transformer model transformer = DiffusionTransformer(num_features=64 * num_layers, num_heads=num_layers, From 0df0095ce1cb728c21eb1d8b44d9a8c9279e01b1 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Sat, 29 Jun 2024 04:31:25 +0000 Subject: [PATCH 24/27] Equally spaced timesteps during eval --- diffusion/models/t2i_transformer.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/diffusion/models/t2i_transformer.py b/diffusion/models/t2i_transformer.py index 99d3cca5..ab623d16 100644 --- a/diffusion/models/t2i_transformer.py +++ b/diffusion/models/t2i_transformer.py @@ -8,6 +8,7 @@ import torch import torch.nn.functional as F from composer.models import ComposerModel +from composer.utils import dist from torchmetrics import MeanSquaredError from tqdm.auto import tqdm @@ -279,11 +280,20 @@ def embed_tokenized_prompts( def diffusion_forward_process(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Diffusion forward process using a rectified flow.""" - # First, sample timesteps according to a logit-normal distribution - u = torch.randn(inputs.shape[0], device=inputs.device, generator=self.rng_generator, dtype=inputs.dtype) - u = self.timestep_mean + self.timestep_std * u - timesteps = torch.sigmoid(u).view(-1, 1, 1) - timesteps = self.timestep_shift * timesteps / (1 + (self.timestep_shift - 1) * timesteps) + if not self.model.training: + # Sample equally spaced timesteps across all devices + global_batch_size = inputs.shape[0] * dist.get_world_size() + global_timesteps = torch.linspace(0, 1, global_batch_size) + # Get this device's subset of all the timesteps + idx_offset = dist.get_global_rank() * inputs.shape[0] + timesteps = global_timesteps[idx_offset:idx_offset + inputs.shape[0]].to(inputs.device) + timesteps = timesteps.view(-1, 1, 1) + else: + # Sample timesteps according to a logit-normal distribution + u = torch.randn(inputs.shape[0], device=inputs.device, generator=self.rng_generator, dtype=inputs.dtype) + u = self.timestep_mean + self.timestep_std * u + timesteps = torch.sigmoid(u).view(-1, 1, 1) + timesteps = self.timestep_shift * timesteps / (1 + (self.timestep_shift - 1) * timesteps) # Then, add the noise to the latents according to the recitified flow noise = torch.randn(*inputs.shape, device=inputs.device, generator=self.rng_generator) noised_inputs = (1 - timesteps) * inputs + timesteps * noise From ab641eb84ca51574fdb09afef89c9e2020ffa157 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 18 Jul 2024 00:52:34 +0000 Subject: [PATCH 25/27] Separate AdaLN and modulation modules --- diffusion/models/transformer.py | 98 ++++++++++++++++++++++----------- 1 file changed, 66 insertions(+), 32 deletions(-) diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index 9a35112a..892d12bd 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -41,6 +41,59 @@ def get_multidimensional_position_embeddings(position_embeddings: torch.Tensor, return sequenced_embeddings # (B, S, F, D) +class AdaptiveLayerNorm(nn.Module): + """Adaptive LayerNorm. 
+ + Scales and shifts the output of a LayerNorm using an MLP conditioned on a scalar. + + Args: + num_features (int): Number of input features. + """ + + def __init__(self, num_features: int): + super().__init__() + self.num_features = num_features + # MLP for computing modulations. + # Initialized to zero so modulation acts as identity at initialization. + self.adaLN_mlp_linear = nn.Linear(self.num_features, 2 * self.num_features, bias=True) + nn.init.zeros_(self.adaLN_mlp_linear.weight) + nn.init.zeros_(self.adaLN_mlp_linear.bias) + self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear) + # LayerNorm + self.layernorm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) + + def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + # Calculate the modulations + mods = self.adaLN_mlp(t).unsqueeze(1).chunk(2, dim=2) + # Apply the modulations + return modulate(self.layernorm(x), mods[0], mods[1]) + + +class ModulationLayer(nn.Module): + """Modulation layer. + + Scales the input by a factor determined by a scalar input. + + Args: + num_features (int): Number of input features. + """ + + def __init__(self, num_features: int): + super().__init__() + self.num_features = num_features + # MLP for computing modulation. + # Initialized to zero so modulation starts off at zero. + self.adaLN_mlp_linear = nn.Linear(self.num_features, self.num_features, bias=True) + nn.init.zeros_(self.adaLN_mlp_linear.weight) + nn.init.zeros_(self.adaLN_mlp_linear.bias) + self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear) + + def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + # Calculate the modulations + mods = self.adaLN_mlp(t).unsqueeze(1) + return x * mods + + class ScalarEmbedding(nn.Module): """Embedding block for scalars. @@ -121,14 +174,8 @@ class PreAttentionBlock(nn.Module): def __init__(self, num_features: int): super().__init__() self.num_features = num_features - - # AdaLN MLP for pre-attention. Initialized to zero so modulation acts as identity at initialization. - self.adaLN_mlp_linear = nn.Linear(self.num_features, 2 * self.num_features, bias=True) - nn.init.zeros_(self.adaLN_mlp_linear.weight) - nn.init.zeros_(self.adaLN_mlp_linear.bias) - self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear) - # Input layernorm - self.input_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6) + # Adaptive layernorm + self.adaptive_layernorm = AdaptiveLayerNorm(self.num_features) # Linear layer to get q, k, and v self.qkv = nn.Linear(self.num_features, 3 * self.num_features) # QK layernorms. Original MMDiT used RMSNorm here. @@ -140,10 +187,7 @@ def __init__(self, num_features: int): nn.init.normal_(self.qkv.weight, std=0.02) def forward(self, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - # Calculate the modulations - mods = self.adaLN_mlp(t).unsqueeze(1).chunk(2, dim=2) - # Forward, with modulations - x = modulate(self.input_norm(x), mods[0], mods[1]) + x = self.adaptive_layernorm(x, t) # Calculate the query, key, and values all in one go q, k, v = self.qkv(x).chunk(3, dim=-1) q = self.q_norm(q) @@ -196,15 +240,12 @@ def __init__(self, num_features: int, expansion_factor: int = 4): super().__init__() self.num_features = num_features self.expansion_factor = expansion_factor - # AdaLN MLP for post-attention. Initialized to zero so modulation acts as identity at initialization. 
-        self.adaLN_mlp_linear = nn.Linear(self.num_features, 4 * self.num_features, bias=True)
-        nn.init.zeros_(self.adaLN_mlp_linear.weight)
-        nn.init.zeros_(self.adaLN_mlp_linear.bias)
-        self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear)
+        # Input modulation
+        self.modulate_v = ModulationLayer(self.num_features)
         # Linear layer to process v
         self.linear_v = nn.Linear(self.num_features, self.num_features)
         # Layernorm for the output
-        self.output_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6)
+        self.output_norm = AdaptiveLayerNorm(self.num_features)
         # Transformer style MLP layers
         self.linear_1 = nn.Linear(self.num_features, self.expansion_factor * self.num_features)
         self.nonlinearity = nn.GELU(approximate='tanh')
         self.linear_2 = nn.Linear(self.expansion_factor * self.num_features, self.num_features)
@@ -214,20 +255,20 @@ def __init__(self, num_features: int, expansion_factor: int = 4):
         nn.init.zeros_(self.linear_2.bias)
         # Output MLP
         self.output_mlp = nn.Sequential(self.linear_1, self.nonlinearity, self.linear_2)
+        # Output modulation
+        self.modulate_output = ModulationLayer(self.num_features)
 
     def forward(self, v: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
         """Forward takes v from self attention and the original sequence x with scalar conditioning t."""
-        # Calculate the modulations
-        mods = self.adaLN_mlp(t).unsqueeze(1).chunk(4, dim=2)
         # Postprocess v with linear + gating modulation
-        y = mods[0] * self.linear_v(v)
+        y = self.modulate_v(self.linear_v(v), t)
         y = x + y
         # Adaptive layernorm
-        y = modulate(self.output_norm(y), mods[1], mods[2])
+        y = self.output_norm(y, t)
         # Output MLP
         y = self.output_mlp(y)
         # Gating modulation for the output
-        y = mods[3] * y
+        y = self.modulate_output(y, t)
         y = x + y
         return y
 
@@ -353,17 +394,11 @@ def __init__(self,
         self.transformer_blocks.append(
             MMDiTBlock(self.num_features, self.num_heads, expansion_factor=self.expansion_factor, is_last=True))
         # Output projection layer
-        self.final_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6)
+        self.final_norm = AdaptiveLayerNorm(self.num_features)
         self.final_linear = nn.Linear(self.num_features, self.input_features)
         # Init the output layer to zero
         nn.init.zeros_(self.final_linear.weight)
         nn.init.zeros_(self.final_linear.bias)
-        # AdaLN MLP for the output layer
-        self.adaLN_mlp_linear = nn.Linear(self.num_features, 2 * self.num_features)
-        # Init the modulations to zero.
This will ensure the block acts as identity at initialization - nn.init.zeros_(self.adaLN_mlp_linear.weight) - nn.init.zeros_(self.adaLN_mlp_linear.bias) - self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear) def fsdp_wrap_fn(self, module: nn.Module) -> bool: if isinstance(module, MMDiTBlock): @@ -438,7 +473,6 @@ def forward(self, for block in self.transformer_blocks: y, c = block(y, c, t, mask=mask) # Pass through the output layers to get the right number of elements - mods = self.adaLN_mlp(t).unsqueeze(1).chunk(2, dim=2) - y = modulate(self.final_norm(y), mods[0], mods[1]) + y = self.final_norm(y, t) y = self.final_linear(y) return y From ce55ee0303cfbd67ca31d0b5bee30f98f54ed7bf Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 25 Jul 2024 21:31:58 +0000 Subject: [PATCH 26/27] Fix formatting --- diffusion/models/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffusion/models/transformer.py b/diffusion/models/transformer.py index 892d12bd..453208fd 100644 --- a/diffusion/models/transformer.py +++ b/diffusion/models/transformer.py @@ -77,7 +77,7 @@ class ModulationLayer(nn.Module): Args: num_features (int): Number of input features. """ - + def __init__(self, num_features: int): super().__init__() self.num_features = num_features From 0f0ca90e1e016d43f04fb688f0045e867811cc39 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 25 Jul 2024 23:32:31 +0000 Subject: [PATCH 27/27] Some tests are gpu only --- tests/test_transformer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_transformer.py b/tests/test_transformer.py index 024c272b..e138f9fb 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -53,6 +53,7 @@ def test_patch_and_unpatch(patch_size, batch_size, C, H, W): assert torch.allclose(image_recon[i], image[i], atol=1e-6) +@pytest.mark.gpu def test_t2i_transformer_forward(): # fp16 vae does not run on cpu device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') @@ -75,6 +76,7 @@ def test_t2i_transformer_forward(): assert outputs['targets'].shape == output_shape +@pytest.mark.gpu @pytest.mark.parametrize('guidance_scale', [0.0, 3.0]) @pytest.mark.parametrize('negative_prompt', [None, 'so cool']) def test_t2i_transformer_generate(guidance_scale, negative_prompt): @@ -84,8 +86,8 @@ def test_t2i_transformer_generate(guidance_scale, negative_prompt): negative_prompt=negative_prompt, num_inference_steps=1, num_images_per_prompt=1, - height=32, - width=32, + height=64, + width=64, guidance_scale=guidance_scale, progress_bar=False, )
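
A note on PATCH 24/27: during eval, diffusion_forward_process shards one global set of equally spaced timesteps across data-parallel ranks, so that together the ranks cover [0, 1] evenly instead of each drawing a random logit-normal sample. Below is a minimal single-process sketch of that indexing, under the assumption that rank and world_size stand in for composer's dist.get_global_rank() and dist.get_world_size(); eval_timesteps is a hypothetical helper used only for illustration, not part of the patch.

import torch


def eval_timesteps(local_batch_size: int, rank: int, world_size: int) -> torch.Tensor:
    """Return this rank's slice of globally equally spaced timesteps in [0, 1]."""
    global_batch_size = local_batch_size * world_size
    # Every rank builds the same global set of evenly spaced timesteps
    global_timesteps = torch.linspace(0, 1, global_batch_size)
    # Each rank then keeps a contiguous, non-overlapping slice of that set
    idx_offset = rank * local_batch_size
    return global_timesteps[idx_offset:idx_offset + local_batch_size].view(-1, 1, 1)


# Two ranks with 4 samples each split 8 evenly spaced points between them
print(eval_timesteps(4, rank=0, world_size=2).flatten())  # tensor([0.0000, 0.1429, 0.2857, 0.4286])
print(eval_timesteps(4, rank=1, world_size=2).flatten())  # tensor([0.5714, 0.7143, 0.8571, 1.0000])

Training keeps the logit-normal draw from the else branch, so the change only makes the eval-time noise levels deterministic and evenly spread from run to run.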