use state of the art vision transformer in CLIP - 0.0.18
lucidrains committed Jan 8, 2021
1 parent c32f724 commit e0456ae
Showing 3 changed files with 22 additions and 50 deletions.
43 changes: 4 additions & 39 deletions README.md
@@ -42,53 +42,18 @@ clip = CLIP(
dim_image = 512,
dim_latent = 512,
num_text_tokens = 10000,
num_visual_tokens = 512,
text_enc_depth = 6,
visual_enc_depth = 6,
text_seq_len = 256,
visual_seq_len = 1024,
text_heads = 8,
visual_heads = 8
)

text = torch.randint(0, 10000, (2, 256))
images = torch.randint(0, 512, (2, 1024))
mask = torch.ones_like(text).bool()

loss = clip(text, images, text_mask = mask, return_loss = True)
loss.backward()
```

Combine pretrained VAE with CLIP, and train off raw images

```python
import torch
from dalle_pytorch import DiscreteVAE, CLIP

vae = DiscreteVAE(
num_layers = 3,
num_tokens = 2000,
dim = 512,
hidden_dim = 64
)

clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 10000,
num_visual_tokens = 512,
text_enc_depth = 6,
visual_enc_depth = 6,
text_seq_len = 256,
visual_seq_len = 1024,
text_heads = 8,
visual_heads = 8,
vae = vae
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
)

text = torch.randint(0, 10000, (2, 256))
images = torch.randn(2, 3, 256, 256) # raw images
images = torch.randn(2, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = clip(text, images, text_mask = mask, return_loss = True)
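For reference, a minimal sketch of the usage after this change, assembled from the new constructor arguments and the raw-image input visible in the README diff above; the exact post-commit README may differ slightly. The image side is now described by `visual_image_size` and `visual_patch_size` rather than a discrete `visual_seq_len`, and images are passed as raw pixel tensors instead of VAE codebook indices.

```python
import torch
from dalle_pytorch import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 10000,
    text_enc_depth = 6,
    text_seq_len = 256,
    text_heads = 8,
    num_visual_tokens = 512,
    visual_enc_depth = 6,
    visual_heads = 8,
    visual_image_size = 256,   # full image resolution
    visual_patch_size = 32     # ViT-style patch size
)

text = torch.randint(0, 10000, (2, 256))   # text token ids
images = torch.randn(2, 3, 256, 256)       # raw pixels, no VAE tokenization needed
mask = torch.ones_like(text).bool()

loss = clip(text, images, text_mask = mask, return_loss = True)
loss.backward()
```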
27 changes: 17 additions & 10 deletions dalle_pytorch/dalle_pytorch.py
@@ -65,7 +65,7 @@ def generate_images(
images = vae.decode(img_seq)

if exists(clipper):
scores = clipper(text_seq, img_seq, return_loss = False)
scores = clipper(text_seq, images, return_loss = False)
return images, scores

return images
@@ -158,13 +158,15 @@ def __init__(
dim_image = 512,
dim_latent = 512,
num_text_tokens = 10000,
num_visual_tokens = 512,
text_enc_depth = 6,
visual_enc_depth = 6,
text_seq_len = 256,
visual_seq_len = 1024,
text_heads = 8,
num_visual_tokens = 512,
visual_enc_depth = 6,
visual_heads = 8,
visual_image_size = 256,
visual_patch_size = 32,
channels = 3,
vae = None
):
super().__init__()
@@ -173,8 +175,13 @@ def __init__(
self.text_transformer = Encoder(dim = dim_text, depth = text_enc_depth, heads = text_heads)
self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)

self.visual_emb = nn.Embedding(num_visual_tokens, dim_image)
self.visual_pos_emb = nn.Embedding(visual_seq_len, dim_image)
assert visual_image_size % visual_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
num_patches = (visual_image_size // visual_patch_size) ** 2
patch_dim = channels * visual_patch_size ** 2

self.visual_patch_size = visual_patch_size
self.to_visual_embedding = nn.Linear(patch_dim, dim_image)
self.visual_pos_emb = nn.Embedding(num_patches, dim_image)
self.visual_transformer = Encoder(dim = dim_image, depth = visual_enc_depth, heads = visual_heads)
self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)

@@ -192,17 +199,17 @@ def forward(
text_mask = None,
return_loss = False
):
b, device = text.shape[0], text.device
b, device, p = text.shape[0], text.device, self.visual_patch_size

if exists(self.vae):
image = self.vae.get_codebook_indices(image)

text_emb = self.text_emb(text)
text_emb += self.text_pos_emb(torch.arange(text.shape[1], device = device))


image_emb = self.visual_emb(image)
image_emb += self.visual_pos_emb(torch.arange(image.shape[1], device = device))
image_patches = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
image_emb = self.to_visual_embedding(image_patches)
image_emb += self.visual_pos_emb(torch.arange(image_emb.shape[1], device = device))

enc_text = self.text_transformer(text_emb, mask = text_mask)
enc_image = self.visual_transformer(image_emb)
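The heart of the commit is the ViT-style patch embedding that replaces the discrete visual token embedding. Below is a self-contained sketch of that step under the default shapes in the diff (256x256 RGB images, 32x32 patches, giving (256 // 32) ** 2 = 64 patches of dimension 3 * 32 * 32 = 3072). The local variable names are illustrative; `to_visual_embedding`, `visual_pos_emb`, and the `rearrange` pattern mirror the new module code.

```python
import torch
from torch import nn
from einops import rearrange

image_size, patch_size, channels, dim_image = 256, 32, 3, 512
assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

num_patches = (image_size // patch_size) ** 2   # (256 // 32) ** 2 = 64
patch_dim = channels * patch_size ** 2          # 3 * 32 * 32 = 3072

to_visual_embedding = nn.Linear(patch_dim, dim_image)   # linear patch projection
visual_pos_emb = nn.Embedding(num_patches, dim_image)   # learned position per patch

images = torch.randn(2, channels, image_size, image_size)  # raw pixel batch

# cut each image into non-overlapping patches and flatten every patch to a vector
patches = rearrange(images, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)
image_emb = to_visual_embedding(patches)                           # (2, 64, 512)
image_emb = image_emb + visual_pos_emb(torch.arange(num_patches))  # add positions

print(image_emb.shape)  # torch.Size([2, 64, 512])
```

The resulting patch sequence then feeds the same visual `Encoder` that previously consumed discrete visual token embeddings.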
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'dalle-pytorch',
packages = find_packages(),
version = '0.0.16',
version = '0.0.18',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
