Add latent logger for T5-XXL text encoder (#154)
* add latent logger + small fix to Image.Lanczos

* revert PIL change

* fixes?

---------

Co-authored-by: rishab-partha <[email protected]>
rishab-partha and rishab-partha authored Jul 30, 2024
1 parent adebf01 commit 52088bf
Showing 1 changed file with 101 additions and 12 deletions.
113 changes: 101 additions & 12 deletions diffusion/callbacks/log_diffusion_images.py
@@ -3,13 +3,15 @@

"""Logger for generated images."""

import gc
from math import ceil
from typing import List, Optional, Tuple, Union

import torch
from composer import Callback, Logger, State
from composer.core import TimeUnit, get_precision_context
from torch.nn.parallel import DistributedDataParallel
from transformers import AutoModel, AutoTokenizer, CLIPTextModel


class LogDiffusionImages(Callback):
@@ -35,6 +37,9 @@ class LogDiffusionImages(Callback):
seed (int, optional): Random seed to use for generation. Set a seed for reproducible generation.
Default: ``1138``.
use_table (bool): Whether to make a table of the images or not. Default: ``False``.
t5_encoder (str, optional): Path to the T5 encoder to use as the second text encoder.
clip_encoder (str, optional): Path to the CLIP encoder to use as the first text encoder.
cache_dir (str, optional): Path for Hugging Face to cache files while downloading the model.
"""

def __init__(self,
@@ -45,14 +50,18 @@ def __init__(self,
guidance_scale: float = 0.0,
rescaled_guidance: Optional[float] = None,
seed: Optional[int] = 1138,
use_table: bool = False):
use_table: bool = False,
t5_encoder: Optional[str] = None,
clip_encoder: Optional[str] = None,
cache_dir: Optional[str] = '/tmp/hf_files'):
self.prompts = prompts
self.size = (size, size) if isinstance(size, int) else size
self.num_inference_steps = num_inference_steps
self.guidance_scale = guidance_scale
self.rescaled_guidance = rescaled_guidance
self.seed = seed
self.use_table = use_table
self.cache_dir = cache_dir

# Batch prompts
batch_size = len(prompts) if batch_size is None else batch_size
@@ -62,6 +71,66 @@ def __init__(self,
start, end = i * batch_size, (i + 1) * batch_size
self.batched_prompts.append(prompts[start:end])

if (t5_encoder is not None and clip_encoder is None) or (t5_encoder is None and clip_encoder is not None):
raise ValueError('Cannot specify only one of the T5 encoder and the CLIP encoder.')

self.precomputed_latents = False
self.batched_latents = []
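# When both encoder paths are provided, precompute the prompt latents once at init
# so that eval only needs the cached tensors and the model's projection layers.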
if t5_encoder:
self.precomputed_latents = True
t5_tokenizer = AutoTokenizer.from_pretrained(t5_encoder, cache_dir=self.cache_dir, local_files_only=True)
clip_tokenizer = AutoTokenizer.from_pretrained(clip_encoder,
subfolder='tokenizer',
cache_dir=self.cache_dir,
local_files_only=True)

t5_model = AutoModel.from_pretrained(t5_encoder,
torch_dtype=torch.float16,
cache_dir=self.cache_dir,
local_files_only=True).encoder.cuda().eval()
clip_model = CLIPTextModel.from_pretrained(clip_encoder,
subfolder='text_encoder',
torch_dtype=torch.float16,
cache_dir=self.cache_dir,
local_files_only=True).cuda().eval()

for batch in self.batched_prompts:
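# Encode each prompt batch once: T5 latents come from the encoder output, CLIP latents
# from the penultimate hidden state plus the pooled output. Keep everything on CPU until eval.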
latent_batch = {}
tokenized_t5 = t5_tokenizer(batch,
padding='max_length',
max_length=t5_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
t5_attention_mask = tokenized_t5['attention_mask'].to(torch.bool).cuda()
t5_ids = tokenized_t5['input_ids'].cuda()
t5_latents = t5_model(input_ids=t5_ids, attention_mask=t5_attention_mask)[0].cpu()
t5_attention_mask = t5_attention_mask.cpu().to(torch.long)

tokenized_clip = clip_tokenizer(batch,
padding='max_length',
max_length=clip_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
clip_attention_mask = tokenized_clip['attention_mask'].cuda()
clip_ids = tokenized_clip['input_ids'].cuda()
clip_outputs = clip_model(input_ids=clip_ids,
attention_mask=clip_attention_mask,
output_hidden_states=True)
clip_latents = clip_outputs.hidden_states[-2].cpu()
clip_pooled = clip_outputs[1].cpu()
clip_attention_mask = clip_attention_mask.cpu().to(torch.long)

latent_batch['T5_LATENTS'] = t5_latents
latent_batch['CLIP_LATENTS'] = clip_latents
latent_batch['ATTENTION_MASK'] = torch.cat([t5_attention_mask, clip_attention_mask], dim=1)
latent_batch['CLIP_POOLED'] = clip_pooled
self.batched_latents.append(latent_batch)
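# The text encoders are only needed for this precompute step; free them before training starts.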

del t5_model
del clip_model
gc.collect()
torch.cuda.empty_cache()

def eval_start(self, state: State, logger: Logger):
# Get the model object if it has been wrapped by DDP to access the image generation function.
if isinstance(state.model, DistributedDataParallel):
@@ -72,17 +141,37 @@ def eval_start(self, state: State, logger: Logger):
# Generate images
with get_precision_context(state.precision):
all_gen_images = []
for batch in self.batched_prompts:
gen_images = model.generate(
prompt=batch, # type: ignore
height=self.size[0],
width=self.size[1],
guidance_scale=self.guidance_scale,
rescaled_guidance=self.rescaled_guidance,
progress_bar=False,
num_inference_steps=self.num_inference_steps,
seed=self.seed)
all_gen_images.append(gen_images)
if self.precomputed_latents:
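# Project the cached T5/CLIP latents through the model's projection layers and
# generate from prompt embeddings instead of raw prompt strings.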
for batch in self.batched_latents:
pooled_prompt = batch['CLIP_POOLED'].cuda()
prompt_mask = batch['ATTENTION_MASK'].cuda()
t5_embeds = model.t5_proj(batch['T5_LATENTS'].cuda())
clip_embeds = model.clip_proj(batch['CLIP_LATENTS'].cuda())
prompt_embeds = torch.cat([t5_embeds, clip_embeds], dim=1)

gen_images = model.generate(prompt_embeds=prompt_embeds,
pooled_prompt=pooled_prompt,
prompt_mask=prompt_mask,
height=self.size[0],
width=self.size[1],
guidance_scale=self.guidance_scale,
rescaled_guidance=self.rescaled_guidance,
progress_bar=False,
num_inference_steps=self.num_inference_steps,
seed=self.seed)
all_gen_images.append(gen_images)
else:
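# Fall back to generating from the raw prompt strings when latents were not precomputed.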
for batch in self.batched_prompts:
gen_images = model.generate(
prompt=batch, # type: ignore
height=self.size[0],
width=self.size[1],
guidance_scale=self.guidance_scale,
rescaled_guidance=self.rescaled_guidance,
progress_bar=False,
num_inference_steps=self.num_inference_steps,
seed=self.seed)
all_gen_images.append(gen_images)
gen_images = torch.cat(all_gen_images)

# Log images to wandb

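For context, the following minimal sketch (not part of the commit) shows how the updated callback might be constructed with the new arguments. The constructor parameters come from the diff above; the import path follows the changed file's location, while the Hugging Face checkpoint names and the Trainer wiring are assumptions for illustration.

from diffusion.callbacks.log_diffusion_images import LogDiffusionImages

image_logger = LogDiffusionImages(
    prompts=['a photo of a corgi wearing a tiny hat'],
    size=1024,
    num_inference_steps=30,
    guidance_scale=7.0,
    seed=1138,
    use_table=False,
    # Both encoder paths must be provided together; giving only one raises a ValueError.
    t5_encoder='google/t5-v1_1-xxl',  # assumed T5-XXL checkpoint name
    clip_encoder='stabilityai/stable-diffusion-xl-base-1.0',  # assumed; the 'tokenizer' and 'text_encoder' subfolders are read from it
    cache_dir='/tmp/hf_files',
)
# Note: the encoders are loaded with local_files_only=True, so the checkpoints must already be present in cache_dir.
# The callback is then passed to the Composer Trainer, e.g. Trainer(..., callbacks=[image_logger]).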