diff --git a/compressive_transformer_pytorch/compressive_transformer_pytorch.py b/compressive_transformer_pytorch/compressive_transformer_pytorch.py
index dd29516..56eb1b5 100644
--- a/compressive_transformer_pytorch/compressive_transformer_pytorch.py
+++ b/compressive_transformer_pytorch/compressive_transformer_pytorch.py
@@ -198,10 +198,11 @@ def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_me
 
         if input_mask is not None:
             mask = input_mask[:, None, :, None] * input_mask[:, None, None, :]
-            mask = F.pad(mask, (mem_len + cmem_len, 0), value = False)
+            mask = F.pad(mask, (mem_len + cmem_len, 0), value = True)
             dots.masked_fill_(~mask, mask_value)
 
-        mask = torch.ones(t, kv_len, **to(x)).triu_(diagonal = 1 + kv_len).bool()
+        total_mem_len = mem_len + cmem_len
+        mask = torch.ones(t, t + total_mem_len, **to(x)).triu_(diagonal = 1 + total_mem_len).bool()
         dots.masked_fill_(mask[None, None, ...], mask_value)
 
         attn = dots.softmax(dim=-1)
@@ -248,7 +249,6 @@ def forward(self, x, memories = None, pos_emb = None, input_mask = None, calc_me
                 full_attn(q, cmem_k, cmem_v)
             )
 
-
         return logits, Memory(new_mem, new_cmem), aux_loss
 
 # transformer
diff --git a/setup.py b/setup.py
index 15ac555..8dbdecf 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'compressive-transformer-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.3.8',
+  version = '0.3.10',
   license='MIT',
   description = 'Implementation of Compressive Transformer in Pytorch',
   author = 'Phil Wang',
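
For context, a minimal standalone sketch (not part of the patch) of why the causal-mask change matters, assuming `t` is the current sequence length and `total_mem_len` the combined memory plus compressed-memory length:

```python
import torch

# Illustrative sizes: t query positions attend over (t + total_mem_len) key
# positions (compressed memory + memory + current sequence).
t, total_mem_len = 4, 3

# Patched construction: query position i may attend to every memory slot and
# to sequence positions <= i; True marks the future positions to be masked.
mask = torch.ones(t, t + total_mem_len).triu_(diagonal = 1 + total_mem_len).bool()
print(mask)
# tensor([[False, False, False, False,  True,  True,  True],
#         [False, False, False, False, False,  True,  True],
#         [False, False, False, False, False, False,  True],
#         [False, False, False, False, False, False, False]])

# Previous construction: with kv_len = t + total_mem_len, the diagonal
# 1 + kv_len lies outside the (t, kv_len) matrix, so the mask is all False
# and no causal masking is applied at all.
old = torch.ones(t, t + total_mem_len).triu_(diagonal = 1 + (t + total_mem_len)).bool()
assert not old.any()
```

Similarly, padding the key-padding mask with `value = True` keeps the memory slots visible instead of blanking them out along with padded tokens.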