Commit
Merge pull request #20 from lucidrains/prefix-full-attention
omit the prefix sections of the sequence undergoing full attention fr…
lucidrains authored Jan 18, 2021
2 parents 945055a + 013606b commit 8586b4f
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions dalle_pytorch/dalle_pytorch.py
@@ -256,6 +256,7 @@ def __init__(
ff_dropout = 0,
sparse_attn = False,
noncausal_attn_len = 0,
ignore_index = -100
):
super().__init__()
assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'
@@ -279,7 +280,9 @@ def __init__(
seq_len = text_seq_len + image_seq_len
total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS
self.total_tokens = total_tokens
self.noncausal_attn_len = noncausal_attn_len

self.vae = vae
if exists(self.vae):
self.vae = vae
@@ -319,6 +322,8 @@ def __init__(

self.register_buffer('logits_mask', logits_mask)

self.ignore_index = ignore_index

@torch.no_grad()
@eval_decorator
def generate_images(
@@ -404,9 +409,16 @@ def forward(
return logits

assert exists(image), 'when training, image must be supplied'

noncausal_attn_len = self.noncausal_attn_len
offsetted_image = image + self.num_text_tokens
labels = torch.cat((text, offsetted_image), dim = 1)

if noncausal_attn_len > 0:
    mask = torch.arange(seq_len, device = device)
    mask = mask < noncausal_attn_len
    labels.masked_fill_(mask[None, :], self.ignore_index) # exclude the non-causal prefix from the autoregressive loss targets

labels = F.pad(labels, (0, 1), value = eos_token_id) # last token predicts EOS
loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels[:, 1:], ignore_index = self.ignore_index)
return loss
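What the new code path does, in isolation: positions inside the non-causal prefix have their labels overwritten with the ignore index, and F.cross_entropy then skips those targets entirely, so the prefix that attends with full attention is never trained on autoregressively. A minimal standalone sketch of that mechanism (toy sizes, not code from this repository):

import torch
import torch.nn.functional as F

batch, seq_len, num_classes = 2, 8, 16   # illustrative sizes
noncausal_attn_len = 3
ignore_index = -100

logits = torch.randn(batch, num_classes, seq_len)         # cross_entropy expects (b, c, n)
labels = torch.randint(0, num_classes, (batch, seq_len))

# True for the first noncausal_attn_len positions, broadcast over the batch
prefix_mask = torch.arange(seq_len) < noncausal_attn_len
labels = labels.masked_fill(prefix_mask[None, :], ignore_index)

# targets equal to ignore_index contribute nothing to the mean loss
loss = F.cross_entropy(logits, labels, ignore_index = ignore_index)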

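For callers, only the ignore_index keyword is new; noncausal_attn_len already existed. A hedged usage sketch follows — the DiscreteVAE settings and the other DALLE arguments are illustrative values following the project's README of the period, not part of this commit:

import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
    image_size = 256,
    num_layers = 3,
    num_tokens = 2048,
    codebook_dim = 512,
    hidden_dim = 64
)

dalle = DALLE(
    dim = 512,
    vae = vae,
    num_text_tokens = 10000,
    text_seq_len = 256,
    depth = 6,
    heads = 8,
    noncausal_attn_len = 256,   # let the text prefix attend with full (non-causal) attention
    ignore_index = -100         # new in this commit; -100 matches PyTorch's cross entropy default
)

text = torch.randint(0, 10000, (1, 256))
images = torch.randn(1, 3, 256, 256)

loss = dalle(text, images, return_loss = True)
loss.backward()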