Skip to content

Commit

Permalink
complete ranking with clip
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 6, 2021
1 parent d40aecc commit 57965d9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 2 deletions.
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,25 @@ images = generate_images(
images.shape # (2, 3, 256, 256)
```

To get the similarity scores from your trained Clipper, just do

```python
from dalle_pytorch import generate_images

images, scores = generate_images(
dalle,
vae = vae,
text = text,
mask = mask,
clipper = clip
)

scores.shape # (2,)
images.shape # (2, 3, 256, 256)

# do your topk here, in paper they sampled 512 and chose top 32
```

## Citations

```bibtex
Expand Down
10 changes: 9 additions & 1 deletion dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def generate_images(
model,
vae,
text,
clipper = None,
mask = None,
filter_thres = 0.9,
temperature = 1.
Expand All @@ -55,10 +56,17 @@ def generate_images(
if out.shape[1] <= text_seq_len:
mask = F.pad(mask, (0, 1), value=True)

text_seq = torch.cat((x[:, :1], out[:, :(text_seq_len - 1)]), dim = 1)
img_seq = out[:, -(image_seq_len + 1):-1]
img_seq -= model.num_text_tokens
img_seq.clamp_(min = 0, max = (model.num_image_tokens - 1))
img_seq.clamp_(min = 0, max = (model.num_image_tokens - 1)) # extra insurance - todo: get rid of this at a future date and rely only on masking of logits

images = vae.decode(img_seq)

if exists(clipper):
scores = clipper(text_seq, img_seq, return_loss = False)
return images, scores.diag()

return images

class DiscreteVAE(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'dalle-pytorch',
packages = find_packages(),
version = '0.0.5',
version = '0.0.7',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 57965d9

Please sign in to comment.