MMDiT implementation and text-to-image training with rectified flows (#…
coryMosaicML authored Jul 26, 2024
1 parent 03a0b6a commit ef74f2b
Showing 8 changed files with 1,277 additions and 4 deletions.
4 changes: 2 additions & 2 deletions diffusion/inference/__init__.py
@@ -3,6 +3,6 @@

"""Inference endpoint."""

-from diffusion.inference.inference_model import StableDiffusionInference, StableDiffusionXLInference
+from diffusion.inference.inference_model import ModelInference, StableDiffusionInference, StableDiffusionXLInference

-__all__ = ['StableDiffusionInference', 'StableDiffusionXLInference']
+__all__ = ['ModelInference', 'StableDiffusionInference', 'StableDiffusionXLInference']
78 changes: 78 additions & 0 deletions diffusion/inference/inference_model.py
@@ -11,6 +11,7 @@
from composer.utils.file_helpers import get_file
from PIL import Image

import diffusion.models
from diffusion.models import stable_diffusion_2, stable_diffusion_xl

# Local checkpoint params
@@ -225,3 +226,80 @@ def predict(self, model_requests: List[Dict[str, Any]]):
base64_encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
png_images.append(base64_encoded_image)
return png_images


class ModelInference():
"""Generic inference endpoint class for diffusion models with a model.generate() method.
Args:
model_name (str): Name of the model from `diffusion.models` to load. Ex: for stable diffusion xl, use 'stable_diffusion_xl'.
local_checkpoint_path (str): Path to the local checkpoint. Default: '/tmp/model.pt'.
strict (bool): Whether to load the model weights strictly. Default: False.
**model_kwargs: Keyword arguments to pass to the model initialization.
"""

def __init__(self, model_name, local_checkpoint_path: str = LOCAL_CHECKPOINT_PATH, strict=False, **model_kwargs):
self.device = torch.cuda.current_device()
model_factory = getattr(diffusion.models, model_name)
model = model_factory(**model_kwargs)

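# If the factory was configured with pretrained weights, skip loading the local checkpoint.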
if 'pretrained' in model_kwargs and model_kwargs['pretrained']:
pass
else:
state_dict = torch.load(local_checkpoint_path)
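# Drop validation metric entries from the checkpoint state dict so only model weights are loaded.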
for key in list(state_dict['state']['model'].keys()):
if 'val_metrics.' in key:
del state_dict['state']['model'][key]
model.load_state_dict(state_dict['state']['model'], strict=strict)
model.to(self.device)
self.model = model.eval()

def predict(self, model_requests: List[Dict[str, Any]]):
prompts = []
negative_prompts = []
generate_kwargs = {}

# assumes the same generate_kwargs across all samples
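# (each request is expected to carry a 'parameters' entry; the last one seen is used)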
for req in model_requests:
if 'input' not in req:
raise RuntimeError('"input" must be provided to generate call')
inputs = req['input']

# Prompts and negative prompts if available
if isinstance(inputs, str):
prompts.append(inputs)
elif isinstance(inputs, Dict):
if 'prompt' not in inputs:
raise RuntimeError('"prompt" must be provided to generate call if using a dict as input')
prompts.append(inputs['prompt'])
if 'negative_prompt' in inputs:
negative_prompts.append(inputs['negative_prompt'])
else:
raise RuntimeError(f'Input must be of type string or dict, but it is type: {type(inputs)}')

generate_kwargs = req['parameters']

# Check for prompts
if len(prompts) == 0:
raise RuntimeError('No prompts provided, must be either a string or dictionary with "prompt"')

# Check negative prompt length
if len(negative_prompts) == 0:
negative_prompts = None
elif len(prompts) != len(negative_prompts):
raise RuntimeError('There must be the same number of negative prompts as prompts.')

# Generate images
with torch.cuda.amp.autocast(True):
imgs = self.model.generate(prompt=prompts, negative_prompt=negative_prompts, **generate_kwargs).cpu()

# Send as bytes
png_images = []
for i in range(imgs.shape[0]):
img = (imgs[i].permute(1, 2, 0).numpy() * 255).round().astype('uint8')
pil_image = Image.fromarray(img, 'RGB')
img_byte_arr = io.BytesIO()
pil_image.save(img_byte_arr, format='PNG')
base64_encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
png_images.append(base64_encoded_image)
return png_images
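A minimal usage sketch for the new generic endpoint (hypothetical values: the model name must be a factory in diffusion.models, and which generate parameters are accepted depends on the underlying model's generate() signature):

# Hypothetical example; not part of this commit's diff.
endpoint = ModelInference('stable_diffusion_xl', local_checkpoint_path='/tmp/model.pt')
requests = [{
    'input': {'prompt': 'a photo of a corgi', 'negative_prompt': 'low quality'},
    'parameters': {'height': 1024, 'width': 1024},
}]
images = endpoint.predict(requests)  # list of base64-encoded PNG strings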
4 changes: 3 additions & 1 deletion diffusion/models/__init__.py
@@ -4,7 +4,8 @@
"""Diffusion models."""

from diffusion.models.models import (build_autoencoder, build_diffusers_autoencoder, continuous_pixel_diffusion,
discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl)
discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl,
text_to_image_transformer)
from diffusion.models.noop import NoOpModel
from diffusion.models.pixel_diffusion import PixelDiffusion
from diffusion.models.stable_diffusion import StableDiffusion
@@ -19,4 +20,5 @@
'stable_diffusion_2',
'stable_diffusion_xl',
'StableDiffusion',
'text_to_image_transformer',
]
132 changes: 132 additions & 0 deletions diffusion/models/models.py
@@ -4,6 +4,7 @@
"""Constructors for diffusion models."""

import logging
import math
from typing import List, Optional, Tuple, Union

import torch
@@ -17,7 +18,9 @@
from diffusion.models.layers import ClippedAttnProcessor2_0, ClippedXFormersAttnProcessor, zero_module
from diffusion.models.pixel_diffusion import PixelDiffusion
from diffusion.models.stable_diffusion import StableDiffusion
from diffusion.models.t2i_transformer import ComposerTextToImageMMDiT
from diffusion.models.text_encoder import MultiTextEncoder, MultiTokenizer
from diffusion.models.transformer import DiffusionTransformer
from diffusion.schedulers.schedulers import ContinuousTimeScheduler
from diffusion.schedulers.utils import shift_noise_schedule

@@ -496,6 +499,135 @@ def stable_diffusion_xl(
return model


def text_to_image_transformer(
tokenizer_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/tokenizer'),
text_encoder_names: Union[str, Tuple[str, ...]] = ('stabilityai/stable-diffusion-xl-base-1.0/text_encoder'),
vae_model_name: str = 'madebyollin/sdxl-vae-fp16-fix',
autoencoder_path: Optional[str] = None,
autoencoder_local_path: str = '/tmp/autoencoder_weights.pt',
num_layers: int = 28,
max_image_side: int = 1280,
conditioning_features: int = 768,
conditioning_max_sequence_length: int = 77,
patch_size: int = 2,
latent_mean: Union[float, Tuple, str] = 0.0,
latent_std: Union[float, Tuple, str] = 7.67754318618,
timestep_mean: float = 0.0,
timestep_std: float = 1.0,
timestep_shift: float = 1.0,
image_key: str = 'image',
caption_key: str = 'captions',
caption_mask_key: str = 'attention_mask',
pretrained: bool = False,
):
"""Text to image transformer training setup.
Args:
tokenizer_names (str, Tuple[str, ...]): HuggingFace name(s) of the tokenizer(s) to load.
Default: ``('stabilityai/stable-diffusion-xl-base-1.0/tokenizer')``.
text_encoder_names (str, Tuple[str, ...]): HuggingFace name(s) of the text encoder(s) to load.
Default: ``('stabilityai/stable-diffusion-xl-base-1.0/text_encoder')``.
vae_model_name (str): Name of the VAE model to load. Defaults to 'madebyollin/sdxl-vae-fp16-fix'.
autoencoder_path (optional, str): Path to autoencoder weights if using custom autoencoder. If not specified,
will use the vae from `vae_model_name`. Default: `None`.
autoencoder_local_path (optional, str): Path to autoencoder weights. Default: `/tmp/autoencoder_weights.pt`.
num_layers (int): Number of layers in the transformer. Number of heads and layer width are determined by
this according to `num_features = 64 * num_layers`, and `num_heads = num_layers`. Default: `28`.
max_image_side (int): Maximum side length of the image. Default: `1280`.
conditioning_features (int): Number of features in the conditioning transformer. Default: `768`.
conditioning_max_sequence_length (int): Maximum sequence length for the conditioning transformer. Default: `77`.
patch_size (int): Patch size for the transformer. Default: `2`.
latent_mean (float, Tuple, str): The mean of the autoencoder latents. Either a float for a single value,
a tuple of means, or `'latent_statistics'` to try to use the value from the autoencoder
checkpoint. Defaults to `0.0`.
latent_std (float, Tuple, str): The std. dev. of the autoencoder latents. Either a float for a single value,
a tuple of std_devs, or `'latent_statistics'` to try to use the value from the autoencoder
checkpoint. Defaults to `1/0.13025`.
timestep_mean (float): The mean of the timesteps. Default: `0.0`.
timestep_std (float): The std. dev. of the timesteps. Default: `1.0`.
timestep_shift (float): The shift of the timesteps. Default: `1.0`.
image_key (str): The key for the image in the batch. Default: `image`.
caption_key (str): The key for the captions in the batch. Default: `captions`.
caption_mask_key (str): The key for the caption mask in the batch. Default: `attention_mask`.
pretrained (bool): Whether to load pretrained weights. Not used. Defaults to False.
"""
latent_mean, latent_std = _parse_latent_statistics(latent_mean), _parse_latent_statistics(latent_std)

if (isinstance(tokenizer_names, tuple) or
isinstance(text_encoder_names, tuple)) and len(tokenizer_names) != len(text_encoder_names):
raise ValueError('Number of tokenizer_names and text_encoder_names must be equal')

# Make the tokenizer and text encoder
tokenizer = MultiTokenizer(tokenizer_names_or_paths=tokenizer_names)
text_encoder = MultiTextEncoder(model_names=text_encoder_names, encode_latents_in_fp16=True, pretrained_sdxl=False)

precision = torch.float16
# Make the autoencoder
if autoencoder_path is None:
if latent_mean == 'latent_statistics' or latent_std == 'latent_statistics':
raise ValueError('Cannot use tracked latent_statistics when using the pretrained vae.')
downsample_factor = 8
autoencoder_channels = 4
# Use the pretrained vae
try:
vae = AutoencoderKL.from_pretrained(vae_model_name, subfolder='vae', torch_dtype=precision)
except: # for handling SDXL vae fp16 fixed checkpoint
vae = AutoencoderKL.from_pretrained(vae_model_name, torch_dtype=precision)
else:
# Use a custom autoencoder
vae, latent_statistics = load_autoencoder(autoencoder_path, autoencoder_local_path, torch_dtype=precision)
if latent_statistics is None and (latent_mean == 'latent_statistics' or latent_std == 'latent_statistics'):
raise ValueError(
'Must specify latent scale when using a custom autoencoder without tracking latent statistics.')
if isinstance(latent_mean, str) and latent_mean == 'latent_statistics':
assert isinstance(latent_statistics, dict)
latent_mean = tuple(latent_statistics['latent_channel_means'])
if isinstance(latent_std, str) and latent_std == 'latent_statistics':
assert isinstance(latent_statistics, dict)
latent_std = tuple(latent_statistics['latent_channel_stds'])
downsample_factor = 2**(len(vae.config['channel_multipliers']) - 1)
autoencoder_channels = vae.config['latent_channels']
assert isinstance(vae, torch.nn.Module)
if isinstance(latent_mean, float):
latent_mean = (latent_mean,) * autoencoder_channels
if isinstance(latent_std, float):
latent_std = (latent_std,) * autoencoder_channels
assert isinstance(latent_mean, tuple) and isinstance(latent_std, tuple)
# Figure out the maximum input sequence length
input_max_sequence_length = math.ceil(max_image_side / (downsample_factor * patch_size))
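# With the defaults (max_image_side=1280, downsample_factor=8 for the pretrained VAE, patch_size=2): ceil(1280 / 16) = 80.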
# Make the transformer model
transformer = DiffusionTransformer(num_features=64 * num_layers,
num_heads=num_layers,
num_layers=num_layers,
input_features=autoencoder_channels * (patch_size**2),
input_max_sequence_length=input_max_sequence_length,
input_dimension=2,
conditioning_features=conditioning_features,
conditioning_max_sequence_length=conditioning_max_sequence_length,
conditioning_dimension=1,
expansion_factor=4)
# Make the composer model
model = ComposerTextToImageMMDiT(model=transformer,
autoencoder=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
latent_mean=latent_mean,
latent_std=latent_std,
patch_size=patch_size,
downsample_factor=downsample_factor,
latent_channels=autoencoder_channels,
timestep_mean=timestep_mean,
timestep_std=timestep_std,
timestep_shift=timestep_shift,
image_key=image_key,
caption_key=caption_key,
caption_mask_key=caption_mask_key)

if torch.cuda.is_available():
model = DeviceGPU().module_to_device(model)
return model
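For reference, a sketch of how the defaults play out (illustrative only; the import becomes available with this commit's change to diffusion/models/__init__.py). With num_layers=28 the transformer is built with num_features = 64 * 28 = 1792 and num_heads = 28, and the returned object is a ComposerTextToImageMMDiT wrapping a DiffusionTransformer:

# Hypothetical example; not part of this commit's diff.
from diffusion.models import text_to_image_transformer

# Uses the default SDXL tokenizer/text encoder and the madebyollin/sdxl-vae-fp16-fix VAE.
model = text_to_image_transformer(num_layers=28)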


def build_autoencoder(input_channels: int = 3,
output_channels: int = 3,
hidden_channels: int = 128,
(Remaining diff content not shown: the rest of diffusion/models/models.py and the other 4 changed files.)
