diff --git a/dalle_pytorch/dalle_pytorch.py b/dalle_pytorch/dalle_pytorch.py
index bf082eab..9d83b386 100644
--- a/dalle_pytorch/dalle_pytorch.py
+++ b/dalle_pytorch/dalle_pytorch.py
@@ -268,7 +268,7 @@ def __init__(
         self.text_emb = nn.Embedding(num_text_tokens, dim)
         self.image_emb = nn.Embedding(num_image_tokens, dim)
 
-        self.text_pos_emb = nn.Embedding(text_seq_len, dim)
+        self.text_pos_emb = nn.Embedding(text_seq_len + 1, dim) # +1 for <bos>
         self.image_pos_emb = AxialPositionalEmbedding(dim, axial_shape = (image_size, image_size))
 
         self.num_text_tokens = num_text_tokens # for offsetting logits index and calculating cross entropy loss
@@ -278,7 +278,7 @@ def __init__(
         self.image_seq_len = image_seq_len
 
         seq_len = text_seq_len + image_seq_len
-        total_tokens = num_text_tokens + num_image_tokens + 1 # extra for EOS
+        total_tokens = num_text_tokens + num_image_tokens
         self.total_tokens = total_tokens
 
         self.noncausal_attn_len = noncausal_attn_len
@@ -298,7 +298,7 @@ def __init__(
             reversible = reversible,
             attn_dropout = attn_dropout,
             ff_dropout = ff_dropout,
-            noncausal_attn_len = noncausal_attn_len,
+            noncausal_attn_len = (noncausal_attn_len + 1),
             sparse_attn = sparse_attn,
             sparse_attn_global_indices = range(text_seq_len)
         )
@@ -315,9 +315,8 @@ def __init__(
         logits_range = rearrange(logits_range, 'd -> () () d')
 
         logits_mask = (
-            ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
-            ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
-            ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1)))
+            ((seq_range >= text_seq_len) & (logits_range < num_text_tokens)) |
+            ((seq_range < text_seq_len) & (logits_range >= num_text_tokens))
         )
 
         self.register_buffer('logits_mask', logits_mask)
@@ -338,7 +337,8 @@ def generate_images(
         vae, text_seq_len, image_seq_len, num_text_tokens = self.vae, self.text_seq_len, self.image_seq_len, self.num_text_tokens
         total_len = text_seq_len + image_seq_len
 
-        out = text
+        out = F.pad(text, (1, 0), value = 0)
+
         for cur_len in range(text.shape[1], total_len):
             is_image = cur_len >= text_seq_len
@@ -374,9 +374,9 @@ def forward(
         mask = None,
         return_loss = False
     ):
-        device = text.device
-        eos_token_id = self.total_tokens - 1
+        device, ignore_index = text.device, self.ignore_index
 
+        text = F.pad(text, (1, 0), value = 0) # use padding as <bos>
         tokens = self.text_emb(text)
         tokens += self.text_pos_emb(torch.arange(text.shape[1], device = device))
@@ -393,17 +393,17 @@ def forward(
             tokens = torch.cat((tokens, image_emb), dim = 1)
 
-            seq_len += image_len
+            seq_len += (image_len - 1)
 
         if exists(mask):
             mask = F.pad(mask, (0, image_emb.shape[1]), value = True)
 
-        out = self.transformer(tokens, mask = mask)
+        out = self.transformer(tokens[:, :-1], mask = mask)
         logits = self.to_logits(out)
 
         # mask logits to make sure text predicts text (except last token), and image predicts image
-        mask = self.logits_mask[:, :seq_len]
+        logits_mask = self.logits_mask[:, :seq_len]
         max_neg_value = -torch.finfo(logits.dtype).max
-        logits.masked_fill_(mask, max_neg_value)
+        logits.masked_fill_(logits_mask, max_neg_value)
 
         if not return_loss:
             return logits
@@ -411,13 +411,13 @@ def forward(
         assert exists(image), 'when training, image must be supplied'
         noncausal_attn_len = self.noncausal_attn_len
         offsetted_image = image + self.num_text_tokens
-        labels = torch.cat((text, offsetted_image), dim = 1)
+        labels = torch.cat((text[:, 1:], offsetted_image), dim = 1)
 
         if noncausal_attn_len > 0:
             seq_range = torch.arange(seq_len, device = device)
-            mask = seq_range < (noncausal_attn_len - 1)
-            labels.masked_fill_(mask[None, :], -100) # -100 is the ignore index for cross entropy loss
+            noncausal_attn_mask = seq_range < noncausal_attn_len
+            noncausal_attn_mask = rearrange(noncausal_attn_mask, 'n -> () n')
+            labels.masked_fill_(noncausal_attn_mask, ignore_index)
 
-        labels = F.pad(labels, (0, 1), value = eos_token_id) # last token predicts EOS
-        loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels[:, 1:])
+        loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels)
         return loss
diff --git a/dalle_pytorch/transformer.py b/dalle_pytorch/transformer.py
index 51127f27..017507e8 100644
--- a/dalle_pytorch/transformer.py
+++ b/dalle_pytorch/transformer.py
@@ -87,7 +87,7 @@ def forward(self, x, mask = None):
         i, j = dots.shape[-2:]
         mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
 
-        if self.noncausal_attn_len > 0:
+        if self.noncausal_attn_len > 1:
             ind = slice(0, self.noncausal_attn_len)
             mask[ind, ind] = False
diff --git a/setup.py b/setup.py
index 4756f75c..89615e79 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@ setup(
   name = 'dalle-pytorch',
   packages = find_packages(),
-  version = '0.0.45',
+  version = '0.0.46',
   license='MIT',
   description = 'DALL-E - Pytorch',
   author = 'Phil Wang',