Skip to content

Commit

Permalink
combine DALLE with discrete VAE
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 6, 2021
1 parent 952cd11 commit d474d96
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 2 deletions.
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,29 @@ loss = dalle(text, images, mask = mask, return_loss = True)
loss.backward()
```

Combine pretrained VAE with DALL-E, and pass in raw images

```python
import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
num_tokens = 512
)

dalle = DALLE(
dim = 512,
vae = vae
)

text = torch.randint(0, 10000, (2, 256))
images = torch.randn(2, 3, 256, 256) # train directly on raw images, VAE converts to proper embeddings
mask = torch.ones_like(text).bool()

loss = dalle(text, images, return_loss = True)
loss.backward()
```

## Citations

```bibtex
Expand Down
13 changes: 12 additions & 1 deletion dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
nn.Conv2d(64, 3, 1)
)

self.num_tokens = num_tokens
self.codebook = nn.Embedding(num_tokens, dim)

def forward(
Expand Down Expand Up @@ -140,7 +141,8 @@ def __init__(
text_seq_len = 256,
image_seq_len = 1024,
depth = 6, # should be 64
heads = 8
heads = 8,
vae = None
):
super().__init__()
self.text_emb = nn.Embedding(num_text_tokens, dim)
Expand All @@ -153,6 +155,8 @@ def __init__(
self.image_seq_len = image_seq_len
self.total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS

self.vae = vae
self.image_emb = vae.codebook
self.transformer = Decoder(dim = dim, depth = depth, heads = heads)

self.to_logits = nn.Sequential(
Expand All @@ -168,10 +172,17 @@ def forward(
return_loss = False
):
device = text.device
is_raw_image = len(image.shape) == 4

text_emb = self.text_emb(text)
text_emb += self.text_pos_emb(torch.arange(text.shape[1], device = device))

if is_raw_image:
assert exists(self.vae), 'VAE must be passed into constructor if you are to train directly on raw images'
image_logits = self.vae(image, return_logits = True)
codebook_indices = image_logits.argmax(dim = 1).flatten(1)
image = codebook_indices

image_emb = self.image_emb(image)
image_emb += self.image_pos_emb(torch.arange(image.shape[1], device = device))

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'dalle-pytorch',
packages = find_packages(),
version = '0.0.1',
version = '0.0.2',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit d474d96

Please sign in to comment.