
Commit

Merge pull request #22 from lucidrains/switch-eos-over-to-bos
switch over to using <bos> instead of <eos>, assuming fixed length ge…
lucidrains authored Jan 19, 2021
2 parents 0c3cb2e + 89d19bc commit 6a50564
Showing 3 changed files with 20 additions and 20 deletions.
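In short, the model now prepends a <bos> token (reusing id 0, the padding token) to the text instead of training the last image position to emit an <eos>; every input position then simply predicts the next token. A minimal sketch of the input/label alignment this sets up, with toy shapes that are assumptions rather than values from the diff:

```python
import torch
import torch.nn.functional as F

# toy sizes, assumptions for illustration only
batch, text_seq_len, image_seq_len = 2, 4, 6
num_text_tokens, num_image_tokens = 100, 50

text  = torch.randint(0, num_text_tokens, (batch, text_seq_len))
image = torch.randint(0, num_image_tokens, (batch, image_seq_len))

# prepend <bos>: pad the text on the left with token id 0 ("use padding as <bos>")
text_with_bos = F.pad(text, (1, 0), value = 0)

# image token ids are offset past the text vocabulary, as in the model
tokens = torch.cat((text_with_bos, image + num_text_tokens), dim = 1)

# the transformer sees everything except the final token ...
inputs = tokens[:, :-1]

# ... and each position is trained to predict the token one step ahead,
# so the labels are the original sequence with the <bos> dropped
labels = torch.cat((text, image + num_text_tokens), dim = 1)

assert inputs.shape[1] == labels.shape[1] == text_seq_len + image_seq_len
```

The text_seq_len + 1 positional embedding and the tokens[:, :-1] slice in the diff below follow directly from this alignment.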
36 changes: 18 additions & 18 deletions dalle_pytorch/dalle_pytorch.py
@@ -268,7 +268,7 @@ def __init__(
self.text_emb = nn.Embedding(num_text_tokens, dim)
self.image_emb = nn.Embedding(num_image_tokens, dim)

- self.text_pos_emb = nn.Embedding(text_seq_len, dim)
+ self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) # +1 for <bos>
self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_size, image_size))

self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
@@ -278,7 +278,7 @@ def __init__(
self.image_seq_len = image_seq_len

seq_len = text_seq_len + image_seq_len
- total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS
+ total_tokens = num_text_tokens + num_image_tokens
self.total_tokens = total_tokens

self.noncausal_attn_len = noncausal_attn_len
@@ -298,7 +298,7 @@ def __init__(
reversible = reversible,
attn_dropout = attn_dropout,
ff_dropout = ff_dropout,
- noncausal_attn_len = noncausal_attn_len,
+ noncausal_attn_len = (noncausal_attn_len + 1),
sparse_attn = sparse_attn,
sparse_attn_global_indices = range(text_seq_len)
)
@@ -315,9 +315,8 @@ def __init__(
logits_range = rearrange(logits_range, 'd -> () () d')

logits_mask = (
- ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
- ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
- ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1)))
+ ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
+ ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
)

self.register_buffer('logits_mask', logits_mask)
@@ -338,7 +337,8 @@ def generate_images(
vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
total_len = text_seq_len + image_seq_len

- out = text
+ out = F.pad(text, (1, 0), value = 0)

for cur_len in range(text.shape[1], total_len):
is_image = cur_len >= text_seq_len

@@ -374,9 +374,9 @@ def forward(
mask = None,
return_loss = False
):
- device = text.device
- eos_token_id = self.total_tokens - 1
+ device, ignore_index = text.device, self.ignore_index

+ text = F.pad(text, (1, 0), value = 0) # use padding as <bos>
tokens = self.text_emb(text)
tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device))

@@ -393,31 +393,31 @@ def forward(

tokens = torch.cat((tokens, image_emb), dim = 1)

- seq_len += image_len
+ seq_len += (image_len - 1)
if exists(mask):
mask = F.pad(mask, (0, image_emb.shape[1]), value = True)

- out = self.transformer(tokens, mask = mask)
+ out = self.transformer(tokens[:, :-1], mask = mask)
logits = self.to_logits(out)

# mask logits to make sure text predicts text (except last token), and image predicts image
- mask = self.logits_mask[:, :seq_len]
+ logits_mask = self.logits_mask[:, :seq_len]
max_neg_value = -torch.finfo(logits.dtype).max
- logits.masked_fill_(mask, max_neg_value)
+ logits.masked_fill_(logits_mask, max_neg_value)

if not return_loss:
return logits

assert exists(image), 'when training, image must be supplied'
noncausal_attn_len = self.noncausal_attn_len
offsetted_image = image + self.num_text_tokens
- labels = torch.cat((text, offsetted_image), dim = 1)
+ labels = torch.cat((text[:, 1:], offsetted_image), dim = 1)

if noncausal_attn_len > 0:
seq_range = torch.arange(seq_len, device = device)
- mask = seq_range < (noncausal_attn_len - 1)
- labels.masked_fill_(mask[None, :], -100) # -100 is the ignore index for cross entropy loss
+ noncausal_attn_mask = seq_range < noncausal_attn_len
+ noncausal_attn_mask = rearrange(noncausal_attn_mask, 'n -> () n')
+ labels.masked_fill_(noncausal_attn_mask, ignore_index)

- labels = F.pad(labels, (0, 1), value = eos_token_id) # last token predicts EOS
- loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels[:, 1:])
+ loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels)
return loss
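With the <eos> target gone, the logits mask defined earlier loses its third clause and reduces to two block conditions: positions that predict text may only emit text tokens, and positions that predict image tokens may only emit image tokens. A standalone reproduction with toy sizes (sizes are assumptions; the mask construction mirrors the updated lines above):

```python
import torch
from einops import rearrange

# toy sizes, assumptions for illustration only
text_seq_len, image_seq_len = 4, 6
num_text_tokens, num_image_tokens = 100, 50

seq_len = text_seq_len + image_seq_len
total_tokens = num_text_tokens + num_image_tokens

seq_range    = rearrange(torch.arange(seq_len), 'n -> () n ()')       # position axis
logits_range = rearrange(torch.arange(total_tokens), 'd -> () () d')  # vocabulary axis

# True marks (position, vocab entry) pairs to suppress:
# image positions may not emit text tokens, text positions may not emit image tokens
logits_mask = (
    ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
    ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
)

logits = torch.zeros(1, seq_len, total_tokens)
logits.masked_fill_(logits_mask, -torch.finfo(logits.dtype).max)
```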
2 changes: 1 addition & 1 deletion dalle_pytorch/transformer.py
@@ -87,7 +87,7 @@ def forward(self, x, mask = None):
i, j = dots.shape[-2:]
mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()

- if self.noncausal_attn_len > 0:
+ if self.noncausal_attn_len > 1:
ind = slice(0, self.noncausal_attn_len)
mask[ind, ind] = False

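For context, the threshold change from > 0 to > 1 appears to pair with the (noncausal_attn_len + 1) offset in dalle_pytorch.py above: the <bos> position is now counted, so a user setting of 0 reaches this check as 1 and should not unlock any bidirectional prefix. A toy illustration of the mask this block builds (sizes are assumptions):

```python
import torch

i = j = 5               # toy query/key length
noncausal_attn_len = 3  # value as received by the transformer, i.e. user setting + 1

# causal mask: True marks keys a query may NOT attend to
mask = torch.ones(i, j).triu_(j - i + 1).bool()

# carve out a bidirectional prefix only when it covers more than the <bos> itself
if noncausal_attn_len > 1:
    ind = slice(0, noncausal_attn_len)
    mask[ind, ind] = False

print(mask.int())
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.int32)
```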
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'dalle-pytorch',
packages = find_packages(),
- version = '0.0.45',
+ version = '0.0.46',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
