Skip to content

Commit

Permalink
complete t5 small integration
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 24, 2022
1 parent 3c2f8fb commit ee79a1c
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 16 deletions.
17 changes: 11 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,14 @@ from imagen_pytorch import Unet, Imagen
# unet for imagen

unet1 = Unet(
dim = 128,
text_embed_dim = 512,
dim = 32,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()

unet2 = Unet(
dim = 128,
text_embed_dim = 512,
dim = 32,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
Expand Down Expand Up @@ -61,12 +59,19 @@ for i in (1, 2):
# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm

images = imagen.sample(text_embeds = text_embeds) # (4, 3, 256, 256)
images = imagen.sample(texts = [
'a whale breaching from afar',
'young girl blowing out candles on her birthday cake',
'fireworks with blue and green sparkles'
])

images.shape # (3, 3, 256, 256)
```

## Todo

- [ ] use huggingface transformers for T5-small text embeddings, allow for one to set T5-large
- [x] use huggingface transformers for T5-small text embeddings
- [ ] allow for one to set T5-large (and perhaps small factory method to take in any huggingface transformer)
- [ ] separate unet into base unet and SR3 unet
- [ ] build whatever efficient unet they came up with
- [ ] add the noise level conditioning with the pseudocode in appendix, and figure out what is this sweep they do at inference time
Expand Down
25 changes: 16 additions & 9 deletions imagen_pytorch/imagen_pytorch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
from typing import List
from tqdm import tqdm
from inspect import isfunction
from functools import partial, wraps
Expand All @@ -21,6 +22,8 @@

from resize_right import resize

from imagen_pytorch.t5 import t5_encode_text, T5_SMALL_EMBED_DIM

# constants

NAT = 1. / math.log(2.)
Expand Down Expand Up @@ -693,7 +696,7 @@ def __init__(
dim,
*,
image_embed_dim = None,
text_embed_dim = None,
text_embed_dim = T5_SMALL_EMBED_DIM,
cond_dim = None,
num_image_tokens = 4,
num_time_tokens = 2,
Expand Down Expand Up @@ -1295,19 +1298,22 @@ def p_losses(self, unet, x_start, times, *, lowres_cond_img = None, text_embeds
@eval_decorator
def sample(
self,
text = None,
texts: List[str] = None,
text_mask = None,
text_embeds = None,
batch_size = 1,
cond_scale = 1.,
stop_at_unet_number = None
):
device = next(self.parameters()).device

if exists(texts) and not exists(text_embeds) and not self.unconditional:
text_embeds = t5_encode_text(texts)
text_embeds.to(device)

if not self.unconditional:
batch_size = text_embeds.shape[0]

if exists(text) and not exists(text_embeds) and not self.unconditional:
assert False, 'needs to be built'

assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented'

Expand Down Expand Up @@ -1347,15 +1353,15 @@ def sample(
def forward(
self,
image,
text = None,
texts: List[str] = None,
text_embeds = None,
text_mask = None,
unet_number = None
):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
unet_index = unet_number - 1

unet = self.get_unet(unet_number)

target_image_size = self.image_sizes[unet_index]
Expand All @@ -1369,8 +1375,9 @@ def forward(

times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)

if exists(text) and not exists(text_embeds) and not self.unconditional:
assert False, 'not built yet'
if exists(texts) and not exists(text_embeds) and not self.unconditional:
text_embeds = t5_encode_text(texts)
text_embds.to(image.device)

assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified'
assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented'
Expand Down
41 changes: 41 additions & 0 deletions imagen_pytorch/t5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

def exists(val):
    """Return True for any value other than ``None`` (falsy values included)."""
    return not (val is None)

# singleton globals

MODEL = None  # cached T5 model; stays None until get_t5() loads it on first call
TOKENIZER = None  # cached tokenizer; stays None until get_tokenizer() loads it on first call
T5_SMALL_EMBED_DIM = 512  # text embedding width for t5-small; used as the default text_embed_dim in imagen_pytorch.py

def get_tokenizer():
    """Return the shared "t5-small" tokenizer, loading it on first call.

    The instance is cached in the module-level ``TOKENIZER`` global, so
    ``T5Tokenizer.from_pretrained`` runs at most once per process.
    """
    global TOKENIZER

    if TOKENIZER is None:
        TOKENIZER = T5Tokenizer.from_pretrained("t5-small")

    return TOKENIZER

def get_t5():
    """Return the shared "t5-small" model, loading it on first call.

    The instance is cached in the module-level ``MODEL`` global. Right after
    loading, the model is moved to CUDA when ``torch.cuda.is_available()``
    reports a device; otherwise it stays where ``from_pretrained`` put it.
    """
    global MODEL

    # already loaded: hand back the cached instance
    if MODEL is not None:
        return MODEL

    MODEL = T5ForConditionalGeneration.from_pretrained("t5-small")

    if torch.cuda.is_available():
        MODEL = MODEL.cuda()

    return MODEL

# encoding text

def t5_encode_text(texts):
    """Encode a batch of texts with the cached T5 encoder.

    Args:
        texts: a list of strings. A single string is also accepted and is
            treated as a batch of one.

    Returns:
        The encoder's last hidden state for the tokenized batch, computed
        under ``torch.no_grad()`` on the device the T5 model lives on.
    """
    if isinstance(texts, str):
        texts = [texts]

    t5 = get_t5()
    tokenizer = get_tokenizer()
    device = next(t5.parameters()).device

    encoded = tokenizer.batch_encode_plus(texts, return_tensors = "pt", padding = True, truncation = True)
    input_ids = encoded.input_ids.to(device)

    # padding = True pads shorter texts up to the longest one in the batch;
    # without the mask those pad tokens are attended to and change the embeddings
    attention_mask = encoded.attention_mask.to(device)

    t5.eval()
    with torch.no_grad():
        # run only the encoder stack: no dummy decoder inputs and no wasted decoder pass
        encoder_output = t5.get_encoder()(input_ids = input_ids, attention_mask = attention_mask)

    return encoder_output.last_hidden_state
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'imagen-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.1',
version = '0.0.3',
license='MIT',
description = 'Imagen - unprecedented photorealism × deep level of language understanding',
author = 'Phil Wang',
Expand Down

0 comments on commit ee79a1c

Please sign in to comment.