diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py
index 5b339c5b..60a71006 100644
--- a/dalle_pytorch/dalle_pytorch.py
+++ b/dalle_pytorch/dalle_pytorch.py
@@ -256,6 +256,7 @@ def __init__(
         ff_dropout = 0,
         sparse_attn = False,
         noncausal_attn_len = 0,
+        ignore_index = -100
     ):
         super().__init__()
         assert isinstance(vae, DiscreteVAE), 'vae must be an instance of DiscreteVAE'
@@ -279,7 +280,9 @@ def __init__(
         seq_len = text_seq_len + image_seq_len
         total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS
         self.total_tokens = total_tokens
-
+
+        self.noncausal_attn_len = noncausal_attn_len
+
         self.vae = vae
         if exists(self.vae):
             self.vae = vae
@@ -319,6 +322,8 @@ def __init__(
 
         self.register_buffer('logits_mask', logits_mask)
 
+        self.ignore_index = ignore_index
+
     @torch.no_grad()
     @eval_decorator
     def generate_images(
@@ -404,9 +409,16 @@ def forward(
             return logits
 
         assert exists(image), 'when training, image must be supplied'
-
+        noncausal_attn_len = self.noncausal_attn_len
         offsetted_image = image + self.num_text_tokens
         labels = torch.cat((text, offsetted_image), dim = 1)
+
+        if noncausal_attn_len > 0:
+            # exclude the noncausal prefix from the loss, using the configurable ignore index
+            mask = torch.arange(seq_len, device = device)
+            mask = mask < noncausal_attn_len
+            labels.masked_fill_(mask[None, :], self.ignore_index)
+
         labels = F.pad(labels, (0, 1), value = eos_token_id) # last token predicts EOS
-        loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels[:, 1:])
+        loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels[:, 1:], ignore_index = self.ignore_index)
         return loss
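
A minimal sketch (not part of the diff) of what the `ignore_index` plumbing does: any label position set to `ignore_index` is dropped from the cross entropy loss, which is how the noncausal prefix gets excluded from training. The tensor sizes and the toy `noncausal_attn_len` value below are illustrative assumptions, not values from the repository, and the `transpose` stands in for the einops `rearrange('b n c -> b c n')` used in `forward`.

import torch
import torch.nn.functional as F

ignore_index = -100                        # matches torch's default for F.cross_entropy
num_tokens, seq_len = 10, 6                # toy sizes, for illustration only
logits = torch.randn(1, seq_len, num_tokens)
labels = torch.randint(0, num_tokens, (1, seq_len))

noncausal_attn_len = 2                     # hypothetical prefix length
mask = torch.arange(seq_len) < noncausal_attn_len
labels = labels.masked_fill(mask[None, :], ignore_index)

# F.cross_entropy expects (batch, classes, positions), hence the transpose
loss = F.cross_entropy(logits.transpose(1, 2), labels, ignore_index = ignore_index)

With no positions masked the loss averages over all six terms; with the first two labels set to `ignore_index`, only the remaining four contribute. This is why the `masked_fill_` in `forward` must write `self.ignore_index` (not a hardcoded -100) and why the same value has to be passed through to `F.cross_entropy`.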