From aea93ea31fc186b06004825879de1b3bc3c0caa7 Mon Sep 17 00:00:00 2001
From: Cory Stephenson
Date: Mon, 26 Aug 2024 21:49:28 +0000
Subject: [PATCH] Make latent dtype configurable

---
 diffusion/datasets/image_caption_latents.py | 49 +++++++++++++--------
 1 file changed, 31 insertions(+), 18 deletions(-)

diff --git a/diffusion/datasets/image_caption_latents.py b/diffusion/datasets/image_caption_latents.py
index 199da1dc..fa1158f5 100644
--- a/diffusion/datasets/image_caption_latents.py
+++ b/diffusion/datasets/image_caption_latents.py
@@ -39,23 +39,25 @@ class StreamingImageCaptionLatentsDataset(StreamingDataset):
             Default: ``((512, 4096), (77, 768))``.
         attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
             Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
+        latent_dtype (torch.dtype): The dtype to cast the text latents to. Default: ``torch.bfloat16``.
         **streaming_kwargs: Additional arguments to pass in the construction of the StreamingDataloader
     """

     def __init__(
-            self,
-            streams: Sequence[Stream],
-            caption_drop_prob: float = 0.0,
-            microcond_drop_prob: float = 0.0,
-            crop: Optional[Callable] = None,
-            transform: Optional[Callable] = None,
-            image_key: str = 'image',
-            caption_keys: Tuple[str, ...] = ('caption',),
-            caption_selection_probs: Tuple[float, ...] = (1.0,),
-            text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
-            text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)),
-            attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
-            **streaming_kwargs,
+        self,
+        streams: Sequence[Stream],
+        caption_drop_prob: float = 0.0,
+        microcond_drop_prob: float = 0.0,
+        crop: Optional[Callable] = None,
+        transform: Optional[Callable] = None,
+        image_key: str = 'image',
+        caption_keys: Tuple[str, ...] = ('caption',),
+        caption_selection_probs: Tuple[float, ...] = (1.0,),
+        text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
+        text_latent_shapes: Tuple[Tuple[int, int], ...] = ((512, 4096), (77, 768)),
+        attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
+        latent_dtype: torch.dtype = torch.bfloat16,
+        **streaming_kwargs,
     ):

         # Set defaults for vision-friendly streaming args.
@@ -73,6 +75,7 @@ def __init__(
         self.text_latent_keys = text_latent_keys
         self.text_latent_shapes = text_latent_shapes
         self.attention_mask_keys = attention_mask_keys
+        self.latent_dtype = latent_dtype

     def __getitem__(self, index):
         sample = super().__getitem__(index)
@@ -124,18 +127,19 @@ def __getitem__(self, index):
             attention_key = f'{caption_key}_{self.attention_mask_keys[i]}'

             if torch.rand(1) < self.caption_drop_prob:
-                out[self.text_latent_keys[i]] = torch.zeros(latent_shape, dtype=torch.float16)
+                out[self.text_latent_keys[i]] = torch.zeros(latent_shape, dtype=self.latent_dtype)
                 out[self.attention_mask_keys[i]] = torch.zeros(latent_shape[0])
                 if 'CLIP_LATENTS' in latent_key:
                     out['CLIP_POOLED'] = torch.zeros(latent_shape[1])
             else:
-                text_latent = np.frombuffer(sample[latent_key], dtype=np.float16).copy()
-                out[self.text_latent_keys[i]] = torch.from_numpy(text_latent).reshape(latent_shape)
+                text_latent = np.frombuffer(sample[latent_key], dtype=np.float32).copy()
+                out[self.text_latent_keys[i]] = torch.from_numpy(text_latent).to(
+                    self.latent_dtype).reshape(latent_shape)
                 attention_mask = np.frombuffer(sample[attention_key], dtype=np.bool_).copy()
                 out[self.attention_mask_keys[i]] = torch.from_numpy(attention_mask).to(dtype=torch.float).reshape(-1) #.reshape(latent_shape[0])
                 if 'CLIP_LATENTS' in latent_key:
-                    clip_pooled = np.frombuffer(sample[f'{caption_key}_CLIP_POOLED_TEXT'], dtype=np.float16).copy()
-                    out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).reshape(latent_shape[1])
+                    clip_pooled = np.frombuffer(sample[f'{caption_key}_CLIP_POOLED_TEXT'], dtype=np.float32).copy()
+                    out['CLIP_POOLED'] = torch.from_numpy(clip_pooled).to(self.latent_dtype).reshape(latent_shape[1])

         return out

@@ -155,6 +159,7 @@ def build_streaming_image_caption_latents_dataloader(
     text_latent_keys: Tuple[str, ...] = ('T5_LATENTS', 'CLIP_LATENTS'),
     text_latent_shapes: Tuple[Tuple, ...] = ((512, 4096), (77, 768)),
     attention_mask_keys: Tuple[str, ...] = ('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK'),
+    latent_dtype: str = 'torch.bfloat16',
     streaming_kwargs: Optional[Dict] = None,
     dataloader_kwargs: Optional[Dict] = None,
 ):
@@ -185,6 +190,8 @@ def build_streaming_image_caption_latents_dataloader(
             Default: ``((512, 4096), (77, 768))``.
         attention_mask_keys (Tuple[str, ...]): Key(s) associated with attention masks in the streaming dataset.
             Default: ``('T5_ATTENTION_MASK', 'CLIP_ATTENTION_MASK')``.
+        latent_dtype (str): The torch dtype to cast the text latents to. One of 'torch.float16', 'torch.float32',
+            or 'torch.bfloat16'. Default: ``'torch.bfloat16'``.
         streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``.
         dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``.
     """
@@ -197,6 +204,11 @@ def build_streaming_image_caption_latents_dataloader(
         raise ValueError(
             'If using crop_type="aspect_ratio", specify aspect ratio buckets in resize_size as a tuple of tuples.')

+    # Check latent dtype
+    dtypes = {'torch.float16': torch.float16, 'torch.float32': torch.float32, 'torch.bfloat16': torch.bfloat16}
+    assert latent_dtype in dtypes, f'Invalid latent_dtype: {latent_dtype}. Must be one of {list(dtypes.keys())}'
+    dtype = dtypes[latent_dtype]
+
     # Handle ``None`` kwargs
     if streaming_kwargs is None:
         streaming_kwargs = {}
@@ -233,6 +245,7 @@ def build_streaming_image_caption_latents_dataloader(
         text_latent_keys=text_latent_keys,
         text_latent_shapes=text_latent_shapes,
         attention_mask_keys=attention_mask_keys,
+        latent_dtype=dtype,
         **streaming_kwargs,
     )
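
Not part of the patch itself: a minimal usage sketch of the new option. The builder now takes `latent_dtype` as a string and maps it to a `torch.dtype` before constructing the dataset. The `remote`/`local` paths and `batch_size` values below are placeholders, and the builder's leading arguments are assumed from the rest of the file since they sit outside this diff's context.

    import torch

    from diffusion.datasets.image_caption_latents import build_streaming_image_caption_latents_dataloader

    # Placeholder paths; any streaming-compatible remote/local pair works here.
    dataloader = build_streaming_image_caption_latents_dataloader(
        remote='s3://my-bucket/latents',
        local='/tmp/mds-cache/latents',
        batch_size=8,
        latent_dtype='torch.bfloat16',  # also accepts 'torch.float16' or 'torch.float32'
    )

    # Text latents in each batch should now arrive in the requested dtype,
    # whether they came from the stored buffers or from caption dropout.
    batch = next(iter(dataloader))
    assert batch['T5_LATENTS'].dtype == torch.bfloat16
    assert batch['CLIP_LATENTS'].dtype == torch.bfloat16

Note that any other spelling (e.g. 'bfloat16' without the 'torch.' prefix) trips the new assert in the builder, so configs must use the full 'torch.*' name.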