
Commit afa6c66

Add composer model class for running with precomputed CLIP and T5 text latents (#171)
coryMosaicML authored Oct 4, 2024
1 parent 4d6e4aa commit afa6c66
Showing 5 changed files with 886 additions and 24 deletions.
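At a glance, the commit threads configurable batch keys for precomputed CLIP and T5 text latents through the image-logging callback, the streaming latent dataset, and a new PrecomputedTextLatentDiffusion composer model. As a rough sketch (not part of the commit), the per-sample entries these pieces exchange look like the following, using the default keys and latent shapes from the diff below; the mask lengths and dtypes are assumptions:

    import torch

    # Default keys and shapes from the diff below; the masks are assumed to be
    # sequence-length vectors of 0/1 values, and the dtypes are illustrative only.
    sample = {
        'T5_LATENTS': torch.zeros(512, 4096, dtype=torch.bfloat16),
        'T5_ATTENTION_MASK': torch.ones(512, dtype=torch.long),
        'CLIP_LATENTS': torch.zeros(77, 768, dtype=torch.bfloat16),
        'CLIP_ATTENTION_MASK': torch.ones(77, dtype=torch.long),
        'CLIP_POOLED': torch.zeros(768, dtype=torch.bfloat16),
    }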
35 changes: 25 additions & 10 deletions diffusion/callbacks/log_diffusion_images.py
@@ -39,6 +39,11 @@ class LogDiffusionImages(Callback):
use_table (bool): Whether to make a table of the images or not. Default: ``False``.
t5_encoder (str, optional): path to the T5 encoder to use as the second text encoder.
clip_encoder (str, optional): path to the CLIP encoder to use as the first text encoder.
t5_latent_key (str): key to use for the T5 latents in the batch. Default: ``'T5_LATENTS'``.
t5_mask_key (str): key to use for the T5 attention mask in the batch. Default: ``'T5_ATTENTION_MASK'``.
clip_latent_key (str): key to use for the CLIP latents in the batch. Default: ``'CLIP_LATENTS'``.
clip_mask_key (str): key to use for the CLIP attention mask in the batch. Default: ``'CLIP_ATTENTION_MASK'``.
clip_pooled_key (str): key to use for the pooled CLIP output in the batch. Default: ``'CLIP_POOLED'``.
cache_dir (str, optional): path for HF to cache files while downloading models.
"""

@@ -53,6 +58,11 @@ def __init__(self,
use_table: bool = False,
t5_encoder: Optional[str] = None,
clip_encoder: Optional[str] = None,
t5_latent_key: str = 'T5_LATENTS',
t5_mask_key: str = 'T5_ATTENTION_MASK',
clip_latent_key: str = 'CLIP_LATENTS',
clip_mask_key: str = 'CLIP_ATTENTION_MASK',
clip_pooled_key: str = 'CLIP_POOLED',
cache_dir: Optional[str] = '/tmp/hf_files'):
self.prompts = prompts
self.size = (size, size) if isinstance(size, int) else size
@@ -61,6 +71,11 @@ def __init__(self,
self.rescaled_guidance = rescaled_guidance
self.seed = seed
self.use_table = use_table
self.t5_latent_key = t5_latent_key
self.t5_mask_key = t5_mask_key
self.clip_latent_key = clip_latent_key
self.clip_mask_key = clip_mask_key
self.clip_pooled_key = clip_pooled_key
self.cache_dir = cache_dir

# Batch prompts
@@ -120,10 +135,11 @@ def __init__(self,
clip_pooled = clip_outputs[1].cpu()
clip_attention_mask = clip_attention_mask.cpu().to(torch.long)

latent_batch['T5_LATENTS'] = t5_latents
latent_batch['CLIP_LATENTS'] = clip_latents
latent_batch['ATTENTION_MASK'] = torch.cat([t5_attention_mask, clip_attention_mask], dim=1)
latent_batch['CLIP_POOLED'] = clip_pooled
latent_batch[self.t5_latent_key] = t5_latents
latent_batch[self.t5_mask_key] = t5_attention_mask
latent_batch[self.clip_latent_key] = clip_latents
latent_batch[self.clip_mask_key] = clip_attention_mask
latent_batch[self.clip_pooled_key] = clip_pooled
self.batched_latents.append(latent_batch)

del t5_model
@@ -143,12 +159,11 @@ def eval_start(self, state: State, logger: Logger):
all_gen_images = []
if self.precomputed_latents:
for batch in self.batched_latents:
pooled_prompt = batch['CLIP_POOLED'].cuda()
prompt_mask = batch['ATTENTION_MASK'].cuda()
t5_embeds = model.t5_proj(batch['T5_LATENTS'].cuda())
clip_embeds = model.clip_proj(batch['CLIP_LATENTS'].cuda())
prompt_embeds = torch.cat([t5_embeds, clip_embeds], dim=1)

pooled_prompt = batch[self.clip_pooled_key].cuda()
prompt_embeds, prompt_mask = model.prepare_text_embeddings(batch[self.t5_latent_key].cuda(),
batch[self.clip_latent_key].cuda(),
batch[self.t5_mask_key].cuda(),
batch[self.clip_mask_key].cuda())
gen_images = model.generate(prompt_embeds=prompt_embeds,
pooled_prompt=pooled_prompt,
prompt_mask=prompt_mask,
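For reference, a minimal sketch of configuring the updated callback with the new key arguments; the encoder paths and prompt are placeholders, and any constructor arguments not shown in this diff are assumed to keep their existing defaults:

    from diffusion.callbacks.log_diffusion_images import LogDiffusionImages

    # Hypothetical configuration; the key arguments simply restate the new defaults.
    image_logger = LogDiffusionImages(
        prompts=['a photo of a corgi wearing a top hat'],  # placeholder prompt
        size=512,
        use_table=False,
        t5_encoder='google/t5-v1_1-xxl',               # assumed HF checkpoint path
        clip_encoder='openai/clip-vit-large-patch14',  # assumed HF checkpoint path
        t5_latent_key='T5_LATENTS',
        t5_mask_key='T5_ATTENTION_MASK',
        clip_latent_key='CLIP_LATENTS',
        clip_mask_key='CLIP_ATTENTION_MASK',
        clip_pooled_key='CLIP_POOLED',
        cache_dir='/tmp/hf_files',
    )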
47 changes: 37 additions & 10 deletions diffusion/datasets/image_caption_latents.py
@@ -14,7 +14,8 @@
from torch.utils.data import DataLoader
from torchvision import transforms

from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransform, RandomCropSquare
from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropAspectRatioTransform,
RandomCropBucketedAspectRatioTransform, RandomCropSquare)
from diffusion.datasets.utils import make_streams

log = logging.getLogger(__name__)
@@ -32,6 +33,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset):
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
@@ -40,6 +42,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset):
attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
latent_dtype (torch.dtype): The dtype to cast the text latents to. Default: ``torch.bfloat16``.
drop_nans (bool): Whether to treat samples with NaN latents as dropped captions. Default: ``True``.
**streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataset
"""

@@ -53,10 +56,12 @@ def __init__(
image_key: str = 'image',
caption_keys: Tuple[str, ...] = ('caption',),
caption_selection_probs: Tuple[float, ...] = (1.0,),
aspect_ratio_bucket_key: Optional[str] = None,
text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)),
attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
latent_dtype: torch.dtype = torch.bfloat16,
drop_nans: bool = True,
**streaming_kwargs,
):

@@ -72,10 +77,14 @@ def __init__(
self.image_key = image_key
self.caption_keys = caption_keys
self.caption_selection_probs = caption_selection_probs
self.aspect_ratio_bucket_key = aspect_ratio_bucket_key
if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransform'
self.text_latent_keys = text_latent_keys
self.text_latent_shapes = text_latent_shapes
self.attention_mask_keys = attention_mask_keys
self.latent_dtype = latent_dtype
self.drop_nans = drop_nans

def __getitem__(self, index):
sample = super().__getitem__(index)
@@ -90,15 +99,16 @@ def __getitem__(self, index):
out['cond_original_size'] = torch.tensor(img.size)

# Image transforms
if self.crop is not None:
if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
img, crop_top, crop_left = self.crop(img, sample[self.aspect_ratio_bucket_key])
elif self.crop is not None:
img, crop_top, crop_left = self.crop(img)
else:
crop_top, crop_left = 0, 0
out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])

if self.transform is not None:
img = self.transform(img)
out['image'] = img
out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])

# Get the new height and width
if isinstance(img, torch.Tensor):
@@ -140,6 +150,13 @@ def __getitem__(self, index):
if 'CLIP_LATENTS' in latent_key:
clip_pooled = np.frombuffer(sample[f'{caption_key}_CLIP_POOLED_TEXT'], dtype=np.float32).copy()
out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).to(self.latent_dtype).reshape(latent_shape[1])
if self.drop_nans:
for latent_key, attn_key in zip(self.text_latent_keys, self.attention_mask_keys):
if out[latent_key].isnan().any():
out[latent_key] = torch.zeros_like(out[latent_key])
out[attn_key] = torch.zeros_like(out[attn_key])
if 'CLIP_LATENTS' in latent_key and out['CLIP_POOLED'].isnan().any():
out['CLIP_POOLED'] = torch.zeros_like(out['CLIP_POOLED'])
return out


@@ -160,6 +177,7 @@ def build_streaming_image_caption_latents_dataloader(
text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)),
attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
latent_dtype: str = 'torch.bfloat16',
aspect_ratio_bucket_key: Optional[str] = None,
streaming_kwargs: Optional[Dict] = None,
dataloader_kwargs: Optional[Dict] = None,
):
@@ -178,11 +196,12 @@
``None``, the bucket with the smallest distance to the current sample's aspect ratio is selected.
Default: ``None``.
transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio'].
crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio'].
Default: ``'square'``.
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
@@ -192,18 +211,22 @@
Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
latent_dtype (str): The torch dtype to cast the text latents to. One of 'torch.float16', 'torch.float32',
or 'torch.bfloat16'. Default: ``'torch.bfloat16'``.
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset.
Needed if using ``crop_type='bucketed_aspect_ratio'``. Default: ``None``.
streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
"""
# Check crop type
if crop_type is not None:
crop_type = crop_type.lower()
if crop_type not in ['square', 'random', 'aspect_ratio']:
raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]')
if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)):
raise ValueError(
'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.')
if crop_type not in ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']:
raise ValueError(
f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", "bucketed_aspect_ratio", None]'
)
if crop_type in ['aspect_ratio', 'bucketed_aspect_ratio'] and (isinstance(resize_size, int) or
isinstance(resize_size[0], int)):
raise ValueError(
'If using aspect ratio bucketing, specify aspect ratio buckets in resize_size as a tuple of tuples.')
# Check latent dtype
dtypes = {'torch.float16': torch.float16, 'torch.float32': torch.float32, 'torch.bfloat16': torch.bfloat16}
assert latent_dtype in dtypes, f'Invalid latent_dtype: {latent_dtype}. Must be one of {list(dtypes.keys())}'
@@ -225,6 +248,9 @@ def build_streaming_image_caption_latents_dataloader(
crop = RandomCropSquare(resize_size)
elif crop_type == 'aspect_ratio':
crop = RandomCropAspectRatioTransform(resize_size, ar_bucket_boundaries) # type: ignore
elif crop_type == 'bucketed_aspect_ratio':
assert aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using bucketed_aspect_ratio crop type'
crop = RandomCropBucketedAspectRatioTransform(resize_size) # type: ignore
else:
crop = None

@@ -242,6 +268,7 @@ def build_streaming_image_caption_latents_dataloader(
image_key=image_key,
caption_keys=caption_keys,
caption_selection_probs=caption_selection_probs,
aspect_ratio_bucket_key=aspect_ratio_bucket_key,
text_latent_keys=text_latent_keys,
text_latent_shapes=text_latent_shapes,
attention_mask_keys=attention_mask_keys,
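As a usage sketch for the new bucketed crop path (not taken from the commit), the dataloader might be built roughly as follows; the remote, local, and batch_size arguments and the 'aspect_ratio_bucket' column name are assumptions, while the remaining keyword arguments appear in the signature above:

    from diffusion.datasets.image_caption_latents import build_streaming_image_caption_latents_dataloader

    dataloader = build_streaming_image_caption_latents_dataloader(
        remote='s3://my-bucket/mds-latents',  # assumed argument and path
        local='/tmp/mds-cache',               # assumed argument and path
        batch_size=32,                        # assumed argument
        resize_size=((512, 512), (576, 448), (448, 576)),  # aspect ratio buckets as a tuple of tuples
        crop_type='bucketed_aspect_ratio',
        aspect_ratio_bucket_key='aspect_ratio_bucket',  # hypothetical dataset column
        text_latent_keys=('T5_LATENTS', 'CLIP_LATENTS'),
        text_latent_shapes=((512, 4096), (77, 768)),
        attention_mask_keys=('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
        latent_dtype='torch.bfloat16',
    )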
7 changes: 5 additions & 2 deletions diffusion/models/__init__.py
@@ -4,10 +4,11 @@
"""Diffusion models."""

from diffusion.models.models import (build_autoencoder, build_diffusers_autoencoder, continuous_pixel_diffusion,
discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl,
text_to_image_transformer)
discrete_pixel_diffusion, precomputed_text_latent_diffusion, stable_diffusion_2,
stable_diffusion_xl, text_to_image_transformer)
from diffusion.models.noop import NoOpModel
from diffusion.models.pixel_diffusion import PixelDiffusion
from diffusion.models.precomputed_text_latent_diffusion import PrecomputedTextLatentDiffusion
from diffusion.models.stable_diffusion import StableDiffusion

__all__ = [
@@ -17,8 +18,10 @@
'discrete_pixel_diffusion',
'NoOpModel',
'PixelDiffusion',
'precomputed_text_latent_diffusion',
'stable_diffusion_2',
'stable_diffusion_xl',
'StableDiffusion',
'PrecomputedTextLatentDiffusion',
'text_to_image_transformer',
]
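For orientation, a minimal sketch of the new exports; whether the factory runs with no arguments is an assumption, since its full signature lives in diffusion/models/models.py and is not shown in this diff:

    from diffusion.models import PrecomputedTextLatentDiffusion, precomputed_text_latent_diffusion

    # Assumption: the factory provides defaults for all of its keyword arguments.
    model = precomputed_text_latent_diffusion()
    assert isinstance(model, PrecomputedTextLatentDiffusion)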