MMDiT implementation and text-to-image training with rectified flows #155

Merged
merged 27 commits on Jul 26, 2024
Changes from all commits (27 commits)
293d30d
Initial transformer
corystephenson-db Jun 15, 2024
a882d67
Initial running composer model
corystephenson-db Jun 16, 2024
307bc12
Configurable model size
corystephenson-db Jun 16, 2024
e2dd53e
Fix calculation of input sequence length
corystephenson-db Jun 16, 2024
3317a09
Forgot downsample factor in sequence length calc
corystephenson-db Jun 16, 2024
312c13e
Need affines for the layernorms
corystephenson-db Jun 17, 2024
36b3432
Turn off weight decay for biases, norms, and position embeddings
corystephenson-db Jun 17, 2024
40b21cc
Refactor and add tests
corystephenson-db Jun 18, 2024
b396fa6
Wrapping, pooled conditioning, flop calc fix
corystephenson-db Jun 19, 2024
36df692
Initial working MMDiT implementation
corystephenson-db Jun 19, 2024
d440142
Simplify mask logic
corystephenson-db Jun 20, 2024
de3d50e
Rectified flows and pooled embeddings
corystephenson-db Jun 24, 2024
2ad5fc2
Prep for inference
corystephenson-db Jun 25, 2024
499e7ee
Pooled embeddings should be zeroed after embedding for cfg
corystephenson-db Jun 26, 2024
d06580a
Use shared functions to reduce error surface
corystephenson-db Jun 26, 2024
4e21bc0
Docs and a subtle timestep bug fix
corystephenson-db Jun 27, 2024
602c0a9
Docs and types for transformer
corystephenson-db Jun 27, 2024
a59ee6d
Refactor composer model to be separate from base transformer
corystephenson-db Jun 27, 2024
5db34d4
Docs and types for composer model
corystephenson-db Jun 27, 2024
f2ffcb6
Add dummy pretrained flag
corystephenson-db Jun 27, 2024
b8aaa4c
Minor cleanup
corystephenson-db Jun 27, 2024
40b5d9f
Update tests
corystephenson-db Jun 27, 2024
f62f647
Figure out max input sequence length from image size
corystephenson-db Jun 27, 2024
0df0095
Equally spaced timesteps during eval
corystephenson-db Jun 29, 2024
ab641eb
Separate AdaLN and modulation modules
corystephenson-db Jul 18, 2024
ce55ee0
Fix formatting
corystephenson-db Jul 25, 2024
0f0ca90
Some tests are gpu only
corystephenson-db Jul 25, 2024
4 changes: 2 additions & 2 deletions diffusion/inference/__init__.py
@@ -3,6 +3,6 @@

"""Inference endpoint."""

from diffusion.inference.inference_model import StableDiffusionInference, StableDiffusionXLInference
from diffusion.inference.inference_model import ModelInference, StableDiffusionInference, StableDiffusionXLInference

__all__ = ['StableDiffusionInference', 'StableDiffusionXLInference']
__all__ = ['ModelInference', 'StableDiffusionInference', 'StableDiffusionXLInference']
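
With `ModelInference` added to the public exports above, downstream endpoint code can import it the same way as the existing classes; a one-line sketch:

from diffusion.inference import ModelInference  # generic endpoint added in this PR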
78 changes: 78 additions & 0 deletions diffusion/inference/inference_model.py
@@ -11,6 +11,7 @@
from composer.utils.file_helpers import get_file
from PIL import Image

import diffusion.models
from diffusion.models import stable_diffusion_2, stable_diffusion_xl

# Local checkpoint params
@@ -225,3 +226,80 @@ def predict(self, model_requests: List[Dict[str, Any]]):
base64_encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
png_images.append(base64_encoded_image)
return png_images


class ModelInference():
"""Generic inference endpoint class for diffusion models with a model.generate() method.

Args:
model_name (str): Name of the model from `diffusion.models` to load. Ex: for stable diffusion xl, use 'stable_diffusion_xl'.
local_checkpoint_path (str): Path to the local checkpoint. Default: '/tmp/model.pt'.
strict (bool): Whether to load the model weights strictly. Default: False.
**model_kwargs: Keyword arguments to pass to the model initialization.
"""

def __init__(self, model_name, local_checkpoint_path: str = LOCAL_CHECKPOINT_PATH, strict=False, **model_kwargs):
self.device = torch.cuda.current_device()
model_factory = getattr(diffusion.models, model_name)
model = model_factory(**model_kwargs)

if 'pretrained' in model_kwargs and model_kwargs['pretrained']:
pass
else:
state_dict = torch.load(local_checkpoint_path)
for key in list(state_dict['state']['model'].keys()):
if 'val_metrics.' in key:
del state_dict['state']['model'][key]
model.load_state_dict(state_dict['state']['model'], strict=strict)
model.to(self.device)
self.model = model.eval()

def predict(self, model_requests: List[Dict[str, Any]]):
prompts = []
negative_prompts = []
generate_kwargs = {}

# assumes the same generate_kwargs across all samples
for req in model_requests:
if 'input' not in req:
raise RuntimeError('"input" must be provided to generate call')
inputs = req['input']

# Prompts and negative prompts if available
if isinstance(inputs, str):
prompts.append(inputs)
elif isinstance(inputs, Dict):
if 'prompt' not in inputs:
raise RuntimeError('"prompt" must be provided to generate call if using a dict as input')
prompts.append(inputs['prompt'])
if 'negative_prompt' in inputs:
negative_prompts.append(inputs['negative_prompt'])
else:
raise RuntimeError(f'Input must be of type string or dict, but it is type: {type(inputs)}')

generate_kwargs = req['parameters']

# Check for prompts
if len(prompts) == 0:
raise RuntimeError('No prompts provided, must be either a string or dictionary with "prompt"')

# Check negative prompt length
if len(negative_prompts) == 0:
negative_prompts = None
elif len(prompts) != len(negative_prompts):
raise RuntimeError('There must be the same number of negative prompts as prompts.')

# Generate images
with torch.cuda.amp.autocast(True):
imgs = self.model.generate(prompt=prompts, negative_prompt=negative_prompts, **generate_kwargs).cpu()

# Send as bytes
png_images = []
for i in range(imgs.shape[0]):
img = (imgs[i].permute(1, 2, 0).numpy() * 255).round().astype('uint8')
pil_image = Image.fromarray(img, 'RGB')
img_byte_arr = io.BytesIO()
pil_image.save(img_byte_arr, format='PNG')
base64_encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
png_images.append(base64_encoded_image)
return png_images
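
A hedged usage sketch of the request format this endpoint parses (string inputs, or dicts with 'prompt' and an optional 'negative_prompt', plus a top-level 'parameters' dict of generate() kwargs); the model name follows the docstring's example, while the checkpoint location and generate kwargs are illustrative:

# Illustrative only: assumes a CUDA device, a Composer checkpoint at the default
# LOCAL_CHECKPOINT_PATH ('/tmp/model.pt'), and that the generate() kwargs shown
# are supported by the chosen model's generate() signature.
from diffusion.inference import ModelInference

endpoint = ModelInference('stable_diffusion_xl', strict=False)
requests = [
    {
        'input': {'prompt': 'a watercolor painting of a lighthouse',
                  'negative_prompt': 'low quality'},
        'parameters': {'height': 1024, 'width': 1024, 'guidance_scale': 7.0},
    },
]
base64_pngs = endpoint.predict(requests)  # one base64-encoded PNG string per image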
4 changes: 3 additions & 1 deletion diffusion/models/__init__.py
@@ -4,7 +4,8 @@
"""Diffusion models."""

from diffusion.models.models import (build_autoencoder, build_diffusers_autoencoder, continuous_pixel_diffusion,
discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl)
discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl,
text_to_image_transformer)
from diffusion.models.noop import NoOpModel
from diffusion.models.pixel_diffusion import PixelDiffusion
from diffusion.models.stable_diffusion import StableDiffusion
@@ -19,4 +20,5 @@
'stable_diffusion_2',
'stable_diffusion_xl',
'StableDiffusion',
'text_to_image_transformer',
]
132 changes: 132 additions & 0 deletions diffusion/models/models.py
@@ -4,6 +4,7 @@
"""Constructors for diffusion models."""

import logging
import math
from typing import List, Optional, Tuple, Union

import torch
@@ -17,7 +18,9 @@
from diffusion.models.layers import ClippedAttnProcessor2_0, ClippedXFormersAttnProcessor, zero_module
from diffusion.models.pixel_diffusion import PixelDiffusion
from diffusion.models.stable_diffusion import StableDiffusion
from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT
from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer
from diffusion.models.transformer import DiffusionTransformer
from diffusion.schedulers.schedulers import ContinuousTimeScheduler
from diffusion.schedulers.utils import shift_noise_schedule

@@ -496,6 +499,135 @@ def stable_diffusion_xl(
return model


def text_to_image_transformer(
tokenizer_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/tokenizer'),
text_encoder_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/text_encoder'),
vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix',
autoencoder_path: Optional[str] = None,
autoencoder_local_path: str = '/tmp/autoencoder_weights.pt',
num_layers: int = 28,
max_image_side: int = 1280,
conditioning_features: int = 768,
conditioning_max_sequence_length: int = 77,
patch_size: int = 2,
latent_mean: Union[float, Tuple, str] = 0.0,
latent_std: Union[float, Tuple, str] = 7.67754318618,
timestep_mean: float = 0.0,
timestep_std: float = 1.0,
timestep_shift: float = 1.0,
image_key: str = 'image',
caption_key: str = 'captions',
caption_mask_key: str = 'attention_mask',
pretrained: bool = False,
):
"""Text to image transformer training setup.

Args:
tokenizer_names (str, Tuple[str, ...]): HuggingFace name(s) of the tokenizer(s) to load.
Default: ``('stabilityai/stable-diffusion-xl-base-1.0/tokenizer')``.
text_encoder_names (str, Tuple[str, ...]): HuggingFace name(s) of the text encoder(s) to load.
Default: ``('stabilityai/stable-diffusion-xl-base-1.0/text_encoder')``.
vae_model_name (str): Name of the VAE model to load. Defaults to 'madebyollin/sdxl-vae-fp16-fix'.
autoencoder_path (optional, str): Path to autoencoder weights if using a custom autoencoder. If not specified,
will use the vae from `vae_model_name`. Default: `None`.
autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`.
num_layers (int): Number of layers in the transformer. Number of heads and layer width are determined by
this according to `num_features = 64 * num_layers`, and `num_heads = num_layers`. Default: `28`.
max_image_side (int): Maximum side length of the image. Default: `1280`.
conditioning_features (int): Number of features in the conditioning transformer. Default: `768`.
conditioning_max_sequence_length (int): Maximum sequence length for the conditioning transformer. Default: `77`.
patch_size (int): Patch size for the transformer. Default: `2`.
latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value,
a tuple of means, or `'latent_statistics'` to try to use the value from the autoencoder
checkpoint. Defaults to `0.0`.
latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value,
a tuple of std_devs, or `'latent_statistics'` to try to use the value from the autoencoder
checkpoint. Defaults to `1/0.13025` (`7.67754318618`).
timestep_mean (float): The mean of the timesteps. Default: `0.0`.
timestep_std (float): The std. dev. of the timesteps. Default: `1.0`.
timestep_shift (float): The shift of the timesteps. Default: `1.0`.
image_key (str): The key for the image in the batch. Default: `image`.
caption_key (str): The key for the captions in the batch. Default: `captions`.
caption_mask_key (str): The key for the caption mask in the batch. Default: `attention_mask`.
pretrained (bool): Whether to load pretrained weights. Not used. Defaults to False.
"""
latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)

if (isinstance(tokenizer_names, tuple) or
isinstance(text_encoder_names, tuple)) and len(tokenizer_names) != len(text_encoder_names):
raise ValueError('Number of tokenizer_names and text_encoder_names must be equal')

# Make the tokenizer and text encoder
tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names)
text_encoder = MultiTextEncoder(model_names=text_encoder_names, encode_latents_in_fp16=True, pretrained_sdxl=False)

precision = torch.float16
# Make the autoencoder
if autoencoder_path is None:
if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics':
raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.')
downsample_factor = 8
autoencoder_channels = 4
# Use the pretrained vae
try:
vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=precision)
except: # for handling SDXL vae fp16 fixed checkpoint
vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=precision)
else:
# Use a custom autoencoder
vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=precision)
if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'):
raise ValueError(
'Must specify latent scale when using a custom autoencoder without tracking latent statistics.')
if isinstance(latent_mean, str) and latent_mean == 'latent_statistics':
assert isinstance(latent_statistics, dict)
latent_mean = tuple(latent_statistics['latent_channel_means'])
if isinstance(latent_std, str) and latent_std == 'latent_statistics':
assert isinstance(latent_statistics, dict)
latent_std = tuple(latent_statistics['latent_channel_stds'])
downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
autoencoder_channels = vae.config['latent_channels']
assert isinstance(vae, torch.nn.Module)
if isinstance(latent_mean, float):
latent_mean = (latent_mean,) * autoencoder_channels
if isinstance(latent_std, float):
latent_std = (latent_std,) * autoencoder_channels
assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple)
# Figure out the maximum input sequence length
input_max_sequence_length = math.ceil(max_image_side / (downsample_factor * patch_size))
# Make the transformer model
transformer = DiffusionTransformer(num_features=64 * num_layers,
num_heads=num_layers,
num_layers=num_layers,
input_features=autoencoder_channels * (patch_size**2),
input_max_sequence_length=input_max_sequence_length,
input_dimension=2,
conditioning_features=conditioning_features,
conditioning_max_sequence_length=conditioning_max_sequence_length,
conditioning_dimension=1,
expansion_factor=4)
# Make the composer model
model = ComposerTextToImageMMDiT(model=transformer,
autoencoder=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
latent_mean=latent_mean,
latent_std=latent_std,
patch_size=patch_size,
downsample_factor=downsample_factor,
latent_channels=autoencoder_channels,
timestep_mean=timestep_mean,
timestep_std=timestep_std,
timestep_shift=timestep_shift,
image_key=image_key,
caption_key=caption_key,
caption_mask_key=caption_mask_key)

if torch.cuda.is_available():
model = DeviceGPU().module_to_device(model)
return model


def build_autoencoder(input_channels: int = 3,
output_channels: int = 3,
hidden_channels: int = 128,
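
As a quick check on the defaults of `text_to_image_transformer` above: with `max_image_side=1280`, the SDXL VAE's 8x downsampling factor, and `patch_size=2`, the maximum input sequence length works out to `ceil(1280 / (8 * 2)) = 80`, and `num_layers=28` gives `num_features = 64 * 28 = 1792` with 28 attention heads. A minimal construction sketch, assuming the default pretrained tokenizer, text encoder, and VAE are available from HuggingFace; the overrides shown are illustrative:

# Construction sketch; all overrides simply restate the defaults from the signature
# above. Building this downloads the SDXL tokenizer/text-encoder weights and the
# fp16-fixed SDXL VAE from HuggingFace.
from diffusion.models import text_to_image_transformer

model = text_to_image_transformer(
    num_layers=28,        # num_features = 64 * 28 = 1792, num_heads = 28
    max_image_side=1280,  # input_max_sequence_length = ceil(1280 / (8 * 2)) = 80
    timestep_shift=1.0,   # rectified-flow timestep shift (default)
)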