From 57965d9eb24e6adbc23fe35a97bd48b8fcaf0bb5 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 6 Jan 2021 11:15:19 -0800 Subject: [PATCH] complete ranking with clip --- README.md | 19 +++++++++++++++++++ dalle_pytorch/dalle_pytorch.py | 10 +++++++++- setup.py | 2 +- 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 04db5b51..fc430e42 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,25 @@ images = generate_images( images.shape # (2, 3, 256, 256) ``` +To also get the similarity scores from your trained CLIP, pass it as `clipper`; `generate_images` then returns both the images and one score per image + +```python +from dalle_pytorch import generate_images + +images, scores = generate_images( + dalle, + vae = vae, + text = text, + mask = mask, + clipper = clip +) + +scores.shape # (2,) +images.shape # (2, 3, 256, 256) + +# do your top-k selection here; in the paper they sampled 512 images and kept the top 32 +``` + ## Citations ```bibtex diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py index a353208f..4c30c2d7 100644 --- a/dalle_pytorch/dalle_pytorch.py +++ b/dalle_pytorch/dalle_pytorch.py @@ -32,6 +32,7 @@ def generate_images( model, vae, text, + clipper = None, mask = None, filter_thres = 0.9, temperature = 1. 
@@ -55,10 +56,17 @@ def generate_images( if out.shape[1] <= text_seq_len: mask = F.pad(mask, (0, 1), value=True) + text_seq = torch.cat((x[:, :1], out[:, :(text_seq_len - 1)]), dim = 1) img_seq = out[:, -(image_seq_len + 1):-1] img_seq -= model.num_text_tokens - img_seq.clamp_(min = 0, max = (model.num_image_tokens - 1)) + img_seq.clamp_(min = 0, max = (model.num_image_tokens - 1)) # extra insurance - todo: get rid of this at a future date and rely only on masking of logits + images = vae.decode(img_seq) + + if exists(clipper): + scores = clipper(text_seq, img_seq, return_loss = False) + return images, scores.diag() + return images class DiscreteVAE(nn.Module): diff --git a/setup.py b/setup.py index dae46bb0..e5004f41 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'dalle-pytorch', packages = find_packages(), - version = '0.0.5', + version = '0.0.7', license='MIT', description = 'DALL-E - Pytorch', author = 'Phil Wang',