Skip to content

Commit

Permalink
always mask logits to correct modality, even during training
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 7, 2021
1 parent 8995720 commit 031e3be
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,15 +297,19 @@ def forward(
tokens = self.text_emb(text)
tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device))

seq_len = tokens.shape[1]

if exists(image) and not is_empty(image):
is_raw_image = len(image.shape) == 4
image_len = image.shape[1]
seq_len += image_len

if is_raw_image:
assert exists(self.vae), 'VAE must be passed into constructor if you are to train directly on raw images'
image = self.vae.get_codebook_indices(image)

image_emb = self.image_emb(image)
image_emb += self.image_pos_emb(torch.arange(image.shape[1], device = device))
image_emb += self.image_pos_emb(torch.arange(image_len, device = device))

tokens = torch.cat((tokens, image_emb), dim = 1)

Expand All @@ -315,11 +319,12 @@ def forward(
out = self.transformer(tokens, mask = mask)
logits = self.to_logits(out)

# mask logits to make sure text predicts text (except last token), and image predicts image
mask = self.logits_mask[:, :seq_len]
max_neg_value = -torch.finfo(logits.dtype).max
logits.masked_fill_(mask, max_neg_value)

if not return_loss:
seq_len = tokens.shape[1]
mask = self.logits_mask[:, :seq_len]
max_neg_value = -torch.finfo(logits.dtype).max
logits.masked_fill_(mask, max_neg_value)
return logits

assert exists(image), 'when training, image must be supplied'
Expand Down

0 comments on commit 031e3be

Please sign in to comment.