diff --git a/README.md b/README.md
index 4f1d27fd..a0d41166 100644
--- a/README.md
+++ b/README.md
@@ -61,6 +61,29 @@ loss = dalle(text, images, mask = mask, return_loss = True)
 loss.backward()
 ```
 
+Combine pretrained VAE with DALL-E, and pass in raw images
+
+```python
+import torch
+from dalle_pytorch import DiscreteVAE, DALLE
+
+vae = DiscreteVAE(
+    num_tokens = 512
+)
+
+dalle = DALLE(
+    dim = 512,
+    vae = vae
+)
+
+text = torch.randint(0, 10000, (2, 256))
+images = torch.randn(2, 3, 256, 256) # train directly on raw images, VAE converts to proper embeddings
+mask = torch.ones_like(text).bool()
+
+loss = dalle(text, images, return_loss = True)
+loss.backward()
+```
+
 ## Citations
 
 ```bibtex
diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py
index a7bb9264..f8e4322a 100644
--- a/dalle_pytorch/dalle_pytorch.py
+++ b/dalle_pytorch/dalle_pytorch.py
@@ -45,6 +45,7 @@ def __init__(
             nn.Conv2d(64, 3, 1)
         )
 
+        self.num_tokens = num_tokens
         self.codebook = nn.Embedding(num_tokens, dim)
 
     def forward(
@@ -140,7 +141,8 @@ def __init__(
         text_seq_len = 256,
         image_seq_len = 1024,
         depth = 6, # should be 64
-        heads = 8
+        heads = 8,
+        vae = None
     ):
         super().__init__()
         self.text_emb = nn.Embedding(num_text_tokens, dim)
@@ -153,6 +155,8 @@
         self.image_seq_len = image_seq_len
         self.total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS
 
+        self.vae = vae
+        self.image_emb = vae.codebook
         self.transformer = Decoder(dim = dim, depth = depth, heads = heads)
 
         self.to_logits = nn.Sequential(
@@ -168,10 +172,17 @@ def forward(
         return_loss = False
     ):
         device = text.device
+        is_raw_image = len(image.shape) == 4
 
         text_emb = self.text_emb(text)
         text_emb += self.text_pos_emb(torch.arange(text.shape[1], device = device))
 
+        if is_raw_image:
+            assert exists(self.vae), 'VAE must be passed into constructor if you are to train directly on raw images'
+            image_logits = self.vae(image, return_logits = True)
+            codebook_indices = image_logits.argmax(dim = 1).flatten(1)
+            image = codebook_indices
+
         image_emb = self.image_emb(image)
         image_emb += self.image_pos_emb(torch.arange(image.shape[1], device = device))
 
diff --git a/setup.py b/setup.py
index 0672ffa0..8a3af731 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.1',
+  version = '0.0.2',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
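
For reference, here is a minimal sketch of what the new raw-image path in `DALLE.forward` does: the VAE returns per-position logits over its codebook, `argmax(dim = 1)` picks the most likely code at each spatial position, and `flatten(1)` collapses the grid into a token sequence. The shapes below (512 codebook entries, a 32 x 32 feature map giving `image_seq_len = 1024`) are illustrative assumptions, and `fake_logits` merely stands in for the output of `self.vae(image, return_logits = True)`.

```python
import torch

batch_size = 2
num_tokens = 512   # assumed size of the VAE codebook
fmap_size  = 32    # assumed spatial size of the encoded feature map (32 * 32 = 1024 image tokens)

# stand-in for the VAE logits: one score per codebook entry per spatial position,
# shape (batch, num_tokens, height, width)
fake_logits = torch.randn(batch_size, num_tokens, fmap_size, fmap_size)

# argmax over the codebook dimension selects the most likely code at each position,
# flatten(1) turns the (height, width) grid into a flat sequence of image tokens
codebook_indices = fake_logits.argmax(dim = 1).flatten(1)

print(codebook_indices.shape)  # torch.Size([2, 1024])
```

These indices are then looked up through `self.image_emb` (which the patch points at `vae.codebook`), exactly as if precomputed image token ids had been passed in.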