diff --git a/README.md b/README.md
index 5a320be..c3cc4a9 100644
--- a/README.md
+++ b/README.md
@@ -23,16 +23,14 @@ from imagen_pytorch import Unet, Imagen
 # unet for imagen
 
 unet1 = Unet(
-    dim = 128,
-    text_embed_dim = 512,
+    dim = 32,
     cond_dim = 128,
     channels = 3,
     dim_mults=(1, 2, 4, 8)
 ).cuda()
 
 unet2 = Unet(
-    dim = 128,
-    text_embed_dim = 512,
+    dim = 32,
     cond_dim = 128,
     channels = 3,
     dim_mults=(1, 2, 4, 8)
@@ -61,12 +59,19 @@ for i in (1, 2):
 # do the above for many many many many steps
 
 # now you can sample an image based on the text embeddings from the cascading ddpm
 
-images = imagen.sample(text_embeds = text_embeds) # (4, 3, 256, 256)
+images = imagen.sample(texts = [
+    'a whale breaching from afar',
+    'young girl blowing out candles on her birthday cake',
+    'fireworks with blue and green sparkles'
+])
+
+images.shape # (3, 3, 256, 256)
 ```
 
 ## Todo
 
-- [ ] use huggingface transformers for T5-small text embeddings, allow for one to set T5-large
+- [x] use huggingface transformers for T5-small text embeddings
+- [ ] allow for one to set T5-large (and perhaps small factory method to take in any huggingface transformer)
 - [ ] separate unet into base unet and SR3 unet
 - [ ] build whatever efficient unet they came up with
 - [ ] add the noise level conditioning with the pseudocode in appendix, and figure out what is this sweep they do at inference time
diff --git a/imagen_pytorch/imagen_pytorch.py b/imagen_pytorch/imagen_pytorch.py
index b156977..d04d98d 100644
--- a/imagen_pytorch/imagen_pytorch.py
+++ b/imagen_pytorch/imagen_pytorch.py
@@ -1,4 +1,5 @@
 import math
+from typing import List
 from tqdm import tqdm
 from inspect import isfunction
 from functools import partial, wraps
@@ -21,6 +22,8 @@
 from resize_right import resize
 
+from imagen_pytorch.t5 import t5_encode_text, T5_SMALL_EMBED_DIM
+
 # constants
 
 NAT = 1. / math.log(2.)
@@ -693,7 +696,7 @@ def __init__(
         dim,
         *,
         image_embed_dim = None,
-        text_embed_dim = None,
+        text_embed_dim = T5_SMALL_EMBED_DIM,
         cond_dim = None,
         num_image_tokens = 4,
         num_time_tokens = 2,
@@ -1295,19 +1298,22 @@ def p_losses(self, unet, x_start, times, *, lowres_cond_img = None, text_embeds
     @eval_decorator
     def sample(
         self,
-        text = None,
+        texts: List[str] = None,
         text_mask = None,
         text_embeds = None,
         batch_size = 1,
         cond_scale = 1.,
         stop_at_unet_number = None
     ):
+        device = next(self.parameters()).device
+
+        if exists(texts) and not exists(text_embeds) and not self.unconditional:
+            text_embeds = t5_encode_text(texts)
+            text_embeds = text_embeds.to(device) # Tensor.to is not in-place, so reassign
+
         if not self.unconditional:
             batch_size = text_embeds.shape[0]
 
-        if exists(text) and not exists(text_embeds) and not self.unconditional:
-            assert False, 'needs to be built'
-
         assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
         assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented'
@@ -1347,7 +1353,7 @@ def sample(
     def forward(
         self,
         image,
-        text = None,
+        texts: List[str] = None,
         text_embeds = None,
         text_mask = None,
         unet_number = None
@@ -1355,7 +1361,7 @@ def forward(
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
         unet_number = default(unet_number, 1)
         unet_index = unet_number - 1
-        
+
         unet = self.get_unet(unet_number)
 
         target_image_size = self.image_sizes[unet_index]
@@ -1369,8 +1375,9 @@ def forward(
 
         times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
 
-        if exists(text) and not exists(text_embeds) and not self.unconditional:
-            assert False, 'not built yet'
+        if exists(texts) and not exists(text_embeds) and not self.unconditional:
+            text_embeds = t5_encode_text(texts)
+            text_embeds = text_embeds.to(image.device) # Tensor.to is not in-place, so reassign
 
         assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into decoder if specified'
         assert not (not self.condition_on_text and exists(text_embeds)), 'decoder specified not to be conditioned on text, yet it is presented'
diff --git a/imagen_pytorch/t5.py b/imagen_pytorch/t5.py
new file mode 100644
index 0000000..fa6e381
--- /dev/null
+++ b/imagen_pytorch/t5.py
@@ -0,0 +1,41 @@
+import torch
+from transformers import T5Tokenizer, T5ForConditionalGeneration
+
+def exists(val):
+    return val is not None
+
+# singleton globals
+
+MODEL = None
+TOKENIZER = None
+T5_SMALL_EMBED_DIM = 512
+
+def get_tokenizer():
+    global TOKENIZER
+    if not exists(TOKENIZER):
+        TOKENIZER = T5Tokenizer.from_pretrained("t5-small")
+    return TOKENIZER
+
+def get_t5():
+    global MODEL
+    if not exists(MODEL):
+        MODEL = T5ForConditionalGeneration.from_pretrained("t5-small")
+        if torch.cuda.is_available():
+            MODEL = MODEL.cuda()
+
+    return MODEL
+
+# encoding text
+
+def t5_encode_text(texts):
+    t5 = get_t5()
+    tokenizer = get_tokenizer()
+
+    input_ids = tokenizer.batch_encode_plus(texts, return_tensors = "pt", padding = True, truncation = True).input_ids
+    input_ids = input_ids.to(next(t5.parameters()).device)
+
+    t5.eval()
+    with torch.no_grad():
+        output = t5.encoder(input_ids = input_ids) # run only the encoder - no dummy decoder inputs needed
+
+    return output.last_hidden_state
diff --git a/setup.py b/setup.py
index 7c79018..739e415 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'imagen-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.1',
+  version = '0.0.3',
   license='MIT',
   description = 'Imagen - unprecedented photorealism × deep level of language understanding',
   author = 'Phil Wang',
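
For reference, a minimal sketch of the text-conditioning path this diff wires up, assuming only what the diff itself adds (`t5_encode_text` and `T5_SMALL_EMBED_DIM` from the new `imagen_pytorch/t5.py`); the captions and shape comment are illustrative, and `imagen` in the trailing comment refers to the README example above.

```python
from imagen_pytorch.t5 import t5_encode_text, T5_SMALL_EMBED_DIM

# encode a batch of captions through the shared T5-small singleton;
# the result is the encoder's last hidden state, padded across the batch
text_embeds = t5_encode_text([
    'a whale breaching from afar',
    'fireworks with blue and green sparkles'
])

print(text_embeds.shape)  # (2, seq_len, 512)
assert text_embeds.shape[-1] == T5_SMALL_EMBED_DIM

# precomputed embeddings can still be passed in directly, skipping the
# on-the-fly encoding that sample(texts = ...) now performs:
# images = imagen.sample(text_embeds = text_embeds)
```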