From 013606b1863a5e4af06b1b3e956f2307d72c2d7d Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 18 Jan 2021 09:02:00 -0800 Subject: [PATCH] omit the prefix sections of the sequence undergoing full attention from the cross entropy loss --- dalle_pytorch/dalle_pytorch.py | 16 ++++++++++++++-- setup.py | 2 +- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py index 5b339c5b..60a71006 100644 --- a/dalle_pytorch/dalle_pytorch.py +++ b/dalle_pytorch/dalle_pytorch.py @@ -256,6 +256,7 @@ def __init__( ff_dropout = 0, sparse_attn = False, noncausal_attn_len = 0, + ignore_index = -100 ): super().__init__() assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE' @@ -279,7 +280,9 @@ def __init__( seq_len = text_seq_len + image_seq_len total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS self.total_tokens = total_tokens - + + self.noncausal_attn_len = noncausal_attn_len + self.vae = vae if exists(self.vae): self.vae = vae @@ -319,6 +322,8 @@ def __init__( self.register_buffer('logits_mask', logits_mask) + self.ignore_index = ignore_index + @torch.no_grad() @eval_decorator def generate_images( @@ -404,9 +409,16 @@ def forward( return logits assert exists(image), 'when training, image must be supplied' - + noncausal_attn_len = self.noncausal_attn_len offsetted_image = image + self.num_text_tokens labels = torch.cat((text, offsetted_image), dim = 1) + + if noncausal_attn_len > 0: + mask = torch.arange(seq_len, device = device) + mask = mask < noncausal_attn_len + print(mask) + labels.masked_fill_(mask[None, :], -100) # -100 is the ignore index for cross entropy loss + labels = F.pad(labels, (0, 1), value = eos_token_id) # last token predicts EOS loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels[:, 1:]) return loss diff --git a/setup.py b/setup.py index 6090c28f..b100b33f 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'dalle-pytorch', packages = find_packages(), - version = '0.0.40', + version = '0.0.41', license='MIT', description = 'DALL-E - Pytorch', author = 'Phil Wang',