Add composer model class for running with precomputed CLIP and T5 text latents #171

Merged Oct 4, 2024 · 28 commits

Commits
3459613
Initial model class
corystephenson-db Aug 28, 2024
4b5b227
Support truncating embeddings
corystephenson-db Aug 29, 2024
39ed5f7
Truncate before embedding, and add position embeds
corystephenson-db Aug 29, 2024
8ef76e6
Don't need an arg for the loss
corystephenson-db Aug 29, 2024
402b1a5
Prep for inference
corystephenson-db Aug 29, 2024
19cc8fb
Prep for string inputs
corystephenson-db Aug 30, 2024
101c353
Don't need to check for negative prompt existing
corystephenson-db Aug 30, 2024
3402223
Timesteps shall be ints
corystephenson-db Aug 30, 2024
72476c2
Fix off by one
corystephenson-db Aug 30, 2024
363edd1
Add layernorms before sequence concat
corystephenson-db Sep 1, 2024
260eb6c
Changes for running with bf16
corystephenson-db Sep 2, 2024
e30fa2b
Update docstrings and fix types
corystephenson-db Sep 9, 2024
ca94d5b
Drop nans
corystephenson-db Sep 9, 2024
3d7b65e
Clean up names
corystephenson-db Sep 10, 2024
7bc1796
Fix deprecation
corystephenson-db Sep 10, 2024
4c36a06
More name changes
corystephenson-db Sep 11, 2024
040afae
Fixes for running inference
corystephenson-db Sep 11, 2024
9abc15b
Update docstrings
corystephenson-db Sep 11, 2024
ebacb59
Configurable schedulers
corystephenson-db Sep 11, 2024
5ceee3f
Add schedule shifting
corystephenson-db Sep 11, 2024
d950a50
Add option for LoRA
corystephenson-db Sep 11, 2024
40ecb59
Proper tensor timestep
corystephenson-db Sep 14, 2024
479fe54
Add option for pre-bucketed aspect ratio buckets
corystephenson-db Sep 27, 2024
e7fcb59
Fix some missing keys
corystephenson-db Sep 28, 2024
31116f2
Update hardcoded keys in logger and dataset
corystephenson-db Oct 2, 2024
ca21177
Configurable text encoder dtypes
corystephenson-db Oct 2, 2024
caf0d46
Configurable keys for the model class
corystephenson-db Oct 3, 2024
36485c5
Change text encoder default dtype
corystephenson-db Oct 3, 2024
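
Taken together, these commits establish a batch contract between the dataset, the logging callback, and the new model class: batches carry precomputed text latents under configurable keys. Below is a minimal sketch of the batch layout implied by the default keys and shapes in the diffs that follow; the batch size of 4 and the `image` entry are illustrative assumptions, not values from this PR.

```python
import torch

# Illustrative batch; key names and latent shapes follow the defaults in the
# diffs below (T5: (512, 4096), CLIP: (77, 768), pooled CLIP: (768,)).
batch = {
    'image': torch.randn(4, 3, 512, 512),
    'T5_LATENTS': torch.randn(4, 512, 4096, dtype=torch.bfloat16),
    'T5_ATTENTION_MASK': torch.ones(4, 512, dtype=torch.long),
    'CLIP_LATENTS': torch.randn(4, 77, 768, dtype=torch.bfloat16),
    'CLIP_ATTENTION_MASK': torch.ones(4, 77, dtype=torch.long),
    'CLIP_POOLED': torch.randn(4, 768, dtype=torch.bfloat16),
}
```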
35 changes: 25 additions & 10 deletions diffusion/callbacks/log_diffusion_images.py
@@ -39,6 +39,11 @@ class LogDiffusionImages(Callback):
use_table (bool): Whether to make a table of the images or not. Default: ``False``.
t5_encoder (str, optional): path to the T5 encoder to use as the second text encoder.
clip_encoder (str, optional): path to the CLIP encoder to use as the first text encoder.
t5_latent_key (str): key to use for the T5 latents in the batch. Default: ``'T5_LATENTS'``.
t5_mask_key (str): key to use for the T5 attention mask in the batch. Default: ``'T5_ATTENTION_MASK'``.
clip_latent_key (str): key to use for the CLIP latents in the batch. Default: ``'CLIP_LATENTS'``.
clip_mask_key (str): key to use for the CLIP attention mask in the batch. Default: ``'CLIP_ATTENTION_MASK'``.
clip_pooled_key (str): key to use for the pooled CLIP embedding in the batch. Default: ``'CLIP_POOLED'``.
cache_dir (str, optional): path for HuggingFace to cache files while downloading the model.
"""

@@ -53,6 +58,11 @@ def __init__(self,
use_table: bool = False,
t5_encoder: Optional[str] = None,
clip_encoder: Optional[str] = None,
t5_latent_key: str = 'T5_LATENTS',
t5_mask_key: str = 'T5_ATTENTION_MASK',
clip_latent_key: str = 'CLIP_LATENTS',
clip_mask_key: str = 'CLIP_ATTENTION_MASK',
clip_pooled_key: str = 'CLIP_POOLED',
cache_dir: Optional[str] = '/tmp/hf_files'):
self.prompts = prompts
self.size = (size, size) if isinstance(size, int) else size
@@ -61,6 +71,11 @@ def __init__(self,
self.rescaled_guidance = rescaled_guidance
self.seed = seed
self.use_table = use_table
self.t5_latent_key = t5_latent_key
self.t5_mask_key = t5_mask_key
self.clip_latent_key = clip_latent_key
self.clip_mask_key = clip_mask_key
self.clip_pooled_key = clip_pooled_key
self.cache_dir = cache_dir

# Batch prompts
@@ -120,10 +135,11 @@ def __init__(self,
clip_pooled = clip_outputs[1].cpu()
clip_attention_mask = clip_attention_mask.cpu().to(torch.long)

latent_batch['T5_LATENTS'] = t5_latents
latent_batch['CLIP_LATENTS'] = clip_latents
latent_batch['ATTENTION_MASK'] = torch.cat([t5_attention_mask, clip_attention_mask], dim=1)
latent_batch['CLIP_POOLED'] = clip_pooled
latent_batch[self.t5_latent_key] = t5_latents
latent_batch[self.t5_mask_key] = t5_attention_mask
latent_batch[self.clip_latent_key] = clip_latents
latent_batch[self.clip_mask_key] = clip_attention_mask
latent_batch[self.clip_pooled_key] = clip_pooled
self.batched_latents.append(latent_batch)

del t5_model
@@ -143,12 +159,11 @@ def eval_start(self, state: State, logger: Logger):
all_gen_images = []
if self.precomputed_latents:
for batch in self.batched_latents:
pooled_prompt = batch['CLIP_POOLED'].cuda()
prompt_mask = batch['ATTENTION_MASK'].cuda()
t5_embeds = model.t5_proj(batch['T5_LATENTS'].cuda())
clip_embeds = model.clip_proj(batch['CLIP_LATENTS'].cuda())
prompt_embeds = torch.cat([t5_embeds, clip_embeds], dim=1)

pooled_prompt = batch[self.clip_pooled_key].cuda()
prompt_embeds, prompt_mask = model.prepare_text_embeddings(batch[self.t5_latent_key].cuda(),
batch[self.clip_latent_key].cuda(),
batch[self.t5_mask_key].cuda(),
batch[self.clip_mask_key].cuda())
gen_images = model.generate(prompt_embeds=prompt_embeds,
pooled_prompt=pooled_prompt,
prompt_mask=prompt_mask,
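
The callback now reads each latent from its own configurable key and defers projection and concatenation to `model.prepare_text_embeddings`, which returns the combined prompt embeddings and mask. A hedged sketch of constructing the callback with explicit keys; the prompts and encoder paths are placeholders, not values from this diff:

```python
from diffusion.callbacks.log_diffusion_images import LogDiffusionImages

# Placeholder prompts/encoder paths; the key names must match those emitted
# by the dataset so eval_start can find the precomputed latents in each batch.
image_logger = LogDiffusionImages(
    prompts=['a photo of a corgi wearing a top hat'],
    size=512,
    t5_encoder='google/t5-v1_1-xxl',               # assumed checkpoint path
    clip_encoder='openai/clip-vit-large-patch14',  # assumed checkpoint path
    t5_latent_key='T5_LATENTS',
    t5_mask_key='T5_ATTENTION_MASK',
    clip_latent_key='CLIP_LATENTS',
    clip_mask_key='CLIP_ATTENTION_MASK',
    clip_pooled_key='CLIP_POOLED',
)
```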
47 changes: 37 additions & 10 deletions diffusion/datasets/image_caption_latents.py
@@ -14,7 +14,8 @@
from torch.utils.data import DataLoader
from torchvision import transforms

from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransform, RandomCropSquare
from diffusion.datasets.laion.transforms import (LargestCenterSquare, RandomCropAspectRatioTransform,
RandomCropBucketedAspectRatioTransform, RandomCropSquare)
from diffusion.datasets.utils import make_streams

log = logging.getLogger(__name__)
@@ -32,6 +33,7 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset):
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
@@ -40,6 +42,7 @@
attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
latent_dtype (torch.dtype): The dtype to cast the text latents to. Default: ``torch.bfloat16``.
drop_nans (bool): Whether to treat samples with NaN latents as dropped captions. Default: ``True``.
**streaming_kwargs: Additional arguments to pass in the construction of the ``StreamingDataset``.
"""

@@ -53,10 +56,12 @@ def __init__(
image_key: str = 'image',
caption_keys: Tuple[str, ...] = ('caption',),
caption_selection_probs: Tuple[float, ...] = (1.0,),
aspect_ratio_bucket_key: Optional[str] = None,
text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)),
attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
latent_dtype: torch.dtype = torch.bfloat16,
drop_nans: bool = True,
**streaming_kwargs,
):

@@ -72,10 +77,14 @@ def __init__(
self.image_key = image_key
self.caption_keys = caption_keys
self.caption_selection_probs = caption_selection_probs
self.aspect_ratio_bucket_key = aspect_ratio_bucket_key
if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
assert self.aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using RandomCropBucketedAspectRatioTransform'
self.text_latent_keys = text_latent_keys
self.text_latent_shapes = text_latent_shapes
self.attention_mask_keys = attention_mask_keys
self.latent_dtype = latent_dtype
self.drop_nans = drop_nans

def __getitem__(self, index):
sample = super().__getitem__(index)
@@ -90,15 +99,16 @@ def __getitem__(self, index):
out['cond_original_size'] = torch.tensor(img.size)

# Image transforms
if self.crop is not None:
if isinstance(self.crop, RandomCropBucketedAspectRatioTransform):
img, crop_top, crop_left = self.crop(img, sample[self.aspect_ratio_bucket_key])
elif self.crop is not None:
img, crop_top, crop_left = self.crop(img)
else:
crop_top, crop_left = 0, 0
out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])

if self.transform is not None:
img = self.transform(img)
out['image'] = img
out['cond_crops_coords_top_left'] = torch.tensor([crop_top, crop_left])

# Get the new height and width
if isinstance(img, torch.Tensor):
@@ -140,6 +150,13 @@ def __getitem__(self, index):
if 'CLIP_LATENTS' in latent_key:
clip_pooled = np.frombuffer(sample[f'{caption_key}_CLIP_POOLED_TEXT'], dtype=np.float32).copy()
out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).to(self.latent_dtype).reshape(latent_shape[1])
if self.drop_nans:
for latent_key, attn_key in zip(self.text_latent_keys, self.attention_mask_keys):
if out[latent_key].isnan().any():
out[latent_key] = torch.zeros_like(out[latent_key])
out[attn_key] = torch.zeros_like(out[attn_key])
if 'CLIP_LATENTS' in latent_key and out['CLIP_POOLED'].isnan().any():
out['CLIP_POOLED'] = torch.zeros_like(out['CLIP_POOLED'])
return out


@@ -160,6 +177,7 @@ def build_streaming_image_caption_latents_dataloader(
text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)),
attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
latent_dtype: str = 'torch.bfloat16',
aspect_ratio_bucket_key: Optional[str] = None,
streaming_kwargs: Optional[Dict] = None,
dataloader_kwargs: Optional[Dict] = None,
):
@@ -178,11 +196,12 @@
``None``, the bucket with the smallest distance to the current sample's aspect ratio is selected.
Default: ``None``.
transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio'].
crop_type (str, optional): Type of crop to perform, either ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio'].
Default: ``'square'``.
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
caption_keys (Tuple[str, ...]): Key(s) associated with captions in the streaming dataset. Default: ``('caption',)``.
caption_selection_probs (Tuple[float, ...]): The probability of selecting each caption key. Default: ``(1.0,)``.
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset. Default: ``None``.
text_latent_keys (Tuple[str, ...]): Key(s) associated with text latents in the streaming dataset.
Default: ``('T5_LATENTS', 'CLIP_LATENTS')``.
text_latent_shapes (Tuple[Tuple[int, int], ...]): The shape(s) of the text latents in the streaming dataset.
@@ -192,18 +211,22 @@
Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
latent_dtype (str): The torch dtype to cast the text latents to. One of 'torch.float16', 'torch.float32',
or 'torch.bfloat16'. Default: ``'torch.bfloat16'``.
aspect_ratio_bucket_key (str, optional): Key associated with the aspect ratio bucket in the streaming dataset.
Needed if using ``crop_type='bucketed_aspect_ratio'``. Default: ``None``.
streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
"""
# Check crop type
if crop_type is not None:
crop_type = crop_type.lower()
if crop_type not in ['square', 'random', 'aspect_ratio']:
raise ValueError(f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", None]')
if crop_type == 'aspect_ratio' and (isinstance(resize_size, int) or isinstance(resize_size[0], int)):
raise ValueError(
'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.')

if crop_type not in ['square', 'random', 'aspect_ratio', 'bucketed_aspect_ratio']:
raise ValueError(
f'Invalid crop_type: {crop_type}. Must be ["square", "random", "aspect_ratio", "bucketed_aspect_ratio", None]'
)
if crop_type in ['aspect_ratio', 'bucketed_aspect_ratio'] and (isinstance(resize_size, int) or
isinstance(resize_size[0], int)):
raise ValueError(
'If using aspect ratio bucketing, specify aspect ratio buckets in resize_size as a tuple of tuples.')
# Check latent dtype
dtypes = {'torch.float16': torch.float16, 'torch.float32': torch.float32, 'torch.bfloat16': torch.bfloat16}
assert latent_dtype in dtypes, f'Invalid latent_dtype: {latent_dtype}. Must be one of {list(dtypes.keys())}'
@@ -225,6 +248,9 @@
crop = RandomCropSquare(resize_size)
elif crop_type == 'aspect_ratio':
crop = RandomCropAspectRatioTransform(resize_size, ar_bucket_boundaries) # type: ignore
elif crop_type == 'bucketed_aspect_ratio':
assert aspect_ratio_bucket_key is not None, 'aspect_ratio_bucket_key must be provided when using bucketed_aspect_ratio crop type'
crop = RandomCropBucketedAspectRatioTransform(resize_size) # type: ignore
else:
crop = None

@@ -242,6 +268,7 @@
image_key=image_key,
caption_keys=caption_keys,
caption_selection_probs=caption_selection_probs,
aspect_ratio_bucket_key=aspect_ratio_bucket_key,
text_latent_keys=text_latent_keys,
text_latent_shapes=text_latent_shapes,
attention_mask_keys=attention_mask_keys,
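
A usage sketch for the new `bucketed_aspect_ratio` crop type; the remote/local paths and the bucket key name are assumptions, and `resize_size` must be a tuple of (height, width) buckets rather than a single size:

```python
from diffusion.datasets.image_caption_latents import build_streaming_image_caption_latents_dataloader

# Paths and the bucket key are placeholders; each shard sample must store a
# precomputed bucket index under `aspect_ratio_bucket_key` for this crop type.
train_dataloader = build_streaming_image_caption_latents_dataloader(
    remote='s3://my-bucket/image-caption-latents',  # assumption
    local='/tmp/mds-cache',                         # assumption
    batch_size=8,
    resize_size=((512, 512), (640, 384), (384, 640)),  # aspect ratio buckets
    crop_type='bucketed_aspect_ratio',
    aspect_ratio_bucket_key='aspect_ratio_bucket',     # assumption
    latent_dtype='torch.bfloat16',
)
```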
7 changes: 5 additions & 2 deletions diffusion/models/__init__.py
@@ -4,10 +4,11 @@
"""Diffusion models."""

from diffusion.models.models import (build_autoencoder, build_diffusers_autoencoder, continuous_pixel_diffusion,
discrete_pixel_diffusion, stable_diffusion_2, stable_diffusion_xl,
text_to_image_transformer)
discrete_pixel_diffusion, precomputed_text_latent_diffusion, stable_diffusion_2,
stable_diffusion_xl, text_to_image_transformer)
from diffusion.models.noop import NoOpModel
from diffusion.models.pixel_diffusion import PixelDiffusion
from diffusion.models.precomputed_text_latent_diffusion import PrecomputedTextLatentDiffusion
from diffusion.models.stable_diffusion import StableDiffusion

__all__ = [
@@ -17,8 +18,10 @@
'discrete_pixel_diffusion',
'NoOpModel',
'PixelDiffusion',
'precomputed_text_latent_diffusion',
'stable_diffusion_2',
'stable_diffusion_xl',
'StableDiffusion',
'PrecomputedTextLatentDiffusion',
'text_to_image_transformer',
]
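
For completeness, a minimal sketch of consuming the new exports; whether `precomputed_text_latent_diffusion()` is callable with no arguments depends on its defaults, which are not shown in this hunk:

```python
from diffusion.models import PrecomputedTextLatentDiffusion, precomputed_text_latent_diffusion

# Assumes the builder provides sensible defaults; check its docstring for the
# full argument list (scheduler config and LoRA options are mentioned in the
# commit history above).
model = precomputed_text_latent_diffusion()
assert isinstance(model, PrecomputedTextLatentDiffusion)
```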