diff --git a/README.md b/README.md
index 994b1b06..0db0ff22 100644
--- a/README.md
+++ b/README.md
@@ -42,53 +42,18 @@ clip = CLIP(
     dim_image = 512,
     dim_latent = 512,
     num_text_tokens = 10000,
-    num_visual_tokens = 512,
     text_enc_depth = 6,
-    visual_enc_depth = 6,
     text_seq_len = 256,
-    visual_seq_len = 1024,
     text_heads = 8,
-    visual_heads = 8
-)
-
-text = torch.randint(0, 10000, (2, 256))
-images = torch.randint(0, 512, (2, 1024))
-mask = torch.ones_like(text).bool()
-
-loss = clip(text, images, text_mask = mask, return_loss = True)
-loss.backward()
-```
-
-Combine pretrained VAE with CLIP, and train off raw images
-
-```python
-import torch
-from dalle_pytorch import DiscreteVAE, CLIP
-
-vae = DiscreteVAE(
-    num_layers = 3,
-    num_tokens = 2000,
-    dim = 512,
-    hidden_dim = 64
-)
-
-clip = CLIP(
-    dim_text = 512,
-    dim_image = 512,
-    dim_latent = 512,
-    num_text_tokens = 10000,
     num_visual_tokens = 512,
-    text_enc_depth = 6,
     visual_enc_depth = 6,
-    text_seq_len = 256,
-    visual_seq_len = 1024,
-    text_heads = 8,
-    visual_heads = 8,
-    vae = vae
+    visual_image_size = 256,
+    visual_patch_size = 32,
+    visual_heads = 8
 )
 
 text = torch.randint(0, 10000, (2, 256))
-images = torch.randn(2, 3, 256, 256) # raw images
+images = torch.randn(2, 3, 256, 256)
 mask = torch.ones_like(text).bool()
 
 loss = clip(text, images, text_mask = mask, return_loss = True)
diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py
index d5d08f65..46ce2657 100644
--- a/dalle_pytorch/dalle_pytorch.py
+++ b/dalle_pytorch/dalle_pytorch.py
@@ -65,7 +65,7 @@ def generate_images(
     images = vae.decode(img_seq)
 
     if exists(clipper):
-        scores = clipper(text_seq, img_seq, return_loss = False)
+        scores = clipper(text_seq, images, return_loss = False)
         return images, scores
 
     return images
@@ -158,13 +158,15 @@ def __init__(
         dim_image = 512,
         dim_latent = 512,
         num_text_tokens = 10000,
-        num_visual_tokens = 512,
         text_enc_depth = 6,
-        visual_enc_depth = 6,
         text_seq_len = 256,
-        visual_seq_len = 1024,
         text_heads = 8,
+        num_visual_tokens = 512,
+        visual_enc_depth = 6,
         visual_heads = 8,
+        visual_image_size = 256,
+        visual_patch_size = 32,
+        channels = 3,
         vae = None
     ):
         super().__init__()
@@ -173,8 +175,13 @@ def __init__(
         self.text_transformer = Encoder(dim = dim_text, depth = text_enc_depth, heads = text_heads)
         self.to_text_latent = nn.Linear(dim_text, dim_latent, bias = False)
 
-        self.visual_emb = nn.Embedding(num_visual_tokens, dim_image)
-        self.visual_pos_emb = nn.Embedding(visual_seq_len, dim_image)
+        assert visual_image_size % visual_patch_size == 0, 'Image dimensions must be divisible by the patch size.'
+        num_patches = (visual_image_size // visual_patch_size) ** 2
+        patch_dim = channels * visual_patch_size ** 2
+
+        self.visual_patch_size = visual_patch_size
+        self.to_visual_embedding = nn.Linear(patch_dim, dim_image)
+        self.visual_pos_emb = nn.Embedding(num_patches, dim_image)
         self.visual_transformer = Encoder(dim = dim_image, depth = visual_enc_depth, heads = visual_heads)
         self.to_visual_latent = nn.Linear(dim_image, dim_latent, bias = False)
 
@@ -192,7 +199,7 @@ def forward(
         text_mask = None,
         return_loss = False
     ):
-        b, device = text.shape[0], text.device
+        b, device, p = text.shape[0], text.device, self.visual_patch_size
 
         if exists(self.vae):
             image = self.vae.get_codebook_indices(image)
@@ -200,9 +207,9 @@ def forward(
 
         text_emb = self.text_emb(text)
         text_emb += self.text_pos_emb(torch.arange(text.shape[1], device = device))
-
-        image_emb = self.visual_emb(image)
-        image_emb += self.visual_pos_emb(torch.arange(image.shape[1], device = device))
+        image_patches = rearrange(image, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
+        image_emb = self.to_visual_embedding(image_patches)
+        image_emb += self.visual_pos_emb(torch.arange(image_emb.shape[1], device = device))
 
         enc_text = self.text_transformer(text_emb, mask = text_mask)
         enc_image = self.visual_transformer(image_emb)
diff --git a/setup.py b/setup.py
index eaace065..93d7b2eb 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.16',
+  version = '0.0.18',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',
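
The core of this change is that CLIP's visual side no longer embeds discrete visual token indices; it now consumes raw images ViT-style, cutting each image into fixed-size patches, flattening them, and linearly projecting them to the encoder dimension. Below is a minimal standalone sketch of that patch-embedding step, assuming the same shapes as the diff's defaults (256x256 RGB images, 32x32 patches, `dim_image = 512`); the variable names mirror the diff but this is an illustration, not the library's exact module:

```python
import torch
from torch import nn
from einops import rearrange  # einops is already used in the diff's forward()

image_size, patch_size, channels, dim_image = 256, 32, 3, 512
assert image_size % patch_size == 0, 'image size must be divisible by patch size'

num_patches = (image_size // patch_size) ** 2  # (256 // 32)^2 = 64 patches
patch_dim = channels * patch_size ** 2         # 3 * 32^2 = 3072 values per patch

to_visual_embedding = nn.Linear(patch_dim, dim_image)   # mirrors self.to_visual_embedding
visual_pos_emb = nn.Embedding(num_patches, dim_image)   # mirrors self.visual_pos_emb

images = torch.randn(2, channels, image_size, image_size)

# (b, c, h, w) -> (b, num_patches, patch_dim), same rearrange pattern as forward()
patches = rearrange(images, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)

image_emb = to_visual_embedding(patches)                    # (2, 64, 512)
image_emb += visual_pos_emb(torch.arange(image_emb.shape[1]))
print(image_emb.shape)                                      # torch.Size([2, 64, 512])
```

This also explains the `generate_images` fix above: since CLIP now scores raw pixels rather than codebook indices, the clipper must receive the decoded `images`, not the token sequence `img_seq`.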